From 00ecea025b7ac361b36e5471e24a732edadf6ff6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 26 Nov 2024 15:46:01 +0100 Subject: [PATCH 01/13] Update [ghstack-poisoned] --- src/mrpro/algorithms/prewhiten_kspace.py | 2 +- src/mrpro/algorithms/reconstruction/DirectReconstruction.py | 2 +- .../algorithms/reconstruction/IterativeSENSEReconstruction.py | 2 +- src/mrpro/algorithms/reconstruction/Reconstruction.py | 2 +- .../reconstruction/RegularizedIterativeSENSEReconstruction.py | 2 +- src/mrpro/data/{_kdata => }/KData.py | 0 src/mrpro/data/__init__.py | 2 +- src/mrpro/operators/FourierOp.py | 2 +- 8 files changed, 7 insertions(+), 7 deletions(-) rename src/mrpro/data/{_kdata => }/KData.py (100%) diff --git a/src/mrpro/algorithms/prewhiten_kspace.py b/src/mrpro/algorithms/prewhiten_kspace.py index ab50f325e..3e1dde12d 100644 --- a/src/mrpro/algorithms/prewhiten_kspace.py +++ b/src/mrpro/algorithms/prewhiten_kspace.py @@ -5,7 +5,7 @@ import torch from einops import einsum, parse_shape, rearrange -from mrpro.data._kdata.KData import KData +from mrpro.data.KData import KData from mrpro.data.KNoise import KNoise diff --git a/src/mrpro/algorithms/reconstruction/DirectReconstruction.py b/src/mrpro/algorithms/reconstruction/DirectReconstruction.py index 3feab5cdb..a201b86c2 100644 --- a/src/mrpro/algorithms/reconstruction/DirectReconstruction.py +++ b/src/mrpro/algorithms/reconstruction/DirectReconstruction.py @@ -3,10 +3,10 @@ from collections.abc import Callable from mrpro.algorithms.reconstruction.Reconstruction import Reconstruction -from mrpro.data._kdata.KData import KData from mrpro.data.CsmData import CsmData from mrpro.data.DcfData import DcfData from mrpro.data.IData import IData +from mrpro.data.KData import KData from mrpro.data.KNoise import KNoise from mrpro.operators.FourierOp import FourierOp from mrpro.operators.LinearOperator import LinearOperator diff --git a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py index 32785a91a..444d85712 100644 --- a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py +++ b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py @@ -7,9 +7,9 @@ from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import ( RegularizedIterativeSENSEReconstruction, ) -from mrpro.data._kdata.KData import KData from mrpro.data.CsmData import CsmData from mrpro.data.DcfData import DcfData +from mrpro.data.KData import KData from mrpro.data.KNoise import KNoise from mrpro.operators.LinearOperator import LinearOperator diff --git a/src/mrpro/algorithms/reconstruction/Reconstruction.py b/src/mrpro/algorithms/reconstruction/Reconstruction.py index c4208157e..54a1f6af2 100644 --- a/src/mrpro/algorithms/reconstruction/Reconstruction.py +++ b/src/mrpro/algorithms/reconstruction/Reconstruction.py @@ -8,10 +8,10 @@ from typing_extensions import Self from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace -from mrpro.data._kdata.KData import KData from mrpro.data.CsmData import CsmData from mrpro.data.DcfData import DcfData from mrpro.data.IData import IData +from mrpro.data.KData import KData from mrpro.data.KNoise import KNoise from mrpro.operators.FourierOp import FourierOp from mrpro.operators.LinearOperator import LinearOperator diff --git a/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py index c9a307ebe..e3c1c49ce 100644 --- a/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py +++ b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py @@ -9,10 +9,10 @@ from mrpro.algorithms.optimizers.cg import cg from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction -from mrpro.data._kdata.KData import KData from mrpro.data.CsmData import CsmData from mrpro.data.DcfData import DcfData from mrpro.data.IData import IData +from mrpro.data.KData import KData from mrpro.data.KNoise import KNoise from mrpro.operators.IdentityOp import IdentityOp from mrpro.operators.LinearOperator import LinearOperator diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/KData.py similarity index 100% rename from src/mrpro/data/_kdata/KData.py rename to src/mrpro/data/KData.py diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index d5667a5bc..89cc6695b 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -6,7 +6,7 @@ from mrpro.data.EncodingLimits import EncodingLimits, Limits from mrpro.data.IData import IData from mrpro.data.IHeader import IHeader -from mrpro.data._kdata.KData import KData +from mrpro.data.KData import KData from mrpro.data.KHeader import KHeader from mrpro.data.KNoise import KNoise from mrpro.data.KTrajectory import KTrajectory diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index a3e81aba7..c25400eea 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -8,8 +8,8 @@ from torchkbnufft import KbNufft, KbNufftAdjoint from typing_extensions import Self -from mrpro.data._kdata.KData import KData from mrpro.data.enums import TrajType +from mrpro.data.KData import KData from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp From 15964abca731b9761c441a1a4c36aef10b708523 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 26 Nov 2024 15:46:05 +0100 Subject: [PATCH 02/13] Update [ghstack-poisoned] --- src/mrpro/data/AcqInfo.py | 111 +++++++++++--------- src/mrpro/data/_kdata/KDataRemoveOsMixin.py | 6 -- tests/conftest.py | 7 +- tests/data/test_kdata.py | 8 +- 4 files changed, 68 insertions(+), 64 deletions(-) diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index f5d677f97..5926e36a8 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -82,6 +82,37 @@ class AcqIdx(MoveDataMixin): """User index 7.""" +@dataclass(slots=True) +class UserValues(MoveDataMixin): + """User Values used in AcqInfo.""" + + float1: torch.Tensor + float2: torch.Tensor + float3: torch.Tensor + float4: torch.Tensor + float5: torch.Tensor + float6: torch.Tensor + float7: torch.Tensor + float8: torch.Tensor + int1: torch.Tensor + int2: torch.Tensor + int3: torch.Tensor + int4: torch.Tensor + int5: torch.Tensor + int6: torch.Tensor + int7: torch.Tensor + int8: torch.Tensor + + +@dataclass(slots=True) +class PhysiologyTimestamps: + """Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units.""" + + timestamp1: torch.Tensor + timestamp2: torch.Tensor + timestamp3: torch.Tensor + + @dataclass(slots=True) class AcqInfo(MoveDataMixin): """Acquisition information for each readout.""" @@ -92,43 +123,19 @@ class AcqInfo(MoveDataMixin): acquisition_time_stamp: torch.Tensor """Clock time stamp. Not in s but in vendor-specific time units (e.g. 2.5ms for Siemens)""" - active_channels: torch.Tensor - """Number of active receiver coil elements.""" - - available_channels: torch.Tensor - """Number of available receiver coil elements.""" - - center_sample: torch.Tensor - """Index of the readout sample corresponding to k-space center (zero indexed).""" - - channel_mask: torch.Tensor - """Bit mask indicating active coils (64*16 = 1024 bits).""" - - discard_post: torch.Tensor - """Number of readout samples to be discarded at the end (e.g. if the ADC is active during gradient events).""" - - discard_pre: torch.Tensor - """Number of readout samples to be discarded at the beginning (e.g. if the ADC is active during gradient events)""" - - encoding_space_ref: torch.Tensor - """Indexed reference to the encoding spaces enumerated in the MRD (xml) header.""" - flags: torch.Tensor """A bit mask of common attributes applicable to individual acquisition readouts.""" measurement_uid: torch.Tensor """Unique ID corresponding to the readout.""" - number_of_samples: torch.Tensor - """Number of sample points per readout (readouts may have different number of sample points).""" - orientation: Rotation """Rotation describing the orientation of the readout, phase and slice encoding direction.""" patient_table_position: SpatialDimension[torch.Tensor] """Offset position of the patient table, in LPS coordinates [m].""" - physiology_time_stamp: torch.Tensor + physiology_time_stamps: PhysiologyTimestamps """Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units""" position: SpatialDimension[torch.Tensor] @@ -140,17 +147,8 @@ class AcqInfo(MoveDataMixin): scan_counter: torch.Tensor """Zero-indexed incrementing counter for readouts.""" - trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists. - """Dimensionality of the k-space trajectory vector.""" - - user_float: torch.Tensor - """User-defined float parameters.""" - - user_int: torch.Tensor - """User-defined int parameters.""" - - version: torch.Tensor - """Major version number.""" + user: UserValues + """User defined float or int values""" @classmethod def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) -> Self: @@ -228,33 +226,44 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: user6=tensor(idx['user'][:, 6]), user7=tensor(idx['user'][:, 7]), ) - + user = UserValues( + tensor_2d(headers['user_float'][:, 0]), + tensor_2d(headers['user_float'][:, 1]), + tensor_2d(headers['user_float'][:, 2]), + tensor_2d(headers['user_float'][:, 3]), + tensor_2d(headers['user_float'][:, 4]), + tensor_2d(headers['user_float'][:, 5]), + tensor_2d(headers['user_float'][:, 6]), + tensor_2d(headers['user_float'][:, 7]), + tensor_2d(headers['user_int'][:, 0]), + tensor_2d(headers['user_int'][:, 1]), + tensor_2d(headers['user_int'][:, 2]), + tensor_2d(headers['user_int'][:, 3]), + tensor_2d(headers['user_int'][:, 4]), + tensor_2d(headers['user_int'][:, 5]), + tensor_2d(headers['user_int'][:, 6]), + tensor_2d(headers['user_int'][:, 7]), + ) + physiology_time_stamps = PhysiologyTimestamps( + tensor_2d(headers['physiology_time_stamp'][:, 0]).double(), + tensor_2d(headers['physiology_time_stamp'][:, 1]).double(), + tensor_2d(headers['physiology_time_stamp'][:, 2]).double(), + ) acq_info = cls( idx=acq_idx, - acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']), - active_channels=tensor_2d(headers['active_channels']), - available_channels=tensor_2d(headers['available_channels']), - center_sample=tensor_2d(headers['center_sample']), - channel_mask=tensor_2d(headers['channel_mask']), - discard_post=tensor_2d(headers['discard_post']), - discard_pre=tensor_2d(headers['discard_pre']), - encoding_space_ref=tensor_2d(headers['encoding_space_ref']), + acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']).double(), flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), - number_of_samples=tensor_2d(headers['number_of_samples']), orientation=Rotation.from_directions( spatialdimension_2d(headers['slice_dir']), spatialdimension_2d(headers['phase_dir']), spatialdimension_2d(headers['read_dir']), ), patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), - physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), position=spatialdimension_2d(headers['position']).apply_(mm_to_m), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), - trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above - user_float=tensor_2d(headers['user_float']), - user_int=tensor_2d(headers['user_int']), - version=tensor_2d(headers['version']), + user=user, + physiology_time_stamps=physiology_time_stamps, ) return acq_info diff --git a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py b/src/mrpro/data/_kdata/KDataRemoveOsMixin.py index 555f56a39..592d123bd 100644 --- a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py +++ b/src/mrpro/data/_kdata/KDataRemoveOsMixin.py @@ -65,11 +65,5 @@ def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor: # Adapt header parameters header = deepcopy(self.header) - header.acq_info.center_sample -= start_cropped_readout - header.acq_info.number_of_samples[:] = cropped_data.shape[-1] header.encoding_matrix.x = cropped_data.shape[-1] - - header.acq_info.discard_post = (header.acq_info.discard_post * x_ratio).to(torch.int32) - header.acq_info.discard_pre = (header.acq_info.discard_pre * x_ratio).to(torch.int32) - return type(self)(header, cropped_data, cropped_traj) diff --git a/tests/conftest.py b/tests/conftest.py index 30ae9c229..0918e8628 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -192,7 +192,7 @@ def random_acq_info(random_acquisition): return acq_info -@pytest.fixture(params=({'seed': 0, 'n_other': 10, 'n_k2': 40, 'n_k1': 20},)) +@pytest.fixture(params=({'seed': 0, 'n_other': 10, 'n_k2': 40, 'n_k1': 20, 'n_k0': 64, 'n_coils': 2},)) def random_kheader_shape(request, random_acquisition, random_full_ismrmrd_header): """Random (not necessarily valid) KHeader with defined shape.""" # Get dimensions @@ -206,14 +206,15 @@ def random_kheader_shape(request, random_acquisition, random_full_ismrmrd_header # Generate acquisitions random_acq_info = AcqInfo.from_ismrmrd_acquisitions([random_acquisition for _ in range(n_k1 * n_k2 * n_other)]) - n_k0 = int(random_acq_info.number_of_samples[0]) - n_coils = int(random_acq_info.active_channels[0]) # Generate trajectory + n_k0 = int(request.param['n_k0']) ktraj = [generate_random_trajectory(generator, shape=(n_k0, 2)) for _ in range(n_k1 * n_k2 * n_other)] # Put it all together to a KHeader object kheader = KHeader.from_ismrmrd(random_full_ismrmrd_header, acq_info=random_acq_info, defaults={'trajectory': ktraj}) + n_coils = int(request.param['n_coils']) + return kheader, n_other, n_coils, n_k2, n_k1, n_k0 diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index fa3e4ebd9..386ef60be 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -260,8 +260,8 @@ def test_KData_to_complex128_header(ismrmrd_cart): """Change KData dtype complex128: test header""" kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) kdata_complex128 = kdata.to(dtype=torch.complex128) - assert kdata_complex128.header.acq_info.user_float.dtype == torch.float64 - assert kdata_complex128.header.acq_info.user_int.dtype == torch.int32 + assert kdata_complex128.header.acq_info.user.float1.dtype == torch.float64 + assert kdata_complex128.header.acq_info.user.int1.dtype == torch.int32 @pytest.mark.cuda @@ -285,7 +285,7 @@ def test_KData_cuda(ismrmrd_cart): assert kdata_cuda.traj.kz.is_cuda assert kdata_cuda.traj.ky.is_cuda assert kdata_cuda.traj.kx.is_cuda - assert kdata_cuda.header.acq_info.user_int.is_cuda + assert kdata_cuda.header.acq_info.user.int1.is_cuda assert kdata_cuda.device == torch.device(torch.cuda.current_device()) assert kdata_cuda.header.acq_info.device == torch.device(torch.cuda.current_device()) assert kdata_cuda.is_cuda @@ -301,7 +301,7 @@ def test_KData_cpu(ismrmrd_cart): assert kdata_cpu.traj.kz.is_cpu assert kdata_cpu.traj.ky.is_cpu assert kdata_cpu.traj.kx.is_cpu - assert kdata_cpu.header.acq_info.user_int.is_cpu + assert kdata_cpu.header.acq_info.user.int1.is_cpu assert kdata_cpu.device == torch.device('cpu') assert kdata_cpu.header.acq_info.device == torch.device('cpu') From 9bf825328f0858d63cdbecb8f864ce08f8259920 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 27 Nov 2024 13:17:32 +0100 Subject: [PATCH 03/13] Update [ghstack-poisoned] --- src/mrpro/data/KData.py | 15 +- .../traj_calculators/KTrajectoryCalculator.py | 80 +++-- .../traj_calculators/KTrajectoryCartesian.py | 36 +- .../traj_calculators/KTrajectoryIsmrmrd.py | 6 +- .../traj_calculators/KTrajectoryPulseq.py | 39 ++- .../traj_calculators/KTrajectoryRadial2D.py | 35 +- .../data/traj_calculators/KTrajectoryRpe.py | 107 +++--- .../KTrajectorySunflowerGoldenRpe.py | 112 ++++--- tests/data/test_traj_calculators.py | 308 ++++++++---------- 9 files changed, 393 insertions(+), 345 deletions(-) diff --git a/src/mrpro/data/KData.py b/src/mrpro/data/KData.py index 4b5df6250..56ba66a06 100644 --- a/src/mrpro/data/KData.py +++ b/src/mrpro/data/KData.py @@ -21,6 +21,7 @@ from mrpro.data.acq_filters import has_n_coils, is_image_acquisition from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits +from mrpro.data.enums import AcqFlags from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape @@ -238,7 +239,19 @@ def from_file( case KTrajectoryIsmrmrd(): ktrajectory_final = ktrajectory(acquisitions).sort_and_reshape(sort_idx, n_k2, n_k1) case KTrajectoryCalculator(): - ktrajectory_or_rawshape = ktrajectory(kheader) + reversed_readout_mask = (kheader.acq_info.flags[..., 0] & AcqFlags.ACQ_IS_REVERSE.value).bool() + + ktrajectory_or_rawshape = ktrajectory( + n_k0=0, + k0_center=0, + k1_idx=kheader.acq_info.idx.k1, + k1_center=0, + k2_idx=kheader.acq_info.idx.k2, + k2_center=0, + reversed_readout_mask=reversed_readout_mask, + encoding_matrix=kheader.encoding_matrix, + ) + if isinstance(ktrajectory_or_rawshape, KTrajectoryRawShape): ktrajectory_final = ktrajectory_or_rawshape.sort_and_reshape(sort_idx, n_k2, n_k1) else: diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py index 1893d761c..3a3087ba8 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py @@ -4,52 +4,83 @@ import torch -from mrpro.data.enums import AcqFlags -from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape +from mrpro.data.SpatialDimension import SpatialDimension class KTrajectoryCalculator(ABC): """Base class for k-space trajectories.""" @abstractmethod - def __call__(self, header: KHeader) -> KTrajectory | KTrajectoryRawShape: + def __call__( + self, + *, + n_k0: int, + k0_center: int, + k1_idx: torch.Tensor, + k1_center: int, + k2_idx: torch.Tensor, + k2_center: int, + encoding_matrix: SpatialDimension, + reversed_readout_mask: torch.Tensor | None = None, + ) -> KTrajectory | KTrajectoryRawShape: """Calculate the trajectory for given KHeader. The shapes of kz, ky and kx of the calculated trajectory must be broadcastable to (prod(all_other_dimensions), k2, k1, k0). + + Not all of the parameters will be used by all implementations. + + Parameters + ---------- + n_k0 + number of samples in k0 + k1_idx + indices of k1 + k2_idx + indices of k2 + k0_center + position of k-space center in k0 + k1_center + position of k-space center in k1 + k2_center + position of k-space center in k2 + reversed_readout_mask + boolean tensor indicating reversed redout + encoding_matrix + encoding matrix, describing the extend of the k-space coordinates + + + + Returns + ------- + Trajectory + """ - ... - def _kfreq(self, kheader: KHeader) -> torch.Tensor: + def _readout(self, n_k0: int, k0_center: int, reversed_readout_mask: torch.Tensor | None) -> torch.Tensor: """Calculate the trajectory along one readout (k0 dimension). Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in readout + k0_center + position of k-space center in readout + reversed_readout_mask + boolean tensor indicating reversed readout, e.g bipolar readout Returns ------- - trajectory along ONE readout + trajectory along one readout - Raises - ------ - ValueError - Number of samples have to be the same for each readout """ - n_samples = torch.unique(kheader.acq_info.number_of_samples) - center_sample = kheader.acq_info.center_sample - if len(n_samples) > 1: - raise ValueError('Trajectory can only be calculated if each acquisition has the same number of samples') - n_k0 = int(n_samples.item()) - - # Data can be obtained with standard or reversed readout (e.g. bipolar readout). - k0 = torch.linspace(0, n_k0 - 1, n_k0, dtype=torch.float32) - center_sample + k0 = torch.linspace(0, n_k0 - 1, n_k0, dtype=torch.float32) - k0_center # Data can be obtained with standard or reversed readout (e.g. bipolar readout). - reversed_readout_mask = (kheader.acq_info.flags[..., 0] & AcqFlags.ACQ_IS_REVERSE.value).bool() - k0[reversed_readout_mask, :] = torch.flip(k0[reversed_readout_mask, :], (-1,)) + if reversed_readout_mask is not None: + k0, reversed_readout_mask = torch.broadcast_tensors(k0, reversed_readout_mask) + k0[reversed_readout_mask] = torch.flip(k0[reversed_readout_mask], (-1,)) return k0 @@ -59,8 +90,9 @@ class DummyTrajectory(KTrajectoryCalculator): Shape will fit to all data. Only used as dummy for testing. """ - def __call__(self, header: KHeader) -> KTrajectory: # noqa: ARG002 + def __call__(self, **_) -> KTrajectory: """Calculate dummy trajectory.""" kx = torch.zeros(1, 1, 1, 1) - ky = kz = torch.zeros(1, 1, 1, 1) + ky = torch.zeros(1, 1, 1, 1) + kz = torch.zeros(1, 1, 1, 1) return KTrajectory(kz, ky, kx) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py index 1b0742ee1..63923c99c 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py @@ -3,7 +3,6 @@ import torch from einops import repeat -from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator @@ -11,24 +10,47 @@ class KTrajectoryCartesian(KTrajectoryCalculator): """Cartesian trajectory.""" - def __call__(self, kheader: KHeader) -> KTrajectory: + def __call__( + self, + *, + n_k0: int, + k0_center: int, + k1_idx: torch.Tensor, + k1_center: int, + k2_idx: torch.Tensor, + k2_center: int, + reversed_readout_mask: torch.Tensor | None = None, + **_, + ) -> KTrajectory: """Calculate Cartesian trajectory for given KHeader. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + k0_center + position of k-space center in k0 + k1_idx + indices of k1 + k1_center + position of k-space center in k1 + k2_idx + indices of k2 + k2_center + position of k-space center in k2 + reversed_readout_mask + boolean tensor indicating reversed readout Returns ------- Cartesian trajectory for given KHeader """ # K-space locations along readout lines - kx = self._kfreq(kheader) + kx = self._readout(n_k0, k0_center, reversed_readout_mask=reversed_readout_mask) # Trajectory along phase and slice encoding - ky = (kheader.acq_info.idx.k1 - kheader.encoding_limits.k1.center).to(torch.float32) - kz = (kheader.acq_info.idx.k2 - kheader.encoding_limits.k2.center).to(torch.float32) + ky = (k1_idx - k1_center).to(torch.float32) + kz = (k2_idx - k2_center).to(torch.float32) # Bring to correct dimensions ky = repeat(ky, '... k2 k1-> ... k2 k1 k0', k0=1) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryIsmrmrd.py b/src/mrpro/data/traj_calculators/KTrajectoryIsmrmrd.py index 2b6aad369..be0a09c40 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryIsmrmrd.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryIsmrmrd.py @@ -47,6 +47,10 @@ def __call__(self, acquisitions: Sequence[ismrmrd.Acquisition]) -> KTrajectoryRa kx=ktraj_mrd[..., 0], ) else: - ktraj = KTrajectoryRawShape(kz=ktraj_mrd[..., 2], ky=ktraj_mrd[..., 1], kx=ktraj_mrd[..., 0]) + ktraj = KTrajectoryRawShape( + kz=ktraj_mrd[..., 2], + ky=ktraj_mrd[..., 1], + kx=ktraj_mrd[..., 0], + ) return ktraj diff --git a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py index 7c843572d..bfcf67060 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py @@ -6,7 +6,7 @@ import torch from einops import rearrange -from mrpro.data.KHeader import KHeader +from mrpro.data import SpatialDimension from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator @@ -28,13 +28,21 @@ def __init__(self, seq_path: str | Path, repeat_detection_tolerance: None | floa self.seq_path = seq_path self.repeat_detection_tolerance = repeat_detection_tolerance - def __call__(self, kheader: KHeader) -> KTrajectoryRawShape: + def __call__( + self, + *, + n_k0: int, + encoding_matrix: SpatialDimension, + **_, + ) -> KTrajectoryRawShape: """Calculate trajectory from given .seq file and header information. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + encoding_matrix + encoding matrix, describing the extend of the k-space coordinates Returns ------- @@ -48,19 +56,20 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape: k_traj_adc_numpy, _, _, _, _ = seq.calculate_kspace() k_traj_adc = torch.tensor(k_traj_adc_numpy, dtype=torch.float32) - n_samples = kheader.acq_info.number_of_samples - n_samples = torch.unique(n_samples) - if len(n_samples) > 1: - raise ValueError('We currently only support constant number of samples') - n_k0 = int(n_samples.item()) - - def reshape_pulseq_traj(k_traj: torch.Tensor, encoding_size: int): - k_traj *= encoding_size / (2 * torch.max(torch.abs(k_traj))) + def reshape(k_traj: torch.Tensor, encoding_size: int) -> torch.Tensor: + max_value_range = 2 * torch.max(torch.abs(k_traj)) + if max_value_range > 1e-9 and encoding_size > 1: + k_traj = k_traj * encoding_size / max_value_range + else: + # If encoding matrix is 1, we force k_traj to be 0. We assume here that the values are + # numerical noise returned by pulseq, not real trajectory values + # even if pulseq returned some numerical noise. Also we avoid division by zero. + k_traj = torch.zeros_like(k_traj) return rearrange(k_traj, '(other k0) -> other k0', k0=n_k0) # rearrange k-space trajectory to match MRpro convention - kx = reshape_pulseq_traj(k_traj_adc[0], kheader.encoding_matrix.x) - ky = reshape_pulseq_traj(k_traj_adc[1], kheader.encoding_matrix.y) - kz = reshape_pulseq_traj(k_traj_adc[2], kheader.encoding_matrix.z) + kx = reshape(k_traj_adc[0], encoding_matrix.x) + ky = reshape(k_traj_adc[1], encoding_matrix.y) + kz = reshape(k_traj_adc[2], encoding_matrix.z) return KTrajectoryRawShape(kz, ky, kx, self.repeat_detection_tolerance) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py b/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py index 1616af909..930b60c84 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py @@ -3,7 +3,6 @@ import torch from einops import repeat -from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator @@ -22,27 +21,35 @@ def __init__(self, angle: float = torch.pi * 0.618034) -> None: super().__init__() self.angle: float = angle - def __call__(self, kheader: KHeader) -> KTrajectory: + def __call__( + self, + *, + n_k0: int, + k0_center: int, + k1_idx: torch.Tensor, + reversed_readout_mask: torch.Tensor | None = None, + **_, + ) -> KTrajectory: """Calculate radial 2D trajectory for given KHeader. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + k0_center + position of k-space center in k0 + k1_idx + indices of k1 + reversed_readout_mask + boolean tensor indicating reversed readout Returns ------- radial 2D trajectory for given KHeader """ - # K-space locations along readout lines - krad = self._kfreq(kheader) - - # Angles of readout lines - kang = repeat(kheader.acq_info.idx.k1 * self.angle, '... k2 k1 -> ... k2 k1 k0', k0=1) - - # K-space radial coordinates - kx = krad * torch.cos(kang) - ky = krad * torch.sin(kang) + radial = self._readout(n_k0=n_k0, k0_center=k0_center, reversed_readout_mask=reversed_readout_mask) + angle = repeat(k1_idx * self.angle, '... k2 k1 -> ... k2 k1 k0', k0=1) + kx = radial * torch.cos(angle) + ky = radial * torch.sin(angle) kz = torch.zeros(1, 1, 1, 1) - return KTrajectory(kz, ky, kx) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py index 377a7ae97..bd6dfdf64 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py @@ -3,7 +3,6 @@ import torch from einops import repeat -from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator @@ -39,7 +38,7 @@ def __init__(self, angle: float, shift_between_rpe_lines: tuple | torch.Tensor = self.angle: float = angle self.shift_between_rpe_lines: torch.Tensor = torch.as_tensor(shift_between_rpe_lines) - def _apply_shifts_between_rpe_lines(self, krad: torch.Tensor, kang_idx: torch.Tensor) -> torch.Tensor: + def _apply_shifts_between_rpe_lines(self, k_radial: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """Shift radial phase encoding lines relative to each other. Example: shift_between_rpe_lines = [0, 0.5, 0.25, 0.75] leads to a shift of the 0th line by 0, @@ -56,80 +55,70 @@ def _apply_shifts_between_rpe_lines(self, krad: torch.Tensor, kang_idx: torch.Te Parameters ---------- - krad - k-space positions along each phase encoding line - kang_idx - indices of angles to be used for shift calculation - - References - ---------- - .. [PRI2010] Prieto C, Schaeffter T (2010) 3D undersampled golden-radial phase encoding - for DCE-MRA using inherently regularized iterative SENSE. MRM 64(2). https://doi.org/10.1002/mrm.22446 - """ - for ind, shift in enumerate(self.shift_between_rpe_lines): - curr_angle_idx = torch.nonzero( - torch.fmod(kang_idx, len(self.shift_between_rpe_lines)) == ind, - as_tuple=True, - ) - curr_krad = krad[curr_angle_idx] - - # Do not shift the k-space center - curr_krad += shift * (curr_krad != 0) - - krad[curr_angle_idx] = curr_krad - return krad - - def _kang(self, kheader: KHeader) -> torch.Tensor: - """Calculate the angles of the phase encoding lines. - - Parameters - ---------- - kheader - MR raw data header (KHeader) containing required meta data + k_radial + k-space positions along each phase encoding line, zo be shifted + idx + indices used for shift calculation Returns ------- - angles of phase encoding lines - """ - return repeat(kheader.acq_info.idx.k2 * self.angle, '... k2 k1 -> ... k2 k1 k0', k0=1) - - def _krad(self, kheader: KHeader) -> torch.Tensor: - """Calculate the k-space locations along the phase encoding lines. + shifted radial k-space positions - Parameters + References ---------- - kheader - MR raw data header (KHeader) containing required meta data - - Returns - ------- - k-space locations along the phase encoding lines + .. [PRI2010] Prieto C, Schaeffter T (2010) 3D undersampled golden-radial phase encoding + for DCE-MRA using inherently regularized iterative SENSE. MRM 64(2). https://doi.org/10.1002/mrm.22446 """ - krad = (kheader.acq_info.idx.k1 - kheader.encoding_limits.k1.center).to(torch.float32) - krad = self._apply_shifts_between_rpe_lines(krad, kheader.acq_info.idx.k2) - return repeat(krad, '... k2 k1 -> ... k2 k1 k0', k0=1) + # do not shift k-space center + not_center = ~torch.isclose(k_radial, torch.tensor(0)) - def __call__(self, kheader: KHeader) -> KTrajectory: + for ind, shift in enumerate(self.shift_between_rpe_lines): + current_mask = (idx % len(self.shift_between_rpe_lines)) == ind + current_mask &= not_center + k_radial[current_mask] += shift + + return k_radial + + def __call__( + self, + *, + n_k0: int, + k0_center: int, + k1_idx: torch.Tensor, + k1_center: int, + k2_idx: torch.Tensor, + reversed_readout_mask: torch.Tensor | None = None, + **_, + ) -> KTrajectory: """Calculate radial phase encoding trajectory for given KHeader. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + k0_center + position of k-space center in k0 + k1_idx + indices of k1 + k1_center + position of k-space center in k1 + k2_idx + indices of k2 + reversed_readout_mask + boolean tensor indicating reversed readout Returns ------- radial phase encoding trajectory for given KHeader """ - # Trajectory along readout - kx = self._kfreq(kheader) + angles = repeat(k2_idx * self.angle, '... k2 k1 -> ... k2 k1 k0', k0=1) - # Angles of phase encoding lines - kang = self._kang(kheader) + radial = (k1_idx - k1_center).to(torch.float32) + radial = self._apply_shifts_between_rpe_lines(radial, k2_idx) + radial = repeat(radial, '... k2 k1 -> ... k2 k1 k0', k0=1) - # K-space locations along phase encoding lines - krad = self._krad(kheader) + kz = radial * torch.sin(angles) + ky = radial * torch.cos(angles) + kx = self._readout(n_k0, k0_center, reversed_readout_mask=reversed_readout_mask) - kz = krad * torch.sin(kang) - ky = krad * torch.cos(kang) return KTrajectory(kz, ky, kx) diff --git a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py index 8ab4abd92..0d3de5de7 100644 --- a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py @@ -4,90 +4,94 @@ import torch from einops import repeat -from mrpro.data.KHeader import KHeader -from mrpro.data.traj_calculators.KTrajectoryRpe import KTrajectoryRpe +from mrpro.data.KTrajectory import KTrajectory +from mrpro.data.traj_calculators import KTrajectoryCalculator +GOLDEN_RATIO = 0.5 * (5**0.5 + 1) -class KTrajectorySunflowerGoldenRpe(KTrajectoryRpe): + +class KTrajectorySunflowerGoldenRpe(KTrajectoryCalculator): """Radial phase encoding trajectory with a sunflower pattern.""" - def __init__(self, rad_us_factor: float = 1.0) -> None: + def __init__(self, radial_undersampling_factor: float = 1.0) -> None: """Initialize KTrajectorySunflowerGoldenRpe. Parameters ---------- - rad_us_factor + radial_undersampling_factor undersampling factor along radial phase encoding direction. """ - super().__init__(angle=torch.pi * 0.618034) - self.rad_us_factor: float = rad_us_factor + self.angle = torch.pi * 0.618034 + + if radial_undersampling_factor != 1: + raise NotImplementedError('Radial undersampling is not yet implemented') def _apply_sunflower_shift_between_rpe_lines( self, - krad: torch.Tensor, - kang: torch.Tensor, - kheader: KHeader, + radial: torch.Tensor, + angles: torch.Tensor, + k2_idx: torch.Tensor, ) -> torch.Tensor: """Shift radial phase encoding lines relative to each other. The shifts are applied to create a sunflower pattern of k-space points in the ky-kz phase encoding plane. - The applied shifts can lead to a scaling of the FOV. This scaling depends on the undersampling factor along the - radial phase encoding direction and is compensated for at the end. Parameters ---------- - krad - k-space positions along each phase encoding line - kang - angles of the radial phase encoding lines - kheader - MR raw data header (KHeader) containing required meta data + radial + position along radial direction + angles + angle of spokes + k2_idx + indices in k2 """ - kang = kang.flatten() - _, indices = np.unique(kang, return_index=True) + angles = angles.flatten() + _, indices = np.unique(angles, return_index=True) shift_idx = np.argsort(indices) - - # Apply sunflower shift - golden_ratio = 0.5 * (np.sqrt(5) + 1) for ind, shift in enumerate(shift_idx): - krad[kheader.acq_info.idx.k2 == ind] += ((shift * golden_ratio) % 1) - 0.5 + radial[k2_idx == ind] += ((shift * GOLDEN_RATIO) % 1) - 0.5 + return radial - # Set asym k-space point to 0 because this point was used to obtain a self-navigator signal. - krad[kheader.acq_info.idx.k1 == 0] = 0 - - return krad - - def _kang(self, kheader: KHeader) -> torch.Tensor: - """Calculate the angles of the phase encoding lines. + def __call__( + self, + *, + n_k0: int, + k0_center: int, + k1_idx: torch.Tensor, + k1_center: int, + k2_idx: torch.Tensor, + reversed_readout_mask: torch.Tensor | None = None, + **_, + ) -> KTrajectory: + """Calculate radial phase encoding trajectory for given KHeader. Parameters ---------- - kheader - MR raw data header (KHeader) containing required meta data + n_k0 + number of samples in k0 + k0_center + position of k-space center in k0 + k1_idx + indices of k1 + k1_center + position of k-space center in k1 + k2_idx + indices of k2 + reversed_readout_mask + boolean tensor indicating reversed readout Returns ------- - angles of phase encoding lines + radial phase encoding trajectory for given KHeader """ - return repeat((kheader.acq_info.idx.k2 * self.angle) % torch.pi, '... k2 k1 -> ... k2 k1 k0', k0=1) + angles = repeat((k2_idx * self.angle) % torch.pi, '... k2 k1 -> ... k2 k1 k0', k0=1) + radial = repeat((k1_idx - k1_center).to(torch.float32), '... k2 k1 -> ... k2 k1 k0', k0=1) + radial = self._apply_sunflower_shift_between_rpe_lines(radial, angles, k2_idx) - def _krad(self, kheader: KHeader) -> torch.Tensor: - """Calculate the k-space locations along the phase encoding lines. + # Asymmetric k-space point is used to obtain a self-navigator signal, thus should be in k-space center + radial[k1_idx == 0] = 0 - Parameters - ---------- - kheader - MR raw data header (KHeader) containing required meta data - - Returns - ------- - k-space locations along the phase encoding lines - """ - kang = self._kang(kheader) - krad = repeat( - (kheader.acq_info.idx.k1 - kheader.encoding_limits.k1.center).to(torch.float32), - '... k2 k1 -> ... k2 k1 k0', - k0=1, - ) - krad = self._apply_sunflower_shift_between_rpe_lines(krad, kang, kheader) - return krad + kz = radial * torch.sin(angles) + ky = radial * torch.cos(angles) + kx = self._readout(n_k0, k0_center, reversed_readout_mask=reversed_readout_mask) + return KTrajectory(kz, ky, kx) diff --git a/tests/data/test_traj_calculators.py b/tests/data/test_traj_calculators.py index 7ddcb30aa..42a2b6bfa 100644 --- a/tests/data/test_traj_calculators.py +++ b/tests/data/test_traj_calculators.py @@ -1,6 +1,5 @@ """Tests for KTrajectory Calculator classes.""" -import numpy as np import pytest import torch from einops import repeat @@ -18,202 +17,125 @@ from tests.data import IsmrmrdRawTestData, PulseqRadialTestSeq -@pytest.fixture -def valid_rad2d_kheader(monkeypatch, random_kheader): - """KHeader with all necessary parameters for radial 2D trajectories.""" - # K-space dimensions +def test_KTrajectoryRadial2D(): + """Test shapes returned by KTrajectoryRadial2D.""" + n_k0 = 256 n_k1 = 10 - n_k2 = 1 - - # List of k1 indices in the shape - idx_k1 = repeat(torch.arange(n_k1, dtype=torch.int32), 'k1 -> other k2 k1', other=1, k2=1) - - # Set parameters for radial 2D trajectory (AcqInfo is of shape (other k2 k1 dim=1 or 3)) - monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) - monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2) - monkeypatch.setattr(random_kheader.acq_info, 'flags', torch.zeros_like(idx_k1)[..., None]) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k1', idx_k1) - # This is only needed for Pulseq trajectory calculation - monkeypatch.setattr(random_kheader.encoding_matrix, 'x', n_k0) - monkeypatch.setattr(random_kheader.encoding_matrix, 'y', n_k0) # square encoding in kx-ky plane - monkeypatch.setattr(random_kheader.encoding_matrix, 'z', n_k2) - - return random_kheader - - -def radial2D_traj_shape(valid_rad2d_kheader): - """Expected shape of trajectory based on KHeader.""" - n_k0 = valid_rad2d_kheader.acq_info.number_of_samples[0, 0, 0] - n_k1 = valid_rad2d_kheader.acq_info.idx.k1.shape[2] - n_k2 = 1 - n_other = 1 - return ( - torch.Size([n_other, 1, 1, 1]), - torch.Size([n_other, n_k2, n_k1, n_k0]), - torch.Size([n_other, n_k2, n_k1, n_k0]), + trajectory_calculator = KTrajectoryRadial2D() + trajectory = trajectory_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=torch.arange(n_k1)[None, None, :, None], ) + assert trajectory.kz.shape == (1, 1, 1, 1) + assert trajectory.ky.shape == (1, 1, n_k1, n_k0) + assert trajectory.kx.shape == (1, 1, n_k1, n_k0) -def test_KTrajectoryRadial2D_golden(valid_rad2d_kheader): - """Calculate Radial 2D trajectory with golden angle.""" - trajectory_calculator = KTrajectoryRadial2D(angle=torch.pi * 0.618034) - trajectory = trajectory_calculator(valid_rad2d_kheader) - valid_shape = radial2D_traj_shape(valid_rad2d_kheader) - assert trajectory.kx.shape == valid_shape[2] - assert trajectory.ky.shape == valid_shape[1] - assert trajectory.kz.shape == valid_shape[0] - - -@pytest.fixture -def valid_rpe_kheader(monkeypatch, random_kheader): - """KHeader with all necessary parameters for RPE trajectories.""" - # K-space dimensions - n_k0 = 200 +def test_KTrajectoryRpe(): + """Test shapes returned by KTrajectoryRpe""" + n_k0 = 100 n_k1 = 20 n_k2 = 10 - # List of k1 and k2 indices in the shape (other, k2, k1) - k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32) - k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32) - idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy') - idx_k1 = torch.reshape(idx_k1, (1, n_k2, n_k1)) - idx_k2 = torch.reshape(idx_k2, (1, n_k2, n_k1)) - - # Set parameters for RPE trajectory (AcqInfo is of shape (other k2 k1 dim=1 or 3)) - monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) - monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2) - monkeypatch.setattr(random_kheader.acq_info, 'flags', torch.zeros_like(idx_k1)[..., None]) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k1', idx_k1) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k2', idx_k2) - monkeypatch.setattr(random_kheader.encoding_limits.k1, 'center', int(n_k1 // 2)) - monkeypatch.setattr(random_kheader.encoding_limits.k1, 'max', int(n_k1 - 1)) - return random_kheader - - -def rpe_traj_shape(valid_rpe_kheader): - """Expected shape of trajectory based on KHeader.""" - n_k0 = valid_rpe_kheader.acq_info.number_of_samples[0, 0, 0] - n_k1 = valid_rpe_kheader.acq_info.idx.k1.shape[2] - n_k2 = valid_rpe_kheader.acq_info.idx.k1.shape[1] - n_other = 1 - return ( - torch.Size([n_other, n_k2, n_k1, 1]), - torch.Size([n_other, n_k2, n_k1, 1]), - torch.Size([n_other, 1, 1, n_k0]), + trajectory_calculator = KTrajectoryRpe(angle=torch.pi * 0.618034) + trajectory = trajectory_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=torch.arange(n_k2)[None, :, None, None], + k1_center=n_k1 // 2, + k2_idx=torch.arange(n_k1)[None, None, :, None], ) + assert trajectory.kz.shape == (1, n_k2, n_k1, 1) + assert trajectory.ky.shape == (1, n_k2, n_k1, 1) + assert trajectory.kx.shape == (1, 1, 1, n_k0) -def test_KTrajectoryRpe_golden(valid_rpe_kheader): - """Calculate RPE trajectory with golden angle.""" - trajectory_calculator = KTrajectoryRpe(angle=torch.pi * 0.618034) - trajectory = trajectory_calculator(valid_rpe_kheader) - valid_shape = rpe_traj_shape(valid_rpe_kheader) - assert trajectory.kz.shape == valid_shape[0] - assert trajectory.ky.shape == valid_shape[1] - assert trajectory.kx.shape == valid_shape[2] - - -def test_KTrajectoryRpe_uniform(valid_rpe_kheader): - """Calculate RPE trajectory with uniform angle.""" - n_rpe_lines = valid_rpe_kheader.acq_info.idx.k1.shape[1] - trajectory1_calculator = KTrajectoryRpe(angle=torch.pi / n_rpe_lines, shift_between_rpe_lines=torch.tensor([0])) - trajectory1 = trajectory1_calculator(valid_rpe_kheader) +def test_KTrajectoryRpe_angle(): + """Test that every second line matches the first half of lines of a trajectory with double the angular gap.""" + n_k0 = 100 + n_k1 = 20 + n_k2 = 10 + k1_idx = torch.arange(n_k1)[None, None, :, None] + k2_idx = torch.arange(n_k2)[None, :, None, None] + trajectory1_calculator = KTrajectoryRpe(angle=torch.pi / n_k1, shift_between_rpe_lines=(0,)) + trajectory1 = trajectory1_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=k1_idx, + k1_center=n_k1 // 2, + k2_idx=k2_idx, + ) # Calculate trajectory with half the angular gap such that every second line should be the same as above trajectory2_calculator = KTrajectoryRpe( - angle=torch.pi / (2 * n_rpe_lines), - shift_between_rpe_lines=torch.tensor([0]), + angle=torch.pi / (2 * n_k1), + shift_between_rpe_lines=torch.tensor([0, 0, 0, 0]), ) - trajectory2 = trajectory2_calculator(valid_rpe_kheader) - - torch.testing.assert_close(trajectory1.kx[:, : n_rpe_lines // 2, :, :], trajectory2.kx[:, ::2, :, :]) - torch.testing.assert_close(trajectory1.ky[:, : n_rpe_lines // 2, :, :], trajectory2.ky[:, ::2, :, :]) - torch.testing.assert_close(trajectory1.kz[:, : n_rpe_lines // 2, :, :], trajectory2.kz[:, ::2, :, :]) - - -def test_KTrajectoryRpe_shift(valid_rpe_kheader): - """Evaluate radial shifts for RPE trajectory.""" - trajectory1_calculator = KTrajectoryRpe(angle=torch.pi * 0.618034, shift_between_rpe_lines=torch.tensor([0.25])) - trajectory1 = trajectory1_calculator(valid_rpe_kheader) - trajectory2_calculator = KTrajectoryRpe( - angle=torch.pi * 0.618034, - shift_between_rpe_lines=torch.tensor([0.25, 0.25, 0.25, 0.25]), + trajectory2 = trajectory2_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=k1_idx, + k1_center=n_k1 // 2, + k2_idx=k2_idx, ) - trajectory2 = trajectory2_calculator(valid_rpe_kheader) - torch.testing.assert_close(trajectory1.as_tensor(), trajectory2.as_tensor()) - -def test_KTrajectorySunflowerGoldenRpe(valid_rpe_kheader): - """Calculate RPE Sunflower trajectory.""" - trajectory_calculator = KTrajectorySunflowerGoldenRpe(rad_us_factor=2) - trajectory = trajectory_calculator(valid_rpe_kheader) - assert trajectory.broadcasted_shape == np.broadcast_shapes(*rpe_traj_shape(valid_rpe_kheader)) + torch.testing.assert_close(trajectory1.kx[:, : n_k1 // 2, :, :], trajectory2.kx[:, ::2, :, :]) + torch.testing.assert_close(trajectory1.ky[:, : n_k1 // 2, :, :], trajectory2.ky[:, ::2, :, :]) + torch.testing.assert_close(trajectory1.kz[:, : n_k1 // 2, :, :], trajectory2.kz[:, ::2, :, :]) -@pytest.fixture -def valid_cartesian_kheader(monkeypatch, random_kheader): - """KHeader with all necessary parameters for Cartesian trajectories.""" - # K-space dimensions - n_k0 = 200 +def test_KTrajectorySunflowerGoldenRpe(): + """Test shape returned by KTrajectorySunflowerGoldenRpe""" + n_k0 = 100 n_k1 = 20 n_k2 = 10 - n_other = 2 - - # List of k1 and k2 indices in the shape (other, k2, k1) - k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32) - k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32) - idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy') - idx_k1 = repeat(torch.reshape(idx_k1, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) - idx_k2 = repeat(torch.reshape(idx_k2, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) - - # Set parameters for Cartesian trajectory (AcqInfo is of shape (other k2 k1 dim=1 or 3)) - monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) - monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2) - monkeypatch.setattr(random_kheader.acq_info, 'flags', torch.zeros_like(idx_k1)[..., None]) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k1', idx_k1) - monkeypatch.setattr(random_kheader.acq_info.idx, 'k2', idx_k2) - monkeypatch.setattr(random_kheader.encoding_limits.k1, 'center', int(n_k1 // 2)) - monkeypatch.setattr(random_kheader.encoding_limits.k1, 'max', int(n_k1 - 1)) - monkeypatch.setattr(random_kheader.encoding_limits.k2, 'center', int(n_k2 // 2)) - monkeypatch.setattr(random_kheader.encoding_limits.k2, 'max', int(n_k2 - 1)) - return random_kheader - - -def cartesian_traj_shape(valid_cartesian_kheader): - """Expected shape of trajectory based on KHeader.""" - n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0, 0, 0] - n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[2] - n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[1] - n_other = 1 # trajectory along other is the same - return (torch.Size([n_other, n_k2, 1, 1]), torch.Size([n_other, 1, n_k1, 1]), torch.Size([n_other, 1, 1, n_k0])) + k1_idx = torch.arange(n_k1)[None, None, :, None] + k2_idx = torch.arange(n_k2)[None, :, None, None] + trajectory_calculator = KTrajectorySunflowerGoldenRpe() + trajectory = trajectory_calculator( + n_k0=n_k0, k0_center=n_k0 // 2, k1_idx=k1_idx, k1_center=n_k1 // 2, k2_idx=k2_idx + ) + assert trajectory.broadcasted_shape == (1, n_k2, n_k1, n_k0) def test_KTrajectoryCartesian(valid_cartesian_kheader): """Calculate Cartesian trajectory.""" + n_k0 = 30 + n_k1 = 20 + n_k2 = 10 trajectory_calculator = KTrajectoryCartesian() - trajectory = trajectory_calculator(valid_cartesian_kheader) - valid_shape = cartesian_traj_shape(valid_cartesian_kheader) - assert trajectory.kz.shape == valid_shape[0] - assert trajectory.ky.shape == valid_shape[1] - assert trajectory.kx.shape == valid_shape[2] - - -@pytest.fixture -def valid_cartesian_kheader_bipolar(monkeypatch, valid_cartesian_kheader): - """Set readout of other==1 to reversed.""" - acq_info_flags = valid_cartesian_kheader.acq_info.flags - acq_info_flags[1, ...] = AcqFlags.ACQ_IS_REVERSE.value - monkeypatch.setattr(valid_cartesian_kheader.acq_info, 'flags', acq_info_flags) - return valid_cartesian_kheader + trajectory = trajectory_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=torch.arange(n_k1)[None, None, :, None], + k1_center=n_k1 // 2, + k2_idx=torch.arange(n_k2)[None, :, None, None], + ) + assert trajectory.kz.shape == (1, n_k2, 1, 1) + assert trajectory.ky.shape == (1, 1, n_k1, 1) + assert trajectory.kx.shape == (1, 1, 1, n_k0) def test_KTrajectoryCartesian_bipolar(valid_cartesian_kheader_bipolar): - """Calculate Cartesian trajectory with bipolar readout.""" + """Verify that the readout for the second part of a bipolar readout is reversed""" + trajectory_calculator = KTrajectoryCartesian() + n_k0 = 30 + n_k1 = 20 + n_k2 = 10 + reversed_readout_mask = torch.zeros(n_k1, 1, dtype=torch.bool) + reversed_readout_mask[1] = True trajectory_calculator = KTrajectoryCartesian() - trajectory = trajectory_calculator(valid_cartesian_kheader_bipolar) - # Verify that the readout for the second part of the bipolar readout is reversed - torch.testing.assert_close(trajectory.kx[0, ...], torch.flip(trajectory.kx[1, ...], dims=(-1,))) + trajectory = trajectory_calculator( + n_k0=n_k0, + k0_center=n_k0 // 2, + k1_idx=torch.arange(n_k1)[None, None, :, None], + k1_center=n_k1 // 2, + k2_idx=torch.arange(n_k2)[None, :, None, None], + reversed_readout_mask=reversed_readout_mask, + ) + torch.testing.assert_close(trajectory.kx[..., 0, :], torch.flip(trajectory.kx[..., 1, :], dims=(-1,))) @pytest.fixture(scope='session') @@ -252,13 +174,12 @@ def pulseq_example_rad_seq(tmp_path_factory): return seq -def test_KTrajectoryPulseq_validseq_random_header(pulseq_example_rad_seq, valid_rad2d_kheader): +def test_KTrajectoryPulseq(pulseq_example_rad_seq, valid_rad2d_kheader): """Test pulseq File reader with valid seq File.""" - # TODO: Test with valid header # TODO: Test with invalid seq file trajectory_calculator = KTrajectoryPulseq(seq_path=pulseq_example_rad_seq.seq_filename) - trajectory = trajectory_calculator(kheader=valid_rad2d_kheader) + trajectory = trajectory_calculator(n_k0=n_k0, encoding_matrix=encoding_matrix) kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze(0).squeeze(0) kx_test *= valid_rad2d_kheader.encoding_matrix.x / (2 * torch.max(torch.abs(kx_test))) @@ -268,3 +189,50 @@ def test_KTrajectoryPulseq_validseq_random_header(pulseq_example_rad_seq, valid_ torch.testing.assert_close(trajectory.kx.to(torch.float32), kx_test.to(torch.float32), atol=1e-2, rtol=1e-3) torch.testing.assert_close(trajectory.ky.to(torch.float32), ky_test.to(torch.float32), atol=1e-2, rtol=1e-3) + + +@pytest.fixture +def valid_cartesian_kheader(monkeypatch, random_kheader): + """KHeader with all necessary parameters for Cartesian trajectories.""" + # K-space dimensions + n_k0 = 200 + n_k1 = 20 + n_k2 = 10 + n_other = 2 + + # List of k1 and k2 indices in the shape (other, k2, k1) + k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32) + k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32) + idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy') + idx_k1 = repeat(torch.reshape(idx_k1, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) + idx_k2 = repeat(torch.reshape(idx_k2, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) + + # Set parameters for Cartesian trajectory (AcqInfo is of shape (other k2 k1 dim=1 or 3)) + monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) + monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2) + monkeypatch.setattr(random_kheader.acq_info, 'flags', torch.zeros_like(idx_k1)[..., None]) + monkeypatch.setattr(random_kheader.acq_info.idx, 'k1', idx_k1) + monkeypatch.setattr(random_kheader.acq_info.idx, 'k2', idx_k2) + monkeypatch.setattr(random_kheader.encoding_limits.k1, 'center', int(n_k1 // 2)) + monkeypatch.setattr(random_kheader.encoding_limits.k1, 'max', int(n_k1 - 1)) + monkeypatch.setattr(random_kheader.encoding_limits.k2, 'center', int(n_k2 // 2)) + monkeypatch.setattr(random_kheader.encoding_limits.k2, 'max', int(n_k2 - 1)) + return random_kheader + + +def cartesian_traj_shape(valid_cartesian_kheader): + """Expected shape of trajectory based on KHeader.""" + n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0, 0, 0] + n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[2] + n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[1] + n_other = 1 # trajectory along other is the same + return + + +@pytest.fixture +def valid_cartesian_kheader_bipolar(monkeypatch, valid_cartesian_kheader): + """Set readout of other==1 to reversed.""" + acq_info_flags = valid_cartesian_kheader.acq_info.flags + acq_info_flags[1, ...] = AcqFlags.ACQ_IS_REVERSE.value + monkeypatch.setattr(valid_cartesian_kheader.acq_info, 'flags', acq_info_flags) + return valid_cartesian_kheader From 3e660957396f6d7779908f33e0983214cb1df99e Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 27 Nov 2024 14:29:24 +0100 Subject: [PATCH 04/13] Update [ghstack-poisoned] --- src/mrpro/data/AcqInfo.py | 35 +++++++++++++++---- src/mrpro/data/KData.py | 32 ++++++++++------- .../traj_calculators/KTrajectoryCalculator.py | 10 +++--- .../traj_calculators/KTrajectoryCartesian.py | 6 ++-- .../traj_calculators/KTrajectoryPulseq.py | 2 +- .../traj_calculators/KTrajectoryRadial2D.py | 2 +- .../data/traj_calculators/KTrajectoryRpe.py | 4 +-- .../KTrajectorySunflowerGoldenRpe.py | 6 ++-- tests/data/test_traj_calculators.py | 2 ++ 9 files changed, 67 insertions(+), 32 deletions(-) diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index 5926e36a8..44d2c1e52 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from dataclasses import dataclass +from typing import overload import ismrmrd import numpy as np @@ -150,14 +151,31 @@ class AcqInfo(MoveDataMixin): user: UserValues """User defined float or int values""" + @overload @classmethod - def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) -> Self: + def from_ismrmrd_acquisitions( + cls, acquisitions: Sequence[ismrmrd.Acquisition], *, additional_fields: None + ) -> Self: ... + + @overload + @classmethod + def from_ismrmrd_acquisitions( + cls, acquisitions: Sequence[ismrmrd.Acquisition], *, additional_fields: Sequence[str] + ) -> tuple[Self, tuple[torch.Tensor, ...]]: ... + + @classmethod + def from_ismrmrd_acquisitions( + cls, acquisitions: Sequence[ismrmrd.Acquisition], *, additional_fields: Sequence[str] | None = None + ) -> Self | tuple[Self, tuple[torch.Tensor, ...]]: """Read the header of a list of acquisition and store information. Parameters ---------- - acquisitions: + acquisitions list of ismrmrd acquisistions to read from. Needs at least one acquisition. + additional_fields + if supplied, additional fields with these names will be from the ismrmrd acquisitions + and returned as tensors. """ # Idea: create array of structs, then a struct of arrays, # convert it into tensors to store in our dataclass. @@ -167,9 +185,9 @@ def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) raise ValueError('Acquisition list must not be empty.') # Creating the dtype first and casting to bytes - # is a workaround for a bug in cpython > 3.12 causing a warning - # is np.array(AcquisitionHeader) is called directly. - # also, this needs to check the dtyoe only once. + # is a workaround for a bug in cpython causing a warning + # if np.array(AcquisitionHeader) is called directly. + # also, this needs to check the dtype only once. acquisition_head_dtype = np.dtype(ismrmrd.AcquisitionHeader) headers = np.frombuffer( np.array([memoryview(a._head).cast('B') for a in acquisitions]), @@ -266,4 +284,9 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: user=user, physiology_time_stamps=physiology_time_stamps, ) - return acq_info + + if additional_fields is None: + return acq_info + else: + additional_values = tuple(tensor_2d(headers[field]) for field in additional_fields) + return acq_info, additional_values diff --git a/src/mrpro/data/KData.py b/src/mrpro/data/KData.py index 56ba66a06..614d48fbf 100644 --- a/src/mrpro/data/KData.py +++ b/src/mrpro/data/KData.py @@ -137,9 +137,12 @@ def from_file( kdata = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions]) - acqinfo = AcqInfo.from_ismrmrd_acquisitions(acquisitions) + acq_info, (k0_center, n_k0_tensor, discard_pre, discard_post) = AcqInfo.from_ismrmrd_acquisitions( + acquisitions, + additional_fields=('center_sample', 'number_of_samples', 'discard_pre', 'discard_post'), + ) - if len(torch.unique(acqinfo.idx.user5)) > 1: + if len(torch.unique(acq_info.idx.user5)) > 1: warnings.warn( 'The Siemens to ismrmrd converter currently (ab)uses ' 'the user 5 indices for storing the kspace center line number.\n' @@ -147,7 +150,7 @@ def from_file( stacklevel=1, ) - if len(torch.unique(acqinfo.idx.user6)) > 1: + if len(torch.unique(acq_info.idx.user6)) > 1: warnings.warn( 'The Siemens to ismrmrd converter currently (ab)uses ' 'the user 6 indices for storing the kspace center partition number.\n' @@ -158,7 +161,7 @@ def from_file( # Raises ValueError if required fields are missing in the header kheader = KHeader.from_ismrmrd( ismrmrd_header, - acqinfo, + acq_info, defaults={ 'datetime': modification_time, # use the modification time of the dataset as fallback 'trajectory': ktrajectory, @@ -172,9 +175,9 @@ def from_file( # (number_of_samples, center_sample) of (100, 20) (e.g. partial Fourier in the negative k0 direction) and # (100, 80) (e.g. partial Fourier in the positive k0 direction) then this should lead to encoding limits of # [min=0, max=159, center=80] - max_center_sample = int(torch.max(kheader.acq_info.center_sample)) - max_pos_k0_extend = int(torch.max(kheader.acq_info.number_of_samples - kheader.acq_info.center_sample)) - kheader.encoding_limits.k0 = Limits(0, max_center_sample + max_pos_k0_extend - 1, max_center_sample) + max_center_sample = int(torch.max(k0_center)) + max_positive_k0_extend = int(torch.max(n_k0_tensor - k0_center)) + kheader.encoding_limits.k0 = Limits(0, max_center_sample + max_positive_k0_extend - 1, max_center_sample) # Sort and reshape the kdata and the acquisistion info according to the indices. # within "other", the aquisistions are sorted in the order determined by KDIM_SORT_LABELS. @@ -240,14 +243,19 @@ def from_file( ktrajectory_final = ktrajectory(acquisitions).sort_and_reshape(sort_idx, n_k2, n_k1) case KTrajectoryCalculator(): reversed_readout_mask = (kheader.acq_info.flags[..., 0] & AcqFlags.ACQ_IS_REVERSE.value).bool() - + n_k0_unique = torch.unique(n_k0_tensor) + if len(n_k0_unique) > 1: + raise ValueError( + 'Trajectory can only be calculated for constant number of readout samples.\n' + f'Got unique values {list(n_k0_unique)}' + ) ktrajectory_or_rawshape = ktrajectory( - n_k0=0, - k0_center=0, + n_k0=int(n_k0_unique[0]), + k0_center=k0_center, k1_idx=kheader.acq_info.idx.k1, - k1_center=0, + k1_center=kheader.encoding_limits.k1.center, k2_idx=kheader.acq_info.idx.k2, - k2_center=0, + k2_center=kheader.encoding_limits.k2.center, reversed_readout_mask=reversed_readout_mask, encoding_matrix=kheader.encoding_matrix, ) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py index 3a3087ba8..1061ae6a4 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py @@ -17,11 +17,11 @@ def __call__( self, *, n_k0: int, - k0_center: int, + k0_center: int | torch.Tensor, k1_idx: torch.Tensor, - k1_center: int, + k1_center: int | torch.Tensor, k2_idx: torch.Tensor, - k2_center: int, + k2_center: int | torch.Tensor, encoding_matrix: SpatialDimension, reversed_readout_mask: torch.Tensor | None = None, ) -> KTrajectory | KTrajectoryRawShape: @@ -59,7 +59,9 @@ def __call__( """ - def _readout(self, n_k0: int, k0_center: int, reversed_readout_mask: torch.Tensor | None) -> torch.Tensor: + def _readout( + self, n_k0: int, k0_center: int | torch.Tensor, reversed_readout_mask: torch.Tensor | None + ) -> torch.Tensor: """Calculate the trajectory along one readout (k0 dimension). Parameters diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py index 63923c99c..efa6b0b95 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py @@ -14,11 +14,11 @@ def __call__( self, *, n_k0: int, - k0_center: int, + k0_center: int | torch.Tensor, k1_idx: torch.Tensor, - k1_center: int, + k1_center: int | torch.Tensor, k2_idx: torch.Tensor, - k2_center: int, + k2_center: int | torch.Tensor, reversed_readout_mask: torch.Tensor | None = None, **_, ) -> KTrajectory: diff --git a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py index bfcf67060..91dda9d17 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py @@ -6,8 +6,8 @@ import torch from einops import rearrange -from mrpro.data import SpatialDimension from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape +from mrpro.data.SpatialDimension import SpatialDimension from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator diff --git a/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py b/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py index 930b60c84..c458a69f4 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py @@ -25,7 +25,7 @@ def __call__( self, *, n_k0: int, - k0_center: int, + k0_center: int | torch.Tensor, k1_idx: torch.Tensor, reversed_readout_mask: torch.Tensor | None = None, **_, diff --git a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py index bd6dfdf64..86f45813e 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py @@ -83,9 +83,9 @@ def __call__( self, *, n_k0: int, - k0_center: int, + k0_center: int | torch.Tensor, k1_idx: torch.Tensor, - k1_center: int, + k1_center: int | torch.Tensor, k2_idx: torch.Tensor, reversed_readout_mask: torch.Tensor | None = None, **_, diff --git a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py index 0d3de5de7..2c34b7e6d 100644 --- a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py @@ -5,7 +5,7 @@ from einops import repeat from mrpro.data.KTrajectory import KTrajectory -from mrpro.data.traj_calculators import KTrajectoryCalculator +from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator GOLDEN_RATIO = 0.5 * (5**0.5 + 1) @@ -56,9 +56,9 @@ def __call__( self, *, n_k0: int, - k0_center: int, + k0_center: int | torch.Tensor, k1_idx: torch.Tensor, - k1_center: int, + k1_center: int | torch.Tensor, k2_idx: torch.Tensor, reversed_readout_mask: torch.Tensor | None = None, **_, diff --git a/tests/data/test_traj_calculators.py b/tests/data/test_traj_calculators.py index 42a2b6bfa..69d4e3640 100644 --- a/tests/data/test_traj_calculators.py +++ b/tests/data/test_traj_calculators.py @@ -112,6 +112,7 @@ def test_KTrajectoryCartesian(valid_cartesian_kheader): k1_idx=torch.arange(n_k1)[None, None, :, None], k1_center=n_k1 // 2, k2_idx=torch.arange(n_k2)[None, :, None, None], + k2_center=n_k2 // 2, ) assert trajectory.kz.shape == (1, n_k2, 1, 1) assert trajectory.ky.shape == (1, 1, n_k1, 1) @@ -133,6 +134,7 @@ def test_KTrajectoryCartesian_bipolar(valid_cartesian_kheader_bipolar): k1_idx=torch.arange(n_k1)[None, None, :, None], k1_center=n_k1 // 2, k2_idx=torch.arange(n_k2)[None, :, None, None], + k2_center=n_k2 // 2, reversed_readout_mask=reversed_readout_mask, ) torch.testing.assert_close(trajectory.kx[..., 0, :], torch.flip(trajectory.kx[..., 1, :], dims=(-1,))) From 7dd0b2efff2bb5fd110b70b960d3e207ee761597 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 28 Nov 2024 16:50:19 +0100 Subject: [PATCH 05/13] Update [ghstack-poisoned] --- src/mrpro/data/KData.py | 27 ++++++++++++++++----------- tests/data/test_kdata.py | 7 +++++++ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/mrpro/data/KData.py b/src/mrpro/data/KData.py index 4b5df6250..4a0f3256e 100644 --- a/src/mrpro/data/KData.py +++ b/src/mrpro/data/KData.py @@ -330,6 +330,13 @@ def compress_coils( from mrpro.operators import PCACompressionOp coil_dim = -4 % self.data.ndim + + if n_compressed_coils > (n_current_coils := self.data.shape[coil_dim]): + raise ValueError( + f'Number of compressed coils ({n_compressed_coils}) cannot be greater ' + f'than the number of current coils ({n_current_coils}).' + ) + if batch_dims is not None and joint_dims is not Ellipsis: raise ValueError('Either batch_dims or joint_dims can be defined not both.') @@ -347,22 +354,20 @@ def compress_coils( # reshape to (*batch dimension, -1, coils) permute_order = ( - batch_dims_normalized - + [i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized] - + [coil_dim] + *batch_dims_normalized, + *[i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized], + coil_dim, ) - kdata_coil_compressed = self.data.permute(permute_order) - permuted_kdata_shape = kdata_coil_compressed.shape - kdata_coil_compressed = kdata_coil_compressed.flatten( + kdata_permuted = self.data.permute(permute_order) + kdata_flattened = kdata_permuted.flatten( start_dim=len(batch_dims_normalized), end_dim=-2 ) # keep separate dimensions and coil - pca_compression_op = PCACompressionOp(data=kdata_coil_compressed, n_components=n_compressed_coils) - (kdata_coil_compressed,) = pca_compression_op(kdata_coil_compressed) - + pca_compression_op = PCACompressionOp(data=kdata_flattened, n_components=n_compressed_coils) + (kdata_coil_compressed_flattened,) = pca_compression_op(kdata_flattened) + del kdata_flattened # reshape to original dimensions and undo permutation kdata_coil_compressed = torch.reshape( - kdata_coil_compressed, [*permuted_kdata_shape[:-1], n_compressed_coils] + kdata_coil_compressed_flattened, [*kdata_permuted.shape[:-1], n_compressed_coils] ).permute(*np.argsort(permute_order)) - return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone()) diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index fa3e4ebd9..4a3183e8e 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -574,3 +574,10 @@ def test_KData_compress_coils_error_coil_dim(consistently_shaped_kdata): with pytest.raises(ValueError, match='Coil dimension must not'): consistently_shaped_kdata.compress_coils(n_compressed_coils=3, joint_dims=(-4,)) + + +def test_KData_compress_coils_error_n_coils(consistently_shaped_kdata): + """Test if error is raised if new coils would be larger than existing coils""" + existing_coils = consistently_shaped_kdata.data.shape[-4] + with pytest.raises(ValueError, match='greater'): + consistently_shaped_kdata.compress_coils(existing_coils + 1) From 8611f773d688702a77ad0740cccb8a43d01efb3b Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 16 Dec 2024 23:11:15 +0100 Subject: [PATCH 06/13] Apply suggestions from code review Co-authored-by: Christoph Kolbitsch --- src/mrpro/data/AcqInfo.py | 2 +- .../data/traj_calculators/KTrajectoryCalculator.py | 12 ++++++------ src/mrpro/data/traj_calculators/KTrajectoryPulseq.py | 2 +- src/mrpro/data/traj_calculators/KTrajectoryRpe.py | 2 +- .../KTrajectorySunflowerGoldenRpe.py | 5 ++--- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index 143a62efe..9a466a982 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -174,7 +174,7 @@ def from_ismrmrd_acquisitions( acquisitions list of ismrmrd acquisistions to read from. Needs at least one acquisition. additional_fields - if supplied, additional fields with these names will be from the ismrmrd acquisitions + if supplied, additional information from the fields with these names will be extracted from the ismrmrd acquisitions and returned as tensors. """ # Idea: create array of structs, then a struct of arrays, diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py index 2d541b1e4..212da9ca7 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py @@ -37,20 +37,20 @@ def __call__( ---------- n_k0 number of samples in k0 - k1_idx - indices of k1 - k2_idx - indices of k2 k0_center position of k-space center in k0 + k1_idx + indices of k1 k1_center position of k-space center in k1 + k2_idx + indices of k2 k2_center position of k-space center in k2 reversed_readout_mask - boolean tensor indicating reversed redout + boolean tensor indicating reversed readout encoding_matrix - encoding matrix, describing the extend of the k-space coordinates + encoding matrix diff --git a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py index 91dda9d17..87a5b3ff1 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py @@ -42,7 +42,7 @@ def __call__( n_k0 number of samples in k0 encoding_matrix - encoding matrix, describing the extend of the k-space coordinates + encoding matrix Returns ------- diff --git a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py index ffc3a5bd5..0e1841e14 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py @@ -56,7 +56,7 @@ def _apply_shifts_between_rpe_lines(self, k_radial: torch.Tensor, idx: torch.Ten Parameters ---------- k_radial - k-space positions along each phase encoding line, zo be shifted + k-space positions along each phase encoding line, to be shifted idx indices used for shift calculation diff --git a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py index f54a1699b..66b30ab5e 100644 --- a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py @@ -16,7 +16,6 @@ class KTrajectorySunflowerGoldenRpe(KTrajectoryCalculator): def __init__(self) -> None: """Initialize KTrajectorySunflowerGoldenRpe. - Parameters ---------- radial_undersampling_factor undersampling factor along radial phase encoding direction. @@ -71,11 +70,11 @@ def __call__( k0_center position of k-space center in k0 k1_idx - indices of k1 + indices of k1 (radial) k1_center position of k-space center in k1 k2_idx - indices of k2 + indices of k2 (angle) reversed_readout_mask boolean tensor indicating reversed readout From 7ff70bcd4ee038e09d38dcd023a4c9227e02fb07 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 16 Dec 2024 23:28:50 +0100 Subject: [PATCH 07/13] Update [ghstack-poisoned] --- src/mrpro/data/AcqInfo.py | 2 +- .../data/traj_calculators/KTrajectoryCalculator.py | 12 ++++++------ src/mrpro/data/traj_calculators/KTrajectoryPulseq.py | 2 +- src/mrpro/data/traj_calculators/KTrajectoryRpe.py | 2 +- .../KTrajectorySunflowerGoldenRpe.py | 5 +++-- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index 9a466a982..143a62efe 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -174,7 +174,7 @@ def from_ismrmrd_acquisitions( acquisitions list of ismrmrd acquisistions to read from. Needs at least one acquisition. additional_fields - if supplied, additional information from the fields with these names will be extracted from the ismrmrd acquisitions + if supplied, additional fields with these names will be from the ismrmrd acquisitions and returned as tensors. """ # Idea: create array of structs, then a struct of arrays, diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py index 212da9ca7..2d541b1e4 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCalculator.py @@ -37,20 +37,20 @@ def __call__( ---------- n_k0 number of samples in k0 - k0_center - position of k-space center in k0 k1_idx indices of k1 - k1_center - position of k-space center in k1 k2_idx indices of k2 + k0_center + position of k-space center in k0 + k1_center + position of k-space center in k1 k2_center position of k-space center in k2 reversed_readout_mask - boolean tensor indicating reversed readout + boolean tensor indicating reversed redout encoding_matrix - encoding matrix + encoding matrix, describing the extend of the k-space coordinates diff --git a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py index 87a5b3ff1..91dda9d17 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py @@ -42,7 +42,7 @@ def __call__( n_k0 number of samples in k0 encoding_matrix - encoding matrix + encoding matrix, describing the extend of the k-space coordinates Returns ------- diff --git a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py index 0e1841e14..ffc3a5bd5 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryRpe.py @@ -56,7 +56,7 @@ def _apply_shifts_between_rpe_lines(self, k_radial: torch.Tensor, idx: torch.Ten Parameters ---------- k_radial - k-space positions along each phase encoding line, to be shifted + k-space positions along each phase encoding line, zo be shifted idx indices used for shift calculation diff --git a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py index 66b30ab5e..f54a1699b 100644 --- a/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py +++ b/src/mrpro/data/traj_calculators/KTrajectorySunflowerGoldenRpe.py @@ -16,6 +16,7 @@ class KTrajectorySunflowerGoldenRpe(KTrajectoryCalculator): def __init__(self) -> None: """Initialize KTrajectorySunflowerGoldenRpe. + Parameters ---------- radial_undersampling_factor undersampling factor along radial phase encoding direction. @@ -70,11 +71,11 @@ def __call__( k0_center position of k-space center in k0 k1_idx - indices of k1 (radial) + indices of k1 k1_center position of k-space center in k1 k2_idx - indices of k2 (angle) + indices of k2 reversed_readout_mask boolean tensor indicating reversed readout From c575567f98707275506b644ca75a7e49a4429118 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 24 Dec 2024 00:02:30 +0100 Subject: [PATCH 08/13] Update [ghstack-poisoned] --- src/mrpro/data/KHeader.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mrpro/data/KHeader.py b/src/mrpro/data/KHeader.py index 29f0e7220..445e1ab4c 100644 --- a/src/mrpro/data/KHeader.py +++ b/src/mrpro/data/KHeader.py @@ -19,11 +19,9 @@ from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues from mrpro.utils.unit_conversion import mm_to_m, ms_to_s -from .traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator - if TYPE_CHECKING: # avoid circular imports by importing only when type checking - pass + from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator UNKNOWN = 'unknown' From 4c3d06d7afff9ec16475ef453d3dff68bf70feb6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 28 Dec 2024 02:45:54 +0100 Subject: [PATCH 09/13] Update [ghstack-poisoned] --- .../traj_calculators/KTrajectorySpiral2D.py | 139 ++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 src/mrpro/data/traj_calculators/KTrajectorySpiral2D.py diff --git a/src/mrpro/data/traj_calculators/KTrajectorySpiral2D.py b/src/mrpro/data/traj_calculators/KTrajectorySpiral2D.py new file mode 100644 index 000000000..dbed3d5f0 --- /dev/null +++ b/src/mrpro/data/traj_calculators/KTrajectorySpiral2D.py @@ -0,0 +1,139 @@ +import torch + +from mrpro.data import SpatialDimension +from mrpro.data.KTrajectory import KTrajectory +from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator + + +class KTrajectorySpiral(KTrajectoryCalculator): + """A Spiral variable density trajectory. + + Implements the spiral trajectory calculation as described in + Simple Analytic Variable Density Spiral Design by Kim et al., MRM 2003""" + + def __init__( + self, + max_gradient: float, + max_slewrate: float, + fov: SpatialDimension | float, + angle: float, + acceleration_per_interleave: float = 1.0, + density_factor: float = 1.0, + gamma: float = 42577478, + ): + """Create a spiral trajectory calculator. + + Parameters + ---------- + max_gradient + Maximum gradient amplitude [T/m]. + max_slewrate + Maximum slew rate [T/m/s]. + density_factor + Density factor alpha. + fov + Field of view [m]. + acceleration_per_interleave + Acceleration per interleave. + Overall acceleration is (acceleration_per_interleave/n_interleaves), + where n_interleaves is determined by k1_idx + angle + Angle between interleaves [rad]. + Usully set to 2pi/n_interleaves + gamma + Gyromagnetic ratio [Hz/T]. + """ + self.density_factor = density_factor + self.max_gradient_gamma = max_gradient * gamma + self.max_slewrate_gamma = max_slewrate * gamma + self.acceleration_per_interleave = acceleration_per_interleave + self.angle = angle + + if isinstance(fov, float): + self.fov = fov + elif fov.x != fov.y: + raise ValueError('Only square FOV is supported.') + elif fov.z != 0: + raise ValueError('Only 2D trajectories are supported.') + else: + self.fov = fov.x + + if self.fov <= 0: + raise ValueError('FOV must be positive.') + if self.acceleration_per_interleave <= 0: + raise ValueError('Acceleration per interleave must be positive.') + if self.max_gradient_gamma <= 0: + raise ValueError('Max gradient must be positive.') + if self.max_slewrate_gamma <= 0: + raise ValueError('Max slew rate must be positive.') + if self.density_factor <= 0: + raise ValueError('Density factor alpha must be positive.') + + def __call__( + self, + *, + n_k0: int, + k1_idx: torch.Tensor, + encoding_matrix: SpatialDimension, + **_, + ) -> KTrajectory: + """ + Calculate the spiral trajectory. + + Parameters + ---------- + n_k0 + Number of samples along a spiral interleave. + k1_idx + Integer index of the interleaves + encoding_matrix + Dimensions of the encoding matrix. + Only square matrices are supported. + + Returns + ------- + Spiral Trajectory + """ + if encoding_matrix.x != encoding_matrix.y: + raise ValueError('Only square encoding matrices are supported.') + if encoding_matrix.z != 1: + raise ValueError('Only 2D trajectories are supported.') + + lam = 0.5 * (encoding_matrix.x / self.fov) + n_turns = 1 / ( + 1 - (1 - (2 * self.acceleration_per_interleave) / encoding_matrix.x) ** (1 / self.density_factor) + ) # eq. 10 + max_angle = 2 * torch.pi * n_turns + end_time_amplitude = (lam * max_angle) / (self.max_gradient_gamma * (self.density_factor + 1)) # eq. 5, Tes + end_time_slew = torch.sqrt(lam * max_angle**2 / (self.max_slewrate_gamma)) / ( + self.density_factor / 2 + 1 + ) # eq. 8, Tea + + transition_time_slew_to_amplitude = ( + end_time_slew ** ((self.density_factor + 1) / (self.density_factor / 2 + 1)) + * (self.density_factor / 2 + 1) + / end_time_amplitude + / (self.density_factor + 1) + ) ** (1 + 2 / self.density_factor) # eq. 9, Ts2a + + has_amplitude_phase = transition_time_slew_to_amplitude < end_time_slew + end_time = end_time_amplitude if has_amplitude_phase else end_time_slew + + def tau(t: torch.Tensor) -> torch.Tensor: + """Normalized time function.""" + # eq. 11 + slew_phase = (t / end_time_slew) ** (1 / (self.density_factor / 2 + 1)) + slew_phase = slew_phase * ((t >= 0) * (t <= transition_time_slew_to_amplitude)) + if not has_amplitude_phase: + return slew_phase + amplitude_phase = (t / end_time_amplitude) ** (1 / (self.density_factor + 1)) + amplitude_phase = amplitude_phase * ((t > transition_time_slew_to_amplitude) * (t <= end_time_amplitude)) + return slew_phase + amplitude_phase + + t = torch.linspace(0, end_time, n_k0) + tau_t = tau(t) + k = lam * tau_t**self.density_factor * torch.exp(1j * max_angle * tau_t) # eq. 2 + phase_rotation = torch.exp(self.angle * k1_idx) + k = k[None, :] * phase_rotation[:, None] + trajectory = KTrajectory(kx=k.real, ky=k.imag, kz=torch.zeros_like(k.real)) + return trajectory From 3a19c0048c58f6bc97c16ab0a879d7cd6fd5a9eb Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 28 Dec 2024 02:49:04 +0100 Subject: [PATCH 10/13] Update [ghstack-poisoned] --- .../traj_calculators/KTrajectorySpiral2D.py | 139 ------------------ 1 file changed, 139 deletions(-) delete mode 100644 src/mrpro/data/traj_calculators/KTrajectorySpiral2D.py diff --git a/src/mrpro/data/traj_calculators/KTrajectorySpiral2D.py b/src/mrpro/data/traj_calculators/KTrajectorySpiral2D.py deleted file mode 100644 index dbed3d5f0..000000000 --- a/src/mrpro/data/traj_calculators/KTrajectorySpiral2D.py +++ /dev/null @@ -1,139 +0,0 @@ -import torch - -from mrpro.data import SpatialDimension -from mrpro.data.KTrajectory import KTrajectory -from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator - - -class KTrajectorySpiral(KTrajectoryCalculator): - """A Spiral variable density trajectory. - - Implements the spiral trajectory calculation as described in - Simple Analytic Variable Density Spiral Design by Kim et al., MRM 2003""" - - def __init__( - self, - max_gradient: float, - max_slewrate: float, - fov: SpatialDimension | float, - angle: float, - acceleration_per_interleave: float = 1.0, - density_factor: float = 1.0, - gamma: float = 42577478, - ): - """Create a spiral trajectory calculator. - - Parameters - ---------- - max_gradient - Maximum gradient amplitude [T/m]. - max_slewrate - Maximum slew rate [T/m/s]. - density_factor - Density factor alpha. - fov - Field of view [m]. - acceleration_per_interleave - Acceleration per interleave. - Overall acceleration is (acceleration_per_interleave/n_interleaves), - where n_interleaves is determined by k1_idx - angle - Angle between interleaves [rad]. - Usully set to 2pi/n_interleaves - gamma - Gyromagnetic ratio [Hz/T]. - """ - self.density_factor = density_factor - self.max_gradient_gamma = max_gradient * gamma - self.max_slewrate_gamma = max_slewrate * gamma - self.acceleration_per_interleave = acceleration_per_interleave - self.angle = angle - - if isinstance(fov, float): - self.fov = fov - elif fov.x != fov.y: - raise ValueError('Only square FOV is supported.') - elif fov.z != 0: - raise ValueError('Only 2D trajectories are supported.') - else: - self.fov = fov.x - - if self.fov <= 0: - raise ValueError('FOV must be positive.') - if self.acceleration_per_interleave <= 0: - raise ValueError('Acceleration per interleave must be positive.') - if self.max_gradient_gamma <= 0: - raise ValueError('Max gradient must be positive.') - if self.max_slewrate_gamma <= 0: - raise ValueError('Max slew rate must be positive.') - if self.density_factor <= 0: - raise ValueError('Density factor alpha must be positive.') - - def __call__( - self, - *, - n_k0: int, - k1_idx: torch.Tensor, - encoding_matrix: SpatialDimension, - **_, - ) -> KTrajectory: - """ - Calculate the spiral trajectory. - - Parameters - ---------- - n_k0 - Number of samples along a spiral interleave. - k1_idx - Integer index of the interleaves - encoding_matrix - Dimensions of the encoding matrix. - Only square matrices are supported. - - Returns - ------- - Spiral Trajectory - """ - if encoding_matrix.x != encoding_matrix.y: - raise ValueError('Only square encoding matrices are supported.') - if encoding_matrix.z != 1: - raise ValueError('Only 2D trajectories are supported.') - - lam = 0.5 * (encoding_matrix.x / self.fov) - n_turns = 1 / ( - 1 - (1 - (2 * self.acceleration_per_interleave) / encoding_matrix.x) ** (1 / self.density_factor) - ) # eq. 10 - max_angle = 2 * torch.pi * n_turns - end_time_amplitude = (lam * max_angle) / (self.max_gradient_gamma * (self.density_factor + 1)) # eq. 5, Tes - end_time_slew = torch.sqrt(lam * max_angle**2 / (self.max_slewrate_gamma)) / ( - self.density_factor / 2 + 1 - ) # eq. 8, Tea - - transition_time_slew_to_amplitude = ( - end_time_slew ** ((self.density_factor + 1) / (self.density_factor / 2 + 1)) - * (self.density_factor / 2 + 1) - / end_time_amplitude - / (self.density_factor + 1) - ) ** (1 + 2 / self.density_factor) # eq. 9, Ts2a - - has_amplitude_phase = transition_time_slew_to_amplitude < end_time_slew - end_time = end_time_amplitude if has_amplitude_phase else end_time_slew - - def tau(t: torch.Tensor) -> torch.Tensor: - """Normalized time function.""" - # eq. 11 - slew_phase = (t / end_time_slew) ** (1 / (self.density_factor / 2 + 1)) - slew_phase = slew_phase * ((t >= 0) * (t <= transition_time_slew_to_amplitude)) - if not has_amplitude_phase: - return slew_phase - amplitude_phase = (t / end_time_amplitude) ** (1 / (self.density_factor + 1)) - amplitude_phase = amplitude_phase * ((t > transition_time_slew_to_amplitude) * (t <= end_time_amplitude)) - return slew_phase + amplitude_phase - - t = torch.linspace(0, end_time, n_k0) - tau_t = tau(t) - k = lam * tau_t**self.density_factor * torch.exp(1j * max_angle * tau_t) # eq. 2 - phase_rotation = torch.exp(self.angle * k1_idx) - k = k[None, :] * phase_rotation[:, None] - trajectory = KTrajectory(kx=k.real, ky=k.imag, kz=torch.zeros_like(k.real)) - return trajectory From a564f3ce7887a32610db5a91687afc98111eebf7 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 28 Dec 2024 03:00:24 +0100 Subject: [PATCH 11/13] Update [ghstack-poisoned] --- tests/data/test_kheader.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/data/test_kheader.py b/tests/data/test_kheader.py index 55cbfdf67..d1da71281 100644 --- a/tests/data/test_kheader.py +++ b/tests/data/test_kheader.py @@ -1,22 +1,16 @@ """Tests for KHeader class.""" -import pytest import torch from mrpro.data import KHeader from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory -def test_kheader_fail_from_mandatory_ismrmrd_header(random_mandatory_ismrmrd_header, random_acq_info): - """KHeader cannot be created only from ismrmrd header because trajectory is missing.""" - with pytest.raises(ValueError, match='Could not create Header'): - _ = KHeader.from_ismrmrd(random_mandatory_ismrmrd_header, random_acq_info) - - def test_kheader_overwrite_missing_parameter(random_mandatory_ismrmrd_header, random_acq_info): """KHeader can be created if trajectory is provided.""" overwrite = {'trajectory': DummyTrajectory()} - kheader = KHeader.from_ismrmrd(random_mandatory_ismrmrd_header, random_acq_info, overwrite=overwrite) + kheader = KHeader.from_ismrmrd(random_mandatory_ismrmrd_header, random_acq_info) assert kheader is not None + assert kheader.trajectory is overwrite['trajectory'] def test_kheader_set_missing_defaults(random_mandatory_ismrmrd_header, random_acq_info): @@ -24,6 +18,7 @@ def test_kheader_set_missing_defaults(random_mandatory_ismrmrd_header, random_ac defaults = {'trajectory': DummyTrajectory()} kheader = KHeader.from_ismrmrd(random_mandatory_ismrmrd_header, random_acq_info, defaults=defaults) assert kheader is not None + assert kheader.trajectory is defaults['trajectory'] def test_kheader_verify_None(random_mandatory_ismrmrd_header, random_acq_info): From 2b3115d90e988a6c920f0425bfb607d924703a92 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 30 Dec 2024 02:57:51 +0100 Subject: [PATCH 12/13] Update [ghstack-poisoned] --- tests/data/test_kheader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_kheader.py b/tests/data/test_kheader.py index d1da71281..309069d70 100644 --- a/tests/data/test_kheader.py +++ b/tests/data/test_kheader.py @@ -8,7 +8,7 @@ def test_kheader_overwrite_missing_parameter(random_mandatory_ismrmrd_header, random_acq_info): """KHeader can be created if trajectory is provided.""" overwrite = {'trajectory': DummyTrajectory()} - kheader = KHeader.from_ismrmrd(random_mandatory_ismrmrd_header, random_acq_info) + kheader = KHeader.from_ismrmrd(random_mandatory_ismrmrd_header, random_acq_info, overwrite=overwrite) assert kheader is not None assert kheader.trajectory is overwrite['trajectory'] From 4dc9f0b5ffada8b983b031013579d8f2a6c6fa74 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 3 Jan 2025 16:17:16 +0100 Subject: [PATCH 13/13] Update [ghstack-poisoned] --- src/mrpro/data/KData.py | 300 ++++++++++++++++++- src/mrpro/data/_kdata/KDataProtocol.py | 41 --- src/mrpro/data/_kdata/KDataRearrangeMixin.py | 41 --- src/mrpro/data/_kdata/KDataRemoveOsMixin.py | 75 ----- src/mrpro/data/_kdata/KDataSelectMixin.py | 65 ---- src/mrpro/data/_kdata/KDataSplitMixin.py | 161 ---------- 6 files changed, 293 insertions(+), 390 deletions(-) delete mode 100644 src/mrpro/data/_kdata/KDataProtocol.py delete mode 100644 src/mrpro/data/_kdata/KDataRearrangeMixin.py delete mode 100644 src/mrpro/data/_kdata/KDataRemoveOsMixin.py delete mode 100644 src/mrpro/data/_kdata/KDataSelectMixin.py delete mode 100644 src/mrpro/data/_kdata/KDataSplitMixin.py diff --git a/src/mrpro/data/KData.py b/src/mrpro/data/KData.py index 4b5df6250..47ddc27d2 100644 --- a/src/mrpro/data/KData.py +++ b/src/mrpro/data/KData.py @@ -1,23 +1,21 @@ """MR raw data / k-space data class.""" +import copy import dataclasses import datetime import warnings from collections.abc import Callable, Sequence from pathlib import Path from types import EllipsisType +from typing import Literal, cast import h5py import ismrmrd import numpy as np import torch -from einops import rearrange -from typing_extensions import Self +from einops import rearrange, repeat +from typing_extensions import Self, TypeVar -from mrpro.data._kdata.KDataRearrangeMixin import KDataRearrangeMixin -from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin -from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin -from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin from mrpro.data.acq_filters import has_n_coils, is_image_acquisition from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits @@ -29,6 +27,8 @@ from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd +RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation) + KDIM_SORT_LABELS = ( 'k1', 'k2', @@ -63,7 +63,9 @@ @dataclasses.dataclass(slots=True, frozen=True) -class KData(KDataSplitMixin, KDataRearrangeMixin, KDataSelectMixin, KDataRemoveOsMixin, MoveDataMixin): +class KData( + MoveDataMixin, +): """MR raw data / k-space data class.""" header: KHeader @@ -366,3 +368,287 @@ def compress_coils( ).permute(*np.argsort(permute_order)) return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone()) + + def rearrange_k2_k1_into_k1(self: Self) -> Self: + """Rearrange kdata from (... k2 k1 ...) to (... 1 (k2 k1) ...). + + Parameters + ---------- + kdata + K-space data (other coils k2 k1 k0) + + Returns + ------- + K-space data (other coils 1 (k2 k1) k0) + """ + # Rearrange data + kdat = rearrange(self.data, '... coils k2 k1 k0->... coils 1 (k2 k1) k0') + + # Rearrange trajectory + ktraj = rearrange(self.traj.as_tensor(), 'dim ... k2 k1 k0-> dim ... 1 (k2 k1) k0') + + # Create new header with correct shape + kheader = copy.deepcopy(self.header) + + # Update shape of acquisition info index + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...') + ) + + return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) + + def remove_readout_os(self: Self) -> Self: + """Remove any oversampling along the readout (k0) direction [GAD]_. + + Returns a copy of the data. + + Parameters + ---------- + kdata + K-space data + + Returns + ------- + Copy of K-space data with oversampling removed. + + Raises + ------ + ValueError + If the recon matrix along x is larger than the encoding matrix along x. + + References + ---------- + .. [GAD] Gadgetron https://github.com/gadgetron/gadgetron-python + """ + from mrpro.operators.FastFourierOp import FastFourierOp + + # Ratio of k0/x between encoded and recon space + x_ratio = self.header.recon_matrix.x / self.header.encoding_matrix.x + if x_ratio == 1: + # If the encoded and recon space is the same we don't have to do anything + return self + elif x_ratio > 1: + raise ValueError('Recon matrix along x should be equal or larger than encoding matrix along x.') + + # Starting and end point of image after removing oversampling + start_cropped_readout = (self.header.encoding_matrix.x - self.header.recon_matrix.x) // 2 + end_cropped_readout = start_cropped_readout + self.header.recon_matrix.x + + def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor: + # returns a cropped copy + return data_to_crop[..., start_cropped_readout:end_cropped_readout].clone() + + # Transform to image space along readout, crop to reconstruction matrix size and transform back + fourier_k0_op = FastFourierOp(dim=(-1,)) + (cropped_data,) = fourier_k0_op(crop_readout(*fourier_k0_op.H(self.data))) + + # Adapt trajectory + ks = [self.traj.kz, self.traj.ky, self.traj.kx] + # only cropped ks that are not broadcasted/singleton along k0 + cropped_ks = [crop_readout(k) if k.shape[-1] > 1 else k.clone() for k in ks] + cropped_traj = KTrajectory(cropped_ks[0], cropped_ks[1], cropped_ks[2]) + + # Adapt header parameters + header = copy.deepcopy(self.header) + header.acq_info.center_sample -= start_cropped_readout + header.acq_info.number_of_samples[:] = cropped_data.shape[-1] + header.encoding_matrix.x = cropped_data.shape[-1] + + header.acq_info.discard_post = (header.acq_info.discard_post * x_ratio).to(torch.int32) + header.acq_info.discard_pre = (header.acq_info.discard_pre * x_ratio).to(torch.int32) + + return type(self)(header, cropped_data, cropped_traj) + + def select_other_subset( + self: Self, + subset_idx: torch.Tensor, + subset_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], + ) -> Self: + """Select a subset from the other dimension of KData. + + Parameters + ---------- + kdata + K-space data (other coils k2 k1 k0) + subset_idx + Index which elements of the other subset to use, e.g. phase 0,1,2 and 5 + subset_label + Name of the other label, e.g. phase + + Returns + ------- + K-space data (other_subset coils k2 k1 k0) + + Raises + ------ + ValueError + If the subset indices are not available in the data + """ + # Make a copy such that the original kdata.header remains the same + kheader = copy.deepcopy(self.header) + ktraj = self.traj.as_tensor() + + # Verify that the subset_idx is available + label_idx = getattr(kheader.acq_info.idx, subset_label) + if not all(el in torch.unique(label_idx) for el in subset_idx): + raise ValueError('Subset indices are outside of the available index range') + + # Find subset index in acq_info index + other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0) + + # Adapt header + kheader.acq_info.apply_( + lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field + ) + + # Select data + kdat = self.data[other_idx, ...] + + # Select ktraj + if ktraj.shape[1] > 1: + ktraj = ktraj[:, other_idx, ...] + + return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) + + def _split_k2_or_k1_into_other( + self, + split_idx: torch.Tensor, + other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], + split_dir: Literal['k2', 'k1'], + ) -> Self: + """Based on an index tensor, split the data in e.g. phases. + + Parameters + ---------- + split_idx + 2D index describing the k2 or k1 points in each block to be moved to the other dimension + (other_split, k1_per_split) or (other_split, k2_per_split) + other_label + Label of other dimension, e.g. repetition, phase + split_dir + Dimension to split, either 'k1' or 'k2' + + Returns + ------- + K-space data with new shape + ((other other_split) coils k2 k1_per_split k0) or ((other other_split) coils k2_per_split k1 k0) + + Raises + ------ + ValueError + Already existing "other_label" can only be of length 1 + """ + # Number of other + n_other = split_idx.shape[0] + + # Verify that the specified label of the other dimension is unused + if getattr(self.header.encoding_limits, other_label).length > 1: + raise ValueError(f'{other_label} is already used to encode different parts of the scan.') + + # Set-up splitting + if split_dir == 'k1': + # Split along k1 dimensions + def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: + return dat_traj[:, :, :, split_idx, :] + + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + # cast due to https://github.com/python/mypy/issues/10817 + return cast(RotationOrTensor, acq_info[:, :, split_idx, ...]) + + # Rearrange other_split and k1 dimension + rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0' + rearrange_pattern_traj = 'dim other k2 other_split k1 k0->dim (other other_split) k2 k1 k0' + rearrange_pattern_acq_info = 'other k2 other_split k1 ... -> (other other_split) k2 k1 ...' + + elif split_dir == 'k2': + # Split along k2 dimensions + def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: + return dat_traj[:, :, split_idx, :, :] + + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + return cast(RotationOrTensor, acq_info[:, split_idx, ...]) + + # Rearrange other_split and k1 dimension + rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0' + rearrange_pattern_traj = 'dim other other_split k2 k1 k0->dim (other other_split) k2 k1 k0' + rearrange_pattern_acq_info = 'other other_split k2 k1 ... -> (other other_split) k2 k1 ...' + + else: + raise ValueError('split_dir has to be "k1" or "k2"') + + # Split data + kdat = rearrange(split_data_traj(self.data), rearrange_pattern_data) + + # First we need to make sure the other dimension is the same as data then we can split the trajectory + ktraj = self.traj.as_tensor() + # Verify that other dimension of trajectory is 1 or matches data + if ktraj.shape[1] > 1 and ktraj.shape[1] != self.data.shape[0]: + raise ValueError(f'other dimension of trajectory has to be 1 or match data ({self.data.shape[0]})') + elif ktraj.shape[1] == 1 and self.data.shape[0] > 1: + ktraj = repeat(ktraj, 'dim other k2 k1 k0->dim (other_data other) k2 k1 k0', other_data=self.data.shape[0]) + ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj) + + # Create new header with correct shape + kheader = self.header.clone() + + # Update shape of acquisition info index + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info) + if isinstance(field, Rotation | torch.Tensor) + else field + ) + + # Update other label limits and acquisition info + setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0)) + + # acq_info for new other dimensions + acq_info_other_split = repeat( + torch.linspace(0, n_other - 1, n_other), 'other-> other k2 k1', k2=kdat.shape[-3], k1=kdat.shape[-2] + ) + setattr(kheader.acq_info.idx, other_label, acq_info_other_split) + + return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) + + def split_k1_into_other( + self: Self, + split_idx: torch.Tensor, + other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], + ) -> Self: + """Based on an index tensor, split the data in e.g. phases. + + Parameters + ---------- + kdata + K-space data (other coils k2 k1 k0) + split_idx + 2D index describing the k1 points in each block to be moved to other dimension (other_split, k1_per_split) + other_label + Label of other dimension, e.g. repetition, phase + + Returns + ------- + K-space data with new shape ((other other_split) coils k2 k1_per_split k0) + """ + return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k1') + + def split_k2_into_other( + self: Self, + split_idx: torch.Tensor, + other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], + ) -> Self: + """Based on an index tensor, split the data in e.g. phases. + + Parameters + ---------- + kdata + K-space data (other coils k2 k1 k0) + split_idx + 2D index describing the k2 points in each block to be moved to other dimension (other_split, k2_per_split) + other_label + Label of other dimension, e.g. repetition, phase + + Returns + ------- + K-space data with new shape ((other other_split) coils k2_per_split k1 k0) + """ + return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k2') diff --git a/src/mrpro/data/_kdata/KDataProtocol.py b/src/mrpro/data/_kdata/KDataProtocol.py deleted file mode 100644 index 485a8fc4d..000000000 --- a/src/mrpro/data/_kdata/KDataProtocol.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Protocol for KData.""" - -from typing import Literal - -import torch -from typing_extensions import Protocol, Self - -from mrpro.data.KHeader import KHeader -from mrpro.data.KTrajectory import KTrajectory - - -class _KDataProtocol(Protocol): - """Protocol for KData used for type hinting in KData mixins. - - Note that the actual KData class can have more properties and methods than those defined here. - - If you want to use a property or method of KData in a new KDataMixin class, - you must add it to this Protocol to make sure that the type hinting works [PRO]_. - - References - ---------- - .. [PRO] Protocols https://typing.readthedocs.io/en/latest/spec/protocol.html#protocols - """ - - @property - def header(self) -> KHeader: ... - - @property - def data(self) -> torch.Tensor: ... - - @property - def traj(self) -> KTrajectory: ... - - def __init__(self, header: KHeader, data: torch.Tensor, traj: KTrajectory): ... - - def _split_k2_or_k1_into_other( - self, - split_idx: torch.Tensor, - other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - split_dir: Literal['k1', 'k2'], - ) -> Self: ... diff --git a/src/mrpro/data/_kdata/KDataRearrangeMixin.py b/src/mrpro/data/_kdata/KDataRearrangeMixin.py deleted file mode 100644 index 23a58dea6..000000000 --- a/src/mrpro/data/_kdata/KDataRearrangeMixin.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Rearrange KData.""" - -import copy - -from einops import rearrange -from typing_extensions import Self - -from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.AcqInfo import rearrange_acq_info_fields - - -class KDataRearrangeMixin(_KDataProtocol): - """Rearrange KData.""" - - def rearrange_k2_k1_into_k1(self: Self) -> Self: - """Rearrange kdata from (... k2 k1 ...) to (... 1 (k2 k1) ...). - - Parameters - ---------- - kdata - K-space data (other coils k2 k1 k0) - - Returns - ------- - K-space data (other coils 1 (k2 k1) k0) - """ - # Rearrange data - kdat = rearrange(self.data, '... coils k2 k1 k0->... coils 1 (k2 k1) k0') - - # Rearrange trajectory - ktraj = rearrange(self.traj.as_tensor(), 'dim ... k2 k1 k0-> dim ... 1 (k2 k1) k0') - - # Create new header with correct shape - kheader = copy.deepcopy(self.header) - - # Update shape of acquisition info index - kheader.acq_info.apply_( - lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...') - ) - - return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) diff --git a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py b/src/mrpro/data/_kdata/KDataRemoveOsMixin.py deleted file mode 100644 index 555f56a39..000000000 --- a/src/mrpro/data/_kdata/KDataRemoveOsMixin.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Remove oversampling along readout dimension.""" - -from copy import deepcopy - -import torch -from typing_extensions import Self - -from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.KTrajectory import KTrajectory - - -class KDataRemoveOsMixin(_KDataProtocol): - """Remove oversampling along readout dimension.""" - - def remove_readout_os(self: Self) -> Self: - """Remove any oversampling along the readout (k0) direction [GAD]_. - - Returns a copy of the data. - - Parameters - ---------- - kdata - K-space data - - Returns - ------- - Copy of K-space data with oversampling removed. - - Raises - ------ - ValueError - If the recon matrix along x is larger than the encoding matrix along x. - - References - ---------- - .. [GAD] Gadgetron https://github.com/gadgetron/gadgetron-python - """ - from mrpro.operators.FastFourierOp import FastFourierOp - - # Ratio of k0/x between encoded and recon space - x_ratio = self.header.recon_matrix.x / self.header.encoding_matrix.x - if x_ratio == 1: - # If the encoded and recon space is the same we don't have to do anything - return self - elif x_ratio > 1: - raise ValueError('Recon matrix along x should be equal or larger than encoding matrix along x.') - - # Starting and end point of image after removing oversampling - start_cropped_readout = (self.header.encoding_matrix.x - self.header.recon_matrix.x) // 2 - end_cropped_readout = start_cropped_readout + self.header.recon_matrix.x - - def crop_readout(data_to_crop: torch.Tensor) -> torch.Tensor: - # returns a cropped copy - return data_to_crop[..., start_cropped_readout:end_cropped_readout].clone() - - # Transform to image space along readout, crop to reconstruction matrix size and transform back - fourier_k0_op = FastFourierOp(dim=(-1,)) - (cropped_data,) = fourier_k0_op(crop_readout(*fourier_k0_op.H(self.data))) - - # Adapt trajectory - ks = [self.traj.kz, self.traj.ky, self.traj.kx] - # only cropped ks that are not broadcasted/singleton along k0 - cropped_ks = [crop_readout(k) if k.shape[-1] > 1 else k.clone() for k in ks] - cropped_traj = KTrajectory(cropped_ks[0], cropped_ks[1], cropped_ks[2]) - - # Adapt header parameters - header = deepcopy(self.header) - header.acq_info.center_sample -= start_cropped_readout - header.acq_info.number_of_samples[:] = cropped_data.shape[-1] - header.encoding_matrix.x = cropped_data.shape[-1] - - header.acq_info.discard_post = (header.acq_info.discard_post * x_ratio).to(torch.int32) - header.acq_info.discard_pre = (header.acq_info.discard_pre * x_ratio).to(torch.int32) - - return type(self)(header, cropped_data, cropped_traj) diff --git a/src/mrpro/data/_kdata/KDataSelectMixin.py b/src/mrpro/data/_kdata/KDataSelectMixin.py deleted file mode 100644 index 8f8a452cf..000000000 --- a/src/mrpro/data/_kdata/KDataSelectMixin.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Select subset along other dimensions of KData.""" - -import copy -from typing import Literal - -import torch -from typing_extensions import Self - -from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.Rotation import Rotation - - -class KDataSelectMixin(_KDataProtocol): - """Select subset of KData.""" - - def select_other_subset( - self: Self, - subset_idx: torch.Tensor, - subset_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - ) -> Self: - """Select a subset from the other dimension of KData. - - Parameters - ---------- - kdata - K-space data (other coils k2 k1 k0) - subset_idx - Index which elements of the other subset to use, e.g. phase 0,1,2 and 5 - subset_label - Name of the other label, e.g. phase - - Returns - ------- - K-space data (other_subset coils k2 k1 k0) - - Raises - ------ - ValueError - If the subset indices are not available in the data - """ - # Make a copy such that the original kdata.header remains the same - kheader = copy.deepcopy(self.header) - ktraj = self.traj.as_tensor() - - # Verify that the subset_idx is available - label_idx = getattr(kheader.acq_info.idx, subset_label) - if not all(el in torch.unique(label_idx) for el in subset_idx): - raise ValueError('Subset indices are outside of the available index range') - - # Find subset index in acq_info index - other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0) - - # Adapt header - kheader.acq_info.apply_( - lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field - ) - - # Select data - kdat = self.data[other_idx, ...] - - # Select ktraj - if ktraj.shape[1] > 1: - ktraj = ktraj[:, other_idx, ...] - - return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) diff --git a/src/mrpro/data/_kdata/KDataSplitMixin.py b/src/mrpro/data/_kdata/KDataSplitMixin.py deleted file mode 100644 index c28004af4..000000000 --- a/src/mrpro/data/_kdata/KDataSplitMixin.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Mixin class to split KData into other subsets.""" - -from typing import Literal, TypeVar, cast - -import torch -from einops import rearrange, repeat -from typing_extensions import Self - -from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.AcqInfo import rearrange_acq_info_fields -from mrpro.data.EncodingLimits import Limits -from mrpro.data.Rotation import Rotation - -RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation) - - -class KDataSplitMixin(_KDataProtocol): - """Split KData into other subsets.""" - - def _split_k2_or_k1_into_other( - self, - split_idx: torch.Tensor, - other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - split_dir: Literal['k2', 'k1'], - ) -> Self: - """Based on an index tensor, split the data in e.g. phases. - - Parameters - ---------- - split_idx - 2D index describing the k2 or k1 points in each block to be moved to the other dimension - (other_split, k1_per_split) or (other_split, k2_per_split) - other_label - Label of other dimension, e.g. repetition, phase - split_dir - Dimension to split, either 'k1' or 'k2' - - Returns - ------- - K-space data with new shape - ((other other_split) coils k2 k1_per_split k0) or ((other other_split) coils k2_per_split k1 k0) - - Raises - ------ - ValueError - Already existing "other_label" can only be of length 1 - """ - # Number of other - n_other = split_idx.shape[0] - - # Verify that the specified label of the other dimension is unused - if getattr(self.header.encoding_limits, other_label).length > 1: - raise ValueError(f'{other_label} is already used to encode different parts of the scan.') - - # Set-up splitting - if split_dir == 'k1': - # Split along k1 dimensions - def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: - return dat_traj[:, :, :, split_idx, :] - - def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: - # cast due to https://github.com/python/mypy/issues/10817 - return cast(RotationOrTensor, acq_info[:, :, split_idx, ...]) - - # Rearrange other_split and k1 dimension - rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0' - rearrange_pattern_traj = 'dim other k2 other_split k1 k0->dim (other other_split) k2 k1 k0' - rearrange_pattern_acq_info = 'other k2 other_split k1 ... -> (other other_split) k2 k1 ...' - - elif split_dir == 'k2': - # Split along k2 dimensions - def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: - return dat_traj[:, :, split_idx, :, :] - - def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: - return cast(RotationOrTensor, acq_info[:, split_idx, ...]) - - # Rearrange other_split and k1 dimension - rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0' - rearrange_pattern_traj = 'dim other other_split k2 k1 k0->dim (other other_split) k2 k1 k0' - rearrange_pattern_acq_info = 'other other_split k2 k1 ... -> (other other_split) k2 k1 ...' - - else: - raise ValueError('split_dir has to be "k1" or "k2"') - - # Split data - kdat = rearrange(split_data_traj(self.data), rearrange_pattern_data) - - # First we need to make sure the other dimension is the same as data then we can split the trajectory - ktraj = self.traj.as_tensor() - # Verify that other dimension of trajectory is 1 or matches data - if ktraj.shape[1] > 1 and ktraj.shape[1] != self.data.shape[0]: - raise ValueError(f'other dimension of trajectory has to be 1 or match data ({self.data.shape[0]})') - elif ktraj.shape[1] == 1 and self.data.shape[0] > 1: - ktraj = repeat(ktraj, 'dim other k2 k1 k0->dim (other_data other) k2 k1 k0', other_data=self.data.shape[0]) - ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj) - - # Create new header with correct shape - kheader = self.header.clone() - - # Update shape of acquisition info index - kheader.acq_info.apply_( - lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info) - if isinstance(field, Rotation | torch.Tensor) - else field - ) - - # Update other label limits and acquisition info - setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0)) - - # acq_info for new other dimensions - acq_info_other_split = repeat( - torch.linspace(0, n_other - 1, n_other), 'other-> other k2 k1', k2=kdat.shape[-3], k1=kdat.shape[-2] - ) - setattr(kheader.acq_info.idx, other_label, acq_info_other_split) - - return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) - - def split_k1_into_other( - self: Self, - split_idx: torch.Tensor, - other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - ) -> Self: - """Based on an index tensor, split the data in e.g. phases. - - Parameters - ---------- - kdata - K-space data (other coils k2 k1 k0) - split_idx - 2D index describing the k1 points in each block to be moved to other dimension (other_split, k1_per_split) - other_label - Label of other dimension, e.g. repetition, phase - - Returns - ------- - K-space data with new shape ((other other_split) coils k2 k1_per_split k0) - """ - return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k1') - - def split_k2_into_other( - self: Self, - split_idx: torch.Tensor, - other_label: Literal['average', 'slice', 'contrast', 'phase', 'repetition', 'set'], - ) -> Self: - """Based on an index tensor, split the data in e.g. phases. - - Parameters - ---------- - kdata - K-space data (other coils k2 k1 k0) - split_idx - 2D index describing the k2 points in each block to be moved to other dimension (other_split, k2_per_split) - other_label - Label of other dimension, e.g. repetition, phase - - Returns - ------- - K-space data with new shape ((other other_split) coils k2_per_split k1 k0) - """ - return self._split_k2_or_k1_into_other(split_idx, other_label, split_dir='k2')