Skip to content

Commit

Permalink
torch: allow python scalars in result_type
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Feb 1, 2025
1 parent 4d81a0d commit fc04a64
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,11 @@ def _fix_promotion(x1, x2, only_scalar=True):
x1 = x1.to(dtype)
return x1, x2

def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:

_py_scalars = (bool, int, float, complex)


def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
if len(arrays_and_dtypes) == 0:
raise TypeError("At least one array or dtype must be provided")
if len(arrays_and_dtypes) == 1:
Expand All @@ -137,6 +141,9 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))

x, y = arrays_and_dtypes
if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
return torch.result_type(x, y)

xdt = x.dtype if not isinstance(x, torch.dtype) else x
ydt = y.dtype if not isinstance(y, torch.dtype) else y

Expand Down

0 comments on commit fc04a64

Please sign in to comment.