From fc04a6480d9285f8e5c0d2d47913a58611302a89 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 7 Jan 2025 17:23:29 +0200 Subject: [PATCH] torch: allow python scalars in result_type --- array_api_compat/torch/_aliases.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 4cb6451..5b20aab 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -125,7 +125,11 @@ def _fix_promotion(x1, x2, only_scalar=True): x1 = x1.to(dtype) return x1, x2 -def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: + +_py_scalars = (bool, int, float, complex) + + +def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: if len(arrays_and_dtypes) == 0: raise TypeError("At least one array or dtype must be provided") if len(arrays_and_dtypes) == 1: @@ -137,6 +141,9 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) x, y = arrays_and_dtypes + if isinstance(x, _py_scalars) or isinstance(y, _py_scalars): + return torch.result_type(x, y) + xdt = x.dtype if not isinstance(x, torch.dtype) else x ydt = y.dtype if not isinstance(y, torch.dtype) else y