diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 23135f4d6f9..627bb9ff7a5 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -380,11 +380,6 @@ xla_test( name = "conv_depthwise_backprop_filter_test", timeout = "long", srcs = ["conv_depthwise_backprop_filter_test.cc"], - # these backends do not natively handle batch group counts. - disabled_backends = [ - "gpu", - "cpu", - ], shard_count = 40, deps = [ ":test_macros_header", diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc index 871c55ec9b8..4a7070a32f3 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -177,15 +177,9 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) { #endif const string hlo_text = BuildHloTextBatchGroupedConvolution2D( - spec, use_bfloat16, /*scheduled=*/true); + spec, use_bfloat16, /*scheduled=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses( - hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { - BFloat16MixedPrecisionRemoval remover; - TF_RETURN_IF_ERROR(remover.Run(module).status()); - Despecializer despecializer; - return despecializer.Run(module).status(); - })); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01})); } INSTANTIATE_TEST_CASE_P( @@ -196,25 +190,9 @@ INSTANTIATE_TEST_CASE_P( ::testing::Bool()), BatchGroupedConvolution2DTestDataToString); -XLA_TEST_P(BatchGroupedConvolution2DDepthTest, DoItDepth) { - const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam()); - bool use_bfloat16 = ::testing::get<1>(GetParam()); - -#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16 - if (use_bfloat16) { - return; - } -#endif - - const string hlo_text = - BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01})); -} - INSTANTIATE_TEST_CASE_P( BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices, - BatchGroupedConvolution2DDepthTest, + BatchGroupedConvolution2DTest, ::testing::Combine( ::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)), ::testing::Bool()),