[XLA] Several improvements to dynamic padder.

- Support partial slices on dynamic dimensions -- this is achieved by letting the client set the dynamic dimension after building an XLA slice (a sketch of the client-side pattern follows this list).
- Support dynamic pad on a padded dimension.
- Fix a bug exposed by rxsang's experiment where transpose assigned the dynamic size to the wrong output dimension.
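A minimal sketch of the client-side pattern from the first item, assuming xla_builder.h is included and that `input` is an XlaOp whose dimension 0 is dynamic and `new_size` is an S32 scalar XlaOp holding the runtime slice size (both names are illustrative, not part of this change):

  #include "tensorflow/compiler/xla/client/xla_builder.h"

  // Build a static partial slice of the first 5 elements along dimension 0.
  xla::XlaOp slice = xla::Slice(input, /*start_indices=*/{0},
                                /*limit_indices=*/{5}, /*strides=*/{1});
  // Re-attach the dynamic size to the sliced dimension so downstream passes
  // (e.g. the dynamic padder) still see it as dynamic.
  xla::XlaOp dynamic_slice =
      xla::SetDimensionSize(slice, new_size, /*dimension=*/0);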

PiperOrigin-RevId: 303033314
Change-Id: Id76f4619d12e88b8c0b3e9ec75baa1d78d1a7270
Yunxing Dai 2020-03-25 21:02:17 -07:00 committed by TensorFlower Gardener
parent 2d055a4226
commit b3212dd802
6 changed files with 251 additions and 23 deletions


@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
@@ -58,18 +59,21 @@ class SliceOp : public XlaOpKernel {
std::vector<int64> begin;
std::vector<int64> size;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
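// `wrapped_size` mirrors `size`, but with -1 entries ("all elements from
// begin[i] to dim_size(i)") resolved to concrete extents; the raw `size`
// values are kept so explicit -1 requests can be told apart from dynamic
// sizes further below.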
std::vector<int64> wrapped_size(size.size());
if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
// `begin` is a compile-time constant.
for (int i = 0; i < input_dims; ++i) {
if (size[i] == -1) {
// A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
size[i] = input_shape.dim_size(i) - begin[i];
wrapped_size[i] = input_shape.dim_size(i) - begin[i];
} else {
wrapped_size[i] = size[i];
}
}
for (int i = 0; i < input_dims; ++i) {
int64 b = begin[i];
int64 s = size[i];
int64 s = wrapped_size[i];
if (input_shape.dim_size(i) == 0) {
OP_REQUIRES(ctx, b == 0 && s == 0,
errors::InvalidArgument(
@@ -91,10 +95,28 @@ class SliceOp : public XlaOpKernel {
std::vector<int64> limits;
limits.reserve(begin.size());
for (int i = 0; i < begin.size(); ++i) {
limits.push_back(begin[i] + size[i]);
limits.push_back(begin[i] + wrapped_size[i]);
}
std::vector<int64> strides(begin.size(), 1);
ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides));
auto slice = xla::Slice(ctx->Input(0), begin, limits, strides);
// Check for slices on dynamic dimensions.
ctx->set_dynamic_dimension_is_minus_one(true);
std::vector<int64> dynamic_size;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &dynamic_size));
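// With dynamic_dimension_is_minus_one set, re-reading the size input reports
// dynamically sized entries as -1, whereas the first read above resolved
// them to static values.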
for (int64 i = 0; i < size.size(); ++i) {
if (dynamic_size[i] == -1) {
if (size[i] != -1) {
// The requested size along this dimension is dynamic: read it from the
// size operand at runtime and attach it to the slice result.
auto dynamic_size =
xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {});
slice = xla::SetDimensionSize(slice, dynamic_size, i);
}
}
}
ctx->SetOutput(0, slice);
} else {
// `begin` is not a compile-time constant.
for (int i = 0; i < input_dims; ++i) {


@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
namespace tensorflow {
@@ -115,6 +116,72 @@ class StridedSliceOp : public XlaOpKernel {
slice = xla::Rev(slice, dimensions_to_reverse);
}
slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
OP_REQUIRES_OK(ctx, operand_shape_or.status());
xla::Shape xla_shape = operand_shape_or.ValueOrDie();
if (xla_shape.is_static()) {
// Static output shape, return a static slice.
slice = xla::Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
return;
}
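// Replace each dynamic input dimension with -1 and rebuild a partial input
// shape for a second round of shape inference below.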
auto input_dim_sizes = input_shape.dim_sizes();
for (int64 i = 0; i < xla_shape.rank(); ++i) {
if (xla_shape.is_dynamic_dimension(i)) {
input_dim_sizes[i] = -1;
}
}
PartialTensorShape input_partial_shape(input_dim_sizes);
partial_final_shape.Clear();
end.clear();
strides.clear();
begin.clear();
// Run shape inference again with the partial input shape.
OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
&begin_tensor, &end_tensor, strides_tensor,
input_partial_shape, begin_mask_, end_mask_,
ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
&dummy_processing_shape, &partial_final_shape,
&dummy, &dummy, &dummy, &begin, &end, &strides));
if (partial_final_shape.AsTensorShape(&final_shape)) {
// Static output shape, return a static slice.
slice = xla::Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
return;
}
// We consider slicing a dynamic tensor t with negative end indices to be a
// dynamic-sized slice. E.g., for t[:-n] the result length is shape(t) - n.
for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
bool dynamic_dim = partial_final_shape.dim_size(i) == -1;
bool backward_slice = end[i] < 0;
if (dynamic_dim && backward_slice) {
OP_REQUIRES(
ctx, strides[i] == 1,
errors::InvalidArgument("XLA has not implemented dynamic "
"sized slice with non-trival stride yet. "
"Please file a bug against XLA"));
OP_REQUIRES(ctx, begin[i] >= 0,
errors::InvalidArgument(
"XLA has not implemented dynamic "
"sized slice with negative begin index ",
begin[i],
". Please file a bug against XLA"));
// This output dimension is dynamic: its runtime size is
// (dim_size(input, i) + end[i]) - begin[i], since end[i] < 0.
auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);
operand_size = xla::Add(
operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
slice = xla::SetDimensionSize(
slice,
xla::Sub(operand_size,
xla::ConstantR0<int32>(ctx->builder(), begin[i])),
i);
}
}
} else {
// When output shape is fully defined, it must be a size one slice:
//


@@ -2404,6 +2404,7 @@ cc_library(
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/core/platform:macros",
"@com_google_absl//absl/container:flat_hash_map",


@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
namespace xla {
@@ -250,15 +251,25 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
}
const PaddingConfig_PaddingConfigDimension& padding_config =
hlo->padding_config().dimensions(dimension);
if (padding_config.interior_padding() == 0 &&
padding_config.edge_padding_low() == 0 &&
padding_config.edge_padding_high() == 0) {
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
if (padding_config.interior_padding() == 0) {
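// Without interior padding, the dynamic size of the padded dimension is
// simply the operand's dynamic size plus the low and high edge padding.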
HloInstruction* dynamic_size_adjusted = dynamic_size;
HloInstruction* adjustment = hlo->parent()->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
padding_config.edge_padding_low() +
padding_config.edge_padding_high())));
dynamic_size_adjusted =
hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
dynamic_size_adjusted->shape(), HloOpcode::kAdd,
dynamic_size_adjusted, adjustment));
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted,
constraint);
return Status::OK();
} else {
return Unimplemented(
"Dynamic dimension propagation on padding dimension is not "
"supported.");
"Dynamic dimension propagation on interio padding dimension is "
"not "
"supported: %s",
hlo->ToString());
}
});
}
@@ -400,11 +411,19 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
hlo,
[&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension],
dynamic_size, constraint);
DimensionConstraint constraint) -> Status {
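// Output dimension i of a transpose takes its size from input dimension
// hlo->dimensions()[i], so find the unique output dimension that is fed
// by the dynamic input dimension.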
int64 permuted_dim = -1;
for (int64 i = 0; i < hlo->dimensions().size(); ++i) {
if (hlo->dimensions()[i] == dimension) {
TF_RET_CHECK(permuted_dim == -1);
permuted_dim = i;
}
}
parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size,
constraint);
return Status::OK();
});
}
@@ -979,15 +998,9 @@ Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
hlo->slice_strides(dimension) != 1 ||
hlo->slice_limits(dimension) !=
operand->shape().dimensions(dimension)) {
// Slicing a single element out eliminates the dynamic dimension.
if (hlo->shape().dimensions(dimension) == 1) {
// A partial slice eliminates the dynamic dimension here; the client sets
// it again via SetDimensionSize after building the slice.
return Status::OK();
}
return Unimplemented(
"Dynamic dimension propagation on Slice where it doesn't slice "
"out an entire dimension is not supported %s",
hlo->ToString());
}
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);


@@ -386,6 +386,53 @@ TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
}
TEST_F(DynamicDimensionInferenceTest, DotTestMultiContracting) {
auto builder = HloComputation::Builder(TestName());
auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 8, 64});
auto rhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 512});
auto output_shape = ShapeUtil::MakeShape(F32, {8, 64, 512});
auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, lhs_shape, "A"));
auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, rhs_shape, "B"));
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, scalar_shape_, "size_param"));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(1);
auto dot = builder.AddInstruction(
HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
HloTestBase::DefaultPrecisionConfig(2)));
module_->AddEntryComputation(builder.Build());
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{2, {}},
DynamicParameterBinding::DynamicDimension{0, {}, 0}));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{2, {}},
DynamicParameterBinding::DynamicDimension{0, {}, 1}));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{2, {}},
DynamicParameterBinding::DynamicDimension{1, {}, 0}));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{2, {}},
DynamicParameterBinding::DynamicDimension{1, {}, 1}));
SCOPED_TRACE(module_->ToString());
TF_ASSERT_OK(RunInference());
// Nothing is dynamic in the output.
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), nullptr);
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
}
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
auto builder = HloComputation::Builder(TestName());
constexpr int xdim = 3;
@@ -474,6 +521,45 @@ TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1);
}
TEST_F(DynamicDimensionInferenceTest, NonDescendingTransposeTest) {
// Test tracing dynamic dimensions through a transpose with a non-trivial
// permutation.
auto builder = HloComputation::Builder(TestName());
auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
auto output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});
auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, input_shape, "A"));
auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, scalar_shape_, "size_param"));
auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, scalar_shape_, "size_param"));
auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/3, scalar_shape_, "size_param"));
auto* transpose = builder.AddInstruction(
HloInstruction::CreateTranspose(output_shape, a_param, {2, 0, 1}));
module_->AddEntryComputation(builder.Build());
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{1, {}},
DynamicParameterBinding::DynamicDimension{0, {}, 0}));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{2, {}},
DynamicParameterBinding::DynamicDimension{0, {}, 1}));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{3, {}},
DynamicParameterBinding::DynamicDimension{0, {}, 2}));
SCOPED_TRACE(module_->ToString());
TF_ASSERT_OK(RunInference());
EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3);
EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_1);
EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_2);
}
TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
// Test the ability to trace unmodified reshape dimensions.
auto builder = HloComputation::Builder(TestName());


@@ -865,6 +865,45 @@ ENTRY main {
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicPad) {
const string hlo_text = R"(
HloModule TEST
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}
ENTRY main {
param = s32[4] parameter(0)
size = s32[] constant(3)
padding = s32[] constant(2)
param_dynamic = s32[4] set-dimension-size(param, size),
dimensions={0}
// Pad one element of value 2 at the head and one at the tail.
pad = s32[6] pad(param_dynamic, padding), padding=1_1
init = s32[] constant(0)
ROOT reduce = s32[] reduce(pad, init),
dimensions={0},
to_apply=update_s32
}
)";
Literal operand = LiteralUtil::CreateR1<int32>({1, 4, 3, 5});
auto module = GetHloModule(hlo_text);
// The dynamic size is 3, so the effective data is [1, 4, 3]. After padding
// head and tail with "2", the effective data becomes [2, 1, 4, 3, 2], which
// sums to 12.
Literal result = PadAndExecute(std::move(module), {&operand},
/*slice_dynamic_output=*/false);
Literal expected = LiteralUtil::CreateR0<int32>(12);
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
const string hlo_text = R"(
HloModule TEST