diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
index 398926e8168a..4041e522fca1 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
@@ -24,6 +24,7 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
                       SmallVector indiceOneDimShape, int32_t dim,
                       ArrayRef indexShape);
 
+// Default function to create TOSA op with shift value
 mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
                                      TensorType outType, Value lhs, Value rhs,
                                      int32_t shift);
@@ -32,8 +33,8 @@ mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
 template 
 TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
                               TensorType outType, Value lhs, Value rhs) {
-  lhs = promoteType(rewriter, lhs, outType);
-  rhs = promoteType(rewriter, rhs, outType);
+  lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value();
+  rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value();
 
   return CreateOpAndInfer(rewriter, op->getLoc(), outType, lhs, rhs);
 }
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index cf2b8a2df075..1b944b9a1ea4 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -63,10 +63,10 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op,
                              ArrayRef vec, ArrayRef shape,
                              std::optional dtype = {});
 
-LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
-                                   Value src, Type destType, Value &result);
-
-Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
+// Default function to create tosa.cast op. This should be called instead of
+// directly calling rewriter.create.
+std::optional tosaCastTensorToType(PatternRewriter &rewriter, Value src,
+                                   TensorType destType);
 
 // Creates a TOSA operation and performs shape inference on the individual
 // op. This allows shape inference during the framework to TOSA lowering.
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index cbe6e8b87fc0..d0c827e86b2a 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -63,7 +63,7 @@ class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern {
     // Non floating point inputs are not supported in TOSA so we cast the input
     // to result type
     if (!isa(selfTy.getElementType()))
-      self = tosa::promoteType(rewriter, self, resultTy);
+      self = tosa::tosaCastTensorToType(rewriter, self, resultTy).value();
 
     rewriter.replaceOpWithNewOp(op, resultTy, self);
 
@@ -87,7 +87,7 @@ class ConvertAtenUnaryOp : public OpConversionPattern {
         OpConversionPattern::getTypeConverter()->convertType(
             op.getType()));
 
-    self = tosa::promoteType(rewriter, self, outType);
+    self = tosa::tosaCastTensorToType(rewriter, self, outType).value();
 
     rewriter.replaceOpWithNewOp(op, outType, self);
 
@@ -130,8 +130,8 @@ class ConvertAtenBinaryOp : public OpConversionPattern {
           /*round=*/false);
     } else if constexpr (std::is_same() ||
                          std::is_same()) {
-      lhs = tosa::promoteType(rewriter, lhs, outTy);
-      rhs = tosa::promoteType(rewriter, rhs, outTy);
+      lhs = tosa::tosaCastTensorToType(rewriter, lhs, outTy).value();
+      rhs = tosa::tosaCastTensorToType(rewriter, rhs, outTy).value();
       // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum and
       // tosa.minimum
       binaryOp = rewriter.create(
@@ -335,10 +335,11 @@ class ConvertAtenAddSubOp : public OpConversionPattern {
     if (rhsType.getElementType() != rhsAlphaMulElemType) {
       // right is tensor, rhsType == tensor
      // right must be cast to same type as the alpha, so MulOp success
-      rhs = rewriter.create(
-          op->getLoc(),
-          RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType),
-          rhs);
+      rhs =
+          tosa::tosaCastTensorToType(
+              rewriter, rhs,
+              RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType))
+              .value();
       // reinitialize right value type to tensor
       rhsType = dyn_cast(rhs.getType());
     }
@@ -381,7 +382,10 @@ class ConvertAtenAddSubOp : public OpConversionPattern {
         mulAlphaOp);
 
     // cast tensor back to tensor
-    rewriter.replaceOpWithNewOp(op, outType, addOrSubi64Op);
+    auto result =
+        tosa::tosaCastTensorToType(rewriter, addOrSubi64Op, outType).value();
+    rewriter.replaceOp(op, result);
+
     return success();
   }
 
@@ -456,8 +460,9 @@ class ConvertAtenCompareOp : public OpConversionPattern {
         OpConversionPattern::getTypeConverter()->convertType(
             op.getType()));
     if (isBitwiseOp) {
-      lhs = tosa::promoteType(rewriter, lhs, resultTy);
-      rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
+      lhs = tosa::tosaCastTensorToType(rewriter, lhs, resultTy).value();
+      rhsTensor =
+          tosa::tosaCastTensorToType(rewriter, rhsTensor, resultTy).value();
     }
 
     // Support different types comparisons
@@ -466,24 +471,27 @@ class ConvertAtenCompareOp : public OpConversionPattern {
     if (lhsElemTy != rhsElemTy && !isBitwiseOp) {
       if (isLhsElemFloat && !isRhsElemFloat) {
-        rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
+        rhsTensor =
+            tosa::tosaCastTensorToType(rewriter, rhsTensor, lhsTy).value();
       } else if (!isLhsElemFloat && isRhsElemFloat) {
-        lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy);
+        lhs = tosa::tosaCastTensorToType(rewriter, lhs, rhsTensorTy).value();
       } else if (isLhsElemFloat && isRhsElemFloat) {
         auto lhsElemFloatTy = dyn_cast(lhsElemTy);
         auto rhsElemFloatTy = dyn_cast(rhsElemTy);
         if (lhsElemFloatTy.getWidth() > rhsElemFloatTy.getWidth()) {
-          rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
+          rhsTensor =
+              tosa::tosaCastTensorToType(rewriter, rhsTensor, lhsTy).value();
         } else {
-          lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy);
+          lhs = tosa::tosaCastTensorToType(rewriter, lhs, rhsTensorTy).value();
         }
       } else {
         auto lhsElemIntTy = dyn_cast(lhsElemTy);
         auto rhsElemIntTy = dyn_cast(rhsElemTy);
         if (lhsElemIntTy.getWidth() > rhsElemIntTy.getWidth()) {
-          rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
+          rhsTensor =
+              tosa::tosaCastTensorToType(rewriter, rhsTensor, lhsTy).value();
         } else {
-          lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy);
+          lhs = tosa::tosaCastTensorToType(rewriter, lhs, rhsTensorTy).value();
         }
       }
     }
