diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index cc334d8654f..461c357e509 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -282,6 +282,28 @@ StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
       GetI64ElementsAttr(strides, &builder_)));
 }
 
+StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
+    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
+    absl::Span<const int64> slice_sizes) {
+  TF_ASSIGN_OR_RETURN(
+      mlir::Type result_ty,
+      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicSliceOp>(
+      loc_, result_ty, GetValue(operand), GetValues(start_indices),
+      GetI64ElementsAttr(slice_sizes, &builder_)));
+}
+
+StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
+    const Shape& shape, XlaOp operand, XlaOp update,
+    absl::Span<const XlaOp> start_indices) {
+  TF_ASSIGN_OR_RETURN(
+      mlir::Type result_ty,
+      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+  return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicUpdateSliceOp>(
+      loc_, result_ty, GetValue(operand), GetValue(update),
+      GetValues(start_indices)));
+}
+
 StatusOr<XlaOp> MlirHloBuilder::PadInternal(
     const Shape& shape, XlaOp operand, XlaOp padding_value,
     const PaddingConfig& padding_config) {
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index 5a84d60cdc2..fc5baaee44d 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -175,6 +175,14 @@ class MlirHloBuilder : public XlaBuilder {
                                 absl::Span<const int64> limit_indices,
                                 absl::Span<const int64> strides) override;
 
+  StatusOr<XlaOp> DynamicSliceInternal(
+      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
+      absl::Span<const int64> slice_sizes) override;
+
+  StatusOr<XlaOp> DynamicUpdateSliceInternal(
+      const Shape& shape, XlaOp operand, XlaOp update,
+      absl::Span<const XlaOp> start_indices) override;
+
   StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
                               XlaOp padding_value,
                               const PaddingConfig& padding_config) override;
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
index 01398eb7314..e8d5cfe997d 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
@@ -163,6 +163,30 @@ func @truncated_normal() -> tensor<2x2xf32> {
   return %1 : tensor<2x2xf32>
 }
 
+// CHECK-LABEL: dynamic_update_slice
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32>
+func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> {
+
+  // CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]])
+  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>
+  // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64>
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
+  // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
+  // CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
+
+  // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]])
+  // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
+  // CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64>
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
+  // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
+  // CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
+
+  // CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]])
+
+  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
+  return %0: tensor<3x4xi32>
+}
+
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
 }
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 86a2defd3a8..76657bd5e20 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -168,6 +168,8 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get(),
     TypeID::get(),
    TypeID::get(),
+    TypeID::get<TF::XlaDynamicSliceOp>(),
+    TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get()
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 1f83701ea7c..f3e915daa67 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -304,7 +304,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
     self._assertOpOutputMatchesExpected(
         lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
 
-  @test_util.disable_mlir_bridge('Not supported yet')
   def testDynamicSlice(self):
     for dtype in self.numeric_types:
       self._assertOpOutputMatchesExpected(
@@ -317,7 +316,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
                         [[673, 674], [683, 684], [693, 694]]]),
              dtype=dtype))
 
-  @test_util.disable_mlir_bridge('Not supported yet')
+  @test_util.disable_mlir_bridge('Error handling')
   def testDynamicSliceWithIncorrectStartIndicesShape(self):
     with self.session() as session:
       with self.test_scope():
@@ -331,7 +330,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
          (r'start_indices must be a vector with length equal to input rank, '
           r'but input rank is 3 and start_indices has shape \[2\].*'))
 
-  @test_util.disable_mlir_bridge('Not supported yet')
+  @test_util.disable_mlir_bridge('Error handling')
   def testDynamicSliceWithIncorrectSizeIndicesShape(self):
     with self.session() as session:
       with self.test_scope():
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 6539817d524..a4e5b936153 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -864,8 +864,6 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
                                absl::Span<const XlaOp> start_indices,
                                absl::Span<const int64> slice_sizes) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     std::vector<const Shape*> start_indices_shape_ptrs;
     TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
@@ -876,23 +874,28 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
     TF_ASSIGN_OR_RETURN(Shape shape,
                         ShapeInference::InferDynamicSliceShape(
                             *operand_shape, start_indices_shapes, slice_sizes));
-    *instr.mutable_shape() = shape.ToProto();
-
-    for (int64 size : slice_sizes) {
-      instr.add_dynamic_slice_sizes(size);
-    }
-
-    std::vector<XlaOp> operands = {operand};
-    operands.insert(operands.end(), start_indices.begin(), start_indices.end());
-    return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
+    return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
+    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
+    absl::Span<const int64> slice_sizes) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+
+  for (int64 size : slice_sizes) {
+    instr.add_dynamic_slice_sizes(size);
+  }
+
+  std::vector<XlaOp> operands = {operand};
+  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
+  return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
+}
+
 XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                      absl::Span<const XlaOp> start_indices) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
     std::vector<const Shape*> start_indices_shape_ptrs;
@@ -904,15 +907,22 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
     TF_ASSIGN_OR_RETURN(
         Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
                          *operand_shape, *update_shape, start_indices_shapes));
-    *instr.mutable_shape() = shape.ToProto();
-
-    std::vector<XlaOp> operands = {operand, update};
-    operands.insert(operands.end(), start_indices.begin(), start_indices.end());
-    return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
-                          operands);
+    return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
+    const Shape& shape, XlaOp operand, XlaOp update,
+    absl::Span<const XlaOp> start_indices) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+
+  std::vector<XlaOp> operands = {operand, update};
+  operands.insert(operands.end(), start_indices.begin(), start_indices.end());
+  return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
+                        operands);
+}
+
 XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                               int64 dimension) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 24b0cba3a1b..b631514248c 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -423,9 +423,15 @@ class XlaBuilder {
   XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                      absl::Span<const int64> slice_sizes);
+  virtual StatusOr<XlaOp> DynamicSliceInternal(
+      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
+      absl::Span<const int64> slice_sizes);
 
   XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                            absl::Span<const XlaOp> start_indices);
+  virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
+      const Shape& shape, XlaOp operand, XlaOp update,
+      absl::Span<const XlaOp> start_indices);
 
   XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
 
   virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
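
Usage sketch (not part of the patch): the snippet below, modeled on the xla_ops_test.py changes above, shows what this change enables end to end. The function name and the use of tf.function(experimental_compile=True) to force XLA compilation are assumptions for illustration; the patched tests themselves run under the xla_test.XLATestCase harness instead.

# Hypothetical standalone example; exercises tf.XlaDynamicUpdateSlice,
# which this patch whitelists for the MLIR-bridge fallback.
import numpy as np
import tensorflow as tf

from tensorflow.compiler.tf2xla.python import xla


@tf.function(experimental_compile=True)  # assumed way to force XLA jit here
def update_block(operand, update, start_indices):
  # Writes `update` into `operand` at offsets `start_indices`; under the
  # MLIR bridge this legalizes to xla_hlo.dynamic-update-slice, as checked
  # by the legalize-tf-with-tf2xla.mlir test above.
  return xla.dynamic_update_slice(operand, update, start_indices)


operand = np.arange(12, dtype=np.int32).reshape(3, 4)
block = np.full((2, 2), -1, dtype=np.int32)
print(update_block(operand, block, np.array([1, 1], dtype=np.int32)))
# Expected result:
# [[ 0  1  2  3]
#  [ 4 -1 -1  7]
#  [ 8 -1 -1 11]]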