[TF:XLA] Add XLA implementation of ResourceStridedSliceAssign.

PiperOrigin-RevId: 156341053
Author: Peter Hawkins (2017-05-17 12:47:47 -07:00)
Committed by: TensorFlower Gardener
Parent: 1390dd68fe
Commit: 170f0b3507
5 changed files with 185 additions and 1 deletion
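For context, a minimal sketch of the Python-level pattern this commit makes compilable (the exact session and XLA device scoping are assumptions; see the test changes below):

```python
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops

# Sketch: `var[index].assign(value)` emits a ResourceStridedSliceAssign op,
# which now has an XLA kernel. Run inside an XLA scope as in the test below.
x = constant_op.constant(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
var = resource_variable_ops.ResourceVariable(x)
update_op = var[1:2, 1].assign([66.0])  # strided-slice write into the variable
```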


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLA JIT compiler."""
"""Tests for reading and writing variables."""
from __future__ import absolute_import
from __future__ import division
@@ -21,7 +21,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -114,5 +116,68 @@ class VariableOpsTest(XLATestCase):
    self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)


class StridedSliceAssignChecker(object):
  """Compares the results of a slice assignment using TensorFlow and numpy."""

  def __init__(self, test, x, dtype):
    self.dtype = dtype
    self.test = test
    self.x_np = np.array(x).astype(dtype)

  def __setitem__(self, index, value):
    value = np.array(value).astype(self.dtype)

    with self.test.test_session() as sess, self.test.test_scope():
      x = constant_op.constant(self.x_np, dtype=self.dtype)
      var = resource_variable_ops.ResourceVariable(x)
      sess.run(variables.variables_initializer([var]))
      val = sess.run(var[index].assign(value))
      # val_copy is used to check that tf.assign works equivalently to the
      # assign method above.
      val_copy = sess.run(state_ops.assign(var[index], value))

      valnp = np.copy(self.x_np)
      valnp[index] = np.array(value)
      self.test.assertAllEqual(val, valnp)
      self.test.assertAllEqual(val_copy, valnp)


class SliceAssignTest(XLATestCase):

  def testSliceAssign(self):
    for dtype in self.numeric_types:
      checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]],
                                          dtype=dtype)
      # No-op assignment
      checker[:] = [[10, 20, 30], [40, 50, 60]]
      # Checks trivial (1,1) shape tensor
      checker[1:2, 1:2] = [[66]]
      # shrink shape changes
      checker[1:2, 1] = [66]
      checker[1, 1:2] = [66]
      checker[1, 1] = 66
      # newaxis shape changes
      checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
      # shrink and newaxis
      checker[None, None, 0, 0:1] = [[[99]]]
      # Non-unit strides
      checker[::1, 1::-1] = [[3, 33], [4, 44]]
      # degenerate interval
      checker[8:10, 0] = []
      checker[8:10, 8:10] = [[]]

      # Assign a vector to a scalar (rank-0) variable using newaxis.
      checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype)
      checker2[()] = 6  # no indices
      checker2[...] = 6  # ellipsis
      checker2[None] = [6]  # new axis

  def testUninitialized(self):
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "uninitialized variable"):
      with self.test_session() as sess, self.test_scope():
        v = resource_variable_ops.ResourceVariable([1, 2])
        sess.run(v[:].assign([1, 2]))


if __name__ == "__main__":
  googletest.main()
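A note on the degenerate-interval cases above: they lean on numpy-style clamping of out-of-range slices, which the kernel mirrors with its zero-element early return (see the C++ changes below). A quick numpy illustration:

```python
import numpy as np

# An out-of-range slice clamps to an empty window, and assigning an empty
# value to an empty window is a no-op.
x = np.array([[1, 2, 3], [4, 5, 6]])
x[8:10, 0] = []        # empty 1-D slice: nothing is written
x[8:10, 8:10] = [[]]   # likewise for a 2-D empty window
assert (x == [[1, 2, 3], [4, 5, 6]]).all()
```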


@@ -68,6 +68,9 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Range", "limit"},
{"Range", "delta"},
{"Reshape", "shape"},
{"ResourceStridedSliceAssign", "begin"},
{"ResourceStridedSliceAssign", "end"},
{"ResourceStridedSliceAssign", "strides"},
{"Reverse", "dims"},
{"ReverseV2", "axis"},
{"Slice", "begin"},


@@ -219,5 +219,118 @@ class StridedSliceGradOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp);

class StridedSliceAssignOp : public XlaOpKernel {
 public:
  explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape final_shape;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    xla::Literal begin_literal, end_literal, strides_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));

    Tensor begin_tensor, end_tensor, strides_tensor;
    OP_REQUIRES_OK(
        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    OP_REQUIRES_OK(ctx,
                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

    DataType lhs_type;
    TensorShape lhs_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));

    const TensorShape rhs_shape = ctx->InputShape(4);

    TensorShape dummy_processing_shape;
    ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
    ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape(
        &dummy_processing_shape);
    bool dummy = false;
    OP_REQUIRES_OK(
        ctx, ValidateStridedSliceOp(
                 &begin_tensor, &end_tensor, strides_tensor,
                 ShapeReadWriteFromTensorShape(&lhs_shape), begin_mask_,
                 end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
                 &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy,
                 &dummy, &dummy, &begin, &end, &strides));

    if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
      // DynamicUpdateSlice does not allow 0-element updates. We should
      // probably check that rhs_shape can be broadcast to final_shape, but
      // that is probably better handled when implementing broadcasting more
      // generally.
      return;
    }

    // TODO(aselle): This check is too strong; we should only need input_shape
    // to be broadcastable to final_shape.
    OP_REQUIRES(ctx, final_shape == rhs_shape,
                errors::Unimplemented(
                    "sliced l-value shape ", final_shape.DebugString(),
                    " does not match r-value shape ", rhs_shape.DebugString(),
                    ". Automatic broadcasting not yet implemented."));

    xla::ComputationDataHandle lhs;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));

    xla::ComputationDataHandle rhs = ctx->Input(4);

    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
    gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
    for (int i = 0; i < begin.size(); ++i) {
      // TODO(phawkins): implement strides != 1
      OP_REQUIRES(
          ctx, strides[i] == 1 || strides[i] == -1,
          errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
      if (strides[i] > 0) {
        slice_begin.push_back(begin[i]);
        slice_dims.push_back(end[i] - begin[i]);
      } else {
        // Negative stride: swap begin and end, add 1 because the interval
        // is semi-open, and mark the dimension to be reversed.
        slice_begin.push_back(end[i] + 1);
        slice_dims.push_back(begin[i] - end[i]);
        dimensions_to_reverse.push_back(i);
      }
    }

    if (!dimensions_to_reverse.empty()) {
      rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse);
    }
    rhs = ctx->builder()->Reshape(rhs, slice_dims);

    if (lhs_shape.dims() == 0) {
      // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix
      // and remove this workaround.
      lhs = rhs;
    } else {
      lhs = ctx->builder()->DynamicUpdateSlice(
          lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
    }

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
};

REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp);

}  // namespace
}  // namespace tensorflow
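To make the lowering above concrete, here is a small numpy model of it. This is a sketch under the kernel's own restrictions (strides of ±1, begin/end already canonicalized by ValidateStridedSliceOp); the function name and calling convention are illustrative:

```python
import numpy as np

def strided_slice_assign(lhs, begin, end, strides, rhs):
  """Numpy model of the XLA lowering: reverse + dynamic-update-slice."""
  slice_begin, slice_dims, reverse_dims = [], [], []
  for i, stride in enumerate(strides):
    if stride > 0:
      slice_begin.append(begin[i])
      slice_dims.append(end[i] - begin[i])
    else:
      # Negative stride: swap begin and end (+1, since the interval is
      # semi-open) and remember to reverse the update along this dimension.
      slice_begin.append(end[i] + 1)
      slice_dims.append(begin[i] - end[i])
      reverse_dims.append(i)
  if reverse_dims:
    rhs = np.flip(rhs, axis=tuple(reverse_dims))
  rhs = np.reshape(rhs, slice_dims)
  # Dynamic-update-slice: overwrite the window starting at slice_begin.
  out = np.copy(lhs)
  window = tuple(slice(b, b + d) for b, d in zip(slice_begin, slice_dims))
  out[window] = rhs
  return out

# Mirrors the test case `checker[::1, 1::-1] = [[3, 33], [4, 44]]`:
x = np.array([[1, 2, 3], [4, 5, 6]])
y = strided_slice_assign(x, begin=[0, 1], end=[2, -1], strides=[1, -1],
                         rhs=np.array([[3, 33], [4, 44]]))
assert (y == [[33, 3, 3], [44, 4, 6]]).all()
```

The reverse-then-update trick avoids needing a strided scatter: a stride of -1 is just a contiguous window written in reverse order.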


@@ -467,6 +467,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
  result->xla_output_shape.Swap(
      computation_shape.ValueOrDie()->mutable_result());
  VLOG(2) << "XLA output shape: "
          << xla::ShapeUtil::HumanString(result->xla_output_shape);

  auto num_computation_outputs =
      (xla::ShapeUtil::IsTuple(result->xla_output_shape))


@@ -357,6 +357,7 @@ void XlaOpKernelContext::SetVariableOutput(int index, int variable_id) {
Status XlaOpKernelContext::AssignVariable(
    int index, DataType type, const xla::ComputationDataHandle& handle) {
  TF_RET_CHECK(handle.handle() != 0);
  // A variable write is a side effect: mark the op so the compiler will not
  // elide it.
  SetOpHasSideEffects();
  const XlaExpression* expression =