diff --git a/quaternions.py b/quaternions.py index 2ad6a5d..b1a6067 100644 --- a/quaternions.py +++ b/quaternions.py @@ -61,7 +61,7 @@ def pure_quat(v): def quat_inv(q): #Note, 'empty_like' is necessary to prevent in-place modification (which is not auto-diff'able) if q.dim() < 2: - q = q.unsqueeze() + q = q.unsqueeze(dim=0) q_inv = torch.empty_like(q) q_inv[:, :3] = -1*q[:, :3] q_inv[:, 3] = q[:, 3]