Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure compressed number of coils is not greater than existing number #567

Merged
merged 18 commits into from
Jan 7, 2025
26 changes: 16 additions & 10 deletions src/mrpro/data/KData.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,13 @@ def compress_coils(
from mrpro.operators import PCACompressionOp

coil_dim = -4 % self.data.ndim

if n_compressed_coils > (n_current_coils := self.data.shape[coil_dim]):
raise ValueError(
f'Number of compressed coils ({n_compressed_coils}) cannot be greater '
f'than the number of current coils ({n_current_coils}).'
)

if batch_dims is not None and joint_dims is not Ellipsis:
raise ValueError('Either batch_dims or joint_dims can be defined not both.')

Expand All @@ -349,22 +356,21 @@ def compress_coils(

# reshape to (*batch dimension, -1, coils)
permute_order = (
batch_dims_normalized
+ [i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized]
+ [coil_dim]
*batch_dims_normalized,
*[i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized],
coil_dim,
)
kdata_coil_compressed = self.data.permute(permute_order)
permuted_kdata_shape = kdata_coil_compressed.shape
kdata_coil_compressed = kdata_coil_compressed.flatten(
kdata_permuted = self.data.permute(permute_order)
kdata_flattened = kdata_permuted.flatten(
start_dim=len(batch_dims_normalized), end_dim=-2
) # keep separate dimensions and coil

pca_compression_op = PCACompressionOp(data=kdata_coil_compressed, n_components=n_compressed_coils)
(kdata_coil_compressed,) = pca_compression_op(kdata_coil_compressed)

pca_compression_op = PCACompressionOp(data=kdata_flattened, n_components=n_compressed_coils)
(kdata_coil_compressed_flattened,) = pca_compression_op(kdata_flattened)
del kdata_flattened
# reshape to original dimensions and undo permutation
kdata_coil_compressed = torch.reshape(
kdata_coil_compressed, [*permuted_kdata_shape[:-1], n_compressed_coils]
kdata_coil_compressed_flattened, [*kdata_permuted.shape[:-1], n_compressed_coils]
).permute(*np.argsort(permute_order))

return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone())
Expand Down
7 changes: 7 additions & 0 deletions tests/data/test_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,10 @@ def test_KData_compress_coils_error_coil_dim(consistently_shaped_kdata):

with pytest.raises(ValueError, match='Coil dimension must not'):
consistently_shaped_kdata.compress_coils(n_compressed_coils=3, joint_dims=(-4,))


def test_KData_compress_coils_error_n_coils(consistently_shaped_kdata):
"""Test if error is raised if new coils would be larger than existing coils"""
existing_coils = consistently_shaped_kdata.data.shape[-4]
with pytest.raises(ValueError, match='greater'):
consistently_shaped_kdata.compress_coils(existing_coils + 1)
Loading