Convert 4D convolutions with trivial dimensions to lower-dimensional convolutions.
cuDNN cannot handle convolutions with 4 or more spatial dimensions. However, we can work around this in some cases by removing trivial dimensions, i.e. spatial dimensions where the input has size 1 and no padding is applied. PiperOrigin-RevId: 291894511 Change-Id: Ic0e3fa4f4181e105ca62f92a235a55413acd7253
This commit is contained in:
parent
e2b70c0f80
commit
682cef5c04
@ -1767,6 +1767,36 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# HLO pass that rewrites 4D convolutions with trivial spatial dimensions into
# lower-dimensional convolutions.
cc_library(
    name = "convolution_4d_expander",
    srcs = ["convolution_4d_expander.cc"],
    hdrs = ["convolution_4d_expander.h"],
    deps = [
        ":hlo",
        ":op_expander_pass",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
    ],
)
|
||||||
|
|
||||||
|
# Unit tests for the Convolution4DExpander pass.
tf_cc_test(
    name = "convolution_4d_expander_test",
    srcs = ["convolution_4d_expander_test.cc"],
    deps = [
        # Same-package targets are written with a leading ':' like the other
        # local labels in this rule.
        ":convolution_4d_expander",
        ":hlo",
        ":hlo_matchers",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "batchnorm_expander_test",
|
name = "batchnorm_expander_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
175
tensorflow/compiler/xla/service/convolution_4d_expander.cc
Normal file
175
tensorflow/compiler/xla/service/convolution_4d_expander.cc
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// Returns true if `instruction` is a convolution with exactly 4 input spatial
// dimensions of which at least one is trivial. A spatial dimension counts as
// trivial here when the input has size 1 in that dimension and the window
// applies neither low nor high padding to it.
bool Convolution4DExpander::InstructionMatchesPattern(
    HloInstruction* instruction) {
  if (instruction->opcode() != HloOpcode::kConvolution) {
    return false;
  }

  // Only convolutions with exactly 4 spatial dimensions are handled.
  const ConvolutionDimensionNumbers& dim_nums =
      instruction->convolution_dimension_numbers();
  if (dim_nums.input_spatial_dimensions().size() != 4) {
    return false;
  }

  // The input shape and the window are only inspected, so bind them by
  // reference instead of copying them.
  const Shape& input_shape = instruction->operand(0)->shape();
  const Window& window = instruction->window();
  for (int64 i = 0; i < dim_nums.input_spatial_dimensions().size(); ++i) {
    const int64 spatial_dim = dim_nums.input_spatial_dimensions(i);
    const auto& window_dim = window.dimensions(i);
    if (input_shape.dimensions(spatial_dim) == 1 &&
        window_dim.padding_low() == 0 && window_dim.padding_high() == 0) {
      // One trivial dimension is enough for the expansion to make progress.
      return true;
    }
  }
  return false;
}
|
||||||
|
|
||||||
|
// Replaces a convolution that has trivial spatial dimensions (input size 1,
// no low/high padding) with an equivalent convolution without them:
//   1. the input and the kernel are reshaped to drop the trivial dimensions,
//   2. the convolution is cloned onto the reshaped operands with a reduced
//      window and relabeled dimension numbers,
//   3. the result is reshaped back to the shape of the original instruction.
// Returns the final reshape. Note that the window and the dimension numbers of
// `instruction` itself are overwritten as a side effect.
StatusOr<HloInstruction*> Convolution4DExpander::ExpandInstruction(
    HloInstruction* instruction) {
  HloComputation* computation = instruction->parent();
  HloInstruction* input = instruction->mutable_operand(0);
  HloInstruction* kernel = instruction->mutable_operand(1);

  // Deliberately a copy: `instruction` gets new dimension numbers further down.
  const ConvolutionDimensionNumbers dim_nums =
      instruction->convolution_dimension_numbers();

  // The batch and feature dimensions are carried over from the original; the
  // spatial dimensions and the window are rebuilt from the dimensions we keep.
  ConvolutionDimensionNumbers new_dim_nums = dim_nums;
  new_dim_nums.clear_input_spatial_dimensions();
  new_dim_nums.clear_output_spatial_dimensions();
  new_dim_nums.clear_kernel_spatial_dimensions();
  Window new_window;

  std::vector<int64> removed_input_dimensions;
  std::vector<int64> removed_kernel_dimensions;
  std::vector<int64> removed_output_dimensions;

  // Split the spatial dimensions into the trivial ones (to be removed from the
  // input, the kernel and the output) and the ones that are kept.
  const int64 num_spatial_dims = dim_nums.input_spatial_dimensions().size();
  for (int64 i = 0; i < num_spatial_dims; ++i) {
    const int64 input_dim = dim_nums.input_spatial_dimensions(i);
    const int64 output_dim = dim_nums.output_spatial_dimensions(i);
    const int64 kernel_dim = dim_nums.kernel_spatial_dimensions(i);
    const auto& window_dim = instruction->window().dimensions(i);

    const bool is_trivial = input->shape().dimensions(input_dim) == 1 &&
                            window_dim.padding_low() == 0 &&
                            window_dim.padding_high() == 0;
    if (!is_trivial) {
      *new_window.add_dimensions() = window_dim;
      new_dim_nums.add_input_spatial_dimensions(input_dim);
      new_dim_nums.add_output_spatial_dimensions(output_dim);
      new_dim_nums.add_kernel_spatial_dimensions(kernel_dim);
      continue;
    }
    removed_input_dimensions.push_back(input_dim);
    removed_output_dimensions.push_back(output_dim);
    removed_kernel_dimensions.push_back(kernel_dim);
  }

  // Deleting from the highest dimension down keeps the remaining indices
  // valid, so put the removed dimensions into descending order.
  std::sort(removed_input_dimensions.rbegin(), removed_input_dimensions.rend());
  std::sort(removed_output_dimensions.rbegin(),
            removed_output_dimensions.rend());
  std::sort(removed_kernel_dimensions.rbegin(),
            removed_kernel_dimensions.rend());

  // Returns a copy of `shape` without the given dimensions, which must be
  // listed in descending order.
  auto without_dimensions = [](const Shape& shape,
                               const std::vector<int64>& descending_dims) {
    Shape result = shape;
    for (int64 dim : descending_dims) {
      result.DeleteDimension(dim);
    }
    return result;
  };
  const Shape new_input_shape =
      without_dimensions(input->shape(), removed_input_dimensions);
  const Shape new_kernel_shape =
      without_dimensions(kernel->shape(), removed_kernel_dimensions);
  const Shape new_output_shape =
      without_dimensions(instruction->shape(), removed_output_dimensions);

  // Maps a dimension index of the old shape to its index in the new shape:
  // every removed dimension below it shifts it down by one.
  auto relabel = [](const std::vector<int64>& removed_dimensions,
                    int64 old_dimension) {
    int64 new_dimension = old_dimension;
    for (int64 removed_dimension : removed_dimensions) {
      if (removed_dimension < old_dimension) {
        --new_dimension;
      }
    }
    return new_dimension;
  };

  // Input dimension numbers.
  new_dim_nums.set_input_batch_dimension(
      relabel(removed_input_dimensions, new_dim_nums.input_batch_dimension()));
  new_dim_nums.set_input_feature_dimension(relabel(
      removed_input_dimensions, new_dim_nums.input_feature_dimension()));
  for (int64 i = 0; i < new_dim_nums.input_spatial_dimensions().size(); ++i) {
    const int64 old_dim = new_dim_nums.input_spatial_dimensions(i);
    new_dim_nums.set_input_spatial_dimensions(
        i, relabel(removed_input_dimensions, old_dim));
  }

  // Output dimension numbers.
  new_dim_nums.set_output_batch_dimension(relabel(
      removed_output_dimensions, new_dim_nums.output_batch_dimension()));
  new_dim_nums.set_output_feature_dimension(relabel(
      removed_output_dimensions, new_dim_nums.output_feature_dimension()));
  for (int64 i = 0; i < new_dim_nums.output_spatial_dimensions().size(); ++i) {
    const int64 old_dim = new_dim_nums.output_spatial_dimensions(i);
    new_dim_nums.set_output_spatial_dimensions(
        i, relabel(removed_output_dimensions, old_dim));
  }

  // Kernel dimension numbers.
  new_dim_nums.set_kernel_input_feature_dimension(
      relabel(removed_kernel_dimensions,
              new_dim_nums.kernel_input_feature_dimension()));
  new_dim_nums.set_kernel_output_feature_dimension(
      relabel(removed_kernel_dimensions,
              new_dim_nums.kernel_output_feature_dimension()));
  for (int64 i = 0; i < new_dim_nums.kernel_spatial_dimensions().size(); ++i) {
    const int64 old_dim = new_dim_nums.kernel_spatial_dimensions(i);
    new_dim_nums.set_kernel_spatial_dimensions(
        i, relabel(removed_kernel_dimensions, old_dim));
  }

  // Drop the trivial dimensions from both operands.
  HloInstruction* reshaped_input = computation->AddInstruction(
      HloInstruction::CreateReshape(new_input_shape, input));
  HloInstruction* reshaped_kernel = computation->AddInstruction(
      HloInstruction::CreateReshape(new_kernel_shape, kernel));

  // CloneWithNewOperands cannot substitute the window or the dimension
  // numbers, so they are patched onto the old instruction first. That
  // instruction is about to be replaced anyway.
  instruction->set_convolution_dimension_numbers(new_dim_nums);
  instruction->set_window(new_window);
  HloInstruction* new_convolution =
      computation->AddInstruction(instruction->CloneWithNewOperands(
          new_output_shape, {reshaped_input, reshaped_kernel}));

  // Restore the original output shape (the trivial dimensions come back).
  return computation->AddInstruction(
      HloInstruction::CreateReshape(instruction->shape(), new_convolution));
}
|
||||||
|
|
||||||
|
} // namespace xla
|
39
tensorflow/compiler/xla/service/convolution_4d_expander.h
Normal file
39
tensorflow/compiler/xla/service/convolution_4d_expander.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_4D_EXPANDER_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_4D_EXPANDER_H_
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
class Convolution4DExpander : public OpExpanderPass {
|
||||||
|
public:
|
||||||
|
absl::string_view name() const override { return "convolution_4d_expander"; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool InstructionMatchesPattern(HloInstruction* instruction) override;
|
||||||
|
|
||||||
|
StatusOr<HloInstruction*> ExpandInstruction(
|
||||||
|
HloInstruction* instruction) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_4D_EXPANDER_H_
|
172
tensorflow/compiler/xla/service/convolution_4d_expander_test.cc
Normal file
172
tensorflow/compiler/xla/service/convolution_4d_expander_test.cc
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using Convolution4DExpanderTest = HloTestBase;
|
||||||
|
|
||||||
|
// Two of the four spatial input dimensions have size 1 and no padding, so the
// pass should leave a convolution with 2 spatial dimensions behind a reshape.
TEST_F(Convolution4DExpanderTest, ConvertTo2DConvolution) {
  const std::string hlo_string = R"(HloModule convolution_4d_fp32

ENTRY convolution_computation {
  input = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0)
  kernel = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1)
  ROOT conv = f32[15,1,9,1,7,5]{5,4,3,2,1,0} convolution(input, kernel), dim_labels=0123bf_i0123o->f0123b, window={size=1x2x1x4}
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloComputation* computation = module->entry_computation();
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
  EXPECT_EQ(root->window().dimensions_size(), 4);

  Convolution4DExpander expander_pass;
  // Unwrap the StatusOr through the test macro so that an error status fails
  // this test instead of aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, expander_pass.Run(module.get()));
  ASSERT_TRUE(changed);

  root = computation->root_instruction();
  // Stop here if the root is not the expected reshape; the checks below would
  // otherwise inspect an unrelated operand.
  ASSERT_EQ(root->opcode(), HloOpcode::kReshape);
  const HloInstruction* new_convolution = root->operand(0);
  // Check that the new convolution has 2 spatial dimensions.
  EXPECT_EQ(new_convolution->opcode(), HloOpcode::kConvolution);
  EXPECT_EQ(new_convolution->window().dimensions_size(), 2);
}
|
||||||
|
|
||||||
|
// Two spatial input dimensions have size 1, but one of them is padded, so only
// a single dimension can be removed and a 3D convolution remains.
TEST_F(Convolution4DExpanderTest, ConvertTo3DConvolution) {
  const std::string hlo_string = R"(HloModule convolution_4d_fp32

ENTRY convolution_computation {
  input = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0)
  kernel = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1)
  ROOT conv = f32[15,1,9,2,7,5]{5,4,3,2,1,0} convolution(input, kernel), dim_labels=0123bf_i0123o->f0123b, window={size=1x2x1x4 pad=0_0x0_0x1_0x0_0}
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloComputation* computation = module->entry_computation();
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
  EXPECT_EQ(root->window().dimensions_size(), 4);

  Convolution4DExpander expander_pass;
  // Unwrap the StatusOr through the test macro so that an error status fails
  // this test instead of aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, expander_pass.Run(module.get()));
  ASSERT_TRUE(changed);

  root = computation->root_instruction();
  // Stop here if the root is not the expected reshape; the checks below would
  // otherwise inspect an unrelated operand.
  ASSERT_EQ(root->opcode(), HloOpcode::kReshape);
  const HloInstruction* new_convolution = root->operand(0);
  // Check that the new convolution has 3 spatial dimensions. Note that although
  // there are 2 input dimensions of size 1, one of them is not trivial because
  // with the low padding the output dimension will be 2.
  EXPECT_EQ(new_convolution->opcode(), HloOpcode::kConvolution);
  EXPECT_EQ(new_convolution->window().dimensions_size(), 3);
}
|
||||||
|
|
||||||
|
// All four spatial input dimensions have size 1 and no padding, so every
// spatial dimension is removed and a convolution with an empty window remains.
TEST_F(Convolution4DExpanderTest, ConvertTo0DConvolution) {
  const std::string hlo_string = R"(HloModule convolution_4d_fp32

ENTRY convolution_computation {
  input = f32[1,1,1,1,5,20]{5,4,3,2,1,0} parameter(0)
  kernel = f32[20,1,1,1,1,15]{5,4,3,2,1,0} parameter(1)
  ROOT conv = f32[15,1,1,1,1,5]{5,4,3,2,1,0} convolution(input, kernel), dim_labels=0123bf_i0123o->f0123b, window={size=1x1x1x1}
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloComputation* computation = module->entry_computation();
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
  EXPECT_EQ(root->window().dimensions_size(), 4);

  Convolution4DExpander expander_pass;
  // Unwrap the StatusOr through the test macro so that an error status fails
  // this test instead of aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, expander_pass.Run(module.get()));
  ASSERT_TRUE(changed);

  root = computation->root_instruction();
  // Stop here if the root is not the expected reshape; the checks below would
  // otherwise inspect an unrelated operand.
  ASSERT_EQ(root->opcode(), HloOpcode::kReshape);
  const HloInstruction* new_convolution = root->operand(0);
  // Check that the new convolution has 0 spatial dimensions.
  EXPECT_EQ(new_convolution->opcode(), HloOpcode::kConvolution);
  EXPECT_EQ(new_convolution->window().dimensions_size(), 0);
}
|
||||||
|
|
||||||
|
// A convolution with only 3 spatial dimensions must be left untouched, even
// though all of its spatial dimensions have size 1.
TEST_F(Convolution4DExpanderTest, DontConvert3DConvolution) {
  const std::string hlo_string = R"(HloModule convolution_4d_fp32

ENTRY convolution_computation {
  input = f32[1,1,1,5,20]{4,3,2,1,0} parameter(0)
  kernel = f32[20,1,1,1,15]{4,3,2,1,0} parameter(1)
  ROOT conv = f32[15,1,1,1,5]{4,3,2,1,0} convolution(input, kernel), dim_labels=012bf_i012o->f012b, window={size=1x1x1}
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloComputation* computation = module->entry_computation();
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
  EXPECT_EQ(root->window().dimensions_size(), 3);

  Convolution4DExpander expander_pass;
  // Unwrap the StatusOr through the test macro so that an error status fails
  // this test instead of aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, expander_pass.Run(module.get()));
  ASSERT_FALSE(changed);
}
|
||||||
|
|
||||||
|
// No spatial input dimension has size 1, so there is nothing to remove even
// though two of the output dimensions have size 1.
TEST_F(Convolution4DExpanderTest, DontConvertIfNoTrivialDimensionAvailable) {
  const std::string hlo_string = R"(HloModule convolution_4d_fp32

ENTRY convolution_computation {
  input = f32[2,10,2,10,5,20]{5,4,3,2,1,0} parameter(0)
  kernel = f32[20,2,2,2,4,15]{5,4,3,2,1,0} parameter(1)
  ROOT conv = f32[15,1,9,1,7,5]{5,4,3,2,1,0} convolution(input, kernel), dim_labels=0123bf_i0123o->f0123b, window={size=2x2x2x4}
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloComputation* computation = module->entry_computation();
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
  EXPECT_EQ(root->window().dimensions_size(), 4);

  Convolution4DExpander expander_pass;
  // Unwrap the StatusOr through the test macro so that an error status fails
  // this test instead of aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, expander_pass.Run(module.get()));
  ASSERT_FALSE(changed);
}
|
||||||
|
|
||||||
|
// Spatial dimensions of size 1 that carry padding are not trivial and must not
// be removed.
TEST_F(Convolution4DExpanderTest, DontConvertIfPaddingIsNonzero) {
  const std::string hlo_string = R"(HloModule convolution_4d_fp32

ENTRY convolution_computation {
  input = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0)
  kernel = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1)
  ROOT conv = f32[15,1,9,1,7,5]{5,4,3,2,1,0} convolution(input, kernel), dim_labels=0123bf_i0123o->f0123b, window={size=1x2x1x4 stride=2x1x2x1 pad=1_0x0_0x0_1x0_0}
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloComputation* computation = module->entry_computation();
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
  EXPECT_EQ(root->window().dimensions_size(), 4);

  Convolution4DExpander expander_pass;
  // Although we have two spatial input dimensions of size 1, and the
  // corresponding spatial output dimensions are also of size 1, these
  // dimensions are not trivial because they involve lower and/or higher padding
  // plus stride.
  // Unwrap the StatusOr through the test macro so that an error status fails
  // this test instead of aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, expander_pass.Run(module.get()));
  ASSERT_FALSE(changed);
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace xla
|
@ -1122,6 +1122,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||||
"//tensorflow/compiler/xla/service:call_inliner",
|
"//tensorflow/compiler/xla/service:call_inliner",
|
||||||
"//tensorflow/compiler/xla/service:conditional_simplifier",
|
"//tensorflow/compiler/xla/service:conditional_simplifier",
|
||||||
|
"//tensorflow/compiler/xla/service:convolution_4d_expander",
|
||||||
"//tensorflow/compiler/xla/service:convolution_group_converter",
|
"//tensorflow/compiler/xla/service:convolution_group_converter",
|
||||||
"//tensorflow/compiler/xla/service:depthwise_convolution_converter",
|
"//tensorflow/compiler/xla/service:depthwise_convolution_converter",
|
||||||
"//tensorflow/compiler/xla/service:dot_decomposer",
|
"//tensorflow/compiler/xla/service:dot_decomposer",
|
||||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||||
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
|
||||||
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
|
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
|
||||||
#include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h"
|
#include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h"
|
||||||
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
|
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
|
||||||
@ -140,6 +141,8 @@ Status GpuCompiler::OptimizeHloModule(
|
|||||||
|
|
||||||
pipeline.AddPass<DotDecomposer>();
|
pipeline.AddPass<DotDecomposer>();
|
||||||
|
|
||||||
|
pipeline.AddPass<Convolution4DExpander>();
|
||||||
|
|
||||||
auto cost_model = [](HloInstruction*) {
|
auto cost_model = [](HloInstruction*) {
|
||||||
// We need a cost model for GPUs. Currently, do nothing.
|
// We need a cost model for GPUs. Currently, do nothing.
|
||||||
return false;
|
return false;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user