diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 10da458a4ea..8f0f000b26a 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -660,6 +660,7 @@ cc_library( deps = [ ":attribute_importer", ":hlo", + ":hlo_module_importer", ":hlo_utils", ":type_to_shape", "//tensorflow/compiler/xla:comparison_util", diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 774caab77fb..d98e6375f7e 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -19,10 +19,12 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/attribute_importer.h" +#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -140,6 +142,24 @@ StatusOr MlirHloBuilder::GatherInternal( return MakeXlaOp(op); } +StatusOr MlirHloBuilder::ScatterInternal( + const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, + bool unique_indices) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + auto op = builder_.create( + loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates), + ConvertScatterDimensionNumbers(dimension_numbers, &builder_), + builder_.getBoolAttr(indices_are_sorted), + builder_.getBoolAttr(unique_indices)); + + TF_RETURN_IF_ERROR( + ImportComputation(update_computation.proto(), &op.update_computation())); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::RngOpInternal( RandomDistribution distribution, absl::Span parameters, const Shape& shape) { @@ -348,6 +368,18 @@ StatusOr MlirHloBuilder::CreateOp( return MakeXlaOp(op->getResult(0)); } +Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation, + mlir::Region* region) { + TF_ASSIGN_OR_RETURN(auto module_config, + xla::HloModule::CreateModuleConfigFromProto( + computation, xla::DebugOptions())); + TF_ASSIGN_OR_RETURN(auto hlo_module, xla::HloModule::CreateFromProto( + computation, module_config)); + + return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(), + region, &builder_); +} + StatusOr MlirHloBuilder::GetShapePtr(XlaOp op) const { TF_RETURN_IF_ERROR(first_error()); TF_RETURN_IF_ERROR(CheckOpBuilder(op)); diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index fc5baaee44d..0b6bacbfff6 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -129,6 +129,12 @@ class MlirHloBuilder : public XlaBuilder { const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes, bool indices_are_sorted) override; + StatusOr ScatterInternal( + const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, + bool unique_indices) override; + StatusOr RngOpInternal(RandomDistribution distribution, absl::Span parameters, const Shape& shape) override; @@ -196,6 +202,9 @@ class MlirHloBuilder : public XlaBuilder { llvm::ArrayRef operands, llvm::ArrayRef attributes = {}); + Status ImportComputation(const HloModuleProto& computation, + mlir::Region* region); + mlir::OpBuilder builder_; mlir::Location loc_; diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index a808b877867..3f99d71494e 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -187,6 +187,33 @@ func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2 return %0: tensor<3x4xi32> } +// CHECK-LABEL: @sparse_to_dense +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor) +func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor) -> tensor<3x3xf32> { + +// CHECK: %[[CST:.*]] = xla_hlo.constant dense<3> : tensor<2xi32> +// CHECK: %[[DEFAULT:.*]] = "xla_hlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<3x3xf32> + +// CHECK: %[[RESULT:.*]] = "xla_hlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( { +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): // no predecessors +// CHECK: "xla_hlo.return"(%[[ARG4]]) : (tensor) -> () +// CHECK: }) +// CHECK-SAME: indices_are_sorted = false +// CHECK-SAME: scatter_dimension_numbers +// CHECK-SAME: index_vector_dim = 1 : i64 +// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64> +// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64> +// CHECK-SAME: update_window_dims = dense<[]> : tensor<0xi64> +// CHECK-SAME: unique_indices = false +// CHECK-SAME: (tensor<3x3xf32>, tensor<3x2xi32>, tensor<3xf32>) -> tensor<3x3xf32> + +// return %[[RESULT]] : tensor<3x3xf32> + + %cst = xla_hlo.constant dense<3> : tensor<2xi32> + %0 = "tf.SparseToDense"(%arg0, %cst, %arg1, %arg2) {validate_indices = true}: (tensor<3x2xi32>, tensor<2xi32>, tensor<3xf32>, tensor) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index b15974979c9..659cbbe8ebc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -156,6 +156,7 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 91b3ecdfedb..595bef42a5a 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1218,6 +1218,7 @@ tf_xla_py_test( name = "sparse_to_dense_op_test", size = "small", srcs = ["sparse_to_dense_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py index dbfdc3b7247..d80f9ddd702 100644 --- a/tensorflow/compiler/tests/sparse_to_dense_op_test.py +++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test @@ -101,6 +102,7 @@ class SparseToDenseTest(xla_test.XLATestCase): with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): _SparseToDense([1, 3], [[5], [3]], 1, -1) + @test_util.disable_mlir_bridge("Error handling") def testBadValue(self): with self.session(), self.test_scope(): with self.assertRaisesOpError( @@ -108,12 +110,14 @@ class SparseToDenseTest(xla_test.XLATestCase): r"should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [[5], [3]], -1) + @test_util.disable_mlir_bridge("Error handling") def testBadNumValues(self): with self.session(), self.test_scope(): with self.assertRaisesOpError( r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [1, 2, 3], -1) + @test_util.disable_mlir_bridge("Error handling") def testBadDefault(self): with self.session(), self.test_scope(): with self.assertRaisesOpError("default_value should be a scalar"): diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 58365c0f498..440fbebfa5e 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1899,11 +1899,6 @@ XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, bool unique_indices) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - instr.set_indices_are_sorted(indices_are_sorted); - - instr.set_unique_indices(unique_indices); - TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input)); TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape, GetShapePtr(scatter_indices)); @@ -1914,8 +1909,22 @@ XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, Shape shape, ShapeInference::InferScatterShape( *input_shape, *scatter_indices_shape, *updates_shape, to_apply_shape, dimension_numbers)); - *instr.mutable_shape() = shape.ToProto(); + return ScatterInternal(shape, input, scatter_indices, updates, + update_computation, dimension_numbers, + indices_are_sorted, unique_indices); + }); +} +StatusOr XlaBuilder::ScatterInternal( + const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, + bool unique_indices) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + instr.set_indices_are_sorted(indices_are_sorted); + instr.set_unique_indices(unique_indices); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_scatter_dimension_numbers() = dimension_numbers; AddCalledComputation(update_computation, &instr); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 77b6912e51b..82f8cdbabce 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -653,6 +653,12 @@ class XlaBuilder { const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted = false, bool unique_indices = false); + virtual StatusOr ScatterInternal( + const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, + bool unique_indices); + void Send(XlaOp operand, const ChannelHandle& handle); XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);