From 170f0b35076b9d3048f0615fe2cdd9cff8da494c Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 17 May 2017 12:47:47 -0700
Subject: [PATCH] [TF:XLA] Add XLA implementation of
 ResourceStridedSliceAssign.

PiperOrigin-RevId: 156341053
---
 .../compiler/tests/variable_ops_test.py      |  67 ++++++++++-
 tensorflow/compiler/tf2xla/const_analysis.cc |   3 +
 .../tf2xla/kernels/strided_slice_op.cc       | 113 ++++++++++++++++++
 tensorflow/compiler/tf2xla/xla_compiler.cc   |   2 +
 tensorflow/compiler/tf2xla/xla_op_kernel.cc  |   1 +
 5 files changed, 185 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py
index fef390fd67f..70dacd9de4b 100644
--- a/tensorflow/compiler/tests/variable_ops_test.py
+++ b/tensorflow/compiler/tests/variable_ops_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for XLA JIT compiler."""
+"""Tests for reading and writing variables."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -21,7 +21,9 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -114,5 +116,68 @@ class VariableOpsTest(XLATestCase):
     self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)
 
 
+class StridedSliceAssignChecker(object):
+  """Compares the results of a slice assignment using TensorFlow and numpy."""
+
+  def __init__(self, test, x, dtype):
+    self.dtype = dtype
+    self.test = test
+    self.x_np = np.array(x).astype(dtype)
+
+  def __setitem__(self, index, value):
+    value = np.array(value).astype(self.dtype)
+
+    with self.test.test_session() as sess, self.test.test_scope():
+      x = constant_op.constant(self.x_np, dtype=self.dtype)
+      var = resource_variable_ops.ResourceVariable(x)
+      sess.run(variables.variables_initializer([var]))
+      val = sess.run(var[index].assign(value))
+      # val_copy is used to check that tf.assign works equivalently to the
+      # assign method above.
+      val_copy = sess.run(state_ops.assign(var[index], value))
+      valnp = np.copy(self.x_np)
+      valnp[index] = np.array(value)
+      self.test.assertAllEqual(val, valnp)
+      self.test.assertAllEqual(val_copy, valnp)
+
+
+class SliceAssignTest(XLATestCase):
+
+  def testSliceAssign(self):
+    for dtype in self.numeric_types:
+      checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]],
+                                          dtype=dtype)
+      # No-op assignment
+      checker[:] = [[10, 20, 30], [40, 50, 60]]
+      # Checks trivial (1,1) shape tensor
+      checker[1:2, 1:2] = [[66]]
+      # shrink shape changes
+      checker[1:2, 1] = [66]
+      checker[1, 1:2] = [66]
+      checker[1, 1] = 66
+      # newaxis shape changes
+      checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
+      # shrink and newaxis
+      checker[None, None, 0, 0:1] = [[[99]]]
+      # Non-unit strides
+      checker[::1, 1::-1] = [[3, 33], [4, 44]]
+      # degenerate interval
+      checker[8:10, 0] = []
+      checker[8:10, 8:10] = [[]]
+
+      # Assign vector to scalar (rank-0) using newaxis
+      checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype)
+      checker2[()] = 6  # no indices
+      checker2[...] = 6  # ellipsis
+      checker2[None] = [6]  # new axis
+
+  def testUninitialized(self):
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 "uninitialized variable"):
+      with self.test_session() as sess, self.test_scope():
+        v = resource_variable_ops.ResourceVariable([1, 2])
+        sess.run(v[:].assign([1, 2]))
+
+
 if __name__ == "__main__":
   googletest.main()
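Note on the test strategy: StridedSliceAssignChecker uses numpy's __setitem__
as the reference implementation, replaying every TensorFlow slice assignment
on a numpy copy of the data and comparing the two results. The subtlest case
exercised above is the reversed-stride assignment checker[::1, 1::-1]; the
following numpy-only sketch (illustrative, not part of the patch) shows what
the oracle computes for it:

    import numpy as np

    x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    out = x.copy()
    # Mirrors checker[::1, 1::-1] = [[3, 33], [4, 44]] above: on the second
    # axis, stride -1 starting at index 1 selects columns 1, then 0.
    out[::1, 1::-1] = [[3, 33], [4, 44]]
    print(out)  # [[33.  3.  3.]
                #  [44.  4.  6.]]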
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 4adc17b8382..c4cbaebb258 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -68,6 +68,9 @@ Status BackwardsConstAnalysis(const Graph& g,
       {"Range", "limit"},
       {"Range", "delta"},
       {"Reshape", "shape"},
+      {"ResourceStridedSliceAssign", "begin"},
+      {"ResourceStridedSliceAssign", "end"},
+      {"ResourceStridedSliceAssign", "strides"},
       {"Reverse", "dims"},
       {"ReverseV2", "axis"},
       {"Slice", "begin"},
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 03e02299e33..211412d463d 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -219,5 +219,118 @@ class StridedSliceGradOp : public XlaOpKernel {
 
 REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp);
 
+class StridedSliceAssignOp : public XlaOpKernel {
+ public:
+  explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    TensorShape final_shape;
+    gtl::InlinedVector<int64, 4> begin;
+    gtl::InlinedVector<int64, 4> end;
+    gtl::InlinedVector<int64, 4> strides;
+
+    xla::Literal begin_literal, end_literal, strides_literal;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
+
+    Tensor begin_tensor, end_tensor, strides_tensor;
+    OP_REQUIRES_OK(
+        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
+    OP_REQUIRES_OK(ctx,
+                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
+    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
+                                            &strides_tensor));
+
+    DataType lhs_type;
+    TensorShape lhs_shape;
+    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
+
+    const TensorShape rhs_shape = ctx->InputShape(4);
+
+    TensorShape dummy_processing_shape;
+    ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
+    ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape(
+        &dummy_processing_shape);
+    bool dummy = false;
+    OP_REQUIRES_OK(
+        ctx, ValidateStridedSliceOp(
+                 &begin_tensor, &end_tensor, strides_tensor,
+                 ShapeReadWriteFromTensorShape(&lhs_shape), begin_mask_,
+                 end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+                 &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy,
+                 &dummy, &dummy, &begin, &end, &strides));
+
+    if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
+      // DynamicUpdateSlice does not allow 0-element updates. We should
+      // probably check that rhs_shape can be broadcast to final_shape, but
+      // that is better handled when implementing broadcasting more
+      // generally.
+      return;
+    }
+
+    // TODO(aselle): This check is too strong; we should only need rhs_shape
+    // to be broadcastable to final_shape.
+    OP_REQUIRES(ctx, final_shape == rhs_shape,
+                errors::Unimplemented(
+                    "sliced l-value shape ", final_shape.DebugString(),
+                    " does not match r-value shape ", rhs_shape.DebugString(),
+                    ". Automatic broadcasting not yet implemented."));
+
+    xla::ComputationDataHandle lhs;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));
+
+    xla::ComputationDataHandle rhs = ctx->Input(4);
+
+    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
+    gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
+    for (int i = 0; i < begin.size(); ++i) {
+      // TODO(phawkins): implement strides != 1
+      OP_REQUIRES(
+          ctx, strides[i] == 1 || strides[i] == -1,
+          errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
+      if (strides[i] > 0) {
+        slice_begin.push_back(begin[i]);
+        slice_dims.push_back(end[i] - begin[i]);
+      } else {
+        // Negative stride: swap begin and end, add 1 because the interval
+        // is semi-open, and mark the dimension to be reversed.
+        slice_begin.push_back(end[i] + 1);
+        slice_dims.push_back(begin[i] - end[i]);
+        dimensions_to_reverse.push_back(i);
+      }
+    }
+
+    if (!dimensions_to_reverse.empty()) {
+      rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse);
+    }
+    rhs = ctx->builder()->Reshape(rhs, slice_dims);
+
+    if (lhs_shape.dims() == 0) {
+      // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix
+      // and remove this workaround.
+      lhs = rhs;
+    } else {
+      lhs = ctx->builder()->DynamicUpdateSlice(
+          lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
+    }
+
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
+  }
+
+ private:
+  int32 begin_mask_, end_mask_;
+  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
+  DataType index_type_;
+};
+
+REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp);
+
 }  // namespace
 }  // namespace tensorflow
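The negative-stride branch in Compile() above is the least obvious step. With
stride -1, a strided slice visits indices begin, begin-1, ..., end+1, because
the interval is semi-open at end; since DynamicUpdateSlice has no stride
parameter, the kernel rewrites the assignment as a forward slice of length
begin - end starting at end + 1 and reverses the update value along that
dimension. A numpy sketch of this equivalence (illustrative only, not part of
the patch):

    import numpy as np

    x = np.arange(10)
    begin, end = 7, 2            # x[7:2:-1] visits indices 7, 6, 5, 4, 3
    update = np.array([50, 60, 70, 80, 90])

    # Direct assignment with the negative stride:
    a = x.copy()
    a[begin:end:-1] = update

    # The kernel's formulation: a forward slice starting at end + 1 with
    # begin - end elements, assigned the reversed update.
    b = x.copy()
    b[end + 1:end + 1 + (begin - end)] = update[::-1]

    assert np.array_equal(a, b)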
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index d4a917671b9..f8a9c5e9bc6 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -467,6 +467,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
 
   result->xla_output_shape.Swap(
       computation_shape.ValueOrDie()->mutable_result());
+  VLOG(2) << "XLA output shape: "
+          << xla::ShapeUtil::HumanString(result->xla_output_shape);
 
   auto num_computation_outputs =
       (xla::ShapeUtil::IsTuple(result->xla_output_shape))
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 48831ce4c27..4de69ee43c3 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -357,6 +357,7 @@ void XlaOpKernelContext::SetVariableOutput(int index, int variable_id) {
 
 Status XlaOpKernelContext::AssignVariable(
     int index, DataType type, const xla::ComputationDataHandle& handle) {
+  TF_RET_CHECK(handle.handle() != 0);
   SetOpHasSideEffects();
   const XlaExpression* expression =