Support more cases in strided slice ops.
- When ranges are dynamic, strided slice ops can be lowered into XLA dynamic slices.
- Optimize XLA builder dynamism inference to not materialize the parameter tuple.
- Export some helper data structures from strided_slice_op.

PiperOrigin-RevId: 356839276
Change-Id: I3b5d9fcb1289496029afbf97e49df49df26ce699
commit 41b3e84fa9
parent dcaa3c9ed9
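As a rough illustration of the first bullet (this sketch is not part of the commit; the helper name, shapes, and parameter layout are made up): lowering a slice whose begin/end are only known at run time amounts to emitting an xla::DynamicSlice at the full static size, then marking the true size with xla::SetDimensionSize.

// Hypothetical standalone sketch; LowerDynamicRangeSlice and its shapes are
// illustrative, not code from this commit.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaOp LowerDynamicRangeSlice(xla::XlaBuilder* b) {
  // A 1-D f32 input of static size 8 with begin/end known only at run time.
  xla::XlaOp input =
      xla::Parameter(b, 0, xla::ShapeUtil::MakeShape(xla::F32, {8}), "input");
  xla::XlaOp begin =
      xla::Parameter(b, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "begin");
  xla::XlaOp end =
      xla::Parameter(b, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), "end");
  // Slice with the full static size starting at `begin` ...
  xla::XlaOp sliced = xla::DynamicSlice(input, {begin}, /*slice_sizes=*/{8});
  // ... then record the dynamic size end - begin on dimension 0.
  return xla::SetDimensionSize(sliced, xla::Sub(end, begin), /*dimension=*/0);
}

In the real lowering below, the input is additionally padded to twice its size before the DynamicSlice, so that a full-size slice starting at a non-zero index cannot read out of bounds.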
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/ops_util.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -52,6 +53,145 @@ class StridedSliceOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
   }
 
+  void EmitDynamicSlice(XlaOpKernelContext* ctx,
+                        const absl::InlinedVector<int64, 4>& strides,
+                        TensorShape processing_shape, TensorShape final_shape,
+                        PartialTensorShape partial_processing_shape,
+                        PartialTensorShape partial_final_shape,
+                        const StridedSliceShapeSpec& shape_spec,
+                        const std::vector<bool>& begins_are_dynamic,
+                        const std::vector<bool>& ends_are_dynamic) {
+    const TensorShape input_shape = ctx->InputShape(0);
+    xla::XlaOp slice = ctx->Input(0);
+    for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
+      OP_REQUIRES(ctx, strides[i] == 1,
+                  errors::InvalidArgument(
+                      "Strides have to be one when inputs are not constant."));
+    }
+    // Infer static output shape, reconcile unknown dimension with input dim
+    // size.
+    for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
+      if (partial_final_shape.dim_size(i) == -1) {
+        // Use input shape and shape_spec to fill in the unknown dimension.
+        partial_final_shape.set_dim(
+            i,
+            input_shape.dim_size(shape_spec.output_to_processing_mapping[i]));
+      }
+    }
+
+    OP_REQUIRES(
+        ctx, partial_final_shape.AsTensorShape(&final_shape),
+        InvalidArgument("XLA can't deduce compile time constant output "
+                        "shape for strided slice: ",
+                        partial_final_shape.DebugString(),
+                        ", output shape must be a compile-time constant"));
+    for (int64 i = 0; i < partial_processing_shape.dims(); ++i) {
+      if (partial_processing_shape.dim_size(i) == -1) {
+        // Use input shape to fill in the unknown dimension.
+        partial_processing_shape.set_dim(i, input_shape.dim_size(i));
+      }
+    }
+    OP_REQUIRES(
+        ctx, partial_processing_shape.AsTensorShape(&processing_shape),
+        InvalidArgument("XLA can't deduce compile time constant processing "
+                        "shape for strided slice: ",
+                        partial_processing_shape.DebugString(),
+                        ", output shape must be a compile-time constant"));
+    // When inputs are not compile time constants, shape inference can only
+    // infer size 1 slices.
+    std::vector<int64> slice_sizes(input_shape.dims(), 1);
+    // If there is dynamic begin/end (and if the dimension is not shrunk), we
+    // need to use dynamic shape infrastructure -- we slice the output with
+    // full size, then call SetDimensionSize on the output. However, if we
+    // slice with the full size at a non-zero dimension we may get OOB access.
+    // To avoid that, we first pad the input to 2x before calling slice.
+    xla::PaddingConfig padding_config;
+    bool need_padding = false;
+    std::vector<bool> result_dims_are_dynamic;
+    for (int64 i = 0; i < input_shape.dims(); ++i) {
+      int64 sparse_index = shape_spec.processing_to_sparse_mapping[i];
+      bool shrink_axis_set = (1 << i) & shape_spec.shrink_axis_dense_mask;
+      auto* dims = padding_config.add_dimensions();
+      dims->set_edge_padding_low(0);
+      dims->set_interior_padding(0);
+      if ((begins_are_dynamic[sparse_index] ||
+           ends_are_dynamic[sparse_index]) &&
+          !shrink_axis_set) {
+        // Need to slice this dimension so pad first.
+        dims->set_edge_padding_high(input_shape.dim_size(i));
+        need_padding = true;
+        result_dims_are_dynamic.push_back(true);
+      } else {
+        dims->set_edge_padding_high(0);
+        result_dims_are_dynamic.push_back(false);
+      }
+    }
+
+    if (need_padding) {
+      // Pad input to 2x to avoid OOB access.
+      slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
+                       padding_config);
+    }
+    std::vector<xla::XlaOp> start_indices;
+    std::vector<xla::XlaOp> slice_sizes_dynamic;
+    xla::Shape input_xla_shape = ctx->InputXlaShape(0).ValueOrDie();
+    for (int64 i = 0; i < input_shape.dims(); ++i) {
+      bool begin_mask = (1 << i) & shape_spec.begin_dense_mask;
+      bool end_mask = (1 << i) & shape_spec.end_dense_mask;
+      auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin"));
+      xla::XlaOp begin_index, end_index;
+      int64 sparse_index = shape_spec.processing_to_sparse_mapping[i];
+      bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i);
+      xla::XlaOp dim_size;
+      if (xla_input_is_dynamic) {
+        dim_size = xla::GetDimensionSize(ctx->Input(0), i);
+        OP_REQUIRES(ctx, ctx->InputXlaType("begin") == xla::S32,
+                    errors::InvalidArgument("'begin' shape has to be int32 when "
+                                            "indices to slice op are dynamic"));
+      } else {
+        dim_size =
+            xla::ConstantR0WithType(ctx->builder(), ctx->InputXlaType("begin"),
+                                    input_xla_shape.dimensions(i));
+      }
+      if (begin_mask) {
+        begin_index = zero;
+      } else {
+        begin_index = xla::Slice(ctx->Input("begin"), {sparse_index},
+                                 {sparse_index + 1}, {1});
+        begin_index = xla::Reshape(begin_index, {});
+        auto index_negative = xla::Lt(begin_index, zero);
+        auto wrapped_index = xla::Add(dim_size, begin_index);
+        // Wrap negative indices around.
+        begin_index = xla::Select(index_negative, wrapped_index, begin_index);
+      }
+      start_indices.push_back(begin_index);
+      if (end_mask) {
+        end_index = dim_size;
+      } else {
+        end_index = xla::Slice(ctx->Input("end"), {sparse_index},
+                               {sparse_index + 1}, {1});
+        end_index = xla::Reshape(end_index, {});
+        auto index_negative = xla::Lt(end_index, zero);
+        auto wrapped_index = xla::Add(dim_size, end_index);
+        end_index = xla::Select(index_negative, wrapped_index, end_index);
+      }
+      slice_sizes_dynamic.push_back(
+          xla::Max(xla::Sub(end_index, begin_index), zero));
+    }
+
+    slice =
+        xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes());
+
+    for (int64 i = 0; i < input_shape.dims(); ++i) {
+      if (result_dims_are_dynamic[i]) {
+        slice = xla::SetDimensionSize(slice, slice_sizes_dynamic[i], i);
+      }
+    }
+    slice = xla::Reshape(slice, final_shape.dim_sizes());
+    ctx->SetOutput(0, slice);
+  }
+
   void Compile(XlaOpKernelContext* ctx) override {
     const TensorShape input_shape = ctx->InputShape(0);
     const TensorShape begin_shape = ctx->InputShape("begin");
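The per-dimension handling of begin/end above follows Python-style slicing semantics. Below is a minimal sketch of just the negative-index normalization, using the same XLA client calls the kernel already uses (WrapNegativeIndex is a hypothetical helper, not part of the commit):

// Wrap a possibly negative begin/end index around the dimension size,
// as done per dimension in EmitDynamicSlice above (sketch only).
xla::XlaOp WrapNegativeIndex(xla::XlaBuilder* b, xla::XlaOp index,
                             xla::XlaOp dim_size) {
  xla::XlaOp zero = xla::Zero(b, xla::S32);
  xla::XlaOp is_negative = xla::Lt(index, zero);
  xla::XlaOp wrapped = xla::Add(dim_size, index);  // e.g. size 10, index -2 -> 8
  return xla::Select(is_negative, wrapped, index);
}

The dynamic size of each sliced dimension is then computed as xla::Max(xla::Sub(end_index, begin_index), zero) and handed to xla::SetDimensionSize after the full-size DynamicSlice.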
@@ -80,31 +220,33 @@ class StridedSliceOp : public XlaOpKernel {
     }
     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                             &strides_tensor));
-    TensorShape final_shape;
-    PartialTensorShape dummy_processing_shape, partial_final_shape;
+    TensorShape processing_shape, final_shape;
+    PartialTensorShape partial_processing_shape, partial_final_shape;
     bool dummy = false;
-    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
-    absl::InlinedVector<int64, 4> output_to_processing_mapping;
+    StridedSliceShapeSpec shape_spec;
     OP_REQUIRES_OK(
         ctx,
         ValidateStridedSliceOp(
             begin_is_constant ? &begin_tensor : nullptr,
             end_is_constant ? &end_tensor : nullptr, strides_tensor,
             input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
-            shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
-            &dummy, &dummy, &dummy, &begin, &end, &strides,
-            &output_to_sparse_mapping, &output_to_processing_mapping));
-
-    OP_REQUIRES(
-        ctx, partial_final_shape.AsTensorShape(&final_shape),
-        InvalidArgument("XLA can't deduce compile time constant output "
-                        "shape for strided slice: ",
-                        partial_final_shape.DebugString(),
-                        ", output shape must be a compile-time constant"));
-
+            shrink_axis_mask_, &partial_processing_shape, &partial_final_shape,
+            &dummy, &dummy, &dummy, &begin, &end, &strides, &shape_spec));
     xla::XlaOp slice = ctx->Input(0);
+    std::vector<bool> begins_are_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
+    std::vector<bool> ends_are_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
     if (begin_is_constant && end_is_constant) {
+      OP_REQUIRES(
+          ctx, partial_final_shape.AsTensorShape(&final_shape),
+          InvalidArgument("XLA can't deduce compile time constant output "
+                          "shape for strided slice: ",
+                          partial_final_shape.DebugString(),
+                          ", output shape must be a compile-time constant"));
       absl::InlinedVector<int64, 4> dimensions_to_reverse;
       absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
       for (int i = 0; i < begin.size(); ++i) {
@@ -129,12 +271,7 @@ class StridedSliceOp : public XlaOpKernel {
       auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
       OP_REQUIRES_OK(ctx, operand_shape_or.status());
       xla::Shape xla_shape = operand_shape_or.ValueOrDie();
-      std::vector<bool> begins_are_dynamic;
-      OP_REQUIRES_OK(
-          ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
-      std::vector<bool> ends_are_dynamic;
-      OP_REQUIRES_OK(
-          ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
       bool begins_are_static = absl::c_all_of(
           begins_are_dynamic, [](bool dynamic) { return !dynamic; });
       OP_REQUIRES(ctx, begins_are_static,
@@ -150,13 +287,13 @@ class StridedSliceOp : public XlaOpKernel {
       }
 
       for (int64 i = 0; i < final_shape.dims(); ++i) {
-        int64 input_index = output_to_processing_mapping[i];
+        int64 input_index = shape_spec.output_to_processing_mapping[i];
         if (input_index == -1) {
           continue;
         }
         bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);
 
-        int64 sparse_index = output_to_sparse_mapping[i];
+        int64 sparse_index = shape_spec.output_to_sparse_mapping[i];
         bool end_is_dynamic =
             sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
         bool backward_slice = sparse_index == -1
@@ -208,62 +345,9 @@ class StridedSliceOp : public XlaOpKernel {
       ctx->SetOutput(0, slice);
       return;
     } else {
-      // When output shape is fully defined, it must be a size one slice:
-      //
-      // 1. The number of output elements has to be equal to the number of input
-      // elements that are sliced.
-      // 2. The stride of the slice dimensions must be exact one.
-      int64 output_elements = final_shape.num_elements();
-
-      int64 input_elements_sliced = 1;
-      int64 slicing_dim_size = begin_shape.dim_size(0);
-      // We only support slicing major dimensions, so minor dimensions after
-      // slicing dimension are all sliced with their full sizes.
-      for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
-        input_elements_sliced *= input_shape.dim_size(d);
-      }
-
-      OP_REQUIRES(ctx, output_elements == input_elements_sliced,
-                  errors::InvalidArgument(
-                      "Dynamic indices of strided_slice_op have to be leading "
-                      "dimensions in the indices list."));
-
-      for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
-        OP_REQUIRES(
-            ctx, strides[i] == 1,
-            errors::InvalidArgument(
-                "Strides have to be one when inputs are not constant."));
-      }
-
-      // When inputs are not compile time constants, shape inference can only
-      // inference size 1 slice.
-      std::vector<int64> slice_sizes(slicing_dim_size, 1);
-      std::vector<xla::XlaOp> start_indices;
-      auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin"));
-      for (int64 d = 0; d < slicing_dim_size; ++d) {
-        auto index = xla::Slice(ctx->Input("begin"), {d}, {d + 1}, {1});
-        // Convert index to scalar.
-        index = xla::Reshape(index, {});
-        // Negative index: wrap it around with dimension size.
-        auto index_negative = xla::Lt(index, zero);
-        auto dim_size = xla::ConvertElementType(
-            xla::ConstantR0<int32>(ctx->builder(), input_shape.dim_size(d)),
-            ctx->InputXlaType("begin"));
-        auto wrapped_index = xla::Add(dim_size, index);
-        index = xla::Select(index_negative, wrapped_index, index);
-        start_indices.push_back(index);
-      }
-
-      for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
-        // For non-slice dims, naturally we get the full slice starting from 0.
-        slice_sizes.push_back(input_shape.dim_size(d));
-        start_indices.push_back(zero);
-      }
-
-      std::vector<int64> output_shape_dim_sizes;
-      slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
-      slice = xla::Reshape(slice, final_shape.dim_sizes());
-      ctx->SetOutput(0, slice);
+      EmitDynamicSlice(ctx, strides, processing_shape, final_shape,
+                       partial_processing_shape, partial_final_shape,
+                       shape_spec, begins_are_dynamic, ends_are_dynamic);
     }
   }
 
@@ -308,10 +392,7 @@ class StridedSliceGradOp : public XlaOpKernel {
     absl::InlinedVector<int64, 4> begin;
     absl::InlinedVector<int64, 4> end;
     absl::InlinedVector<int64, 4> strides;
-    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
-    absl::InlinedVector<int64, 4> output_to_processing_mapping;
-
+    StridedSliceShapeSpec shape_spec;
     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                             &strides_tensor));
     OP_REQUIRES_OK(
@@ -319,8 +400,7 @@ class StridedSliceGradOp : public XlaOpKernel {
             nullptr, nullptr, strides_tensor, input_shape, begin_mask_,
             end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
             &processing_shape, &final_shape, &dummy, &dummy, &dummy,
-            &begin, &end, &strides, &output_to_sparse_mapping,
-            &output_to_processing_mapping));
+            &begin, &end, &strides, &shape_spec));
     for (int64 i = 0; i < processing_shape.dims(); ++i) {
       OP_REQUIRES(
           ctx, strides[i] == 1,
@@ -341,20 +421,20 @@ class StridedSliceGradOp : public XlaOpKernel {
       // Use grad shape, which is known, to update unknown processing shape.
       // Grad shape is the output of the ValidateStridedSliceOp function in
      // forward pass, thus we use output_to_processing_mapping.
-      if (output_to_processing_mapping[i] != -1) {
-        processing_shape.set_dim(output_to_processing_mapping[i],
+      if (shape_spec.output_to_processing_mapping[i] != -1) {
+        processing_shape.set_dim(shape_spec.output_to_processing_mapping[i],
                                  grad_shape.dimensions(i));
       }
 
       // Similarly, use output_to_sparse_mapping to find out corresponding
       // begin dim of the output, as indices for dynamic update slice.
-      int64 begin_dim = output_to_sparse_mapping[i];
+      int64 begin_dim = shape_spec.output_to_sparse_mapping[i];
       if (begin_dim != -1) {
         auto begin_index =
             xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1});
         auto begin_index_scalar = xla::Reshape(
             xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index);
-        begins[output_to_sparse_mapping[i]] = begin_index_scalar;
+        begins[shape_spec.output_to_sparse_mapping[i]] = begin_index_scalar;
       }
     }
     VLOG(1) << "processing_shape" << processing_shape.DebugString();
@@ -3546,13 +3546,25 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
        }
        case HloOpcode::kTuple:
        case HloOpcode::kTranspose:
-       case HloOpcode::kGetTupleElement:
        case HloOpcode::kSlice:
        case HloOpcode::kBroadcast:
        case HloOpcode::kConcatenate:
        case HloOpcode::kReshape:
        case HloOpcode::kPad:
          break;
+       case HloOpcode::kGetTupleElement: {
+         // Rewrite parameter followed by gte into constants to avoid
+         // rematerializing the tuple parameter (could be very large).
+         int64 operand_handle = instr_proto->operand_ids(0);
+         TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
+                             LookUpInstructionByHandle(operand_handle));
+         TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
+                             StringToHloOpcode(operand_proto->opcode()));
+         if (operand_opcode == HloOpcode::kParameter) {
+           SetInstructionAsConstant(new_instr, id, new_shape, true);
+         }
+         break;
+       }
        case HloOpcode::kGetDimensionSize: {
          int64 dimension = instr_proto->dimensions(0);
          int64 operand_handle = instr_proto->operand_ids(0);
@@ -3646,6 +3658,18 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
        should_visit_operand = false;
      }
 
+     if (opcode == HloOpcode::kGetTupleElement) {
+       int64 operand_handle = instr_proto->operand_ids(0);
+       TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
+                           LookUpInstructionByHandle(operand_handle));
+       TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
+                           StringToHloOpcode(operand_proto->opcode()));
+       if (operand_opcode == HloOpcode::kParameter) {
+         // Don't rematerialize the whole parameter if it's followed by a GTE.
+         should_visit_operand = false;
+       }
+     }
+
      if (opcode == HloOpcode::kSelect) {
        TF_ASSIGN_OR_RETURN(bool constant_predicate,
                            operand_is_constant(instr_proto, 0));
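The two BuildDynamicInferenceGraph hunks above special-case a GetTupleElement whose operand is a Parameter: the element is marked as a constant in the inference graph directly and the operand is not visited, so a potentially huge tuple parameter is never materialized. A purely illustrative setup that would hit this path (names and shapes are invented, not from the commit):

// Illustrative only: a tuple parameter followed by GetTupleElement. Dynamism
// inference can now answer the GTE directly without reconstructing `p`.
xla::XlaBuilder b("dynamism_inference_example");
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(
    {xla::ShapeUtil::MakeShape(xla::S32, {}),
     xla::ShapeUtil::MakeShape(xla::F32, {1024, 1024})});
xla::XlaOp p = xla::Parameter(&b, 0, tuple_shape, "p");
xla::XlaOp begin = xla::GetTupleElement(p, 0);  // a runtime (dynamic) value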
@@ -16,6 +16,8 @@ limitations under the License.
 #include "tensorflow/core/util/strided_slice_op.h"
 
 #include <array>
+#include <iterator>
 
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/lib/core/status.h"
 
@@ -64,6 +66,7 @@ struct StridedSliceDenseSpec {
   // index. A -1 in this vector means there the index is not from the sparse
   // input.
   gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
+  gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
   // The dense indexed shrink mask is which processing dimensions
   // should be shrunk. For example, if foo.shape = (10,10,10,10)
   // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
@@ -81,6 +84,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
     dense->begin.resize(dense->dims);
     dense->end.resize(dense->dims);
     dense->strides.resize(dense->dims);
+    dense->input_shape_gather_indices_sparse.resize(dense->dims);
     // What indices to get the final shape from.
     dense->begin_mask = 0;
     dense->end_mask = 0;
@@ -114,6 +118,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
           dense->end_mask |= (1 << full_index);
           dense->final_shape_gather_indices.push_back(full_index);
           dense->final_shape_gather_indices_sparse.push_back(-1);
+          dense->input_shape_gather_indices_sparse[full_index] = i;
         }
       } else if ((1 << i) & sparse.new_axis_mask) {
         dense->final_shape_gather_indices.push_back(kNewAxis);
@@ -153,6 +158,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
          // from.
          dense->final_shape_gather_indices_sparse.push_back(i);
        }
+       dense->input_shape_gather_indices_sparse[full_index] = i;
        full_index++;
      }
    }
@@ -168,9 +174,7 @@ Status ValidateStridedSliceOp(
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
+    gtl::InlinedVector<int64, 4>* strides, StridedSliceShapeSpec* shape_spec) {
   const bool begin_is_wrong =
       begin_tensor != nullptr &&
       !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
@@ -375,13 +379,18 @@ Status ValidateStridedSliceOp(
   // slices like foo[3,...] will reduce dimension by 1.
   // This cannot be done earlier, because it depends on Step 3.
   final_shape->Clear();
-  if (output_to_sparse_mapping != nullptr) {
-    output_to_sparse_mapping->clear();
+  if (shape_spec != nullptr) {
+    shape_spec->output_to_sparse_mapping.clear();
+    shape_spec->output_to_processing_mapping.clear();
+    shape_spec->processing_to_sparse_mapping.assign(
+        dense_spec.input_shape_gather_indices_sparse.begin(),
+        dense_spec.input_shape_gather_indices_sparse.end());
+
+    shape_spec->begin_dense_mask = dense_spec.begin_mask;
+    shape_spec->end_dense_mask = dense_spec.end_mask;
+    shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask;
   }
-
-  if (output_to_processing_mapping != nullptr) {
-    output_to_processing_mapping->clear();
-  }
   for (int64 dense_dim = 0;
        dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
     int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
@@ -389,22 +398,19 @@ Status ValidateStridedSliceOp(
         dense_spec.final_shape_gather_indices_sparse[dense_dim];
     if (gather_index >= 0) {
       final_shape->AddDim(processing_shape->dim_size(gather_index));
-      if (output_to_sparse_mapping != nullptr) {
-        output_to_sparse_mapping->push_back(sparse_index);
-      }
-      if (output_to_processing_mapping != nullptr) {
-        output_to_processing_mapping->push_back(gather_index);
+      if (shape_spec != nullptr) {
+        shape_spec->output_to_sparse_mapping.push_back(sparse_index);
+        shape_spec->output_to_processing_mapping.push_back(gather_index);
       }
     } else if (gather_index == kNewAxis) {
       final_shape->AddDim(1);
-      if (output_to_sparse_mapping != nullptr) {
-        output_to_sparse_mapping->push_back(-1);
-      }
-      if (output_to_processing_mapping != nullptr) {
-        output_to_processing_mapping->push_back(-1);
+      if (shape_spec != nullptr) {
+        shape_spec->output_to_sparse_mapping.push_back(-1);
+        shape_spec->output_to_processing_mapping.push_back(-1);
      }
    }
  }
 
  return Status::OK();
 }
 
@@ -416,16 +422,14 @@ Status ValidateStridedSliceOp(
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
     gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
+    StridedSliceShapeSpec* shape_spec) {
   // Validate with PartialTensorShape output
   PartialTensorShape partial_processing_shape, partial_final_shape;
   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
       begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
       end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
       &partial_processing_shape, &partial_final_shape, is_identity,
-      is_simple_slice, slice_dim0, begin, end, strides,
-      output_to_sparse_mapping, output_to_processing_mapping));
+      is_simple_slice, slice_dim0, begin, end, strides, shape_spec));
 
   // Verify that the output shapes are fully known
   if (!partial_processing_shape.AsTensorShape(processing_shape) ||
@@ -23,6 +23,26 @@ limitations under the License.
 
 namespace tensorflow {
 
+struct StridedSliceShapeSpec {
+  // Begin mask canonicalized in dense form.
+  int32 begin_dense_mask;
+  // End mask canonicalized in dense form.
+  int32 end_dense_mask;
+  // Shrink axis mask canonicalized in dense form.
+  int32 shrink_axis_dense_mask;
+  // output_to_sparse_mapping[i] represents output[i]'s corresponding dim
+  // index in the begin_tensor. If output_to_sparse_mapping[i] is -1, it means
+  // the dimension doesn't show up in sparse_mapping.
+  gtl::InlinedVector<int64, 4> output_to_sparse_mapping;
+  // output_to_processing_mapping is similar to output_to_sparse_mapping, but
+  // for the processing shape.
+  gtl::InlinedVector<int64, 4> output_to_processing_mapping;
+  // processing_to_sparse_mapping[i] represents input_shape[i]'s corresponding
+  // dim index in the begin_tensor.
+  gtl::InlinedVector<int64, 4> processing_to_sparse_mapping;
+};
+
 // Runs validation on the strided slice op parameters.
 //
 // Is a separate translation unit from the kernel so that:
@@ -41,16 +61,6 @@ namespace tensorflow {
 // (-1). Any validation that can be done without complete information is
 // performed.
 //
-// This function changes the orders of dimensions, output_to_sparse_mapping and
-// output_to_processing_mapping are used to track the order change.
-//
-// output_to_sparse_mapping[i] represents output[i]'s the corresponding dim
-// index in the begin_tensor. If
-// output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up in
-// sparse_mapping.
-//
-// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
-// processing_shape.
 Status ValidateStridedSliceOp(
     const Tensor* begin_tensor, const Tensor* end_tensor,
     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -60,8 +70,7 @@ Status ValidateStridedSliceOp(
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
     gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+    StridedSliceShapeSpec* shape_spec = nullptr);
 
 // Same as above, but the outputs are TensorShape, not PartialTensorShape
 Status ValidateStridedSliceOp(
@@ -72,8 +81,7 @@ Status ValidateStridedSliceOp(
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
     gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+    StridedSliceShapeSpec* shape_spec = nullptr);
 
 }  // namespace tensorflow
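With StridedSliceShapeSpec exported, callers get all of the dimension mappings from ValidateStridedSliceOp through one struct instead of separate output vectors. A hedged usage sketch of the TensorShape overload (the begin/end/strides tensors, input_shape, and the mask integers are assumed to be defined by the caller; this is not code from the commit):

TensorShape processing_shape, final_shape;
bool is_identity, is_simple_slice, slice_dim0;
gtl::InlinedVector<int64, 4> begin, end, strides;
StridedSliceShapeSpec shape_spec;
TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
    &begin_tensor, &end_tensor, strides_tensor, input_shape, begin_mask,
    end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
    &processing_shape, &final_shape, &is_identity, &is_simple_slice,
    &slice_dim0, &begin, &end, &strides, &shape_spec));
// For output dimension i:
//   shape_spec.output_to_processing_mapping[i] -> dim in processing_shape (or -1)
//   shape_spec.output_to_sparse_mapping[i]     -> dim in the begin/end tensors (or -1)
// For processing dimension j:
//   shape_spec.processing_to_sparse_mapping[j] -> dim in the begin/end tensors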