@@ -629,7 +637,7 @@ std::optional truncFloatDivWithDivResult(PatternRewriter &rewriter,
 // towards zero) for float type inputs
 Value truncFloatDiv(PatternRewriter &rewriter, Operation *op,
                     TensorType outType, Value lhs, Value rhs) {
-  rhs = tosa::promoteType(rewriter, rhs, outType);
+  rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value();
 
   auto rhsRcp =
       rewriter.create(op->getLoc(), rhs.getType(), rhs);
@@ -655,8 +663,8 @@ std::optional floorIntDiv(PatternRewriter &rewriter, Operation *op,
   // TOSA IntDiv requires inputs to be i32
   auto i32Type =
       RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32));
-  lhs = tosa::promoteType(rewriter, lhs, i32Type);
-  rhs = tosa::promoteType(rewriter, rhs, i32Type);
+  lhs = tosa::tosaCastTensorToType(rewriter, lhs, i32Type).value();
+  rhs = tosa::tosaCastTensorToType(rewriter, rhs, i32Type).value();
 
   auto intDivOp =
       rewriter.create(op->getLoc(), i32Type, lhs, rhs);
@@ -696,7 +704,8 @@ std::optional floorIntDiv(PatternRewriter &rewriter, Operation *op,
   auto selectOp = rewriter.create(op->getLoc(), i32Type, cond,
                                   truncMinusOne, intDivOp);
 
-  Value result = tosa::promoteType(rewriter, selectOp, outType);
+  Value result =
+      tosa::tosaCastTensorToType(rewriter, selectOp, outType).value();
 
   return result;
 }
@@ -755,7 +764,8 @@ class ConvertAtenDivOp : public OpConversionPattern {
       // The input to the reciprocal is an integer sometimes, and we may need
       // to promote it to a floating point. Per TOSA specification, the input
       // types can only be floating point for tosa::ReciprocalOp.
-      rhsTensor = tosa::promoteType(rewriter, rhsTensor, outType);
+      rhsTensor =
+          tosa::tosaCastTensorToType(rewriter, rhsTensor, outType).value();
       auto rhsRcp = rewriter.create(
           op->getLoc(), rhsTensor.getType(), rhsTensor);
@@ -792,13 +802,15 @@ class ConvertAtenDivOp : public OpConversionPattern {
       // TOSA IntDiv requires inputs to be i32
       auto i32Type = RankedTensorType::get(outType.getShape(),
                                            rewriter.getIntegerType(32));
-      lhs = tosa::promoteType(rewriter, lhs, i32Type);
-      rhsTensor = tosa::promoteType(rewriter, rhsTensor, i32Type);
+      lhs = tosa::tosaCastTensorToType(rewriter, lhs, i32Type).value();
+      rhsTensor =
+          tosa::tosaCastTensorToType(rewriter, rhsTensor, i32Type).value();
 
       auto intDivOp = rewriter.create(op->getLoc(), i32Type, lhs,
                                       rhsTensor);
 
-      result = tosa::promoteType(rewriter, intDivOp, outType);
+      result =
+          tosa::tosaCastTensorToType(rewriter, intDivOp, outType).value();
     }
   }
 
@@ -843,7 +855,7 @@ class ConvertAtenActivationFunctionOp : public OpConversionPattern {
     // Non floating point inputs are not supported for activation functions
     // (erf, sigmoid, tanh) in TOSA so we cast the input to result type
     if (!isa(selfTy.getElementType()))
-      self = tosa::promoteType(rewriter, self, resultTy);
+      self = tosa::tosaCastTensorToType(rewriter, self, resultTy).value();
 
     rewriter.replaceOpWithNewOp(op, resultTy, self);
 
@@ -980,9 +992,11 @@ class ConvertAtenReductionOp : public OpConversionPattern {
                 std::is_same() ||
                 std::is_same() ||
                 std::is_same()) {
-    self = tosa::promoteType(
-        rewriter, self,
-        RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)));
+    self = tosa::tosaCastTensorToType(
+               rewriter, self,
+               RankedTensorType::get(selfTy.getShape(),
+                                     rewriter.getIntegerType(1)))
+               .value();
   }
 
   // Handle dtype output and bool elem type for ReduceSum and ReduceProd ops
@@ -1007,13 +1021,14 @@ class ConvertAtenReductionOp : public OpConversionPattern {
         dtypeType = rewriter.getIntegerType(dtypeType.getIntOrFloatBitWidth());
 
-      self = tosa::promoteType(
-          rewriter, self,
-          RankedTensorType::get(selfTy.getShape(), dtypeType));
+      self = tosa::tosaCastTensorToType(
+                 rewriter, self,
+                 RankedTensorType::get(selfTy.getShape(), dtypeType))
+                 .value();
     }
   } else {
     if (selfElemTy.isInteger(1))
-      self = tosa::promoteType(rewriter, self, outputTy);
+      self = tosa::tosaCastTensorToType(rewriter, self, outputTy).value();
   }
 }
 
