Skip to content

Commit

Permalink
[TOSA] Refactor: Use tosaCastTensorToType() function to create tosa.cast
Browse files Browse the repository at this point in the history
* Previously, in Torch to TOSA, there are 3 ways to create tosa.cast op:
  - `rewriter.create<tosa::CastOp>()`
  - `tosa::promoteType()`
  - `tosa::tosaCastTensorToType()`
* This commit combines the three APIs above into
  `tosa::tosaCastTensorToType()` with the following features:
  - Checking whether source and destination element types are the same
    before casting. If they are same, skip the cast.
  - Custom float to integer cast behavior added from this PR:
    #3946
    TLDR: PyTorch's and TOSA's float to integer casting behaviors are
    different (round to zero vs round to nearest, respectively), which
    requires a custom casting here.
  - Future `TODO`: add a --strict mode which includes
    `checkValidityOfCast()` to ensure that the casting pairs follow TOSA
    specifications.
* Update LIT tests.

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I2aef3c79d8f2d98b93e671d5b815b8eab33e697e
  • Loading branch information
justin-ngo-arm committed Feb 4, 2025
1 parent f7b32d7 commit da60c17
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 267 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t> indiceOneDimShape, int32_t dim,
ArrayRef<int64_t> indexShape);

// Default function to create TOSA op with shift value
mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs,
int32_t shift);
Expand All @@ -32,8 +33,8 @@ mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
template <typename TosaOpT>
TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs) {
lhs = promoteType(rewriter, lhs, outType);
rhs = promoteType(rewriter, rhs, outType);
lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value();
rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value();
return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
}

Expand Down
8 changes: 4 additions & 4 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape,
std::optional<Type> dtype = {});

LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result);

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
// Default function to create tosa.cast op. This should be called instead of
// directly calling rewriter.create<tosa::CastOp>.
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
TensorType destType);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
Expand Down
Loading

0 comments on commit da60c17

Please sign in to comment.