[TF2XLA] Support dynamic slice size in strided slice op.
- Add two side outputs in ValidateStridedSliceOp to help analyze dynamic dimensions. - Correctly set strided slice op's dynamic size if the slice size (slice end) is dynamic PiperOrigin-RevId: 327893466 Change-Id: I7a5061ed1b92006c6f1f15c71661265b5bdfec25
This commit is contained in:
parent
f8066f9dc4
commit
1d654624e4
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
|
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -28,13 +29,26 @@ class BroadcastToOp : public XlaOpKernel {
|
|||||||
: XlaOpKernel(context) {}
|
: XlaOpKernel(context) {}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* context) override {
|
void Compile(XlaOpKernelContext* context) override {
|
||||||
const TensorShape input_shape = context->InputShape(0);
|
|
||||||
TensorShape output_shape;
|
TensorShape output_shape;
|
||||||
OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
|
OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
|
||||||
|
auto output_status_or =
|
||||||
|
BroadcastTo(context->Input(0), output_shape.dim_sizes());
|
||||||
|
OP_REQUIRES_OK(context, output_status_or.status());
|
||||||
|
auto output = output_status_or.ValueOrDie();
|
||||||
|
std::vector<bool> dynamic_dims;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
|
||||||
|
for (int64 dim = 0; dim < dynamic_dims.size(); ++dim) {
|
||||||
|
if (dynamic_dims[dim]) {
|
||||||
|
output = xla::SetDimensionSize(
|
||||||
|
output,
|
||||||
|
xla::Reshape(xla::Slice(context->Input(1), {dim}, {dim + 1}, {1}),
|
||||||
|
{}),
|
||||||
|
dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
|
context->SetOutput(0, output);
|
||||||
OP_REQUIRES_OK(context, output.status());
|
|
||||||
context->SetOutput(0, output.ValueOrDie());
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -15,6 +15,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/util/strided_slice_op.h"
|
#include "tensorflow/core/util/strided_slice_op.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/tf2xla/literal_util.h"
|
#include "tensorflow/compiler/tf2xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||||
@ -23,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/ops_util.h"
|
#include "tensorflow/core/framework/ops_util.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
@ -33,6 +37,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
using errors::InvalidArgument;
|
||||||
|
|
||||||
class StridedSliceOp : public XlaOpKernel {
|
class StridedSliceOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
@ -48,7 +53,7 @@ class StridedSliceOp : public XlaOpKernel {
|
|||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
const TensorShape input_shape = ctx->InputShape(0);
|
const TensorShape input_shape = ctx->InputShape(0);
|
||||||
const TensorShape begin_shape = ctx->InputShape("begin");
|
const TensorShape begin_shape = ctx->InputShape("begin");
|
||||||
|
VLOG(0) << "strided slice";
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, begin_shape.dims() == 1,
|
ctx, begin_shape.dims() == 1,
|
||||||
errors::InvalidArgument("'begin' input has to be a rank 1 vector"));
|
errors::InvalidArgument("'begin' input has to be a rank 1 vector"));
|
||||||
@ -78,20 +83,24 @@ class StridedSliceOp : public XlaOpKernel {
|
|||||||
TensorShape final_shape;
|
TensorShape final_shape;
|
||||||
PartialTensorShape dummy_processing_shape, partial_final_shape;
|
PartialTensorShape dummy_processing_shape, partial_final_shape;
|
||||||
bool dummy = false;
|
bool dummy = false;
|
||||||
OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
|
absl::InlinedVector<int64, 4> output_to_sparse_mapping;
|
||||||
begin_is_constant ? &begin_tensor : nullptr,
|
absl::InlinedVector<int64, 4> output_to_processing_mapping;
|
||||||
end_is_constant ? &end_tensor : nullptr,
|
OP_REQUIRES_OK(
|
||||||
strides_tensor, input_shape, begin_mask_, end_mask_,
|
ctx,
|
||||||
ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
|
ValidateStridedSliceOp(
|
||||||
&dummy_processing_shape, &partial_final_shape,
|
begin_is_constant ? &begin_tensor : nullptr,
|
||||||
&dummy, &dummy, &dummy, &begin, &end, &strides));
|
end_is_constant ? &end_tensor : nullptr, strides_tensor,
|
||||||
|
input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
|
||||||
|
shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
|
||||||
|
&dummy, &dummy, &dummy, &begin, &end, &strides,
|
||||||
|
&output_to_sparse_mapping, &output_to_processing_mapping));
|
||||||
|
|
||||||
OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape),
|
OP_REQUIRES(
|
||||||
errors::InvalidArgument(
|
ctx, partial_final_shape.AsTensorShape(&final_shape),
|
||||||
"XLA can't deduce compile time constant output "
|
InvalidArgument("XLA can't deduce compile time constant output "
|
||||||
"shape for strided slice: ",
|
"shape for strided slice: ",
|
||||||
partial_final_shape.DebugString(),
|
partial_final_shape.DebugString(),
|
||||||
", output shape must be a compile-time constant"));
|
", output shape must be a compile-time constant"));
|
||||||
|
|
||||||
xla::XlaOp slice = ctx->Input(0);
|
xla::XlaOp slice = ctx->Input(0);
|
||||||
if (begin_is_constant && end_is_constant) {
|
if (begin_is_constant && end_is_constant) {
|
||||||
@ -119,69 +128,84 @@ class StridedSliceOp : public XlaOpKernel {
|
|||||||
auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
|
auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
|
||||||
OP_REQUIRES_OK(ctx, operand_shape_or.status());
|
OP_REQUIRES_OK(ctx, operand_shape_or.status());
|
||||||
xla::Shape xla_shape = operand_shape_or.ValueOrDie();
|
xla::Shape xla_shape = operand_shape_or.ValueOrDie();
|
||||||
if (xla_shape.is_static()) {
|
std::vector<bool> begins_are_dynamic;
|
||||||
// Static output shape, return a static slice.
|
OP_REQUIRES_OK(
|
||||||
slice = xla::Reshape(slice, final_shape.dim_sizes());
|
ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
|
||||||
|
std::vector<bool> ends_are_dynamic;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
|
||||||
|
bool begins_are_static = absl::c_all_of(
|
||||||
|
begins_are_dynamic, [](bool dynamic) { return !dynamic; });
|
||||||
|
OP_REQUIRES(ctx, begins_are_static,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"XLA can't use dynamic begin values for slice."));
|
||||||
|
bool ends_are_static = absl::c_all_of(
|
||||||
|
ends_are_dynamic, [](bool dynamic) { return !dynamic; });
|
||||||
|
// Static output shape, return a static slice.
|
||||||
|
slice = xla::Reshape(slice, final_shape.dim_sizes());
|
||||||
|
if (xla_shape.is_static() && ends_are_static) {
|
||||||
ctx->SetOutput(0, slice);
|
ctx->SetOutput(0, slice);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto input_dim_sizes = input_shape.dim_sizes();
|
|
||||||
|
|
||||||
for (int64 i = 0; i < xla_shape.rank(); ++i) {
|
for (int64 i = 0; i < final_shape.dims(); ++i) {
|
||||||
if (xla_shape.is_dynamic_dimension(i)) {
|
int64 input_index = output_to_processing_mapping[i];
|
||||||
input_dim_sizes[i] = -1;
|
if (input_index == -1) {
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
}
|
bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);
|
||||||
PartialTensorShape input_partial_shape(input_dim_sizes);
|
|
||||||
partial_final_shape.Clear();
|
|
||||||
end.clear();
|
|
||||||
strides.clear();
|
|
||||||
begin.clear();
|
|
||||||
// Run shape inferenference again with partial shape.
|
|
||||||
OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
|
|
||||||
&begin_tensor, &end_tensor, strides_tensor,
|
|
||||||
input_partial_shape, begin_mask_, end_mask_,
|
|
||||||
ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
|
|
||||||
&dummy_processing_shape, &partial_final_shape,
|
|
||||||
&dummy, &dummy, &dummy, &begin, &end, &strides));
|
|
||||||
if (partial_final_shape.AsTensorShape(&final_shape)) {
|
|
||||||
// Static output shape, return a static slice.
|
|
||||||
slice = xla::Reshape(slice, final_shape.dim_sizes());
|
|
||||||
ctx->SetOutput(0, slice);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// We consider slicing a dynamic tensor t with negative indices as a
|
int64 sparse_index = output_to_sparse_mapping[i];
|
||||||
// dynamic sized slice. E.g., t[: -n], the result length is shape(t) - n
|
bool end_is_dynamic =
|
||||||
for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
|
sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
|
||||||
bool dynamic_dim = partial_final_shape.dim_size(i) - 1;
|
bool backward_slice = sparse_index == -1
|
||||||
bool backward_slice = end[i] < 0;
|
? false
|
||||||
if (dynamic_dim && backward_slice) {
|
: end_literal.Get<int32>({sparse_index}) < 0;
|
||||||
|
if ((input_is_dynamic && backward_slice) || end_is_dynamic) {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, strides[i] == 1,
|
ctx, strides[input_index] == 1,
|
||||||
errors::InvalidArgument("XLA has not implemented dynamic "
|
errors::InvalidArgument("XLA has not implemented dynamic "
|
||||||
"sized slice with non-trival stride yet. "
|
"sized slice with non-trival stride yet. "
|
||||||
"Please file a bug against XLA"));
|
"Please file a bug against XLA"));
|
||||||
|
|
||||||
OP_REQUIRES(ctx, begin[i] >= 0,
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"XLA has not implemented dynamic "
|
|
||||||
"sized slice with negative begin index %lld. "
|
|
||||||
"Please file a bug against XLA",
|
|
||||||
begin[i]));
|
|
||||||
// If there is a dynamic dimension, properly set dimension size of
|
// If there is a dynamic dimension, properly set dimension size of
|
||||||
// the result.
|
// the result.
|
||||||
auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);
|
auto operand_size = xla::GetDimensionSize(ctx->Input(0), input_index);
|
||||||
|
if (backward_slice) {
|
||||||
operand_size = xla::Add(
|
// We consider slicing a dynamic tensor t with negative indices as
|
||||||
operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
|
// a dynamic sized slice. E.g., t[: -n], the result length is
|
||||||
|
// shape(t) - n.
|
||||||
|
OP_REQUIRES(ctx, !end_is_dynamic,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"XLA has not implemented dynamic "
|
||||||
|
"sized slice with dynamic negative index %lld. "));
|
||||||
|
operand_size = xla::Add(
|
||||||
|
operand_size,
|
||||||
|
xla::ConstantR0<int32>(ctx->builder(),
|
||||||
|
end_literal.Get<int32>({sparse_index})));
|
||||||
|
} else {
|
||||||
|
// The end of slice with dynamic slice size is the min of operand
|
||||||
|
// shape and slice size. E.g., t[:end_size], result size is
|
||||||
|
// min(shape(t), end_size).
|
||||||
|
xla::XlaOp end_size;
|
||||||
|
if (end_is_dynamic) {
|
||||||
|
end_size = xla::Reshape(xla::Slice(ctx->Input(2), {sparse_index},
|
||||||
|
{sparse_index + 1}, {1}),
|
||||||
|
{});
|
||||||
|
} else {
|
||||||
|
end_size =
|
||||||
|
xla::ConstantR0<int32>(ctx->builder(), end[input_index]);
|
||||||
|
}
|
||||||
|
operand_size = xla::Min(operand_size, end_size);
|
||||||
|
}
|
||||||
slice = xla::SetDimensionSize(
|
slice = xla::SetDimensionSize(
|
||||||
slice,
|
slice,
|
||||||
xla::Sub(operand_size,
|
xla::Sub(operand_size, xla::ConstantR0<int32>(
|
||||||
xla::ConstantR0<int32>(ctx->builder(), begin[i])),
|
ctx->builder(), begin[input_index])),
|
||||||
i);
|
i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ctx->SetOutput(0, slice);
|
||||||
|
return;
|
||||||
} else {
|
} else {
|
||||||
// When output shape is fully defined, it must be a size one slice:
|
// When output shape is fully defined, it must be a size one slice:
|
||||||
//
|
//
|
||||||
@ -239,9 +263,9 @@ class StridedSliceOp : public XlaOpKernel {
|
|||||||
|
|
||||||
std::vector<int64> output_shape_dim_sizes;
|
std::vector<int64> output_shape_dim_sizes;
|
||||||
slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
|
slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
|
||||||
|
slice = xla::Reshape(slice, final_shape.dim_sizes());
|
||||||
|
ctx->SetOutput(0, slice);
|
||||||
}
|
}
|
||||||
slice = xla::Reshape(slice, final_shape.dim_sizes());
|
|
||||||
ctx->SetOutput(0, slice);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -59,6 +59,11 @@ struct StridedSliceDenseSpec {
|
|||||||
// is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
|
// is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
|
||||||
// it will be 1. A shrunk dimension is skipped.
|
// it will be 1. A shrunk dimension is skipped.
|
||||||
gtl::InlinedVector<int32, 4> final_shape_gather_indices;
|
gtl::InlinedVector<int32, 4> final_shape_gather_indices;
|
||||||
|
// This vector has the same size as final_shape_gather_indices, but it
|
||||||
|
// remembers the sparse index that a dimension comes from, instead of dense
|
||||||
|
// index. A -1 in this vector means there the index is not from the sparse
|
||||||
|
// input.
|
||||||
|
gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
|
||||||
// The dense indexed shrink mask is which processing dimensions
|
// The dense indexed shrink mask is which processing dimensions
|
||||||
// should be shrunk. For example, if foo.shape = (10,10,10,10)
|
// should be shrunk. For example, if foo.shape = (10,10,10,10)
|
||||||
// foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
|
// foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
|
||||||
@ -108,9 +113,11 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
|
|||||||
dense->begin_mask |= (1 << full_index);
|
dense->begin_mask |= (1 << full_index);
|
||||||
dense->end_mask |= (1 << full_index);
|
dense->end_mask |= (1 << full_index);
|
||||||
dense->final_shape_gather_indices.push_back(full_index);
|
dense->final_shape_gather_indices.push_back(full_index);
|
||||||
|
dense->final_shape_gather_indices_sparse.push_back(-1);
|
||||||
}
|
}
|
||||||
} else if ((1 << i) & sparse.new_axis_mask) {
|
} else if ((1 << i) & sparse.new_axis_mask) {
|
||||||
dense->final_shape_gather_indices.push_back(kNewAxis);
|
dense->final_shape_gather_indices.push_back(kNewAxis);
|
||||||
|
dense->final_shape_gather_indices_sparse.push_back(-1);
|
||||||
} else {
|
} else {
|
||||||
if (full_index == dense->begin.size()) {
|
if (full_index == dense->begin.size()) {
|
||||||
return errors::InvalidArgument("Index out of range using input dim ",
|
return errors::InvalidArgument("Index out of range using input dim ",
|
||||||
@ -138,9 +145,13 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
|
|||||||
// axis (now in dense form) so we can ignore dense->end below.
|
// axis (now in dense form) so we can ignore dense->end below.
|
||||||
if (sparse.shrink_axis_mask & (1 << i)) {
|
if (sparse.shrink_axis_mask & (1 << i)) {
|
||||||
dense->final_shape_gather_indices.push_back(kShrinkAxis);
|
dense->final_shape_gather_indices.push_back(kShrinkAxis);
|
||||||
|
dense->final_shape_gather_indices_sparse.push_back(-1);
|
||||||
dense->shrink_axis_mask |= (1 << full_index);
|
dense->shrink_axis_mask |= (1 << full_index);
|
||||||
} else {
|
} else {
|
||||||
dense->final_shape_gather_indices.push_back(full_index);
|
dense->final_shape_gather_indices.push_back(full_index);
|
||||||
|
// Remember that where in the sparse shape the dense dim comes
|
||||||
|
// from.
|
||||||
|
dense->final_shape_gather_indices_sparse.push_back(i);
|
||||||
}
|
}
|
||||||
full_index++;
|
full_index++;
|
||||||
}
|
}
|
||||||
@ -157,7 +168,9 @@ Status ValidateStridedSliceOp(
|
|||||||
PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
|
PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
|
||||||
bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
|
bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
|
||||||
gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
|
gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
|
||||||
gtl::InlinedVector<int64, 4>* strides) {
|
gtl::InlinedVector<int64, 4>* strides,
|
||||||
|
gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
|
||||||
|
gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
|
||||||
const bool begin_is_wrong =
|
const bool begin_is_wrong =
|
||||||
begin_tensor != nullptr &&
|
begin_tensor != nullptr &&
|
||||||
!(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
|
!(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
|
||||||
@ -362,11 +375,34 @@ Status ValidateStridedSliceOp(
|
|||||||
// slices like foo[3,...] will reduce dimension by 1.
|
// slices like foo[3,...] will reduce dimension by 1.
|
||||||
// This cannot be done earlier, because it depends on Step 3.
|
// This cannot be done earlier, because it depends on Step 3.
|
||||||
final_shape->Clear();
|
final_shape->Clear();
|
||||||
for (auto gather_index : dense_spec.final_shape_gather_indices) {
|
if (output_to_sparse_mapping != nullptr) {
|
||||||
|
output_to_sparse_mapping->clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (output_to_processing_mapping != nullptr) {
|
||||||
|
output_to_processing_mapping->clear();
|
||||||
|
}
|
||||||
|
for (int64 dense_dim = 0;
|
||||||
|
dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
|
||||||
|
int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
|
||||||
|
int64 sparse_index =
|
||||||
|
dense_spec.final_shape_gather_indices_sparse[dense_dim];
|
||||||
if (gather_index >= 0) {
|
if (gather_index >= 0) {
|
||||||
final_shape->AddDim(processing_shape->dim_size(gather_index));
|
final_shape->AddDim(processing_shape->dim_size(gather_index));
|
||||||
|
if (output_to_sparse_mapping != nullptr) {
|
||||||
|
output_to_sparse_mapping->push_back(sparse_index);
|
||||||
|
}
|
||||||
|
if (output_to_processing_mapping != nullptr) {
|
||||||
|
output_to_processing_mapping->push_back(gather_index);
|
||||||
|
}
|
||||||
} else if (gather_index == kNewAxis) {
|
} else if (gather_index == kNewAxis) {
|
||||||
final_shape->AddDim(1);
|
final_shape->AddDim(1);
|
||||||
|
if (output_to_sparse_mapping != nullptr) {
|
||||||
|
output_to_sparse_mapping->push_back(-1);
|
||||||
|
}
|
||||||
|
if (output_to_processing_mapping != nullptr) {
|
||||||
|
output_to_processing_mapping->push_back(-1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -379,14 +415,17 @@ Status ValidateStridedSliceOp(
|
|||||||
int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
|
int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
|
||||||
TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
|
TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
|
||||||
bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
|
bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
|
||||||
gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
|
gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
|
||||||
|
gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
|
||||||
|
gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
|
||||||
// Validate with PartialTensorShape output
|
// Validate with PartialTensorShape output
|
||||||
PartialTensorShape partial_processing_shape, partial_final_shape;
|
PartialTensorShape partial_processing_shape, partial_final_shape;
|
||||||
TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
|
TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
|
||||||
begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
|
begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
|
||||||
end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
|
end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
|
||||||
&partial_processing_shape, &partial_final_shape, is_identity,
|
&partial_processing_shape, &partial_final_shape, is_identity,
|
||||||
is_simple_slice, slice_dim0, begin, end, strides));
|
is_simple_slice, slice_dim0, begin, end, strides,
|
||||||
|
output_to_sparse_mapping, output_to_processing_mapping));
|
||||||
|
|
||||||
// Verify that the output shapes are fully known
|
// Verify that the output shapes are fully known
|
||||||
if (!partial_processing_shape.AsTensorShape(processing_shape) ||
|
if (!partial_processing_shape.AsTensorShape(processing_shape) ||
|
||||||
|
@ -40,6 +40,17 @@ namespace tensorflow {
|
|||||||
// some dimensions of <processing_shape> and/or <final_shape> may be unknown
|
// some dimensions of <processing_shape> and/or <final_shape> may be unknown
|
||||||
// (-1). Any validation that can be done without complete information is
|
// (-1). Any validation that can be done without complete information is
|
||||||
// performed.
|
// performed.
|
||||||
|
//
|
||||||
|
// This function changes the orders of dimensions, output_to_sparse_mapping and
|
||||||
|
// output_to_processing_mapping are used to track the order change.
|
||||||
|
//
|
||||||
|
// output_to_sparse_mapping[i] represents output[i]'s the corresponding dim
|
||||||
|
// index in the begin_tensor. If
|
||||||
|
// output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up in
|
||||||
|
// sparse_mapping.
|
||||||
|
//
|
||||||
|
// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
|
||||||
|
// processing_shape.
|
||||||
Status ValidateStridedSliceOp(
|
Status ValidateStridedSliceOp(
|
||||||
const Tensor* begin_tensor, const Tensor* end_tensor,
|
const Tensor* begin_tensor, const Tensor* end_tensor,
|
||||||
const Tensor& strides_tensor, const PartialTensorShape& input_shape,
|
const Tensor& strides_tensor, const PartialTensorShape& input_shape,
|
||||||
@ -48,7 +59,9 @@ Status ValidateStridedSliceOp(
|
|||||||
PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
|
PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
|
||||||
bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
|
bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
|
||||||
gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
|
gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
|
||||||
gtl::InlinedVector<int64, 4>* strides);
|
gtl::InlinedVector<int64, 4>* strides,
|
||||||
|
gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
|
||||||
|
gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
|
||||||
|
|
||||||
// Same as above, but the outputs are TensorShape, not PartialTensorShape
|
// Same as above, but the outputs are TensorShape, not PartialTensorShape
|
||||||
Status ValidateStridedSliceOp(
|
Status ValidateStridedSliceOp(
|
||||||
@ -58,7 +71,9 @@ Status ValidateStridedSliceOp(
|
|||||||
int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
|
int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
|
||||||
TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
|
TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
|
||||||
bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
|
bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
|
||||||
gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides);
|
gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
|
||||||
|
gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
|
||||||
|
gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user