@@ -1387,7 +1402,8 @@ class ConvertAtenPowOp : public OpConversionPattern {
     // Non floating point inputs are not supported for tosa.pow so we cast the
     // input to result type
     if (!isa(selfTy.getElementType()))
-      selfTensor = tosa::promoteType(rewriter, selfTensor, outType);
+      selfTensor =
+          tosa::tosaCastTensorToType(rewriter, selfTensor, outType).value();
   }
 
   Value expTensor;
@@ -1409,7 +1425,8 @@ class ConvertAtenPowOp : public OpConversionPattern {
     // Non floating point exponents are not supported for tosa.pow so we cast
     // the exponent to result type
     if (!isa(expTy.getElementType()))
-      expTensor = tosa::promoteType(rewriter, expTensor, outType);
+      expTensor =
+          tosa::tosaCastTensorToType(rewriter, expTensor, outType).value();
   }
 
   if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), selfTensor, expTensor)
@@ -2217,7 +2234,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       dyn_cast(getTypeConverter()->convertType(op.getType()));
   auto resultElemTy = resultTy.getElementType();
 
-  self = tosa::promoteType(rewriter, self, resultTy);
+  self = tosa::tosaCastTensorToType(rewriter, self, resultTy).value();
 
   Value otherTensor, alphaTensor;
 
@@ -3077,7 +3094,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // If input is not a float type then cast it to output type
   auto selfElemTy = selfType.getElementType();
   if (!isa(selfElemTy))
-    self = tosa::promoteType(rewriter, self, outType);
+    self = tosa::tosaCastTensorToType(rewriter, self, outType).value();
 
   // Constant value of ln2.
   SmallVector ln2Shape(selfType.getRank(), 1);
@@ -3222,7 +3239,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
 
   // Not a tensor type.
-  auto selfType = dyn_cast(adaptor.getInput().getType());
+  auto self = adaptor.getInput();
+  auto selfType = dyn_cast(self.getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");
@@ -3236,8 +3254,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   if (train)
     return rewriter.notifyMatchFailure(op, "train must be false");
 
-  rewriter.replaceOpWithNewOp(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getInput());
+  auto resultType =
+      dyn_cast(getTypeConverter()->convertType(op.getType()));
+  auto result = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
+
+  rewriter.replaceOp(op, result);
 
   return success();
 }
@@ -3439,9 +3460,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // GELU(x) = x * CDF(x)
   Value cdf =
      buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy).value();
-  cdf = rewriter.createOrFold(
-      op->getLoc(),
-      cast(cdf.getType()).cloneWith({}, selfElemTy), cdf);
+  cdf = tosa::tosaCastTensorToType(rewriter, cdf, selfType).value();
 
   auto result =
       tosa::createMulOpAndCast(rewriter, op, resultType, self, cdf,
                                /*shift=*/0);
@@ -3766,11 +3785,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
                             indicesType.getElementType()),
       indices, rewriter.getDenseI64ArrayAttr(newIndicesShape));
 
-  auto castIndices = rewriter.create(
-      op->getLoc(),
-      RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape),
-                            rewriter.getIntegerType(32)),
-      reshapedIndices);
+  auto castIndices =
+      tosa::tosaCastTensorToType(
+          rewriter, reshapedIndices,
+          RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape),
+                                rewriter.getIntegerType(32)))
+          .value();
 
   SmallVector intermediateOutShape = {1, numIndices, weightShape[1]};
   auto gatherOp = rewriter.create(
@@ -4155,11 +4175,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 
   // index i64 to i32 for tosa compatitable
   if (indexType.getElementType() != rewriter.getIntegerType(32)) {
-    index = rewriter.create(
-        op->getLoc(),
-        RankedTensorType::get(indexType.getShape(),
-                              rewriter.getIntegerType(32)),
-        index);
+    index = tosa::tosaCastTensorToType(
+                rewriter, index,
+                RankedTensorType::get(indexType.getShape(),
+                                      rewriter.getIntegerType(32)))
+                .value();
   }
 
   // Get positive dim
@@ -4241,9 +4261,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 
   // index i64 to i32 for tosa compatible
   if (indexType.getElementType() != rewriter.getIntegerType(32)) {
-    index = rewriter.create(
-        op->getLoc(),
-        RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
+    index = tosa::tosaCastTensorToType(
+                rewriter, index,
+                RankedTensorType::get(indexShape, rewriter.getIntegerType(32)))
+                .value();
   }
 
   // Get positive dim
@@ -4380,10 +4401,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 
   // index i64 to i32 for tosa compatible
   if (indexType.getElementType() != rewriter.getIntegerType(32))
-    index = rewriter.create(
-        op->getLoc(),
-        RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
-        index);
+    index =
+        tosa::tosaCastTensorToType(
+            rewriter, index,
+            RankedTensorType::get(indexShape, rewriter.getIntegerType(32)))
+            .value();
 
   // Expand last dim of index to tf indices [3] -> [3,1]
   // convert [0,0,0] to [[0],[0],[0]]
@@ -4523,10 +4545,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 
     // Make type of index tosa compatible, i64 to i32.
     if (indexType.getElementType() != rewriter.getIntegerType(32)) {
-      index = rewriter.create(
-          op->getLoc(),
-          RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
-          index);
+      index =
+          tosa::tosaCastTensorToType(
+              rewriter, index,
+              RankedTensorType::get(indexShape, rewriter.getIntegerType(32)))
+              .value();
     }
 
     index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op,
@@ -4693,10 +4716,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       auto indexShape = indexType.getShape();
       // index i64 to i32 for tosa compatible
       if (indexType.getElementType() != rewriter.getIntegerType(32)) {
-        index = rewriter.create(
-            op->getLoc(),
-            RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
-            index);
+        index =
+            tosa::tosaCastTensorToType(
+                rewriter, index,
+                RankedTensorType::get(indexShape, rewriter.getIntegerType(32)))
+                .value();
       }
 
       index =
@@ -4782,9 +4806,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 
   // index i64 to i32 for tosa compatitable
   if (indexType.getElementType() != rewriter.getIntegerType(32)) {
-    index = rewriter.create(
-        op->getLoc(),
-        RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
+    index = tosa::tosaCastTensorToType(
+                rewriter, index,
+                RankedTensorType::get(indexShape, rewriter.getIntegerType(32)))
+                .value();
   }
 
   // Get positive dim
@@ -5223,9 +5248,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Failed to equalize ranks among operands and result");
 
-  self = tosa::promoteType(rewriter, self, resultType);
-  min = tosa::promoteType(rewriter, min, resultType);
-  max = tosa::promoteType(rewriter, max, resultType);
+  self = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
+  min = tosa::tosaCastTensorToType(rewriter, min, resultType).value();
+  max = tosa::tosaCastTensorToType(rewriter, max, resultType).value();
 
   // max(xi, min_valuei)
   // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum
@@ -5397,8 +5422,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
         op, "failed to generate constant tensor for arange");
   }
   auto result = maybeResult.value();
+  result = tosa::tosaCastTensorToType(rewriter, result, resultType).value();
+
+  rewriter.replaceOp(op, result);
 
-  rewriter.replaceOpWithNewOp(op, resultType, result);
   return success();
 }
 
