Remove deprecated variants of DynamicSlice and DynamicUpdateSlice builders

Upgraded existing users by converting rank-1 start_indices tensors to a list of scalars. I expect this to be performance neutral, as these tensors are expected to be small. I decided against having the XlaBuilder perform this conversion internally, since we want to discourage the use of vector indices.

PiperOrigin-RevId: 311261628
Change-Id: I4b779a58cfca1699bdf5104c236bc6453fd419bc
This commit is contained in:
Smit Hinsu 2020-05-12 21:30:26 -07:00 committed by TensorFlower Gardener
parent 1c74b32aa2
commit 296993a42c
5 changed files with 36 additions and 83 deletions

View File

@ -28,6 +28,15 @@ limitations under the License.
namespace tensorflow {
namespace {
// Splits a rank-1 index tensor into `rank` scalar XlaOps, one per dimension.
//
// Element i of `input` is sliced out as a length-1 vector and reshaped to a
// scalar, producing the span-of-scalars form expected by the XlaBuilder
// DynamicSlice/DynamicUpdateSlice entry points.
absl::InlinedVector<xla::XlaOp, 4> SliceVector(xla::XlaOp input, int64 rank) {
  absl::InlinedVector<xla::XlaOp, 4> scalars;
  scalars.reserve(rank);
  for (int64 dim = 0; dim < rank; ++dim) {
    // Slice out element `dim` as a 1-element vector, then drop the dimension.
    xla::XlaOp element = xla::Slice(input, {dim}, {dim + 1}, {1});
    scalars.push_back(xla::Reshape(element, {}));
  }
  return scalars;
}
class DynamicUpdateSliceOp : public XlaOpKernel {
public:
explicit DynamicUpdateSliceOp(OpKernelConstruction* context)
@ -41,21 +50,23 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
const TensorShape update_shape = ctx->InputShape("update");
const TensorShape index_shape = ctx->InputShape("indices");
int64 rank = input_shape.dims();
OP_REQUIRES(
ctx,
TensorShapeUtils::IsVector(index_shape) &&
index_shape.num_elements() == input_shape.dims(),
index_shape.num_elements() == rank,
errors::InvalidArgument("index must be a vector with length equal to "
"the number of input dimensions"));
OP_REQUIRES(
ctx, input_shape.dims() == update_shape.dims(),
ctx, rank == update_shape.dims(),
errors::InvalidArgument("input and update must have the same rank,"
" input shape is ",
input_shape.DebugString(), "; update shape is ",
update_shape.DebugString()));
xla::XlaOp indices = ctx->Input("indices");
xla::XlaOp result = xla::DynamicUpdateSlice(
ctx->Input("input"), ctx->Input("update"), ctx->Input("indices"));
ctx->Input("input"), ctx->Input("update"), SliceVector(indices, rank));
ctx->SetOutput(0, result);
}
};
@ -76,17 +87,18 @@ class DynamicSliceOp : public XlaOpKernel {
const TensorShape start_indices_shape = ctx->InputShape("start_indices");
const TensorShape size_indices_shape = ctx->InputShape("size_indices");
int64 rank = input_shape.dims();
OP_REQUIRES(ctx,
TensorShapeUtils::IsVector(start_indices_shape) &&
start_indices_shape.num_elements() == input_shape.dims(),
start_indices_shape.num_elements() == rank,
errors::InvalidArgument(
"start_indices must be a vector with length equal to "
"input rank, but input rank is ",
input_shape.dims(), " and start_indices has shape ",
rank, " and start_indices has shape ",
start_indices_shape.DebugString()));
OP_REQUIRES(ctx,
TensorShapeUtils::IsVector(size_indices_shape) &&
size_indices_shape.num_elements() == input_shape.dims(),
size_indices_shape.num_elements() == rank,
errors::InvalidArgument(
"size_indices must be a vector with length equal to "
"input rank, but input rank is ",
@ -96,8 +108,10 @@ class DynamicSliceOp : public XlaOpKernel {
std::vector<int64> size_indices;
OP_REQUIRES_OK(
ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices));
xla::XlaOp start_indices = ctx->Input("start_indices");
xla::XlaOp result = xla::DynamicSlice(
ctx->Input("input"), ctx->Input("start_indices"), size_indices);
ctx->Input("input"), SliceVector(start_indices, rank), size_indices);
ctx->SetOutput(0, result);
}
};

View File

