Support dynamic input feature for convolutions.

Sometimes TF fails to infer a shape that is supposed to be static and hands
us a dynamically shaped input feature dimension (b/173158715). That doesn't
seem easy to fix on the TF side, so fix it on our side instead.

PiperOrigin-RevId: 343156314
Change-Id: I4bd0aa2f70b2bc94907f279b3f294b527fdf464f
parent fa7d66d8a0
commit b928529dc4
@@ -261,6 +261,31 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
   if (custom_call_handler_) {
     return custom_call_handler_(hlo, parent_);
   }
+
+  if (hlo->custom_call_target() == "DynamicConvolutionForward") {
+    // If the input feature is dynamic and the kernel feature is static, we
+    // can infer that the input feature is actually static.
+    // E.g.:
+    // lhs = [B, X, Y, ?]
+    // rhs = [X, Y, I, O]
+    // dim_labels = b01f_01io
+    // We can infer that the dynamic dimension in lhs is a static I.
+    const ConvolutionDimensionNumbers& dnums =
+        hlo->convolution_dimension_numbers();
+    HloInstruction* input_feature = parent_->GetDynamicSize(
+        hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
+    HloInstruction* kernel_feature = parent_->GetDynamicSize(
+        hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
+
+    if (input_feature != nullptr && kernel_feature == nullptr) {
+      if (hlo->mutable_operand(0)->shape().dimensions(
+              dnums.input_feature_dimension()) ==
+          hlo->mutable_operand(1)->shape().dimensions(
+              dnums.kernel_input_feature_dimension()))
+        parent_->SetDynamicSize(hlo->mutable_operand(0), {},
+                                dnums.input_feature_dimension(), nullptr);
+    }
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
                int64 operand_index, HloInstruction* dynamic_size) {
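The inference rule in this hunk can be stated simply: lhs and rhs contract over the feature dimension, so if the kernel's input-feature dimension is static and matches the lhs bound, a dynamic lhs feature dimension must in fact be full-size. Below is a minimal standalone sketch of the rule, using a hypothetical Shape struct rather than the XLA API:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for an XLA shape: per-dimension sizes plus flags
// marking which dimensions carry a runtime-dynamic size.
struct Shape {
  std::vector<int64_t> dims;     // Static sizes (upper bounds if dynamic).
  std::vector<bool> is_dynamic;  // True if the dimension is dynamic.
};

// If lhs's input-feature dim is dynamic but rhs's kernel input-feature dim
// is static, and their static bounds match, the lhs dimension must actually
// be full-size: mark it static. Mirrors the HandleCustomCall logic above.
void InferStaticInputFeature(Shape& lhs, const Shape& rhs,
                             int input_feature_dim,
                             int kernel_input_feature_dim) {
  if (lhs.is_dynamic[input_feature_dim] &&
      !rhs.is_dynamic[kernel_input_feature_dim] &&
      lhs.dims[input_feature_dim] == rhs.dims[kernel_input_feature_dim]) {
    lhs.is_dynamic[input_feature_dim] = false;
  }
}

int main() {
  // lhs = [B, X, Y, ?], rhs = [X, Y, I, O] with dim_labels = b01f_01io:
  // lhs feature dim is 3, rhs kernel input-feature dim is 2.
  Shape lhs{{8, 32, 32, 16}, {false, false, false, true}};
  Shape rhs{{3, 3, 16, 64}, {false, false, false, false}};
  InferStaticInputFeature(lhs, rhs, /*input_feature_dim=*/3,
                          /*kernel_input_feature_dim=*/2);
  std::cout << "lhs feature dynamic? " << std::boolalpha
            << lhs.is_dynamic[3] << "\n";  // false: inferred static.
}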
@@ -520,7 +545,6 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution(
         HloInstruction* conv = hlo;
         const ConvolutionDimensionNumbers& dimension_numbers =
             conv->convolution_dimension_numbers();
-
         if (operand_index == 0) {
           if (dimension == dimension_numbers.input_batch_dimension()) {
             parent_->SetDynamicSize(conv, {},
@@ -676,9 +700,8 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
       return Status::OK();
     }
   }
-  return Unimplemented(
-      "XLA doesn't support dynamic input feature dimension on convolution: %s",
-      hlo->ToString());
+  // Input feature dim disappears after convolution.
+  return Status::OK();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(
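The removed Unimplemented error is the user-visible change: a dynamic input feature no longer aborts compilation. The new comment captures why returning OK is sound: convolution contracts the input-feature dimension away, so the output carries only batch, spatial, and output-feature dimensions, and once the operands' dynamic region has been zeroed out by the padder (next hunks) there is nothing dynamic left to record on the output.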
@@ -250,6 +250,8 @@ bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
 HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim,
                               HloInstruction* dynamic_size,
                               HloInstruction* padding_scalar) {
+  CHECK(inst != nullptr && dynamic_size != nullptr &&
+        padding_scalar != nullptr);
   const Shape mask_shape =
       ShapeUtil::ChangeElementType(inst->shape(), xla::S32);
   const Shape pred_shape =
@@ -902,6 +904,14 @@ StatusOr<bool> RewriteDynamicConvolutionForward(
         window_dim.stride(), custom_call_conv->padding_type());
     padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
   }
+  // Input feature dim can be dynamic too; reset it to zero.
+  const int64 input_feature_dim = dnums.input_feature_dimension();
+  if (HloInstruction* input_feature_dynamic_size =
+          dynamic_dimension_inference->GetDynamicSize(
+              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
+    input = PadWithScalar(input, input_feature_dim, input_feature_dynamic_size,
+                          zero);
+  }
 
   if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
     input = RewriteInputWithDynamicPadding(
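Why resetting the dynamic region of the input feature to zero is enough: the convolution sums products over the input-feature dimension, and any term whose input factor is zero contributes nothing to that sum. A small self-contained check of the identity, reducing a 1x1 convolution to a dot product (plain C++, illustrative only, not XLA code):

#include <cstddef>
#include <iostream>
#include <vector>

// A 1x1 convolution over the feature dimension is a dot product. With a
// static bound of input.size() features but only `dynamic_size` valid ones,
// zeroing the tail makes the full-bound dot product equal the valid-only sum.
double DotWithZeroedTail(std::vector<double> input,
                         const std::vector<double>& kernel,
                         std::size_t dynamic_size) {
  for (std::size_t i = dynamic_size; i < input.size(); ++i) {
    input[i] = 0.0;  // What PadWithScalar(..., zero) achieves in the HLO.
  }
  double acc = 0.0;
  for (std::size_t i = 0; i < input.size(); ++i) acc += input[i] * kernel[i];
  return acc;
}

int main() {
  std::vector<double> input = {1, 2, 3, 4, 5};   // Bound 5; the tail past
  std::vector<double> kernel = {1, 1, 1, 1, 1};  // the dynamic size is junk.
  // dynamic_size == 3: only the first three features are valid.
  std::cout << DotWithZeroedTail(input, kernel, 3) << "\n";  // 6 == 1+2+3
}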
@@ -976,6 +986,16 @@ StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
     padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
   }
 
+  // We only need to pad input feature on lhs to 0 -- it's mathematically
+  // equivalent to padding both lhs and rhs to 0.
+  const int64 input_feature_dim = dnums.input_feature_dimension();
+  if (HloInstruction* input_feature_dynamic_size =
+          dynamic_dimension_inference->GetDynamicSize(
+              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
+    activations = PadWithScalar(activations, input_feature_dim,
+                                input_feature_dynamic_size, zero);
+  }
+
   if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
     activations = RewriteInputWithDynamicPadding(
         custom_call_conv, activations, zero, absl::MakeSpan(padding_before),
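The "pad input feature on lhs only" comment works because every term accumulated into the kernel gradient is a product of one activations element and one gradient element; zeroing either factor zeroes the term, so masking only the activations' dynamic tail gives the same result as masking both operands. A tiny illustration of that equivalence (plain C++, a simplified per-element view, not the actual convolution layout):

#include <cstddef>
#include <iostream>
#include <vector>

// Accumulated terms are products acts[i] * grads[i]. Masking only the
// invalid tail of the activations yields the same total as masking both
// operands, since a single zero factor kills the whole product.
int main() {
  std::vector<double> acts = {1, 2, 3, 4};  // Valid size 2, bound 4.
  std::vector<double> grads = {5, 6, 7, 8};
  const std::size_t valid = 2;

  double lhs_only = 0, both = 0;
  for (std::size_t i = 0; i < acts.size(); ++i) {
    const double a = (i < valid) ? acts[i] : 0.0;       // lhs masked
    const double g = (i < valid) ? grads[i] : 0.0;      // rhs masked
    lhs_only += a * grads[i];  // Only activations masked.
    both += a * g;             // Both operands masked.
  }
  std::cout << lhs_only << " == " << both << "\n";  // 17 == 17
}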
@@ -919,6 +919,34 @@ ENTRY main {
   EXPECT_EQ(result, expected);
 }
 
+XLA_TEST_F(ExecutionTest, DynamicInputFeature) {
+  const string hlo_text = R"(
+HloModule DynamicInputFeature
+
+ENTRY main {
+  param = f32[1, 1, 5] parameter(0)
+  const = s32[] constant(5)
+  one = f32[] constant(1)
+  kernel = f32[1,5,1]{2,1,0} broadcast(f32[] one), dimensions={}
+  param_dynamic = f32[1,1,<=5] set-dimension-size(param, const), dimensions={2}
+  ROOT conv = f32[1, 1, 1]{2,1,0} custom-call(f32[1, 1, <=5] param_dynamic, f32[1,5,1]{2,1,0} kernel),
+    window={size=1 pad=0_0},
+    dim_labels=b0f_0io->b0f,
+    padding_type=PADDING_VALID,
+    custom_call_target="DynamicConvolutionForward"
+}
+)";
+
+  Literal operand = LiteralUtil::CreateR3<float>({{{1, 2, 3, 4, 5}}});
+  auto module = GetHloModule(hlo_text);
+
+  Literal result = PadAndExecute(std::move(module), {&operand});
+
+  Literal expected = LiteralUtil::CreateR3<float>({{{15}}});
+
+  EXPECT_EQ(result, expected);
+}
+
 XLA_TEST_F(ExecutionTest, DynamicDimensionReshapeUnchanged) {
   const string hlo_text = R"(
 HloModule TensorFlowScatterV1
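The new test exercises the whole path end to end: the operand's feature dimension is declared dynamic with bound 5 (`<=5`) and runtime size 5, the all-ones f32[1,5,1] kernel contracts over it, and the expected output 15 = 1+2+3+4+5 confirms the dynamic feature dimension was inferred, padded, and contracted correctly.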