diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 63a277516ac..21b1ac5f0ea 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -132,6 +132,52 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
+    const Shape& shape, absl::Span<const XlaOp> all_operands,
+    const XlaComputation& computation,
+    absl::Span<const int64> dimensions_to_reduce) {
+  // Reduce takes two sets of variadic operands: inputs and init_values.
+  // all_operands contains both of these, so split it into two parts.
+  int64_t num_args = all_operands.size() / 2;
+  auto op = builder_.create<mlir::xla_hlo::ReduceOp>(
+      loc_, GetValues(all_operands.first(num_args)),
+      GetValues(all_operands.subspan(num_args)),
+      GetI64ElementsAttr(dimensions_to_reduce, &builder_));
+  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
+  if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
+  auto tuple = builder_.create<mlir::xla_hlo::TupleOp>(loc_, op.getResults());
+  return MakeXlaOp(tuple);
+}
+
+StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
+    const Shape& shape, XlaOp operand, XlaOp init_value,
+    const XlaComputation& computation, Window window) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
+  llvm::SmallVector<int64, 8> padding;
+  for (const auto& dim : window.dimensions()) {
+    sizes.push_back(dim.size());
+    strides.push_back(dim.stride());
+    base_dilations.push_back(dim.base_dilation());
+    win_dilations.push_back(dim.window_dilation());
+    padding.push_back(dim.padding_low());
+    padding.push_back(dim.padding_high());
+  }
+  auto padding_ty =
+      mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
+                                  builder_.getIntegerType(64));
+  auto op = builder_.create<mlir::xla_hlo::ReduceWindowOp>(
+      loc_, ty, GetValue(operand), GetValue(init_value),
+      GetI64ElementsAttr(sizes, &builder_),
+      GetI64ElementsAttr(strides, &builder_),
+      GetI64ElementsAttr(base_dilations, &builder_),
+      GetI64ElementsAttr(win_dilations, &builder_),
+      mlir::DenseIntElementsAttr::get(padding_ty, padding));
+  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
+  return MakeXlaOp(op);
+}
+
 XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index 7d93f0b1eae..4b28c32db99 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -124,6 +124,16 @@ class MlirHloBuilder : public XlaBuilder {
                              FftType fft_type,
                              absl::Span<const int64> fft_length) override;
 
+  StatusOr<XlaOp> ReduceInternal(
+      const Shape& shape, absl::Span<const XlaOp> all_operands,
+      const XlaComputation& computation,
+      absl::Span<const int64> dimensions_to_reduce) override;
+
+  StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand,
+                                       XlaOp init_value,
+                                       const XlaComputation& computation,
+                                       Window window) override;
+
   XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
 
   StatusOr<XlaOp> TransposeInternal(
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
index d7c92b95b40..b8a6df54519 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
@@ -236,6 +236,21 @@ func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
   return %1 : tensor<4x7xcomplex<f64>>
 }
 