@@ -5450,6 +5477,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Only tensor types with static shape are supported");
 
+  auto resultTy =
+      dyn_cast(getTypeConverter()->convertType(op.getType()));
+
   // The non_blocking should be a constant `False`.
   bool nonBlocking;
   if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking))) {
@@ -5466,12 +5496,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   if (llvm::equal(selfShape, srcShape) || selfShape.size() == 0) {
     // If we reach here, then it means the given case is handled by implicit
     // broadcasting done by tosa.
-    Value result;
-    if (failed(tosa::tosaCastTensorToType(
-            rewriter, op, adaptor.getSrc(),
-            getTypeConverter()->convertType(op.getType()), result)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: cast to result type not supported");
+    Value result =
+        tosa::tosaCastTensorToType(rewriter, adaptor.getSrc(), resultTy)
+            .value();
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -5528,10 +5555,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto resultTy = cast(
       getTypeConverter()->convertType(op.getResult().getType()));
 
-  Value result;
-  if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(),
-                                        resultTy, result)))
-    return rewriter.notifyMatchFailure(op, "conversion to result type failed");
+  Value result =
+      tosa::tosaCastTensorToType(rewriter, adaptor.getSelf(), resultTy).value();
   rewriter.replaceOp(op, result);
 
   return success();
@@ -5589,7 +5614,7 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern {
                        std::is_same();
 
   if (selfTy.getElementType() != outElemTy)
-    self = rewriter.create(op.getLoc(), outType, self);
+    self = tosa::tosaCastTensorToType(rewriter, self, outType).value();
 
   Value divTensor;
   if (isRemainderOp) {
@@ -5613,13 +5638,15 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern {
       // TOSA IntDiv requires inputs to be i32
       auto i32Type = RankedTensorType::get(outType.getShape(),
                                            rewriter.getIntegerType(32));
-      self = tosa::promoteType(rewriter, self, i32Type);
-      otherTensor = tosa::promoteType(rewriter, otherTensor, i32Type);
+      self = tosa::tosaCastTensorToType(rewriter, self, i32Type).value();
+      otherTensor =
+          tosa::tosaCastTensorToType(rewriter, otherTensor, i32Type).value();
 
      auto intDivTensor = rewriter.create(
           op->getLoc(), i32Type, self, otherTensor);
 
-      divTensor = tosa::promoteType(rewriter, intDivTensor, outType);
+      divTensor =
+          tosa::tosaCastTensorToType(rewriter, intDivTensor, outType).value();
     }
   }
 
@@ -6238,7 +6265,10 @@ class ConvertAtenConstPatternOp : public OpConversionPattern {
 
   auto constOp = tosa::getConstTensor(rewriter, op, values, shape).value();
 
-  rewriter.replaceOpWithNewOp(op, outType, constOp);
+  auto result =
+      tosa::tosaCastTensorToType(rewriter, constOp, outType).value();
+
+  rewriter.replaceOp(op, result);
 
   return success();
 }
@@ -6303,8 +6333,11 @@ class ConvertAtenFillOp : public OpConversionPattern {
         op, "Fill value must be a scalar constant");
   }
 
-  rewriter.replaceOpWithNewOp(op, outType,
-                              fillValueTargetTensor);
+  auto result =
+      tosa::tosaCastTensorToType(rewriter, fillValueTargetTensor, outType)
+          .value();
+
+  rewriter.replaceOp(op, result);
 
   return success();
 }
