Enable batch group count test on all backends.

PiperOrigin-RevId: 330537531
Change-Id: Ibcf082c6198648d2ed2cce0ab5d703bd5acfa990
This commit is contained in:
A. Unique TensorFlower 2020-09-08 10:29:43 -07:00 committed by TensorFlower Gardener
parent b23d35059b
commit 62403106a5
2 changed files with 3 additions and 30 deletions

View File

@ -380,11 +380,6 @@ xla_test(
name = "conv_depthwise_backprop_filter_test",
timeout = "long",
srcs = ["conv_depthwise_backprop_filter_test.cc"],
# these backends do not natively handle batch group counts.
disabled_backends = [
"gpu",
"cpu",
],
shard_count = 40,
deps = [
":test_macros_header",

View File

@ -177,15 +177,9 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) {
#endif
const string hlo_text = BuildHloTextBatchGroupedConvolution2D(
spec, use_bfloat16, /*scheduled=*/true);
spec, use_bfloat16, /*scheduled=*/false);
EXPECT_TRUE(RunAndCompareNoHloPasses(
hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status {
BFloat16MixedPrecisionRemoval remover;
TF_RETURN_IF_ERROR(remover.Run(module).status());
Despecializer despecializer;
return despecializer.Run(module).status();
}));
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}));
}
INSTANTIATE_TEST_CASE_P(
@ -196,25 +190,9 @@ INSTANTIATE_TEST_CASE_P(
::testing::Bool()),
BatchGroupedConvolution2DTestDataToString);
XLA_TEST_P(BatchGroupedConvolution2DDepthTest, DoItDepth) {
const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam());
bool use_bfloat16 = ::testing::get<1>(GetParam());
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16
if (use_bfloat16) {
return;
}
#endif
const string hlo_text =
BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16);
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}));
}
INSTANTIATE_TEST_CASE_P(
BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices,
BatchGroupedConvolution2DDepthTest,
BatchGroupedConvolution2DTest,
::testing::Combine(
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)),
::testing::Bool()),