Add op definition and bridge implementation for TensorStridedSliceUpdate.
PiperOrigin-RevId: 329118405 Change-Id: Ib8ca243f37d2728efb327d4ef52c7b2a53790c3e
This commit is contained in:
parent
976f9b4605
commit
04eeb6d145
tensorflow/compiler
jit
mlir/tensorflow/ir
tf2xla/kernels
@ -2048,6 +2048,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"TensorScatterAdd",
|
||||
"TensorScatterSub",
|
||||
"TensorScatterUpdate",
|
||||
"TensorStridedSliceUpdate",
|
||||
"TridiagonalSolve",
|
||||
"TruncatedNormal",
|
||||
"UpperBound",
|
||||
|
@ -11651,6 +11651,40 @@ On GPU, if an out of bound index is found, the index is ignored.
|
||||
];
|
||||
}
|
||||
|
||||
def TF_TensorStridedSliceUpdateOp : TF_Op<"TensorStridedSliceUpdate", [NoSideEffect]> {
|
||||
let summary = "Assign `value` to the sliced l-value reference of `input`.";
|
||||
|
||||
let description = [{
|
||||
The values of `value` are assigned to the positions in the tensor `input` that
|
||||
are selected by the slice parameters. The slice parameters `begin` `end`
|
||||
`strides` etc. work exactly as in `StridedSlice`.
|
||||
|
||||
NOTE this op currently does not support broadcasting and so `value`'s shape
|
||||
must be exactly the shape produced by the slice of `input`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$input,
|
||||
TF_I32OrI64Tensor:$begin,
|
||||
TF_I32OrI64Tensor:$end,
|
||||
TF_I32OrI64Tensor:$strides,
|
||||
TF_Tensor:$value,
|
||||
|
||||
DefaultValuedAttr<I64Attr, "0">:$begin_mask,
|
||||
DefaultValuedAttr<I64Attr, "0">:$end_mask,
|
||||
DefaultValuedAttr<I64Attr, "0">:$ellipsis_mask,
|
||||
DefaultValuedAttr<I64Attr, "0">:$new_axis_mask,
|
||||
DefaultValuedAttr<I64Attr, "0">:$shrink_axis_mask
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> {
|
||||
let summary = "Constructs a tensor by tiling a given tensor.";
|
||||
|
||||
|
@ -446,7 +446,12 @@ class StridedSliceAssignOp : public XlaOpKernel {
|
||||
|
||||
TensorShape lhs_shape;
|
||||
xla::XlaOp lhs;
|
||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
|
||||
if (ctx->input_type(0) == DT_RESOURCE) {
|
||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
|
||||
} else {
|
||||
lhs_shape = ctx->InputShape(0);
|
||||
lhs = ctx->Input(0);
|
||||
}
|
||||
|
||||
const TensorShape rhs_shape = ctx->InputShape(4);
|
||||
|
||||
@ -504,7 +509,11 @@ class StridedSliceAssignOp : public XlaOpKernel {
|
||||
|
||||
lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin);
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
|
||||
if (ctx->input_type(0) == DT_RESOURCE) {
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
|
||||
} else {
|
||||
ctx->SetOutput(0, lhs);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -520,5 +529,11 @@ REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
|
||||
.CompileTimeConstantInput("strides"),
|
||||
StridedSliceAssignOp);
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorStridedSliceUpdate")
|
||||
.CompileTimeConstantInput("begin")
|
||||
.CompileTimeConstantInput("end")
|
||||
.CompileTimeConstantInput("strides"),
|
||||
StridedSliceAssignOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user