@@ -6361,10 +6394,8 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern {
   auto rhsTensor = rhsType ? rhs : rhsAsTensor;
   auto rhsTensorType = dyn_cast(rhsTensor.getType());
   if (rhsTensorType.getElementType() != outElemTy)
-    rhsTensor = rewriter.create(
-        op.getLoc(),
-        RankedTensorType::get(rhsTensorType.getShape(), outElemTy),
-        rhsTensor);
+    rhsTensor =
+        tosa::tosaCastTensorToType(rewriter, rhsTensor, outType).value();
 
   if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, rhsTensor)
           .failed())
@@ -6399,7 +6430,11 @@ class ConvertAtenCloneOp : public OpConversionPattern {
     auto outType = dyn_cast(
         OpConversionPattern::getTypeConverter()->convertType(
            op.getType()));
-    rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf());
+
+    auto result =
+        tosa::tosaCastTensorToType(rewriter, adaptor.getSelf(), outType)
+            .value();
+    rewriter.replaceOp(op, result);
 
     return success();
   }
@@ -6504,7 +6539,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
 
   for (auto &tensor : builtinTensors)
-    tensor = tosa::promoteType(rewriter, tensor, outType);
+    tensor = tosa::tosaCastTensorToType(rewriter, tensor, outType).value();
 
   auto result = tosa::CreateOpAndInfer(
       rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim));
@@ -6528,11 +6563,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       cast(typeConverter->convertType(op.getType()));
   auto elementType = resultType.getElementType();
 
-  if (isa(selfTy.getElementType())) {
-    self = rewriter.createOrFold(
-        op->getLoc(), RankedTensorType::get(resultType.getShape(), elementType),
-        self);
-  }
+  if (isa(selfTy.getElementType()))
+    self = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
 
   auto oneHalf =
      tosa::getConstTensor(rewriter, op, 0.5, {}, elementType).value();
@@ -7401,7 +7433,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
                                    selfType.getElementType())
                 .value();
 
-  result = tosa::promoteType(rewriter, result, resultType);
+  result = tosa::tosaCastTensorToType(rewriter, result, resultType).value();
 
   rewriter.replaceOp(op, {result});
 
@@ -7467,8 +7499,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       op->getLoc(), RankedTensorType::get(selfShape, rewriter.getI1Type()),
      threshold, self);
 
-  self = tosa::promoteType(rewriter, self, resultType);
-  grad = tosa::promoteType(rewriter, grad, resultType);
+  self = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
+  grad = tosa::tosaCastTensorToType(rewriter, grad, resultType).value();
 
   if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, zero).failed() ||
       mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, grad).failed())
@@ -8118,8 +8150,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
      dyn_cast(typeConverter->convertType(op.getType()));
   auto resultShape = resultType.getShape();
 
-  self = tosa::promoteType(rewriter, self, resultType);
-  vec2 = tosa::promoteType(rewriter, vec2, resultType);
+  self = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
+  vec2 = tosa::tosaCastTensorToType(rewriter, vec2, resultType).value();
 
   SmallVector resultShapeIndex1Replaced({resultShape[0], 1});
   SmallVector resultShapeIndex0Replaced({1, resultShape[1]});
@@ -8371,7 +8403,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // If input is not a float type then cast it to result element type
   auto selfElemTy = selfType.getElementType();
   if (!isa(selfElemTy))
-    self = tosa::promoteType(rewriter, self, resultType);
+    self = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
 
   bool isEpsNone = isa(op.getEps().getType());
 
@@ -8445,7 +8477,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // If input is not a float type then cast it to result element type
   auto selfElemTy = selfType.getElementType();
   if (!isa(selfElemTy))
-    self = tosa::promoteType(rewriter, self, resultType);
+    self = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
 
   auto one =
      tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value();
@@ -8490,7 +8522,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // If input is not a float type then cast it to result element type
   auto selfElemTy = selfType.getElementType();
   if (!isa(selfElemTy))
-    self = tosa::promoteType(rewriter, self, resultType);
+    self = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
 
   auto ten = tosa::getConstTensor(rewriter, op, 10.0f, {}, resultElemTy)
                  .value();
@@ -8544,7 +8576,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // If input is not a float type then cast it to result element type
   auto selfElemTy = selfType.getElementType();
   if (!isa(selfElemTy))
-    self = tosa::promoteType(rewriter, self, resultType);
+    self = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
 
   auto one =
      tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value();
@@ -8585,7 +8617,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Non floating point inputs are not supported in TOSA so we cast the input
   // to result type
   if (!isa(selfType.getElementType()))
-    self = tosa::promoteType(rewriter, self, resultType);
+    self = tosa::tosaCastTensorToType(rewriter, self, resultType).value();
 
   auto sinOp = rewriter.create(op->getLoc(), resultType, self);
 
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index 6448798b7661..68041403cdd1 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -118,8 +118,8 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
 tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
                                TensorType outType, Value lhs, Value rhs,
                                int32_t shift) {
-  lhs = promoteType(rewriter, lhs, outType);
-  rhs = promoteType(rewriter, rhs, outType);
+  lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value();
+  rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value();
 
   auto constShift = tosa::getTosaMulShiftConstTensor(rewriter, op, shift);
 
@@ -138,8 +138,8 @@ createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
         op, "tosa.int_div only supports integer type");
   }
 
