[TF:XLA] Add XLA implementation of ResourceStridedSliceAssign.

PiperOrigin-RevId: 156341053

commit 170f0b3507
parent 1390dd68fe
tensorflow/compiler/tests/variable_ops_test.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for XLA JIT compiler."""
+"""Tests for reading and writing variables."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -21,7 +21,9 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -114,5 +116,68 @@ class VariableOpsTest(XLATestCase):
       self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)
 
 
+class StridedSliceAssignChecker(object):
+  """Compares the results of a slice assignment using TensorFlow and numpy."""
+
+  def __init__(self, test, x, dtype):
+    self.dtype = dtype
+    self.test = test
+    self.x_np = np.array(x).astype(dtype)
+
+  def __setitem__(self, index, value):
+    value = np.array(value).astype(self.dtype)
+
+    with self.test.test_session() as sess, self.test.test_scope():
+      x = constant_op.constant(self.x_np, dtype=self.dtype)
+      var = resource_variable_ops.ResourceVariable(x)
+      sess.run(variables.variables_initializer([var]))
+      val = sess.run(var[index].assign(value))
+      # val_copy is used to check that tf.assign works equivalently to the
+      # assign method above.
+      val_copy = sess.run(state_ops.assign(var[index], value))
+      valnp = np.copy(self.x_np)
+      valnp[index] = np.array(value)
+      self.test.assertAllEqual(val, valnp)
+      self.test.assertAllEqual(val_copy, valnp)
+
+
+class SliceAssignTest(XLATestCase):
+
+  def testSliceAssign(self):
+    for dtype in self.numeric_types:
+      checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]],
+                                          dtype=dtype)
+      # No-op assignment
+      checker[:] = [[10, 20, 30], [40, 50, 60]]
+      # Checks trivial (1,1) shape tensor
+      checker[1:2, 1:2] = [[66]]
+      # shrink shape changes
+      checker[1:2, 1] = [66]
+      checker[1, 1:2] = [66]
+      checker[1, 1] = 66
+      # newaxis shape changes
+      checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
+      # shrink and newaxis
+      checker[None, None, 0, 0:1] = [[[99]]]
+      # Non-unit strides
+      checker[::1, 1::-1] = [[3, 33], [4, 44]]
+      # degenerate interval
+      checker[8:10, 0] = []
+      checker[8:10, 8:10] = [[]]
+
+      # Assign vector to scalar (rank-0) using newaxis
+      checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype)
+      checker2[()] = 6  # no indices
+      checker2[...] = 6  # ellipsis
+      checker2[None] = [6]  # new axis
+
+  def testUninitialized(self):
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 "uninitialized variable"):
+      with self.test_session() as sess, self.test_scope():
+        v = resource_variable_ops.ResourceVariable([1, 2])
+        sess.run(v[:].assign([1, 2]))
+
+
 if __name__ == "__main__":
   googletest.main()
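For reference, the contract these tests pin down is that strided-slice assignment on a resource variable matches numpy's in-place slice assignment exactly, including semi-open intervals, shrink axes, and negative strides. A minimal standalone sketch of that reference behavior (plain numpy, no TensorFlow required):

import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)

# `1::-1` walks columns 1, 0 (stride -1 over a semi-open interval), so the
# update values land in reversed column order.
x[::1, 1::-1] = [[3, 33], [4, 44]]
print(x)  # [[33. 3. 3.], [44. 4. 6.]]

# A bare integer index shrinks that axis away; the right-hand side must
# match the shrunken shape.
x[1, 1:2] = [66]
print(x[1])  # [44. 66. 6.]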
tensorflow/compiler/tf2xla/const_analysis.cc

@@ -68,6 +68,9 @@ Status BackwardsConstAnalysis(const Graph& g,
       {"Range", "limit"},
       {"Range", "delta"},
       {"Reshape", "shape"},
+      {"ResourceStridedSliceAssign", "begin"},
+      {"ResourceStridedSliceAssign", "end"},
+      {"ResourceStridedSliceAssign", "strides"},
       {"Reverse", "dims"},
       {"ReverseV2", "axis"},
       {"Slice", "begin"},
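These table entries matter because XLA compiles fixed-shape programs: the begin/end/strides inputs of the new op must fold to compile-time constants, and BackwardsConstAnalysis propagates that requirement backwards through the graph to the nodes producing those inputs. A toy model of the idea in plain Python (hypothetical names, not the TensorFlow API):

# Toy backwards const-analysis: any node feeding a must-be-constant input is
# itself marked must-be-constant, and the marking propagates to its inputs.
MUST_BE_CONST_INPUTS = {
    ("ResourceStridedSliceAssign", "begin"),
    ("ResourceStridedSliceAssign", "end"),
    ("ResourceStridedSliceAssign", "strides"),
}

def backwards_const_analysis(nodes, edges):
  """nodes: {name: op_type}; edges: list of (src, dst, dst_input_name)."""
  must_be_const = set()
  # Seed: producers of inputs listed in the table above.
  frontier = [src for src, dst, inp in edges
              if (nodes[dst], inp) in MUST_BE_CONST_INPUTS]
  while frontier:
    n = frontier.pop()
    if n in must_be_const:
      continue
    must_be_const.add(n)
    # Everything feeding a must-be-constant node must also be foldable.
    frontier.extend(src for src, dst, _ in edges if dst == n)
  return must_be_const

# Example: `strides` comes from an Add of two Consts; all three are marked.
nodes = {"c1": "Const", "c2": "Const", "add": "Add",
         "assign": "ResourceStridedSliceAssign"}
edges = [("c1", "add", "x"), ("c2", "add", "y"),
         ("add", "assign", "strides")]
print(backwards_const_analysis(nodes, edges))  # {'add', 'c1', 'c2'}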
tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc

@@ -219,5 +219,118 @@ class StridedSliceGradOp : public XlaOpKernel {
 
 REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp);
 
+class StridedSliceAssignOp : public XlaOpKernel {
+ public:
+  explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    TensorShape final_shape;
+    gtl::InlinedVector<int64, 4> begin;
+    gtl::InlinedVector<int64, 4> end;
+    gtl::InlinedVector<int64, 4> strides;
+
+    xla::Literal begin_literal, end_literal, strides_literal;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
+
+    Tensor begin_tensor, end_tensor, strides_tensor;
+    OP_REQUIRES_OK(
+        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
+    OP_REQUIRES_OK(ctx,
+                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
+    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
+                                            &strides_tensor));
+
+    DataType lhs_type;
+    TensorShape lhs_shape;
+    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
+
+    const TensorShape rhs_shape = ctx->InputShape(4);
+
+    TensorShape dummy_processing_shape;
+    ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
+    ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape(
+        &dummy_processing_shape);
+    bool dummy = false;
+    OP_REQUIRES_OK(
+        ctx, ValidateStridedSliceOp(
+                 &begin_tensor, &end_tensor, strides_tensor,
+                 ShapeReadWriteFromTensorShape(&lhs_shape), begin_mask_,
+                 end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+                 &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy,
+                 &dummy, &dummy, &begin, &end, &strides));
+
+    if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
+      // DynamicUpdateSlice does not allow 0-element updates. We should
+      // probably check that rhs_shape can be broadcast to final_shape, but
+      // that is better handled when implementing broadcasting more generally.
+      return;
+    }
+
+    // TODO(aselle): This check is too strong; we should only need input_shape
+    // to be broadcastable to final_shape.
+    OP_REQUIRES(ctx, final_shape == rhs_shape,
+                errors::Unimplemented(
+                    "sliced l-value shape ", final_shape.DebugString(),
+                    " does not match r-value shape ", rhs_shape.DebugString(),
+                    ". Automatic broadcasting not yet implemented."));
+
+    xla::ComputationDataHandle lhs;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));
+
+    xla::ComputationDataHandle rhs = ctx->Input(4);
+
+    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
+    gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
+    for (int i = 0; i < begin.size(); ++i) {
+      // TODO(phawkins): implement strides != 1
+      OP_REQUIRES(
+          ctx, strides[i] == 1 || strides[i] == -1,
+          errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
+      if (strides[i] > 0) {
+        slice_begin.push_back(begin[i]);
+        slice_dims.push_back(end[i] - begin[i]);
+      } else {
+        // Negative stride: swap begin and end, add 1 because the interval
+        // is semi-open, and mark the dimension to be reversed.
+        slice_begin.push_back(end[i] + 1);
+        slice_dims.push_back(begin[i] - end[i]);
+        dimensions_to_reverse.push_back(i);
+      }
+    }
+
+    if (!dimensions_to_reverse.empty()) {
+      rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse);
+    }
+    rhs = ctx->builder()->Reshape(rhs, slice_dims);
+
+    if (lhs_shape.dims() == 0) {
+      // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix
+      // and remove this workaround.
+      lhs = rhs;
+    } else {
+      lhs = ctx->builder()->DynamicUpdateSlice(
+          lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
+    }
+
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
+  }
+
+ private:
+  int32 begin_mask_, end_mask_;
+  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
+  DataType index_type_;
+};
+
+REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp);
+
 }  // namespace
 }  // namespace tensorflow
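The interesting part of the kernel is the negative-stride rewrite: DynamicUpdateSlice only writes a contiguous block, so a stride of -1 is handled by swapping begin and end (plus one, since the interval is semi-open), reversing the update value along that axis with Rev, and then doing an ordinary contiguous update. A numpy sketch of the same rewrite (a toy model; the function name and the canonicalized begin/end values are assumptions for illustration):

import numpy as np

def strided_update(lhs, begin, end, strides, rhs):
  # Toy numpy model of the kernel's rewrite. Like the real kernel, it only
  # handles strides of +1 or -1.
  slice_begin, slice_dims, reverse_axes = [], [], []
  for i, stride in enumerate(strides):
    assert stride in (1, -1), "strides != 1 or -1 are not yet implemented"
    if stride > 0:
      slice_begin.append(begin[i])
      slice_dims.append(end[i] - begin[i])
    else:
      # Negative stride: swap begin and end, add 1 because the interval is
      # semi-open, and mark the axis so the update value gets reversed.
      slice_begin.append(end[i] + 1)
      slice_dims.append(begin[i] - end[i])
      reverse_axes.append(i)
  if reverse_axes:
    rhs = np.flip(rhs, axis=tuple(reverse_axes))  # stands in for builder->Rev
  rhs = rhs.reshape(slice_dims)
  # Stand-in for DynamicUpdateSlice: write a contiguous block at slice_begin.
  out = lhs.copy()
  out[tuple(slice(b, b + d) for b, d in zip(slice_begin, slice_dims))] = rhs
  return out

# x[::1, 1::-1] = [[3, 33], [4, 44]] canonicalizes (assumed here) to
# begin=[0, 1], end=[2, -1], strides=[1, -1].
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
expected = x.copy()
expected[::1, 1::-1] = [[3, 33], [4, 44]]
got = strided_update(x, begin=[0, 1], end=[2, -1], strides=[1, -1],
                     rhs=np.array([[3, 33], [4, 44]], np.float32))
assert (got == expected).all()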
tensorflow/compiler/tf2xla/xla_compiler.cc

@@ -467,6 +467,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
 
   result->xla_output_shape.Swap(
       computation_shape.ValueOrDie()->mutable_result());
+  VLOG(2) << "XLA output shape: "
+          << xla::ShapeUtil::HumanString(result->xla_output_shape);
 
   auto num_computation_outputs =
       (xla::ShapeUtil::IsTuple(result->xla_output_shape))
tensorflow/compiler/tf2xla/xla_op_kernel.cc

@@ -357,6 +357,7 @@ void XlaOpKernelContext::SetVariableOutput(int index, int variable_id) {
 
 Status XlaOpKernelContext::AssignVariable(
     int index, DataType type, const xla::ComputationDataHandle& handle) {
+  TF_RET_CHECK(handle.handle() != 0);
   SetOpHasSideEffects();
 
   const XlaExpression* expression =