diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 5e22c57bbb6..141835349c3 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -16254,7 +16254,7 @@ def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
   let hasCanonicalizer = 1;
 }
 
-def TF_XlaBroadcastHelperOp : TF_Op<"XlaBroadcastHelper", [NoSideEffect]> {
+def TF_XlaBroadcastHelperOp : TF_Op<"XlaBroadcastHelper", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> {
   let summary = "Helper operator for performing XLA-style broadcasts";
 
   let description = [{
@@ -16276,6 +16276,13 @@ for binary operators.
 
   TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>;
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+  let extraClassDeclaration = [{
+    // InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
+      return ArraysAreCastCompatible(l, r);
+    }
+  }];
 }
 
 def TF_XlaConvOp : TF_Op<"XlaConv", [NoSideEffect]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 8b717ec5320..0f8a423124f 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -2969,6 +2969,92 @@ void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<XdivyWithSqrtDivisor>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// XlaBroadcastHelperOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult XlaBroadcastHelperOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto loc = location ? *location : mlir::UnknownLoc::get(context);
+  XlaBroadcastHelperOpAdaptor op(operands, attributes);
+  if (failed(op.verify(loc))) {
+    return failure();
+  }
+
+  Value lhs = op.lhs();
+  Value rhs = op.rhs();
+  auto set_unranked_results = [&]() {
+    auto unranked_lhs = UnrankedTensorType::get(getElementTypeOrSelf(lhs));
+    inferredReturnTypes.push_back(unranked_lhs);
+    auto unranked_rhs = UnrankedTensorType::get(getElementTypeOrSelf(rhs));
+    inferredReturnTypes.push_back(unranked_rhs);
+    return success();
+  };
+
+  RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
+  RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
+  if (!lhs_ty || !rhs_ty) return set_unranked_results();
+
+  int64_t lhs_rank = lhs_ty.getRank();
+  int64_t rhs_rank = rhs_ty.getRank();
+
+  DenseIntElementsAttr dims;
+  if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
+    return set_unranked_results();
+  }
+
+  if (dims.size() == 0) {
+    if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
+      return emitOptionalError(
+          location,
+          "if broadcast_dims is empty, both arguments must have equal rank or "
+          "at least one argument must be a scalar");
+    }
+    inferredReturnTypes.push_back(lhs_ty);
+    inferredReturnTypes.push_back(rhs_ty);
+    return success();
+  }
+
+  const bool broadcast_lhs = lhs_rank < rhs_rank;
+  RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
+  RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;
+
+  if (dims.size() != min_rank_ty.getRank()) {
+    return emitOptionalError(
+        location,
+        "broadcast_dims must have size equal to the smaller argument rank");
+  }
+
+  int64_t output_rank = max_rank_ty.getRank();
+  llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
+  llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
+  for (auto item : llvm::enumerate(dims)) {
+    int64_t index = item.index();
+    int64_t dim = item.value().getSExtValue();
+    if (dim < 0 || dim >= output_rank) {
+      return emitOptionalError(location, "out of range broadcast dim");
+    }
+    if (is_broadcasted[dim]) {
+      return emitOptionalError(location, "broadcast_dims has duplicates");
+    }
+    broadcast_shape[dim] = min_rank_ty.getDimSize(index);
+    is_broadcasted[dim] = true;
+  }
+
+  if (broadcast_lhs) {
+    inferredReturnTypes.push_back(
+        RankedTensorType::get(broadcast_shape, lhs_ty.getElementType()));
+    inferredReturnTypes.push_back(rhs_ty);
+  } else {
+    inferredReturnTypes.push_back(lhs_ty);
+    inferredReturnTypes.push_back(
+        RankedTensorType::get(broadcast_shape, rhs_ty.getElementType()));
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XlaSetDynamicDimensionSizeOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 1b1ae779a80..0f3b5417c3e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -4159,3 +4159,54 @@ func @testVarHandleOp() -> tensor<*x!tf.resource> {
   } : () -> tensor<*x!tf.resource>
   return %0 : tensor<*x!tf.resource>
 }
+
+// -----
+
+func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () {
+  %0 = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
+  // expected-error @+1 {{broadcast_dims must have size equal to the smaller argument rank}}
+  %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<1xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>)
+  return
+}
+
+// -----
+
+func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () {
+  %0 = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
+  // expected-error @+1 {{if broadcast_dims is empty, both arguments must have equal rank or at least one argument must be a scalar}}
+  %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<0xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>)
+  return
+}
+
+// -----
+
+func @testXlaBroadcastHelper(%arg0: tensor<5x2xi32>, %arg1: tensor<2x3x5xi32>) -> () {
+  %0 = "tf.Const"() {value = dense<0> : tensor<2xi64>} : () -> tensor<2xi64>
+  // expected-error @+1 {{broadcast_dims has duplicates}}
+  %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<5x2xi32>, tensor<2x3x5xi32>, tensor<2xi64>) -> (tensor<2x1x5xi32>, tensor<2x3x5xi32>)
+  return
+}
+
+// -----
+
+func @testXlaBroadcastHelper(%arg0: tensor<2xi32>, %arg1: tensor<i32>) -> () {
+  %0 = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
+  %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2xi32>, tensor<i32>, tensor<0xi64>) -> (tensor<2xi32>, tensor<i32>)
+  return
+}
+
+// -----
+
+func @testXlaBroadcastHelper(%arg0: tensor<5x2xi32>, %arg1: tensor<2x3x5xi32>) -> () {
+  %0 = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<5x2xi32>, tensor<2x3x5xi32>, tensor<2xi64>) -> (tensor<2x1x5xi32>, tensor<2x3x5xi32>)
+  return
+}
+
+// -----
+
+func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () {
+  %0 = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<2xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>)
+  return
+}
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 8c1a67f0e87..7a99c27a075 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -101,7 +101,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
         args=(operand, start_indices),
         expected=np.array([[5, 6, 7]]))
 
-  @test_util.disable_mlir_bridge('Dynamic result types not supported')
   def testShiftRightLogical(self):
     self._assertOpOutputMatchesExpected(
         xla.shift_right_logical,
@@ -113,7 +112,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
         args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
         expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
 
-  @test_util.disable_mlir_bridge('Dynamic result types not supported')
   def testShiftRightArithmetic(self):
     self._assertOpOutputMatchesExpected(
         xla.shift_right_arithmetic,
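
Reviewer note: the shape computation added to tf_ops_n_z.cc is the heart of this change, so below is a minimal pure-Python sketch of the same inference, handy for checking the MLIR test expectations by hand. It is illustrative only and not part of the patch: the function name infer_broadcast_helper_shapes is hypothetical, and the sketch covers only the ranked path of XlaBroadcastHelperOp::inferReturnTypes with a statically known broadcast_dims. With result types now inferred this way, the two disable_mlir_bridge skips removed above are no longer needed.

def infer_broadcast_helper_shapes(lhs_shape, rhs_shape, broadcast_dims):
  """Mirrors the C++ inference: returns the (lhs, rhs) result shapes."""
  lhs_rank, rhs_rank = len(lhs_shape), len(rhs_shape)

  # Empty broadcast_dims: ranks must already match, or one side is a scalar.
  if not broadcast_dims:
    if lhs_rank != rhs_rank and lhs_rank != 0 and rhs_rank != 0:
      raise ValueError("if broadcast_dims is empty, both arguments must have "
                       "equal rank or at least one argument must be a scalar")
    return list(lhs_shape), list(rhs_shape)

  # The lower-rank side is reshaped; broadcast_dims names a target output
  # position for each of its dimensions.
  broadcast_lhs = lhs_rank < rhs_rank
  min_shape = lhs_shape if broadcast_lhs else rhs_shape
  max_shape = rhs_shape if broadcast_lhs else lhs_shape
  if len(broadcast_dims) != len(min_shape):
    raise ValueError(
        "broadcast_dims must have size equal to the smaller argument rank")

  # Scatter the lower-rank dims into the output rank; every position not
  # named in broadcast_dims becomes a size-1 dim for XLA to broadcast.
  output_rank = len(max_shape)
  broadcast_shape = [1] * output_rank
  seen = [False] * output_rank
  for index, dim in enumerate(broadcast_dims):
    if dim < 0 or dim >= output_rank:
      raise ValueError("out of range broadcast dim")
    if seen[dim]:
      raise ValueError("broadcast_dims has duplicates")
    broadcast_shape[dim] = min_shape[index]
    seen[dim] = True

  if broadcast_lhs:
    return broadcast_shape, list(rhs_shape)
  return list(lhs_shape), broadcast_shape

# Matches the positive test above: lhs tensor<5x2xi32> broadcast into a
# rank-3 rhs at dims [2, 0] yields tensor<2x1x5xi32>.
assert infer_broadcast_helper_shapes([5, 2], [2, 3, 5], [2, 0]) == (
    [2, 1, 5], [2, 3, 5])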