From b928529dc409c14d27225e8f509ce81a6cf7dac3 Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Wed, 18 Nov 2020 14:24:22 -0800
Subject: [PATCH] Support dynamic input feature for convolutions.

Sometimes TF fails to infer a shape that is supposed to be static and hence
gives us a dynamically shaped input feature (b/173158715). This does not
seem easy to fix in TF itself, so we fix it on the XLA side instead.

PiperOrigin-RevId: 343156314
Change-Id: I4bd0aa2f70b2bc94907f279b3f294b527fdf464f
---
 .../service/dynamic_dimension_inference.cc | 31 ++++++++++++++++---
 .../compiler/xla/service/dynamic_padder.cc | 20 ++++++++++++
 .../xla/service/dynamic_padder_test.cc     | 28 +++++++++++++++++
 3 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 53efbcadd44..e10cd1f176b 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -261,6 +261,31 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
   if (custom_call_handler_) {
     return custom_call_handler_(hlo, parent_);
   }
+
+  if (hlo->custom_call_target() == "DynamicConvolutionForward") {
+    // If the input feature dimension is dynamic but the kernel input feature
+    // dimension is static, the input feature must in fact be static as well.
+    // E.g.:
+    //   lhs = [B, X, Y, ?]
+    //   rhs = [X, Y, I, O]
+    //   dim_labels = b01f_01io
+    // We can infer that the dynamic dimension in lhs is actually static I.
+    const ConvolutionDimensionNumbers& dnums =
+        hlo->convolution_dimension_numbers();
+    HloInstruction* input_feature = parent_->GetDynamicSize(
+        hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
+    HloInstruction* kernel_feature = parent_->GetDynamicSize(
+        hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
+
+    if (input_feature != nullptr && kernel_feature == nullptr) {
+      if (hlo->mutable_operand(0)->shape().dimensions(
+              dnums.input_feature_dimension()) ==
+          hlo->mutable_operand(1)->shape().dimensions(
+              dnums.kernel_input_feature_dimension()))
+        parent_->SetDynamicSize(hlo->mutable_operand(0), {},
+                                dnums.input_feature_dimension(), nullptr);
+    }
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
                int64 operand_index, HloInstruction* dynamic_size) {
@@ -520,7 +545,6 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution(
   HloInstruction* conv = hlo;
   const ConvolutionDimensionNumbers& dimension_numbers =
       conv->convolution_dimension_numbers();
-
   if (operand_index == 0) {
     if (dimension == dimension_numbers.input_batch_dimension()) {
       parent_->SetDynamicSize(conv, {},
@@ -676,9 +700,8 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
       return Status::OK();
     }
   }
-  return Unimplemented(
-      "XLA doesn't support dynamic input feature dimension on convolution: %s",
-      hlo->ToString());
+  // The input feature dim disappears (is contracted away) after convolution.
+  return Status::OK();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index 0c70a82e90a..9850e8710b6 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -250,6 +250,8 @@ bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
 HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim,
                               HloInstruction* dynamic_size,
                               HloInstruction* padding_scalar) {
+  CHECK(inst != nullptr && dynamic_size != nullptr &&
+        padding_scalar != nullptr);
   const Shape mask_shape =
       ShapeUtil::ChangeElementType(inst->shape(), xla::S32);
   const Shape pred_shape =
@@ -902,6 +904,14 @@ StatusOr<HloInstruction*> RewriteDynamicConvolutionForward(
         window_dim.stride(), custom_call_conv->padding_type());
     padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
   }
+  // The input feature dim can be dynamic too; zero out its dynamic region.
+  const int64 input_feature_dim = dnums.input_feature_dimension();
+  if (HloInstruction* input_feature_dynamic_size =
+          dynamic_dimension_inference->GetDynamicSize(
+              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
+    input = PadWithScalar(input, input_feature_dim, input_feature_dynamic_size,
+                          zero);
+  }
 
   if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
     input = RewriteInputWithDynamicPadding(
@@ -976,6 +986,16 @@ StatusOr<HloInstruction*> RewriteDynamicConvolutionKernelGrad(
     padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
   }
 
+  // We only need to zero-pad the input feature on the lhs -- this is
+  // mathematically equivalent to zero-padding both lhs and rhs.
+  const int64 input_feature_dim = dnums.input_feature_dimension();
+  if (HloInstruction* input_feature_dynamic_size =
+          dynamic_dimension_inference->GetDynamicSize(
+              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
+    activations = PadWithScalar(activations, input_feature_dim,
+                                input_feature_dynamic_size, zero);
+  }
+
   if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
     activations = RewriteInputWithDynamicPadding(
         custom_call_conv, activations, zero, absl::MakeSpan(padding_before),
diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
index 3855531a97b..9af25a6fbcc 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
@@ -919,6 +919,34 @@ ENTRY main {
   EXPECT_EQ(result, expected);
 }
 
+XLA_TEST_F(ExecutionTest, DynamicInputFeature) {
+  const string hlo_text = R"(
+HloModule DynamicInputFeature
+
+ENTRY main {
+  param = f32[1, 1, 5] parameter(0)
+  const = s32[] constant(5)
+  one = f32[] constant(1)
+  kernel = f32[1,5,1]{2,1,0} broadcast(f32[] one), dimensions={}
+  param_dynamic = f32[1,1,<=5] set-dimension-size(param, const), dimensions={2}
+  ROOT conv = f32[1, 1, 1]{2,1,0} custom-call(f32[1, 1, <=5] param_dynamic, f32[1,5,1]{2,1,0} kernel),
+    window={size=1 pad=0_0},
+    dim_labels=b0f_0io->b0f,
+    padding_type=PADDING_VALID,
+    custom_call_target="DynamicConvolutionForward"
+}
+)";
+
+  Literal operand = LiteralUtil::CreateR3<float>({{{1, 2, 3, 4, 5}}});
+  auto module = GetHloModule(hlo_text);
+
+  Literal result = PadAndExecute(std::move(module), {&operand});
+
+  Literal expected = LiteralUtil::CreateR3<float>({{{15}}});
+
+  EXPECT_EQ(result, expected);
+}
+
 XLA_TEST_F(ExecutionTest, DynamicDimensionReshapeUnchanged) {
   const string hlo_text = R"(
 HloModule TensorFlowScatterV1
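
A convolution contracts (sums) over the input feature dimension, which is why
the PadWithScalar calls in this patch are sufficient: once the activations are
zeroed past the dynamic size, every padded term contributes 0 * w = 0 to the
sum, no matter what garbage sits in the corresponding kernel (or gradient)
slots. Below is a minimal standalone C++ sketch of that equivalence for a
single output element (illustrative only, not XLA code; the function name and
sample values are invented for the demo):

// Sketch: zero-padding only the lhs past the dynamic size is equivalent to
// contracting over just the valid features, regardless of rhs padding.
#include <cstddef>
#include <iostream>
#include <vector>

// One output element of a convolution, viewed as a dot product over the
// input feature dimension.
double DotOverFeatures(const std::vector<double>& lhs,
                       const std::vector<double>& rhs) {
  double acc = 0.0;
  for (std::size_t i = 0; i < lhs.size(); ++i) acc += lhs[i] * rhs[i];
  return acc;
}

int main() {
  const std::size_t kDynamicSize = 3;  // runtime size of a <=5 dimension

  // Padded buffers: entries past kDynamicSize hold arbitrary garbage.
  std::vector<double> lhs = {1, 2, 3, /*garbage=*/9, 9};
  std::vector<double> rhs = {4, 5, 6, 7, 8};  // kernel stays fully static

  // What the padder effectively does: zero the lhs past the dynamic size.
  std::vector<double> lhs_zeroed = lhs;
  for (std::size_t i = kDynamicSize; i < lhs_zeroed.size(); ++i) {
    lhs_zeroed[i] = 0.0;
  }

  // Reference: contract over only the first kDynamicSize features.
  const std::vector<double> lhs_valid(lhs.begin(), lhs.begin() + kDynamicSize);
  const std::vector<double> rhs_valid(rhs.begin(), rhs.begin() + kDynamicSize);
  const double reference = DotOverFeatures(lhs_valid, rhs_valid);

  // Zero-padding only the lhs matches the reference: each padded term is
  // 0 * rhs[i] = 0, so the garbage never leaks into the result.
  std::cout << reference << " == " << DotOverFeatures(lhs_zeroed, rhs) << "\n";
  return 0;
}

This prints "32 == 32": values past the dynamic size never reach the result,
which is the property the dynamic padder relies on and what the
DynamicInputFeature test above checks end to end (expected sum 1+2+3+4+5 = 15).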