-  lhs = promoteType(rewriter, lhs, outType);
-  rhs = promoteType(rewriter, rhs, outType);
+  lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value();
+  rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value();
   return tosa::CreateOpAndInfer(rewriter, op->getLoc(), outType, lhs, rhs);
 }
@@ -1069,7 +1069,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
   }
 
   auto input_value_casted =
-      tosa::promoteType(rewriter, input_value, output_type);
+      tosa::tosaCastTensorToType(rewriter, input_value, output_type).value();
   auto absVal = CreateOpAndInfer(
       rewriter, op->getLoc(),
      RankedTensorType::get(input_type.getShape(), elemType),
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index bee304e16c7e..897f06ebfb3c 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -213,8 +213,9 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op,
      rewriter.create(op->getLoc(), const_type, const_attr);
 
   if (dtype) {
-    return rewriter.createOrFold(
-        op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
+    return tosa::tosaCastTensorToType(rewriter, const_op,
+                                      RankedTensorType::get(shape, *dtype))
+        .value();
   }
   return const_op.getResult();
 }
 
@@ -243,8 +244,9 @@ std::optional getConstTensor(PatternRewriter &rewriter,
      rewriter.create(op->getLoc(), const_type, const_attr);
 
   if (dtype) {
-    return rewriter.createOrFold(
-        op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
+    return tosa::tosaCastTensorToType(rewriter, const_op,
+                                      RankedTensorType::get(shape, *dtype))
+        .value();
   }
   return const_op.getResult();
 }
 
@@ -272,8 +274,9 @@ std::optional getConstTensor(PatternRewriter &rewriter,
      rewriter.create(op->getLoc(), const_type, const_attr);
 
   if (dtype) {
-    return rewriter.createOrFold(
-        op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
+    return tosa::tosaCastTensorToType(rewriter, const_op,
+                                      RankedTensorType::get(shape, *dtype))
+        .value();
   }
   return const_op.getResult();
 }
@@ -346,10 +349,11 @@ std::optional getConstTensor(PatternRewriter &rewriter,
   return failure();
 }
 
-// Template specialization for float
-LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
-                                   Value src, Type destType, Value &result) {
-
+// Default function to create tosa.cast op. This should be called instead of
+// directly calling rewriter.create.
+std::optional tosaCastTensorToType(PatternRewriter &rewriter, Value src,
+                                   TensorType destType) {
+  Operation *op = src.getDefiningOp();
   TensorType srcType = dyn_cast(src.getType());
   Type srcElemTy = srcType.getElementType();
   Type destElemTy = dyn_cast(destType).getElementType();
@@ -363,76 +367,36 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
   // casting only when needed (the default value of `--strict` mode will be
   // off).
   // if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
-  //   return rewriter.notifyMatchFailure(
-  //       op, "casting to result dtype is invalid or unsupported");
-
-  if (destElemTy.isInteger(1)) {
-    auto srcType = dyn_cast(src.getType());
-    SmallVector srcShape(srcType.getShape());
-    uint64_t num_total_elements = 1;
-    for (int64_t a : srcShape)
-      num_total_elements *= a;
-
-    std::optional constOp;
-    if (srcElemTy.isInteger(64)) {
-      SmallVector values(num_total_elements, 0);
-      constOp =
-          tosa::getConstTensor(rewriter, op, values, srcShape).value();
-    } else if (srcElemTy.isInteger(32)) {
-      SmallVector values(num_total_elements, 0);
-      constOp =
-          tosa::getConstTensor(rewriter, op, values, srcShape).value();
-    } else if (srcElemTy.isF32()) {
-      SmallVector values(num_total_elements, 0.0);
-      constOp =
-          tosa::getConstTensor(rewriter, op, values, srcShape).value();
-    } else if (srcElemTy.isInteger(8)) {
-      SmallVector values(num_total_elements, 0);
-      constOp =
-          tosa::getConstTensor(rewriter, op, values, srcShape).value();
-    }
-    Value equalToZero = rewriter.create(op->getLoc(), destType,
-                                        src, constOp.value());
-    result = rewriter.create(op->getLoc(), destType,
-                             equalToZero);
-  } else {
-    if (llvm::isa(srcElemTy) && destElemTy.isInteger()) {
-      // for float->int conversion, tosa.cast performs round-to-nearest
-      // torch performs round-to-zero instead
-      // generate round-to-zero conversion prior to tosa.cast to match with
-      // expected torch behavior
-      auto floor = rewriter.create(op->getLoc(), srcType, src);
-      auto ceil = rewriter.create(op->getLoc(), srcType, src);
-
-      auto zeroValue =
-          tosa::getConstTensor(rewriter, op, 0, {}, srcElemTy).value();
-
-      if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue)
-              .failed())
-        return rewriter.notifyMatchFailure(
-            op, "Failed to equalize ranks among operands and result");
-
-      auto boolType = srcType.clone(rewriter.getIntegerType(1));
-      auto isNegative = tosa::CreateOpAndInfer(
-          rewriter, op->getLoc(), boolType, zeroValue, src);
-      src = tosa::CreateOpAndInfer(
-          rewriter, op->getLoc(), srcType, isNegative, ceil, floor);
-    }
-    result = rewriter.create(op->getLoc(), destType, src);
+  //   return std::nullopt;
+
+  if (srcElemTy == destElemTy)
+    return src;
+
+  if (llvm::isa(srcElemTy) && destElemTy.isInteger() &&
+      !destElemTy.isInteger(1)) {
+    // For float->int conversion, tosa.cast performs round-to-nearest.
+    // PyTorch performs round-to-zero instead.
+    // Generate round-to-zero conversion prior to tosa.cast to match with
+    // expected torch behavior.
+    auto floor = rewriter.create(op->getLoc(), srcType, src);
+    auto ceil = rewriter.create(op->getLoc(), srcType, src);
+
+    auto zeroValue =
+        tosa::getConstTensor(rewriter, op, 0, {}, srcElemTy).value();
+
+    if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue)
+            .failed())
+      return std::nullopt;
+
+    auto boolType = srcType.clone(rewriter.getIntegerType(1));
+    auto isNegative = tosa::CreateOpAndInfer(
+        rewriter, op->getLoc(), boolType, zeroValue, src);
+    src = tosa::CreateOpAndInfer(
+        rewriter, op->getLoc(), srcType, isNegative, ceil, floor);
   }
-  return success();
-}
-Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
-  Operation *op = input.getDefiningOp();
-  TensorType inType = cast(input.getType());
-
-  if (inType.getElementType() != outType.getElementType()) {
-    TensorType promotedType =
-        inType.cloneWith(inType.getShape(), outType.getElementType());
-    return rewriter.create(op->getLoc(), promotedType, input);
-  }
-  return input;
+
+  TensorType castedSrcType = srcType.clone(destElemTy);
+  return rewriter.create(op->getLoc(), castedSrcType, src);
 }
 
 // Template instantiation
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 267b05b93aa5..3e32a3cf3494 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -898,13 +898,10 @@ func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.dropout$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor
-// CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00
-// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
-// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch.constant.float 0.000000e+00
+// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
 func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
     %float0.000000e00 = torch.constant.float 0.000000e+00