+// CHECK-LABEL: bucketize
+func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> {
+  // CHECK-NOT: tf.Bucketize
+  %0 = "tf.Bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32>
+  return %0 : tensor<2x5xi32>
+}
+
+// CHECK-LABEL: arg_min
+func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
+  // CHECK-NOT: ArgMin
+  %0 = xla_hlo.constant dense<0> : tensor<i32>
+  %1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor<i32>) -> tensor<i32>
+  return %1 : tensor<i32>
+}
+
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
 }
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 477bc654914..e57d6938efb 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -89,6 +89,8 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
+    TypeID::get<TF::ArgMaxOp>(),
+    TypeID::get<TF::ArgMinOp>(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
@@ -100,6 +102,7 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
+    TypeID::get<TF::BucketizeOp>(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
@@ -132,6 +135,8 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
+    TypeID::get<TF::LRNGradOp>(),
+    TypeID::get<TF::LRNOp>(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index c9876035da9..b574622efce 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -185,6 +185,7 @@ tf_xla_py_test(
     name = "argminmax_test",
     size = "small",
     srcs = ["argminmax_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -253,6 +254,7 @@ tf_xla_py_test(
     name = "bucketize_op_test",
     size = "small",
     srcs = ["bucketize_op_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -806,6 +808,7 @@ tf_xla_py_test(
     name = "lrn_ops_test",
     size = "medium",
     srcs = ["lrn_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py
index 75d06706a2d..f6b6d773135 100644
--- a/tensorflow/compiler/tests/bucketize_op_test.py
+++ b/tensorflow/compiler/tests/bucketize_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -57,6 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
           expected_out,
           sess.run(op, {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))
 
+  @test_util.disable_mlir_bridge("Error handling")
   def testInvalidBoundariesOrder(self):
     with self.session() as sess:
       p = array_ops.placeholder(dtypes.int32)
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 77556a72442..bfba48862f6 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -2040,8 +2040,6 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
                          const XlaComputation& computation,
                          absl::Span<const int64> dimensions_to_reduce) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                         computation.GetProgramShape());
 
@@ -2060,6 +2058,17 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
         Shape shape, ShapeInference::InferReduceShape(
                          operand_shape_ptrs, dimensions_to_reduce,
                          called_program_shape));
+    return ReduceInternal(shape, all_operands, computation,
+                          dimensions_to_reduce);
+  });
+}
+
+StatusOr<XlaOp> XlaBuilder::ReduceInternal(
+    const Shape& shape, absl::Span<const XlaOp> all_operands,
+    const XlaComputation& computation,
+    absl::Span<const int64> dimensions_to_reduce) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
     *instr.mutable_shape() = shape.ToProto();
 
     for (int64 dim : dimensions_to_reduce) {
@@ -2067,7 +2076,6 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
     }
 
     AddCalledComputation(computation, &instr);
-
     return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
   });
 }
@@ -2110,28 +2118,35 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
     absl::Span<const int64> window_dilations,
     absl::Span<const std::pair<int64, int64>> padding) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
     TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                         computation.GetProgramShape());
-    TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
+    TF_ASSIGN_OR_RETURN(auto window,
                         ShapeInference::InferWindowFromDimensions(
                             window_dimensions, window_strides, padding,
                             /*lhs_dilation=*/base_dilations,
                             /*rhs_dilation=*/window_dilations));
-    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
-                                         *operand_shape, *init_shape,
-                                         instr.window(), to_apply_shape));
-    *instr.mutable_shape() = shape.ToProto();
-
-    AddCalledComputation(computation, &instr);
-    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
-                          {operand, init_value});
+    TF_ASSIGN_OR_RETURN(
+        Shape shape, ShapeInference::InferReduceWindowShape(
+                         *operand_shape, *init_shape, window, to_apply_shape));
+    return ReduceWindowInternal(shape, operand, init_value, computation,
+                                std::move(window));
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
+    const Shape& shape, XlaOp operand, XlaOp init_value,
+    const XlaComputation& computation, Window window) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  *instr.mutable_window() = std::move(window);
+
+  AddCalledComputation(computation, &instr);
+  return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
+                        {operand, init_value});
+}
+
 XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                                     float epsilon, int64 feature_index) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index d21ae66d365..ffa6a7c3439 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -542,6 +542,11 @@ class XlaBuilder {
                const XlaComputation& computation,
                absl::Span<const int64> dimensions_to_reduce);
 
+  virtual StatusOr<XlaOp> ReduceInternal(
+      const Shape& shape, absl::Span<const XlaOp> all_operands,
+      const XlaComputation& computation,
+      absl::Span<const int64> dimensions_to_reduce);
+
   XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                   const XlaComputation& computation);
 
@@ -558,6 +563,10 @@ class XlaBuilder {
       absl::Span<const int64> window_dilations,
       absl::Span<const std::pair<int64, int64>> padding);
 
+  virtual StatusOr<XlaOp> ReduceWindowInternal(
+      const Shape& shape, XlaOp operand, XlaOp init_value,
+      const XlaComputation& computation, Window window);
+
   XlaOp CrossReplicaSum(XlaOp operand,
                         absl::Span<const ReplicaGroup> replica_groups = {});