From 1619f2f19f3280f36bd351142314f8a8248ac903 Mon Sep 17 00:00:00 2001 From: Russell Power <power@google.com> Date: Mon, 31 Aug 2020 09:16:55 -0700 Subject: [PATCH] Implement TensorStridedSliceAssign XLA op. PiperOrigin-RevId: 329315230 Change-Id: I5aca22493f5fa38fcd03a3f78f6d9e9afdaadb8b --- .../compiler/jit/mark_for_compilation_pass.cc | 4 ++- .../mlir/tensorflow/ir/tf_generated_ops.td | 34 +++++++++++++++++++ .../tf2xla/kernels/strided_slice_op.cc | 19 +++++++++-- .../python/kernel_tests/array_ops_test.py | 4 +++ 4 files changed, 58 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 03ac7b0a59a..af0a192639c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1834,7 +1834,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() { "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV", "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", - "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}}; + "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex", + "TensorStridedSliceUpdate", + }}}; // clang-format on return result; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index faf7d428aea..9b201ba878e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -11651,6 +11651,40 @@ On GPU, if an out of bound index is found, the index is ignored. ]; } +def TF_TensorStridedSliceUpdateOp : TF_Op<"TensorStridedSliceUpdate", [NoSideEffect]> { + let summary = "Assign `value` to the sliced l-value reference of `input`."; + + let description = [{ +The values of `value` are assigned to the positions in the tensor `input` that +are selected by the slice parameters. The slice parameters `begin` `end` +`strides` etc. work exactly as in `StridedSlice`. + +NOTE this op currently does not support broadcasting and so `value`'s shape +must be exactly the shape produced by the slice of `input`. + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_I32OrI64Tensor:$begin, + TF_I32OrI64Tensor:$end, + TF_I32OrI64Tensor:$strides, + TF_Tensor:$value, + + DefaultValuedAttr<I64Attr, "0">:$begin_mask, + DefaultValuedAttr<I64Attr, "0">:$end_mask, + DefaultValuedAttr<I64Attr, "0">:$ellipsis_mask, + DefaultValuedAttr<I64Attr, "0">:$new_axis_mask, + DefaultValuedAttr<I64Attr, "0">:$shrink_axis_mask + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; +} + def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> { let summary = "Constructs a tensor by tiling a given tensor."; diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 268317d84fc..6e5afd98e9d 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -446,7 +446,12 @@ class StridedSliceAssignOp : public XlaOpKernel { TensorShape lhs_shape; xla::XlaOp lhs; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); + if (ctx->input_type(0) == DT_RESOURCE) { + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); + } else { + lhs_shape = ctx->InputShape(0); + lhs = ctx->Input(0); + } const TensorShape rhs_shape = ctx->InputShape(4); @@ -504,7 +509,11 @@ class StridedSliceAssignOp : public XlaOpKernel { lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); + if (ctx->input_type(0) == DT_RESOURCE) { + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); + } else { + ctx->SetOutput(0, lhs); + } } private: @@ -520,5 +529,11 @@ REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") .CompileTimeConstantInput("strides"), StridedSliceAssignOp); +REGISTER_XLA_OP(Name("TensorStridedSliceUpdate") + .CompileTimeConstantInput("begin") + .CompileTimeConstantInput("end") + .CompileTimeConstantInput("strides"), + StridedSliceAssignOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 391930e20d5..7714b010147 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1234,6 +1234,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaises(ValueError): sess.run(v[:].assign(too_small_val)) + @test_util.disable_xla("b/123559667") @test_util.run_in_graph_and_eager_modes def testTensorStridedSliceUpdateWithInputForward(self): """Tests tensor_strided_slice_update with input-forwarding taking effect.""" @@ -1243,6 +1244,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase): return gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0]) self.assertAllEqual([0, 1], self.evaluate(assign(array_ops.zeros([2])))) + @test_util.disable_xla("b/123559667") @test_util.run_in_graph_and_eager_modes def testTensorStridedSliceUpdateNoInputForward(self): """Tests tensor_strided_slice_update with no input-forwarding.""" @@ -1254,6 +1256,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase): ans = y + z self.assertAllClose([1.6, 2.6], self.evaluate(ans)) + @test_util.disable_xla("b/123559667") def testTensorStridedSliceUpdateGradSimple(self): original = constant_op.constant([0.2, 0.3]) updates = constant_op.constant([0.4]) @@ -1272,6 +1275,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase): ([4], [5], [3], [1], [3], 1, 0, 0, 0, 0), ([2, 2, 3, 2], [0, 0, 1], [1, 0, 2], [1, 0, 1], [2, 3], 0, 0, 2, 0, 5) ])) + @test_util.disable_xla("b/123559667") def testTensorStridedSliceUpdateGrad( self, shape, begin, end, strides, updates_shape, *args): with self.cached_session():