Don't match to backward input convolution in unsupported case.
For grouped convolutions, we assume that in the backward input convolution case, the input and output feature dimensions of the kernel are adjacent. If that is not the case, don't treat it as backward input convolution. PiperOrigin-RevId: 339029980 Change-Id: If0b4f8a64cd3ca73e9648358d8a579ce262b27c9
This commit is contained in:
parent
632bf67c1b
commit
edfc5938ba
@ -536,11 +536,12 @@ MatchBackwardInput(HloInstruction* conv) {
|
||||
// 'kernel_output_feature_dimension' by 'feature_group_count'.
|
||||
int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
|
||||
int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
|
||||
// The following code assumes that input_feature_dimension and
|
||||
// output_feature_dimension are adjacent.
|
||||
if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
|
||||
return no_match_result;
|
||||
}
|
||||
|
||||
// In the backward convolution case, the spatial dimensions become the
|
||||
// feature dimensions, and we are guaranteed that the spatial dimensions are
|
||||
// adjacent.
|
||||
CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
|
||||
int64 input_features = rhs->shape().dimensions(input_feature_dimension);
|
||||
int64 output_features = rhs->shape().dimensions(output_feature_dimension);
|
||||
|
||||
|
@ -413,16 +413,18 @@ xla_test(
|
||||
],
|
||||
shard_count = 50,
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":hlo_test_base",
|
||||
":test_macros_header",
|
||||
":test_utils",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:bfloat16_normalization",
|
||||
"//tensorflow/compiler/xla/service:despecializer",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||
@ -23,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
@ -248,5 +253,28 @@ INSTANTIATE_TEST_CASE_P(
|
||||
::testing::Bool()),
|
||||
GroupedConvolution2DTestDataToString);
|
||||
|
||||
using GroupedConvolutionTest = HloTestBase;
|
||||
|
||||
XLA_TEST_F(GroupedConvolutionTest, BackwardInputConvolution) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule convolution_module
|
||||
|
||||
ENTRY convolution {
|
||||
p1 = f32[2,1,1,1]{3,2,1,0} parameter(0)
|
||||
p2 = f32[2,4,4,1]{3,2,1,0} parameter(1)
|
||||
reverse = f32[2,4,4,1]{3,2,1,0} reverse(p2), dimensions={1,2}
|
||||
ROOT convolution = f32[2,4,4,1]{3,2,1,0} convolution(p1, reverse), window={size=4x4 pad=3_3x3_3}, dim_labels=fb01_o01i->f01b, feature_group_count=2
|
||||
}
|
||||
)")
|
||||
.ValueOrDie();
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, MakeFakeArguments(module.get()));
|
||||
std::vector<Literal*> fake_argument_ptrs;
|
||||
absl::c_transform(
|
||||
fake_arguments, std::back_inserter(fake_argument_ptrs),
|
||||
[](const Literal& literal) { return &const_cast<Literal&>(literal); });
|
||||
EXPECT_TRUE(RunAndCompare(std::move(module), fake_argument_ptrs,
|
||||
ErrorSpec{0.01, 0.01}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user