Add op definition and bridge implementation for TensorStridedSliceUpdate.

PiperOrigin-RevId: 329122036
Change-Id: I616b81ec2b328eddc77f662c77ae956972cd265a
This commit is contained in:
A. Unique TensorFlower 2020-08-29 14:28:13 -07:00 committed by TensorFlower Gardener
parent 04eeb6d145
commit 16d2b56de2
3 changed files with 2 additions and 52 deletions

View File

@ -2048,7 +2048,6 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"TensorScatterAdd",
"TensorScatterSub",
"TensorScatterUpdate",
"TensorStridedSliceUpdate",
"TridiagonalSolve",
"TruncatedNormal",
"UpperBound",

View File

@ -11651,40 +11651,6 @@ On GPU, if an out of bound index is found, the index is ignored.
];
}
def TF_TensorStridedSliceUpdateOp : TF_Op<"TensorStridedSliceUpdate", [NoSideEffect]> {
let summary = "Assign `value` to the sliced l-value reference of `input`.";
let description = [{
The values of `value` are assigned to the positions in the tensor `input` that
are selected by the slice parameters. The slice parameters `begin` `end`
`strides` etc. work exactly as in `StridedSlice`.
NOTE this op currently does not support broadcasting and so `value`'s shape
must be exactly the shape produced by the slice of `input`.
}];
let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$begin,
TF_I32OrI64Tensor:$end,
TF_I32OrI64Tensor:$strides,
TF_Tensor:$value,
DefaultValuedAttr<I64Attr, "0">:$begin_mask,
DefaultValuedAttr<I64Attr, "0">:$end_mask,
DefaultValuedAttr<I64Attr, "0">:$ellipsis_mask,
DefaultValuedAttr<I64Attr, "0">:$new_axis_mask,
DefaultValuedAttr<I64Attr, "0">:$shrink_axis_mask
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
}
def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> {
let summary = "Constructs a tensor by tiling a given tensor.";

View File

@ -446,12 +446,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
TensorShape lhs_shape;
xla::XlaOp lhs;
if (ctx->input_type(0) == DT_RESOURCE) {
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
} else {
lhs_shape = ctx->InputShape(0);
lhs = ctx->Input(0);
}
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
const TensorShape rhs_shape = ctx->InputShape(4);
@ -509,11 +504,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin);
if (ctx->input_type(0) == DT_RESOURCE) {
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
} else {
ctx->SetOutput(0, lhs);
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
}
private:
@ -529,11 +520,5 @@ REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
.CompileTimeConstantInput("strides"),
StridedSliceAssignOp);
REGISTER_XLA_OP(Name("TensorStridedSliceUpdate")
.CompileTimeConstantInput("begin")
.CompileTimeConstantInput("end")
.CompileTimeConstantInput("strides"),
StridedSliceAssignOp);
} // namespace
} // namespace tensorflow