@@ -984,14 +981,14 @@ func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> {
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[CST0:.*]] = torch.constant.int 0
-// CHECK: %[[CST5:.*]] = torch.constant.int 5
-// CHECK: %[[CST1:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>}> : () -> tensor<5xi64>
-// CHECK: %[[VAL_1:.*]] = tosa.cast %[[VAL_0]] : (tensor<5xi64>) -> tensor<5xi64>
-// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %1 : tensor<5xi64> -> !torch.vtensor<[5],si64>
-// CHECK: return %[[VAL_2]] : !torch.vtensor<[5],si64>
+// CHECK: %[[VAL_0:.*]] = torch.constant.none
+// CHECK: %[[VAL_1:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 5
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>}> : () -> tensor<5xi64>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<5xi64> -> !torch.vtensor<[5],si64>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[5],si64>
+// CHECK: }
 func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
    %none = torch.constant.none
    %int0 = torch.constant.int 0
    %int5 = torch.constant.int 5
    %int1 = torch.constant.int 1
@@ -1015,23 +1012,24 @@ func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> {
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.copy(
-// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
-// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8>
-// CHECK: %[[CST5:.*]] = torch.constant.int 5
-// CHECK: %[[CST1:.*]] = torch.constant.int 1
-// CHECK: %[[CST11:.*]] = torch.constant.int 11
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[CST0:.*]] = torch.constant.int 0
-// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor
-// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor
-// CHECK: %[[VAL_2:.*]] = tosa.equal %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor
-// CHECK: %[[VAL_3:.*]] = tosa.logical_not %[[VAL_2]] : (tensor) -> tensor
-// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x5x5xi8>}> : () -> tensor<1x1x5x5xi8>
-// CHECK: %[[VAL_5:.*]] = tosa.equal %[[INP]], %[[VAL_4]] : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1>
-// CHECK: %[[VAL_6:.*]] = tosa.logical_not %[[VAL_5]] : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1>
-// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1>
-// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1>
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 5
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 11
+// CHECK: %[[VAL_5:.*]] = torch.constant.none
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor
+// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor) -> tensor
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xi1>
+// CHECK: %[[VAL_12:.*]] = tosa.const_shape {value = dense<[1, 1, 5, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_13:.*]] = tosa.tile %[[VAL_11]], %[[VAL_12]] : (tensor<1x1x1x1xi1>, !tosa.shape<4>) -> tensor<1x1x5x5xi1>
+// CHECK: %[[VAL_14:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1>
+// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1>
+// CHECK: return %[[VAL_15]] : !torch.vtensor<[1,1,5,5],i1>
+// CHECK: }
 func.func @torch.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
    %int5 = torch.constant.int 5
    %int1 = torch.constant.int 1
    %int11 = torch.constant.int 11
    %none = torch.constant.none
    %false = torch.constant.bool false
    %int0 = torch.constant.int 0
@@ -1048,18 +1046,17 @@ func.func @torch.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtens
 }
 
 // -----
-// CHECK-LABEL: func.func @torch.aten.to.dtype(
-// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
-// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64>
-// CHECK: %[[CST11:.*]] = torch.constant.int 11
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x5xi64>}> : () -> tensor<3x5xi64>
-// CHECK: %[[VAL_1:.*]] = tosa.equal %[[INP]], %[[VAL_0]] : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
-// CHECK: %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<3x5xi1>) -> tensor<3x5xi1>
-// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
-// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,5],i1>
-func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
+// CHECK-LABEL: func.func @torch.aten.to.dtype$toBool(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 11
+// CHECK: %[[VAL_3:.*]] = torch.constant.none
+// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x5xi64>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,5],i1>
+// CHECK: }
+func.func @torch.aten.to.dtype$toBool(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
    %int11 = torch.constant.int 11
    %none = torch.constant.none
    %false = torch.constant.bool false
@@ -1068,7 +1065,7 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten
 }
 
 // -----