@ -42,19 +42,17 @@ class SliceOp : public XlaOpKernel {
const TensorShape begin_tensor_shape = ctx->InputShape(1);
const TensorShape size_tensor_shape = ctx->InputShape(2);
const int input_dims = input_shape.dims();
OP_REQUIRES(
ctx,
TensorShapeUtils::IsVector(begin_tensor_shape) &&
TensorShapeUtils::IsVector(size_tensor_shape) &&
begin_tensor_shape.num_elements() == input_shape.dims() &&
size_tensor_shape.num_elements() == input_shape.dims(),
begin_tensor_shape.num_elements() == input_dims &&
size_tensor_shape.num_elements() == input_dims,
errors::InvalidArgument(
"Expected begin and size arguments to be 1-D tensors of size ",
input_shape.dims(), ", but got shapes ",
begin_tensor_shape.DebugString(), " and ",
size_tensor_shape.DebugString(), " instead."));
const int input_dims = input_shape.dims();
input_dims, ", but got shapes ", begin_tensor_shape.DebugString(),
" and ", size_tensor_shape.DebugString(), " instead."));
std::vector<int64> begin;
std::vector<int64> size;
@ -129,7 +127,15 @@ class SliceOp : public XlaOpKernel {
input_shape.dim_size(i), "], but ",
"got ", size[i]));
}
ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size));
absl::InlinedVector<xla::XlaOp, 4> scalar_indices;
scalar_indices.reserve(input_dims);
xla::XlaOp begin = ctx->Input("begin");
for (int i = 0; i < input_dims; i++)
scalar_indices.push_back(
xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {}));
ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), scalar_indices, size));
}
}
};

View File

@ -860,28 +860,6 @@ XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
});
}
// Deprecated variant taking the start indices as a single rank-1 tensor.
// Infers the result shape from the operand and index shapes, then enqueues a
// kDynamicSlice HLO instruction with the requested slice sizes.
XlaOp XlaBuilder::DynamicSlice(XlaOp operand, XlaOp start_indices,
                               absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* indices_shape, GetShapePtr(start_indices));
    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferDynamicSliceShape(
                            *operand_shape, {*indices_shape}, slice_sizes));
    HloInstructionProto instr;
    *instr.mutable_shape() = inferred_shape.ToProto();
    // The static slice extents travel on the instruction proto itself.
    for (const int64 size : slice_sizes) {
      instr.add_dynamic_slice_sizes(size);
    }
    return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice,
                          {operand, start_indices});
  });
}
XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes) {
@ -910,26 +888,6 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
});
}
// Deprecated variant taking the start indices as a single rank-1 tensor.
// Infers the result shape from the operand/update/index shapes, then enqueues
// a kDynamicUpdateSlice HLO instruction.
XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                     XlaOp start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
    TF_ASSIGN_OR_RETURN(const Shape* indices_shape, GetShapePtr(start_indices));
    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferDynamicUpdateSliceShape(
                            *operand_shape, *update_shape, {*indices_shape}));
    HloInstructionProto instr;
    *instr.mutable_shape() = inferred_shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
                          {operand, update, start_indices});
  });
}
XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -3152,20 +3110,11 @@ XlaOp SliceInDim(const XlaOp operand, int64 start_index, int64 limit_index,
stride, dimno);
}
// Free-function convenience wrapper (deprecated rank-1 indices form):
// forwards to the builder that owns `operand`.
XlaOp DynamicSlice(const XlaOp operand, const XlaOp start_indices,
                   absl::Span<const int64> slice_sizes) {
  XlaBuilder* builder = operand.builder();
  return builder->DynamicSlice(operand, start_indices, slice_sizes);
}
// Free-function convenience wrapper (span-of-scalar-indices form):
// forwards to the builder that owns `operand`.
XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes) {
  XlaBuilder* builder = operand.builder();
  return builder->DynamicSlice(operand, start_indices, slice_sizes);
}
// Free-function convenience wrapper (deprecated rank-1 indices form):
// forwards to the builder that owns `operand`.
XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
                         const XlaOp start_indices) {
  XlaBuilder* builder = operand.builder();
  return builder->DynamicUpdateSlice(operand, update, start_indices);
}
XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
absl::Span<const XlaOp> start_indices) {
return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);

View File

@ -421,14 +421,9 @@ class XlaBuilder {
virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
int64 stride, int64 dimno);
ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
absl::Span<const int64> slice_sizes);
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes);
ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices);
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices);
@ -858,14 +853,10 @@ class XlaBuilder {
friend XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
int64 stride, int64 dimno);
friend XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
absl::Span<const int64> slice_sizes);
friend XlaOp DynamicSlice(XlaOp operand,
absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes);
friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
XlaOp start_indices);
friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices);
@ -1438,10 +1429,6 @@ XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
absl::Span<const int64> slice_sizes);
ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
absl::Span<const int64> slice_sizes);
// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
@ -1462,9 +1449,6 @@ XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
absl::Span<const XlaOp> start_indices);
ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices);
// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,

View File

@ -863,7 +863,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
// Starts = iteration * 2;
auto starts = Mul(iteration, ConstantR0<int32>(&builder, 2));
// UpdateSlice.
auto out1 = DynamicUpdateSlice(input, update, starts);
auto out1 = DynamicUpdateSlice(input, update, {starts});
Tuple(&builder, {out0, out1});
body = builder.Build().ConsumeValueOrDie();