Don't match as backward input convolution in an unsupported case.

For grouped convolutions, we assume that in the backward input convolution
case, the input and output feature dimensions of the kernel are adjacent.
If that is not the case, don't treat the convolution as a backward input convolution.

PiperOrigin-RevId: 339029980
Change-Id: If0b4f8a64cd3ca73e9648358d8a579ce262b27c9
Author: Adrian Kuegel
Committed: 2020-10-26 07:03:26 -07:00 by TensorFlower Gardener
Parent: 632bf67c1b
Commit: edfc5938ba
3 changed files with 38 additions and 7 deletions
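For context on the first hunk below: the rewriter reads the kernel's feature dimension numbers off the convolution's dim_labels and now declines the backward-input match when those two dimensions are not next to each other, where the replaced code CHECK-failed. A minimal standalone sketch of the new predicate (illustration only, not code from the commit; the function name is made up):

// Returns true iff the kernel's input and output feature dimensions sit
// next to each other -- the precondition the backward-input match now
// enforces by returning no_match_result when it fails.
#include <cstdint>
#include <cstdlib>

bool FeatureDimensionsAdjacent(int64_t input_feature_dimension,
                               int64_t output_feature_dimension) {
  return std::abs(input_feature_dimension - output_feature_dimension) == 1;
}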

@@ -536,11 +536,12 @@ MatchBackwardInput(HloInstruction* conv) {
   // 'kernel_output_feature_dimension' by 'feature_group_count'.
   int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
   int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
-  // In the backward convolution case, the spatial dimensions become the
-  // feature dimensions, and we are guaranteed that the spatial dimensions are
-  // adjacent.
-  CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
+  // The following code assumes that input_feature_dimension and
+  // output_feature_dimension are adjacent.
+  if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
+    return no_match_result;
+  }
   int64 input_features = rhs->shape().dimensions(input_feature_dimension);
   int64 output_features = rhs->shape().dimensions(output_feature_dimension);

@@ -413,16 +413,18 @@ xla_test(
     ],
     shard_count = 50,
     deps = [
+        ":client_library_test_base",
+        ":hlo_test_base",
         ":test_macros_header",
+        ":test_utils",
+        ":xla_internal_test_main",
         "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/service:bfloat16_normalization",
         "//tensorflow/compiler/xla/service:despecializer",
-        "//tensorflow/compiler/xla/tests:client_library_test_base",
-        "//tensorflow/compiler/xla/tests:hlo_test_base",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/types:optional",
     ],
 )

@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <string>
+#include <vector>
+
+#include "absl/algorithm/container.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
@@ -23,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"

 namespace xla {
 namespace {
@@ -248,5 +253,28 @@ INSTANTIATE_TEST_CASE_P(
         ::testing::Bool()),
     GroupedConvolution2DTestDataToString);

+using GroupedConvolutionTest = HloTestBase;
+
+XLA_TEST_F(GroupedConvolutionTest, BackwardInputConvolution) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule convolution_module
+    ENTRY convolution {
+      p1 = f32[2,1,1,1]{3,2,1,0} parameter(0)
+      p2 = f32[2,4,4,1]{3,2,1,0} parameter(1)
+      reverse = f32[2,4,4,1]{3,2,1,0} reverse(p2), dimensions={1,2}
+      ROOT convolution = f32[2,4,4,1]{3,2,1,0} convolution(p1, reverse), window={size=4x4 pad=3_3x3_3}, dim_labels=fb01_o01i->f01b, feature_group_count=2
+    }
+    )")
+                    .ValueOrDie();
+  TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, MakeFakeArguments(module.get()));
+  std::vector<Literal*> fake_argument_ptrs;
+  absl::c_transform(
+      fake_arguments, std::back_inserter(fake_argument_ptrs),
+      [](const Literal& literal) { return &const_cast<Literal&>(literal); });
+
+  EXPECT_TRUE(RunAndCompare(std::move(module), fake_argument_ptrs,
+                            ErrorSpec{0.01, 0.01}));
+}
+
 }  // namespace
 }  // namespace xla
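Why this HLO exercises the new bail-out (a reading of the test above, not commentary from the commit): in dim_labels=fb01_o01i->f01b the kernel layout is o01i, so the kernel output feature dimension is 0 and the kernel input feature dimension is 3. Those are three apart rather than adjacent, so MatchBackwardInput now returns no_match_result where the old CHECK_EQ would have crashed. A hedged sketch of that mapping, using the ConvolutionDimensionNumbers proto from xla_data.proto:

// Sketch only: set the kernel dimension numbers implied by "o01i" and
// evaluate the adjacency guard the way the rewriter now does.
#include <cstdlib>
#include <iostream>
#include "tensorflow/compiler/xla/xla_data.pb.h"

int main() {
  xla::ConvolutionDimensionNumbers dnums;
  dnums.set_kernel_output_feature_dimension(0);  // 'o' sits at dimension 0
  dnums.set_kernel_input_feature_dimension(3);   // 'i' sits at dimension 3
  const bool adjacent =
      std::abs(dnums.kernel_input_feature_dimension() -
               dnums.kernel_output_feature_dimension()) == 1;
  // Prints "no match": |3 - 0| != 1, so this convolution is left as-is
  // instead of being rewritten as a backward input convolution.
  std::cout << (adjacent ? "backward-input match" : "no match") << "\n";
  return 0;
}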