[TF:XLA] Add XLA implementation of ResourceStridedSliceAssign.
PiperOrigin-RevId: 156341053
This commit is contained in:
parent
1390dd68fe
commit
170f0b3507
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for XLA JIT compiler."""
|
||||
"""Tests for reading and writing variables."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -21,7 +21,9 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -114,5 +116,68 @@ class VariableOpsTest(XLATestCase):
|
||||
self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)
|
||||
|
||||
|
||||
class StridedSliceAssignChecker(object):
|
||||
"""Compares the results of a slice assignment using Tensorflow and numpy."""
|
||||
|
||||
def __init__(self, test, x, dtype):
|
||||
self.dtype = dtype
|
||||
self.test = test
|
||||
self.x_np = np.array(x).astype(dtype)
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
value = np.array(value).astype(self.dtype)
|
||||
|
||||
with self.test.test_session() as sess, self.test.test_scope():
|
||||
x = constant_op.constant(self.x_np, dtype=self.dtype)
|
||||
var = resource_variable_ops.ResourceVariable(x)
|
||||
sess.run(variables.variables_initializer([var]))
|
||||
val = sess.run(var[index].assign(value))
|
||||
# val_copy is used to check that tf.assign works equivalently to the
|
||||
# assign method above.
|
||||
val_copy = sess.run(state_ops.assign(var[index], value))
|
||||
valnp = np.copy(self.x_np)
|
||||
valnp[index] = np.array(value)
|
||||
self.test.assertAllEqual(val, valnp)
|
||||
self.test.assertAllEqual(val_copy, valnp)
|
||||
|
||||
|
||||
class SliceAssignTest(XLATestCase):
|
||||
|
||||
def testSliceAssign(self):
|
||||
for dtype in self.numeric_types:
|
||||
checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]],
|
||||
dtype=dtype)
|
||||
# No-op assignment
|
||||
checker[:] = [[10, 20, 30], [40, 50, 60]]
|
||||
# Checks trivial (1,1) shape tensor
|
||||
checker[1:2, 1:2] = [[66]]
|
||||
# shrink shape changes
|
||||
checker[1:2, 1] = [66]
|
||||
checker[1, 1:2] = [66]
|
||||
checker[1, 1] = 66
|
||||
# newaxis shape changes
|
||||
checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
|
||||
# shrink and newaxis
|
||||
checker[None, None, 0, 0:1] = [[[99]]]
|
||||
# Non unit strides
|
||||
checker[::1, 1::-1] = [[3, 33], [4, 44]]
|
||||
# degenerate interval
|
||||
checker[8:10, 0] = []
|
||||
checker[8:10, 8:10] = [[]]
|
||||
|
||||
# Assign vector to scalar (rank-0) using newaxis
|
||||
checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype)
|
||||
checker2[()] = 6 # no indices
|
||||
checker2[...] = 6 # ellipsis
|
||||
checker2[None] = [6] # new axis
|
||||
|
||||
def testUninitialized(self):
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"uninitialized variable"):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable([1, 2])
|
||||
sess.run(v[:].assign([1, 2]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -68,6 +68,9 @@ Status BackwardsConstAnalysis(const Graph& g,
|
||||
{"Range", "limit"},
|
||||
{"Range", "delta"},
|
||||
{"Reshape", "shape"},
|
||||
{"ResourceStridedSliceAssign", "begin"},
|
||||
{"ResourceStridedSliceAssign", "end"},
|
||||
{"ResourceStridedSliceAssign", "strides"},
|
||||
{"Reverse", "dims"},
|
||||
{"ReverseV2", "axis"},
|
||||
{"Slice", "begin"},
|
||||
|
@ -219,5 +219,118 @@ class StridedSliceGradOp : public XlaOpKernel {
|
||||
|
||||
REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp);
|
||||
|
||||
class StridedSliceAssignOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape final_shape;
|
||||
gtl::InlinedVector<int64, 4> begin;
|
||||
gtl::InlinedVector<int64, 4> end;
|
||||
gtl::InlinedVector<int64, 4> strides;
|
||||
|
||||
xla::Literal begin_literal, end_literal, strides_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
|
||||
|
||||
Tensor begin_tensor, end_tensor, strides_tensor;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
LiteralToHostTensor(end_literal, index_type_, &end_tensor));
|
||||
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
|
||||
&strides_tensor));
|
||||
|
||||
DataType lhs_type;
|
||||
TensorShape lhs_shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
|
||||
|
||||
const TensorShape rhs_shape = ctx->InputShape(4);
|
||||
|
||||
TensorShape dummy_processing_shape;
|
||||
ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
|
||||
ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape(
|
||||
&dummy_processing_shape);
|
||||
bool dummy = false;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ValidateStridedSliceOp(
|
||||
&begin_tensor, &end_tensor, strides_tensor,
|
||||
ShapeReadWriteFromTensorShape(&lhs_shape), begin_mask_,
|
||||
end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
|
||||
&wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy,
|
||||
&dummy, &dummy, &begin, &end, &strides));
|
||||
|
||||
if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
|
||||
// DynamicUpdateSlice does not allow 0-element updates. We should probably
|
||||
// check that rhs_shape can be broadcast to final_shape, but that is
|
||||
// probably better handled when implementing broadcasting more generally.
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(aselle): This check is too strong, we only should need
|
||||
// input_shape to be broadcastable to final_shape
|
||||
OP_REQUIRES(ctx, final_shape == rhs_shape,
|
||||
errors::Unimplemented(
|
||||
"sliced l-value shape ", final_shape.DebugString(),
|
||||
" does not match r-value shape ", rhs_shape.DebugString(),
|
||||
". Automatic broadcasting not yet implemented."));
|
||||
|
||||
xla::ComputationDataHandle lhs;
|
||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));
|
||||
|
||||
xla::ComputationDataHandle rhs = ctx->Input(4);
|
||||
|
||||
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
|
||||
gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
|
||||
for (int i = 0; i < begin.size(); ++i) {
|
||||
// TODO(phawkins): implement strides != 1
|
||||
OP_REQUIRES(
|
||||
ctx, strides[i] == 1 || strides[i] == -1,
|
||||
errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
|
||||
if (strides[i] > 0) {
|
||||
slice_begin.push_back(begin[i]);
|
||||
slice_dims.push_back(end[i] - begin[i]);
|
||||
} else {
|
||||
// Negative stride: swap begin and end, add 1 because the interval
|
||||
// is semi-open, and mark the dimension to be reversed.
|
||||
slice_begin.push_back(end[i] + 1);
|
||||
slice_dims.push_back(begin[i] - end[i]);
|
||||
dimensions_to_reverse.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (!dimensions_to_reverse.empty()) {
|
||||
rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse);
|
||||
}
|
||||
rhs = ctx->builder()->Reshape(rhs, slice_dims);
|
||||
|
||||
if (lhs_shape.dims() == 0) {
|
||||
// TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix
|
||||
// and remove this workaround.
|
||||
lhs = rhs;
|
||||
} else {
|
||||
lhs = ctx->builder()->DynamicUpdateSlice(
|
||||
lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
|
||||
}
|
||||
|
||||
private:
|
||||
int32 begin_mask_, end_mask_;
|
||||
int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
|
||||
DataType index_type_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -467,6 +467,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
|
||||
result->xla_output_shape.Swap(
|
||||
computation_shape.ValueOrDie()->mutable_result());
|
||||
VLOG(2) << "XLA output shape: "
|
||||
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
|
||||
|
||||
auto num_computation_outputs =
|
||||
(xla::ShapeUtil::IsTuple(result->xla_output_shape))
|
||||
|
@ -357,6 +357,7 @@ void XlaOpKernelContext::SetVariableOutput(int index, int variable_id) {
|
||||
|
||||
Status XlaOpKernelContext::AssignVariable(
|
||||
int index, DataType type, const xla::ComputationDataHandle& handle) {
|
||||
TF_RET_CHECK(handle.handle() != 0);
|
||||
SetOpHasSideEffects();
|
||||
|
||||
const XlaExpression* expression =
|
||||
|
Loading…
Reference in New Issue
Block a user