-// CHECK-LABEL: func.func @torch.aten.to.dtype(
+// CHECK-LABEL: func.func @torch.aten.to.dtype$fromBool(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,128],i1> -> tensor<1x128xi1>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 4
 // CHECK: %[[VAL_3:.*]] = torch.constant.none
 // CHECK: %[[VAL_4:.*]] = torch.constant.bool false
 // CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x128xi1>) -> tensor<1x128xi64>
 // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x128xi64> -> !torch.vtensor<[1,128],si64>
 // CHECK: return %[[VAL_6]] : !torch.vtensor<[1,128],si64>
 // CHECK: }
-func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> {
+func.func @torch.aten.to.dtype$fromBool(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> {
    %int4 = torch.constant.int 4
    %none = torch.constant.none
    %false = torch.constant.bool false
@@ -2060,9 +2057,8 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !t
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
 // CHECK: %[[VAL_1:.*]] = torch.constant.int 0
 // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
-// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
-// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[1,12,128,128],f32>
 // CHECK: }
 func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
    %int0 = torch.constant.int 0
@@ -2189,9 +2185,8 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc
 // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32>
 // CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<3x4xi32>) -> tensor<3x4xi64>
 // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi64>}> : () -> tensor<3x4xi64>
-// CHECK: %[[VAL_10:.*]] = tosa.cast %[[VAL_9]] : (tensor<3x4xi64>) -> tensor<3x4xi64>
-// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64>
-// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],si64>
+// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64>
+// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,4],si64>
 // CHECK: }
 func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
    %int0 = torch.constant.int 0
@@ -3092,7 +3087,7 @@ func.func @torch.aten.erf$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtens
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64>
 // CHECK: %[[VAL_2:.*]] = torch.constant.float 1.100000e+00
 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.100000e+00> : tensor}> : () -> tensor
-// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.1000000238418579> : tensor}> : () -> tensor
+// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor
 // CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1xf64>
 // CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_1]] : (tensor<4xi64>) -> tensor<4xf64>
 // CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_5]], %[[VAL_6]] : (tensor<1xf64>, tensor<4xf64>) -> tensor<4xi1>
@@ -3318,12 +3313,14 @@ func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4
 // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_12]] : (tensor<5x2x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x2xf32>
 // CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_4]], %[[VAL_12]] : (tensor<10x2x3x3xf32>, tensor<4xi32>) -> tensor<10x3x3x2xf32>
-// CHECK: %[[VAL_15:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_14]], %[[VAL_11]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>) -> tensor<5x14x24x10xf32>
-// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_15]], %[[VAL_16]] : (tensor<5x14x24x10xf32>, tensor<4xi32>) -> tensor<5x10x14x24xf32>
-// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<5x10x14x24xf32> to tensor<5x10x14x24xf32>
-// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32>
-// CHECK: return %[[VAL_19]] : !torch.vtensor<[5,10,14,24],f32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_14]], %[[VAL_11]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32>
+// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK: %[[VAL_19:.*]] = tosa.transpose %[[VAL_17]], %[[VAL_18]] : (tensor<5x14x24x10xf32>, tensor<4xi32>) -> tensor<5x10x14x24xf32>
+// CHECK: %[[VAL_20:.*]] = tensor.cast %[[VAL_19]] : tensor<5x10x14x24xf32> to tensor<5x10x14x24xf32>
+// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32>
+// CHECK: return %[[VAL_21]] : !torch.vtensor<[5,10,14,24],f32>
 // CHECK: }
 func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> {
    %false = torch.constant.bool false
@@ -3360,12 +3357,14 @@ func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>)
 // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[2, 3, 0, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_5]], %[[VAL_15]] : (tensor<4x1x3x3xf32>, tensor<4xi32>) -> tensor<3x3x4x1xf32>
 // CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<3x3x4x1xf32>) -> tensor<3x3x4x1xf32>
-// CHECK: %[[VAL_18:.*]] = tosa.depthwise_conv2d %[[VAL_14]], %[[VAL_17]], %[[VAL_12]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>) -> tensor<5x5x10x4xf32>
-// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK: %[[VAL_20:.*]] = tosa.transpose %[[VAL_18]], %[[VAL_19]] : (tensor<5x5x10x4xf32>, tensor<4xi32>) -> tensor<5x4x5x10xf32>
-// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<5x4x5x10xf32> to tensor<5x4x5x10xf32>
-// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32>
-// CHECK: return %[[VAL_22]] : !torch.vtensor<[5,4,5,10],f32>
+// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_20:.*]] = tosa.depthwise_conv2d %[[VAL_14]], %[[VAL_17]], %[[VAL_12]], %[[VAL_18]], %[[VAL_19]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32>
+// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK: %[[VAL_22:.*]] = tosa.transpose %[[VAL_20]], %[[VAL_21]] : (tensor<5x5x10x4xf32>, tensor<4xi32>) -> tensor<5x4x5x10xf32>
+// CHECK: %[[VAL_23:.*]] = tensor.cast %[[VAL_22]] : tensor<5x4x5x10xf32> to tensor<5x4x5x10xf32>
+// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32>
+// CHECK: return %[[VAL_24]] : !torch.vtensor<[5,4,5,10],f32>
 // CHECK: }
 func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> {
    %false = torch.constant.bool false