From b3212dd80273040d1f21885ba38abec3ab264a56 Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Wed, 25 Mar 2020 21:02:17 -0700
Subject: [PATCH] [XLA] Several improvements to dynamic padder.

- Support partial slices on dynamic dimensions -- this is achieved by letting
  the client set the dynamic dimension after building an XLA slice.
- Support dynamic pad on a padded dimension.
- Fix a terrible bug, exposed by rxsang's experiment, where transpose created
  the wrong dynamic dimension.
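
For illustration only (not part of this change): a minimal client-side sketch
of the usage this enables, assuming the existing XlaBuilder C++ API
(xla::Slice, xla::SetDimensionSize); the helper name is hypothetical.

    #include "tensorflow/compiler/xla/client/xla_builder.h"

    // Statically slice the first four elements of a 1-D operand, then mark
    // dimension 0 of the result as dynamic with the given runtime size
    // (an S32 scalar). This is what the tf2xla slice kernels below do.
    xla::XlaOp PartialDynamicSlice(xla::XlaOp operand, xla::XlaOp dynamic_size) {
      xla::XlaOp slice = xla::Slice(operand, /*start_indices=*/{0},
                                    /*limit_indices=*/{4}, /*strides=*/{1});
      return xla::SetDimensionSize(slice, dynamic_size, /*dimension=*/0);
    }
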
PiperOrigin-RevId: 303033314
Change-Id: Id76f4619d12e88b8c0b3e9ec75baa1d78d1a7270
---
 .../compiler/tf2xla/kernels/slice_op.cc       | 30 ++++++-
 .../tf2xla/kernels/strided_slice_op.cc        | 67 +++++++++++++++
 tensorflow/compiler/xla/service/BUILD         |  1 +
 .../service/dynamic_dimension_inference.cc    | 51 +++++++----
 .../dynamic_dimension_inference_test.cc       | 86 +++++++++++++++++++
 .../xla/service/dynamic_padder_test.cc        | 39 +++++++++
 6 files changed, 251 insertions(+), 23 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index 1be651da470..17d0b87edda 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/ops_util.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -58,18 +59,21 @@ class SliceOp : public XlaOpKernel {
     std::vector<int64> begin;
     std::vector<int64> size;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
+    std::vector<int64> wrapped_size(size.size());
     if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
       // `begin` is a compile-time constant.
       for (int i = 0; i < input_dims; ++i) {
         if (size[i] == -1) {
           // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
-          size[i] = input_shape.dim_size(i) - begin[i];
+          wrapped_size[i] = input_shape.dim_size(i) - begin[i];
+        } else {
+          wrapped_size[i] = size[i];
         }
       }

       for (int i = 0; i < input_dims; ++i) {
         int64 b = begin[i];
-        int64 s = size[i];
+        int64 s = wrapped_size[i];
         if (input_shape.dim_size(i) == 0) {
           OP_REQUIRES(ctx, b == 0 && s == 0,
                       errors::InvalidArgument(
@@ -91,10 +95,28 @@ class SliceOp : public XlaOpKernel {
       std::vector<int64> limits;
       limits.reserve(begin.size());
       for (int i = 0; i < begin.size(); ++i) {
-        limits.push_back(begin[i] + size[i]);
+        limits.push_back(begin[i] + wrapped_size[i]);
       }
       std::vector<int64> strides(begin.size(), 1);
-      ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides));
+      auto slice = xla::Slice(ctx->Input(0), begin, limits, strides);
+      // Check for slice on dynamic dimensions.
+      ctx->set_dynamic_dimension_is_minus_one(true);
+      std::vector<int64> dynamic_size;
+      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &dynamic_size));
+
+      for (int64 i = 0; i < size.size(); ++i) {
+        if (dynamic_size[i] == -1) {
+          if (size[i] != -1) {
+            // If there is a dynamic dimension, properly set dimension size of
+            // the slice.
+            auto dynamic_size =
+                xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {});
+
+            slice = xla::SetDimensionSize(slice, dynamic_size, i);
+          }
+        }
+      }
+      ctx->SetOutput(0, slice);
     } else {
       // `begin` is not a compile-time constant.
       for (int i = 0; i < input_dims; ++i) {
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 05f1ee1797a..9093175af75 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/mem.h"

 namespace tensorflow {
@@ -115,6 +116,72 @@ class StridedSliceOp : public XlaOpKernel {
         slice = xla::Rev(slice, dimensions_to_reverse);
       }
       slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
+      auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
+      OP_REQUIRES_OK(ctx, operand_shape_or.status());
+      xla::Shape xla_shape = operand_shape_or.ValueOrDie();
+      if (xla_shape.is_static()) {
+        // Static output shape, return a static slice.
+        slice = xla::Reshape(slice, final_shape.dim_sizes());
+        ctx->SetOutput(0, slice);
+        return;
+      }
+      auto input_dim_sizes = input_shape.dim_sizes();
+
+      for (int64 i = 0; i < xla_shape.rank(); ++i) {
+        if (xla_shape.is_dynamic_dimension(i)) {
+          input_dim_sizes[i] = -1;
+        }
+      }
+      PartialTensorShape input_partial_shape(input_dim_sizes);
+      partial_final_shape.Clear();
+      end.clear();
+      strides.clear();
+      begin.clear();
+      // Run shape inference again with the partial shape.
+      OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
+                              &begin_tensor, &end_tensor, strides_tensor,
+                              input_partial_shape, begin_mask_, end_mask_,
+                              ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+                              &dummy_processing_shape, &partial_final_shape,
+                              &dummy, &dummy, &dummy, &begin, &end, &strides));
+      if (partial_final_shape.AsTensorShape(&final_shape)) {
+        // Static output shape, return a static slice.
+        slice = xla::Reshape(slice, final_shape.dim_sizes());
+        ctx->SetOutput(0, slice);
+        return;
+      }
+
+      // We consider slicing a dynamic tensor t with negative indices as a
+      // dynamic sized slice. E.g., for t[:-n] the result length is
+      // shape(t) - n.
+      for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
+        bool dynamic_dim = partial_final_shape.dim_size(i) == -1;
+        bool backward_slice = end[i] < 0;
+        if (dynamic_dim && backward_slice) {
+          OP_REQUIRES(
+              ctx, strides[i] == 1,
+              errors::InvalidArgument("XLA has not implemented dynamic "
+                                      "sized slice with non-trivial stride yet. "
+                                      "Please file a bug against XLA"));
+          OP_REQUIRES(ctx, begin[i] >= 0,
+                      errors::InvalidArgument(
+                          "XLA has not implemented dynamic "
+                          "sized slice with negative begin index %lld. "
+                          "Please file a bug against XLA",
+                          begin[i]));
+          // If there is a dynamic dimension, properly set dimension size of
+          // the result.
+          auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);
+
+          operand_size = xla::Add(
+              operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
+          slice = xla::SetDimensionSize(
+              slice,
+              xla::Sub(operand_size,
+                       xla::ConstantR0<int32>(ctx->builder(), begin[i])),
+              i);
+        }
+      }
     } else {
       // When output shape is fully defined, it must be a size one slice:
       //
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 6d470149ca8..ae629197889 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2404,6 +2404,7 @@ cc_library(
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:window_util",
         "//tensorflow/core/platform:macros",
         "@com_google_absl//absl/container:flat_hash_map",
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 34d144ea1e9..94815e2fdbc 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/while_util.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/window_util.h"

 namespace xla {
@@ -250,15 +251,25 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
         }
         const PaddingConfig_PaddingConfigDimension& padding_config =
             hlo->padding_config().dimensions(dimension);
-        if (padding_config.interior_padding() == 0 &&
-            padding_config.edge_padding_low() == 0 &&
-            padding_config.edge_padding_high() == 0) {
-          parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
+        if (padding_config.interior_padding() == 0) {
+          HloInstruction* dynamic_size_adjusted = dynamic_size;
+          HloInstruction* adjustment = hlo->parent()->AddInstruction(
+              HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
+                  padding_config.edge_padding_low() +
+                  padding_config.edge_padding_high())));
+          dynamic_size_adjusted =
+              hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
+                  dynamic_size_adjusted->shape(), HloOpcode::kAdd,
+                  dynamic_size_adjusted, adjustment));
+          parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted,
+                                  constraint);
           return Status::OK();
         } else {
           return Unimplemented(
-              "Dynamic dimension propagation on padding dimension is not "
-              "supported.");
+              "Dynamic dimension propagation on interior padding dimension "
+              "is not supported: %s",
+              hlo->ToString());
         }
       });
 }
@@ -400,11 +411,19 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {

 Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
   return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
-               int64 operand_index, HloInstruction* dynamic_size,
-               DimensionConstraint constraint) {
-        parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension],
-                                dynamic_size, constraint);
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
+          int64 operand_index, HloInstruction* dynamic_size,
+          DimensionConstraint constraint) -> Status {
+        int64 permuted_dim = -1;
+        for (int64 i = 0; i < hlo->dimensions().size(); ++i) {
+          if (hlo->dimensions()[i] == dimension) {
+            TF_RET_CHECK(permuted_dim == -1);
+            permuted_dim = i;
+          }
+        }
+        parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size,
+                                constraint);
         return Status::OK();
       });
 }
@@ -979,14 +998,8 @@ Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
             hlo->slice_strides(dimension) != 1 ||
             hlo->slice_limits(dimension) !=
                 operand->shape().dimensions(dimension)) {
-          // Slicing a single element out eliminates the dynamic dimension.
-          if (hlo->shape().dimensions(dimension) == 1) {
-            return Status::OK();
-          }
-          return Unimplemented(
-              "Dynamic dimension propagation on Slice where it doesn't slice "
-              "out an entire dimension is not supported %s",
-              hlo->ToString());
+          // Slicing out a partial dimension eliminates the dynamic dimension.
+          return Status::OK();
         }

         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc
index d2913f9d2a1..dbe57985fd4 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc
@@ -386,6 +386,53 @@ TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
   EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
 }

+TEST_F(DynamicDimensionInferenceTest, DotTestMultiContracting) {
+  auto builder = HloComputation::Builder(TestName());
+  auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 8, 64});
+  auto rhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 512});
+  auto output_shape = ShapeUtil::MakeShape(F32, {8, 64, 512});
+
+  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, lhs_shape, "A"));
+  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, rhs_shape, "B"));
+  builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/2, scalar_shape_, "size_param"));
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(0);
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  dot_dnums.add_rhs_contracting_dimensions(1);
+  auto dot = builder.AddInstruction(
+      HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
+                                HloTestBase::DefaultPrecisionConfig(2)));
+
+  module_->AddEntryComputation(builder.Build());
+
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{2, {}},
+      DynamicParameterBinding::DynamicDimension{0, {}, 0}));
+
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{2, {}},
+      DynamicParameterBinding::DynamicDimension{0, {}, 1}));
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{2, {}},
+      DynamicParameterBinding::DynamicDimension{1, {}, 0}));
+
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{2, {}},
+      DynamicParameterBinding::DynamicDimension{1, {}, 1}));
+
+  SCOPED_TRACE(module_->ToString());
+  TF_ASSERT_OK(RunInference());
+  // Nothing is dynamic in the output.
+  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), nullptr);
+  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
+  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
+}
+
 TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
   auto builder = HloComputation::Builder(TestName());
   constexpr int xdim = 3;
@@ -474,6 +521,45 @@ TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
   EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1);
 }

+TEST_F(DynamicDimensionInferenceTest, NonDescendingTransposeTest) {
+  // Test the ability to trace unmodified dimensions.
+  auto builder = HloComputation::Builder(TestName());
+  auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
+  auto output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});
+
+  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, input_shape, "A"));
+  auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, scalar_shape_, "size_param"));
+  auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/2, scalar_shape_, "size_param"));
+  auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/3, scalar_shape_, "size_param"));
+
+  auto* transpose = builder.AddInstruction(
+      HloInstruction::CreateTranspose(output_shape, a_param, {2, 0, 1}));
+
+  module_->AddEntryComputation(builder.Build());
+
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{1, {}},
+      DynamicParameterBinding::DynamicDimension{0, {}, 0}));
+
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{2, {}},
+      DynamicParameterBinding::DynamicDimension{0, {}, 1}));
+
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{3, {}},
+      DynamicParameterBinding::DynamicDimension{0, {}, 2}));
+
+  SCOPED_TRACE(module_->ToString());
+  TF_ASSERT_OK(RunInference());
+  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3);
+  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_1);
+  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_2);
+}
+
 TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
   // Test the ability to trace unmodified reshape dimensions.
   auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
index c37f9d0c3db..e669bc4dbe2 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
@@ -865,6 +865,45 @@ ENTRY main {
   EXPECT_EQ(result, expected);
 }

+XLA_TEST_F(ExecutionTest, DynamicPad) {
+  const string hlo_text = R"(
+HloModule TEST
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+  lhs = s32[] parameter(0)
+  rhs = s32[] parameter(1)
+  ROOT add = s32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  param = s32[4] parameter(0)
+  size = s32[] constant(3)
+  padding = s32[] constant(2)
+  param_dynamic = s32[4] set-dimension-size(param, size),
+    dimensions={0}
+  // Pad one element of value 2 at the head and at the tail.
+  pad = s32[6] pad(param_dynamic, padding), padding=1_1
+
+  init = s32[] constant(0)
+  ROOT reduce = s32[] reduce(pad, init),
+    dimensions={0},
+    to_apply=update_s32
+}
+)";
+
+  Literal operand = LiteralUtil::CreateR1<int32>({1, 4, 3, 5});
+  auto module = GetHloModule(hlo_text);
+
+  // After padding the head and the tail with "2", the effective data will be
+  // [2, 1, 4, 3, 2], which sums to 12.
+  Literal result = PadAndExecute(std::move(module), {&operand},
+                                 /*slice_dynamic_output=*/false);
+  Literal expected = LiteralUtil::CreateR0<int32>(12);
+
+  EXPECT_EQ(result, expected);
+}
+
 XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
   const string hlo_text = R"(
 HloModule TEST