Support dynamic input feature for convolutions.

Sometimes TF fails to infer a shape that should be static and hands us a dynamically shaped input feature dimension (b/173158715). That doesn't seem easy to fix on the TF side, so we fix it on the XLA side instead.
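Roughly, the inference rule this change adds can be sketched as follows. This is a standalone toy model for illustration only, not the real XLA API; the Dim struct and CanClearDynamicInputFeature function are made-up names.

#include <cstdint>
#include <iostream>

// Toy model of one tensor dimension: a static bound plus a flag saying whether
// the true size is only known at run time.
struct Dim {
  int64_t static_size = 0;
  bool is_dynamic = false;
};

// The rule in miniature: if the lhs input-feature dimension is dynamic but the
// kernel's input-feature dimension is static, and their static bounds agree,
// the lhs feature dimension must really be static too, because a convolution
// requires the two feature sizes to match.
bool CanClearDynamicInputFeature(const Dim& lhs_feature,
                                 const Dim& rhs_feature) {
  return lhs_feature.is_dynamic && !rhs_feature.is_dynamic &&
         lhs_feature.static_size == rhs_feature.static_size;
}

int main() {
  Dim lhs_feature{/*static_size=*/5, /*is_dynamic=*/true};   // lhs = [B, X, Y, <=5]
  Dim rhs_feature{/*static_size=*/5, /*is_dynamic=*/false};  // rhs = [X, Y, 5, O]
  if (CanClearDynamicInputFeature(lhs_feature, rhs_feature)) {
    lhs_feature.is_dynamic = false;  // Safe to treat the input feature as static.
  }
  std::cout << "lhs feature is dynamic: " << std::boolalpha
            << lhs_feature.is_dynamic << "\n";  // prints "false"
  return 0;
}

When the rule does not apply, the rewrite below instead masks the elements beyond the dynamic feature size with zeros before running the convolution.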

PiperOrigin-RevId: 343156314
Change-Id: I4bd0aa2f70b2bc94907f279b3f294b527fdf464f
Authored by Yunxing Dai on 2020-11-18 14:24:22 -08:00; committed by TensorFlower Gardener
parent fa7d66d8a0
commit b928529dc4
3 changed files with 75 additions and 4 deletions


@@ -261,6 +261,31 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
if (custom_call_handler_) {
return custom_call_handler_(hlo, parent_);
}
if (hlo->custom_call_target() == "DynamicConvolutionForward") {
// If the input feature dimension of the lhs is dynamic but the kernel's
// input feature dimension is static, we can infer that the input feature is
// in fact static.
// E.g.,:
// lhs = [B, X, Y, ?]
// rhs = [X, Y, I, O]
// dim_labels = b01f_01io
// We can infer that the dynamic dimension ? in lhs is actually the static
// size I.
const ConvolutionDimensionNumbers& dnums =
hlo->convolution_dimension_numbers();
HloInstruction* input_feature = parent_->GetDynamicSize(
hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
HloInstruction* kernel_feature = parent_->GetDynamicSize(
hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
if (input_feature != nullptr && kernel_feature == nullptr) {
if (hlo->mutable_operand(0)->shape().dimensions(
dnums.input_feature_dimension()) ==
hlo->mutable_operand(1)->shape().dimensions(
dnums.kernel_input_feature_dimension()))
parent_->SetDynamicSize(hlo->mutable_operand(0), {},
dnums.input_feature_dimension(), nullptr);
}
}
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size) {
@@ -520,7 +545,6 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution(
HloInstruction* conv = hlo;
const ConvolutionDimensionNumbers& dimension_numbers =
conv->convolution_dimension_numbers();
if (operand_index == 0) {
if (dimension == dimension_numbers.input_batch_dimension()) {
parent_->SetDynamicSize(conv, {},
@@ -676,9 +700,8 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
return Status::OK();
}
}
- return Unimplemented(
-     "XLA doesn't support dynamic input feature dimension on convolution: %s",
-     hlo->ToString());
+ // Input feature dimension disappears after convolution.
+ return Status::OK();
}
Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(


@@ -250,6 +250,8 @@ bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim,
HloInstruction* dynamic_size,
HloInstruction* padding_scalar) {
CHECK(inst != nullptr && dynamic_size != nullptr &&
padding_scalar != nullptr);
const Shape mask_shape =
ShapeUtil::ChangeElementType(inst->shape(), xla::S32);
const Shape pred_shape =
@@ -902,6 +904,14 @@ StatusOr<bool> RewriteDynamicConvolutionForward(
window_dim.stride(), custom_call_conv->padding_type());
padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
}
// The input feature dimension can be dynamic too; mask the elements beyond
// the dynamic size with zeros.
const int64 input_feature_dim = dnums.input_feature_dimension();
if (HloInstruction* input_feature_dynamic_size =
dynamic_dimension_inference->GetDynamicSize(
custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
input = PadWithScalar(input, input_feature_dim, input_feature_dynamic_size,
zero);
}
if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
input = RewriteInputWithDynamicPadding(
@@ -976,6 +986,16 @@ StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
}
// We only need to pad the input feature on the lhs with zeros -- it is
// mathematically equivalent to padding both lhs and rhs with zeros.
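// Every product term that involves a padded element includes the lhs factor,
// so zeroing the lhs alone is enough; whatever values sit in the padded
// region of the rhs cannot reach the result.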
const int64 input_feature_dim = dnums.input_feature_dimension();
if (HloInstruction* input_feature_dynamic_size =
dynamic_dimension_inference->GetDynamicSize(
custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
activations = PadWithScalar(activations, input_feature_dim,
input_feature_dynamic_size, zero);
}
if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
activations = RewriteInputWithDynamicPadding(
custom_call_conv, activations, zero, absl::MakeSpan(padding_before),


@@ -919,6 +919,34 @@ ENTRY main {
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicInputFeature) {
const string hlo_text = R"(
HloModule DynamicInputFeature
ENTRY main {
param = f32[1, 1, 5] parameter(0)
const = s32[] constant(5)
one = f32[] constant(1)
kernel = f32[1,5,1]{2,1,0} broadcast(f32[] one), dimensions={}
param_dynamic = f32[1,1,<=5] set-dimension-size(param, const), dimensions={2}
ROOT conv = f32[1, 1, 1]{2,1,0} custom-call(f32[1, 1, <=5] param_dynamic, f32[1,5,1]{2,1,0} kernel),
window={size=1 pad=0_0},
dim_labels=b0f_0io->b0f,
padding_type=PADDING_VALID,
custom_call_target="DynamicConvolutionForward"
}
)";
Literal operand = LiteralUtil::CreateR3<float>({{{1, 2, 3, 4, 5}}});
auto module = GetHloModule(hlo_text);
Literal result = PadAndExecute(std::move(module), {&operand});
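// The kernel is all ones and the dynamic feature size is set to the full
// bound of 5, so the convolution simply sums the input: 1 + 2 + 3 + 4 + 5 = 15.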
Literal expected = LiteralUtil::CreateR3<float>({{{15}}});
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicDimensionReshapeUnchanged) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1