Implement TensorStridedSliceAssign XLA op.
PiperOrigin-RevId: 329315230 Change-Id: I5aca22493f5fa38fcd03a3f78f6d9e9afdaadb8b
This commit is contained in:
parent
027c55b667
commit
1619f2f19f
@ -1834,7 +1834,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
|
|||||||
"ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
|
"ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
|
||||||
"ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
|
"ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
|
||||||
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
|
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
|
||||||
"Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}};
|
"Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex",
|
||||||
|
"TensorStridedSliceUpdate",
|
||||||
|
}}};
|
||||||
// clang-format on
|
// clang-format on
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11651,6 +11651,40 @@ On GPU, if an out of bound index is found, the index is ignored.
|
|||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_TensorStridedSliceUpdateOp : TF_Op<"TensorStridedSliceUpdate", [NoSideEffect]> {
|
||||||
|
let summary = "Assign `value` to the sliced l-value reference of `input`.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
The values of `value` are assigned to the positions in the tensor `input` that
|
||||||
|
are selected by the slice parameters. The slice parameters `begin` `end`
|
||||||
|
`strides` etc. work exactly as in `StridedSlice`.
|
||||||
|
|
||||||
|
NOTE this op currently does not support broadcasting and so `value`'s shape
|
||||||
|
must be exactly the shape produced by the slice of `input`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_Tensor:$input,
|
||||||
|
TF_I32OrI64Tensor:$begin,
|
||||||
|
TF_I32OrI64Tensor:$end,
|
||||||
|
TF_I32OrI64Tensor:$strides,
|
||||||
|
TF_Tensor:$value,
|
||||||
|
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$begin_mask,
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$end_mask,
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$ellipsis_mask,
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$new_axis_mask,
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$shrink_axis_mask
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_Tensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
}
|
||||||
|
|
||||||
def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> {
|
def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> {
|
||||||
let summary = "Constructs a tensor by tiling a given tensor.";
|
let summary = "Constructs a tensor by tiling a given tensor.";
|
||||||
|
|
||||||
|
|||||||
@ -446,7 +446,12 @@ class StridedSliceAssignOp : public XlaOpKernel {
|
|||||||
|
|
||||||
TensorShape lhs_shape;
|
TensorShape lhs_shape;
|
||||||
xla::XlaOp lhs;
|
xla::XlaOp lhs;
|
||||||
|
if (ctx->input_type(0) == DT_RESOURCE) {
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
|
||||||
|
} else {
|
||||||
|
lhs_shape = ctx->InputShape(0);
|
||||||
|
lhs = ctx->Input(0);
|
||||||
|
}
|
||||||
|
|
||||||
const TensorShape rhs_shape = ctx->InputShape(4);
|
const TensorShape rhs_shape = ctx->InputShape(4);
|
||||||
|
|
||||||
@ -504,7 +509,11 @@ class StridedSliceAssignOp : public XlaOpKernel {
|
|||||||
|
|
||||||
lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin);
|
lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin);
|
||||||
|
|
||||||
|
if (ctx->input_type(0) == DT_RESOURCE) {
|
||||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
|
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
|
||||||
|
} else {
|
||||||
|
ctx->SetOutput(0, lhs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -520,5 +529,11 @@ REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
|
|||||||
.CompileTimeConstantInput("strides"),
|
.CompileTimeConstantInput("strides"),
|
||||||
StridedSliceAssignOp);
|
StridedSliceAssignOp);
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("TensorStridedSliceUpdate")
|
||||||
|
.CompileTimeConstantInput("begin")
|
||||||
|
.CompileTimeConstantInput("end")
|
||||||
|
.CompileTimeConstantInput("strides"),
|
||||||
|
StridedSliceAssignOp);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -1234,6 +1234,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
sess.run(v[:].assign(too_small_val))
|
sess.run(v[:].assign(too_small_val))
|
||||||
|
|
||||||
|
@test_util.disable_xla("b/123559667")
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testTensorStridedSliceUpdateWithInputForward(self):
|
def testTensorStridedSliceUpdateWithInputForward(self):
|
||||||
"""Tests tensor_strided_slice_update with input-forwarding taking effect."""
|
"""Tests tensor_strided_slice_update with input-forwarding taking effect."""
|
||||||
@ -1243,6 +1244,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
return gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0])
|
return gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0])
|
||||||
self.assertAllEqual([0, 1], self.evaluate(assign(array_ops.zeros([2]))))
|
self.assertAllEqual([0, 1], self.evaluate(assign(array_ops.zeros([2]))))
|
||||||
|
|
||||||
|
@test_util.disable_xla("b/123559667")
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testTensorStridedSliceUpdateNoInputForward(self):
|
def testTensorStridedSliceUpdateNoInputForward(self):
|
||||||
"""Tests tensor_strided_slice_update with no input-forwarding."""
|
"""Tests tensor_strided_slice_update with no input-forwarding."""
|
||||||
@ -1254,6 +1256,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
ans = y + z
|
ans = y + z
|
||||||
self.assertAllClose([1.6, 2.6], self.evaluate(ans))
|
self.assertAllClose([1.6, 2.6], self.evaluate(ans))
|
||||||
|
|
||||||
|
@test_util.disable_xla("b/123559667")
|
||||||
def testTensorStridedSliceUpdateGradSimple(self):
|
def testTensorStridedSliceUpdateGradSimple(self):
|
||||||
original = constant_op.constant([0.2, 0.3])
|
original = constant_op.constant([0.2, 0.3])
|
||||||
updates = constant_op.constant([0.4])
|
updates = constant_op.constant([0.4])
|
||||||
@ -1272,6 +1275,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
([4], [5], [3], [1], [3], 1, 0, 0, 0, 0),
|
([4], [5], [3], [1], [3], 1, 0, 0, 0, 0),
|
||||||
([2, 2, 3, 2], [0, 0, 1], [1, 0, 2], [1, 0, 1], [2, 3], 0, 0, 2, 0, 5)
|
([2, 2, 3, 2], [0, 0, 1], [1, 0, 2], [1, 0, 1], [2, 3], 0, 0, 2, 0, 5)
|
||||||
]))
|
]))
|
||||||
|
@test_util.disable_xla("b/123559667")
|
||||||
def testTensorStridedSliceUpdateGrad(
|
def testTensorStridedSliceUpdateGrad(
|
||||||
self, shape, begin, end, strides, updates_shape, *args):
|
self, shape, begin, end, strides, updates_shape, *args):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user