[XLA] Several improvements to dynamic padder.
- Support partial slices on dynamic dimensions -- this is achieved by letting the client set the dynamic dimension after building an XLA slice.
- Support dynamic pad on the padded dimension.
- Fix a terrible bug exposed by rxsang's experiment where transpose creates the wrong dynamic dimension.

PiperOrigin-RevId: 303033314
Change-Id: Id76f4619d12e88b8c0b3e9ec75baa1d78d1a7270
commit b3212dd802 (parent 2d055a4226)
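The first bullet describes a client-side pattern rather than a new API: build a static xla::Slice over the bound of the dynamic dimension, then re-attach the runtime size with xla::SetDimensionSize so the sliced dimension stays dynamic. A minimal standalone sketch of that pattern (the helper name and the slice bounds are illustrative, not part of this change):

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical helper, for illustration only: statically slice the first six
// elements of a 1-D operand whose logical length is dynamic, then mark
// dimension 0 of the result as dynamic again, sized by the S32 scalar
// `dynamic_size`.
xla::XlaOp SliceDynamicDim0(xla::XlaOp operand, xla::XlaOp dynamic_size) {
  xla::XlaOp sliced = xla::Slice(operand, /*start_indices=*/{0},
                                 /*limit_indices=*/{6}, /*strides=*/{1});
  // Without this call the result would be treated as fully static and the
  // dynamic padder could not tell which trailing elements are padding.
  return xla::SetDimensionSize(sliced, dynamic_size, /*dimension=*/0);
}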
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
@@ -58,18 +59,21 @@ class SliceOp : public XlaOpKernel {
    std::vector<int64> begin;
    std::vector<int64> size;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
    std::vector<int64> wrapped_size(size.size());
    if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
      // `begin` is a compile-time constant.
      for (int i = 0; i < input_dims; ++i) {
        if (size[i] == -1) {
          // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
          size[i] = input_shape.dim_size(i) - begin[i];
          wrapped_size[i] = input_shape.dim_size(i) - begin[i];
        } else {
          wrapped_size[i] = size[i];
        }
      }

      for (int i = 0; i < input_dims; ++i) {
        int64 b = begin[i];
        int64 s = size[i];
        int64 s = wrapped_size[i];
        if (input_shape.dim_size(i) == 0) {
          OP_REQUIRES(ctx, b == 0 && s == 0,
                      errors::InvalidArgument(
@@ -91,10 +95,28 @@ class SliceOp : public XlaOpKernel {
      std::vector<int64> limits;
      limits.reserve(begin.size());
      for (int i = 0; i < begin.size(); ++i) {
        limits.push_back(begin[i] + size[i]);
        limits.push_back(begin[i] + wrapped_size[i]);
      }
      std::vector<int64> strides(begin.size(), 1);
      ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides));
      auto slice = xla::Slice(ctx->Input(0), begin, limits, strides);
      // Check for slice on dynamic dimensions.
      ctx->set_dynamic_dimension_is_minus_one(true);
      std::vector<int64> dynamic_size;
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &dynamic_size));

      for (int64 i = 0; i < size.size(); ++i) {
        if (dynamic_size[i] == -1) {
          if (size[i] != -1) {
            // If there is a dynamic dimension, properly set dimension size of
            // the slice.
            auto dynamic_size =
                xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {});

            slice = xla::SetDimensionSize(slice, dynamic_size, i);
          }
        }
      }
      ctx->SetOutput(0, slice);
    } else {
      // `begin` is not a compile-time constant.
      for (int i = 0; i < input_dims; ++i) {
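To make the -1 size handling above concrete, here is a small standalone sketch (plain C++, not TensorFlow code; the numbers are made up) of the wrapped_size / limits arithmetic the kernel performs for a rank-1 input:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Illustrative request: input shape [8], slice begin = {2}, size = {-1}.
  std::vector<int64_t> dims = {8}, begin = {2}, size = {-1};
  std::vector<int64_t> wrapped_size(size.size()), limits(size.size());
  for (size_t i = 0; i < size.size(); ++i) {
    // A size of -1 means "all elements from begin[i] to the end of the dim".
    wrapped_size[i] = (size[i] == -1) ? dims[i] - begin[i] : size[i];
    limits[i] = begin[i] + wrapped_size[i];
  }
  // Prints "wrapped_size=6 limit=8": the static slice covers elements [2, 8).
  // If dimension 0 is dynamic, SetDimensionSize later shrinks the logical
  // size of the result to the runtime value read from the size operand.
  std::cout << "wrapped_size=" << wrapped_size[0] << " limit=" << limits[0]
            << "\n";
  return 0;
}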
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {
@@ -115,6 +116,72 @@ class StridedSliceOp : public XlaOpKernel {
      slice = xla::Rev(slice, dimensions_to_reverse);
    }
    slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
    auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
    OP_REQUIRES_OK(ctx, operand_shape_or.status());
    xla::Shape xla_shape = operand_shape_or.ValueOrDie();
    if (xla_shape.is_static()) {
      // Static output shape, return a static slice.
      slice = xla::Reshape(slice, final_shape.dim_sizes());
      ctx->SetOutput(0, slice);
      return;
    }
    auto input_dim_sizes = input_shape.dim_sizes();

    for (int64 i = 0; i < xla_shape.rank(); ++i) {
      if (xla_shape.is_dynamic_dimension(i)) {
        input_dim_sizes[i] = -1;
      }
    }
    PartialTensorShape input_partial_shape(input_dim_sizes);
    partial_final_shape.Clear();
    end.clear();
    strides.clear();
    begin.clear();
    // Run shape inference again with partial shape.
    OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
                            &begin_tensor, &end_tensor, strides_tensor,
                            input_partial_shape, begin_mask_, end_mask_,
                            ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
                            &dummy_processing_shape, &partial_final_shape,
                            &dummy, &dummy, &dummy, &begin, &end, &strides));
    if (partial_final_shape.AsTensorShape(&final_shape)) {
      // Static output shape, return a static slice.
      slice = xla::Reshape(slice, final_shape.dim_sizes());
      ctx->SetOutput(0, slice);
      return;
    }

    // We consider slicing a dynamic tensor t with negative indices as a
    // dynamic sized slice. E.g., t[: -n], the result length is shape(t) - n
    for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
      bool dynamic_dim = partial_final_shape.dim_size(i) - 1;
      bool backward_slice = end[i] < 0;
      if (dynamic_dim && backward_slice) {
        OP_REQUIRES(
            ctx, strides[i] == 1,
            errors::InvalidArgument("XLA has not implemented dynamic "
                                    "sized slice with non-trivial stride yet. "
                                    "Please file a bug against XLA"));

        OP_REQUIRES(ctx, begin[i] >= 0,
                    errors::InvalidArgument(
                        "XLA has not implemented dynamic "
                        "sized slice with negative begin index %lld. "
                        "Please file a bug against XLA",
                        begin[i]));
        // If there is a dynamic dimension, properly set dimension size of
        // the result.
        auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);

        operand_size = xla::Add(
            operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
        slice = xla::SetDimensionSize(
            slice,
            xla::Sub(operand_size,
                     xla::ConstantR0<int32>(ctx->builder(), begin[i])),
            i);
      }
    }
  } else {
    // When output shape is fully defined, it must be a size one slice:
    //
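The negative-end branch above computes the result's runtime size as (operand_size + end[i]) - begin[i]. A standalone numeric sketch of that arithmetic (values are illustrative, not from the commit):

#include <cstdint>
#include <iostream>

int main() {
  // t[1:-2] on a dimension whose runtime size is 5: begin[i] == 1, end[i] == -2.
  int32_t operand_size = 5;
  int32_t begin = 1, end = -2;
  // Mirrors the GetDimensionSize / Add / Sub chain emitted by the kernel.
  int32_t result_size = (operand_size + end) - begin;
  std::cout << result_size << "\n";  // prints "2": elements t[1] and t[2]
  return 0;
}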
@@ -2404,6 +2404,7 @@ cc_library(
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:window_util",
        "//tensorflow/core/platform:macros",
        "@com_google_absl//absl/container:flat_hash_map",
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
@@ -250,15 +251,25 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
        }
        const PaddingConfig_PaddingConfigDimension& padding_config =
            hlo->padding_config().dimensions(dimension);
        if (padding_config.interior_padding() == 0 &&
            padding_config.edge_padding_low() == 0 &&
            padding_config.edge_padding_high() == 0) {
          parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
        if (padding_config.interior_padding() == 0) {
          HloInstruction* dynamic_size_adjusted = dynamic_size;
          HloInstruction* adjustment = hlo->parent()->AddInstruction(
              HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
                  padding_config.edge_padding_low() +
                  padding_config.edge_padding_high())));
          dynamic_size_adjusted =
              hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
                  dynamic_size_adjusted->shape(), HloOpcode::kAdd,
                  dynamic_size_adjusted, adjustment));
          parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted,
                                  constraint);
          return Status::OK();
        } else {
          return Unimplemented(
              "Dynamic dimension propagation on padding dimension is not "
              "supported.");
              "Dynamic dimension propagation on interior padding dimension is "
              "not "
              "supported: %s",
              hlo->ToString());
        }
      });
}
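With interior padding ruled out, the HandlePad change above sizes the padded dimension as dynamic_size + edge_padding_low + edge_padding_high. A standalone numeric sketch of that adjustment (this is the same configuration the DynamicPad execution test further down exercises):

#include <cstdint>
#include <iostream>

int main() {
  int32_t dynamic_size = 3;             // runtime size before padding
  int32_t edge_low = 1, edge_high = 1;  // padding=1_1, interior padding 0
  // Mirrors the kAdd of the constant (low + high) onto the dynamic size.
  std::cout << dynamic_size + edge_low + edge_high << "\n";  // prints "5"
  return 0;
}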
@@ -400,11 +411,19 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {

Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size,
               DimensionConstraint constraint) {
        parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension],
                                dynamic_size, constraint);
      hlo,
      [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
          int64 operand_index, HloInstruction* dynamic_size,
          DimensionConstraint constraint) -> Status {
        int64 permuted_dim = -1;
        for (int64 i = 0; i < hlo->dimensions().size(); ++i) {
          if (hlo->dimensions()[i] == dimension) {
            TF_RET_CHECK(permuted_dim == -1);
            permuted_dim = i;
          }
        }
        parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size,
                                constraint);
        return Status::OK();
      });
}
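The transpose fix above maps a dynamic input dimension d to the output dimension i with dimensions()[i] == d (the inverse permutation), where the old code used dimensions()[d] directly. A standalone sketch using the permutation from the NonDescendingTransposeTest below:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Output dimension i of the transpose reads input dimension dims[i].
  std::vector<int64_t> dims = {2, 0, 1};
  int64_t dynamic_input_dim = 0;  // input dimension 0 is dynamic
  int64_t permuted_dim = -1;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (dims[i] == dynamic_input_dim) permuted_dim = i;
  }
  // Prints "1"; the old code would have reported dims[0] == 2 instead.
  std::cout << permuted_dim << "\n";
  return 0;
}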
@@ -979,14 +998,8 @@ Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
            hlo->slice_strides(dimension) != 1 ||
            hlo->slice_limits(dimension) !=
                operand->shape().dimensions(dimension)) {
          // Slicing a single element out eliminates the dynamic dimension.
          if (hlo->shape().dimensions(dimension) == 1) {
            return Status::OK();
          }
          return Unimplemented(
              "Dynamic dimension propagation on Slice where it doesn't slice "
              "out an entire dimension is not supported %s",
              hlo->ToString());
          // Slicing a partial element out eliminates the dynamic dimension.
          return Status::OK();
        }

        parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
@@ -386,6 +386,53 @@ TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
}

TEST_F(DynamicDimensionInferenceTest, DotTestMultiContracting) {
  auto builder = HloComputation::Builder(TestName());
  auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 8, 64});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 512});
  auto output_shape = ShapeUtil::MakeShape(F32, {8, 64, 512});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  auto dot = builder.AddInstruction(
      HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
                                HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  // Nothing is dynamic in the output.
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
}

TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
  auto builder = HloComputation::Builder(TestName());
  constexpr int xdim = 3;
@@ -474,6 +521,45 @@ TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1);
}

TEST_F(DynamicDimensionInferenceTest, NonDescendingTransposeTest) {
  // Test the ability to trace unmodified dimensions
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  auto output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));
  auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));
  auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/3, scalar_shape_, "size_param"));

  auto* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(output_shape, a_param, {2, 0, 1}));

  module_->AddEntryComputation(builder.Build());

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_1);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_2);
}

TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
  // Test the ability to trace unmodified reshape dimensions.
  auto builder = HloComputation::Builder(TestName());
@@ -865,6 +865,45 @@ ENTRY main {
  EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, DynamicPad) {
  const string hlo_text = R"(
HloModule TEST

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
  param = s32[4] parameter(0)
  size = s32[] constant(3)
  padding = s32[] constant(2)
  param_dynamic = s32[4] set-dimension-size(param, size),
    dimensions={0}
  // pad head and tail to 2
  pad = s32[6] pad(param_dynamic, padding), padding=1_1

  init = s32[] constant(0)
  ROOT reduce = s32[] reduce(pad, init),
    dimensions={0},
    to_apply=update_s32
}
)";

  Literal operand = LiteralUtil::CreateR1<int32>({1, 4, 3, 5});
  auto module = GetHloModule(hlo_text);

  // After padding head and tail with "2", the effective data will be [2, 1, 4,
  // 3, 2]

  Literal result = PadAndExecute(std::move(module), {&operand},
                                 /*slice_dynamic_output=*/false);
  Literal expected = LiteralUtil::CreateR0<int32>(12);

  EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
  const string hlo_text = R"(
HloModule TEST