Enable batch group count test on all backends.
PiperOrigin-RevId: 330537531 Change-Id: Ibcf082c6198648d2ed2cce0ab5d703bd5acfa990
This commit is contained in:
parent
b23d35059b
commit
62403106a5
@ -380,11 +380,6 @@ xla_test(
|
||||
name = "conv_depthwise_backprop_filter_test",
|
||||
timeout = "long",
|
||||
srcs = ["conv_depthwise_backprop_filter_test.cc"],
|
||||
# these backends do not natively handle batch group counts.
|
||||
disabled_backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
shard_count = 40,
|
||||
deps = [
|
||||
":test_macros_header",
|
||||
|
@ -177,15 +177,9 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) {
|
||||
#endif
|
||||
|
||||
const string hlo_text = BuildHloTextBatchGroupedConvolution2D(
|
||||
spec, use_bfloat16, /*scheduled=*/true);
|
||||
spec, use_bfloat16, /*scheduled=*/false);
|
||||
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(
|
||||
hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status {
|
||||
BFloat16MixedPrecisionRemoval remover;
|
||||
TF_RETURN_IF_ERROR(remover.Run(module).status());
|
||||
Despecializer despecializer;
|
||||
return despecializer.Run(module).status();
|
||||
}));
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
@ -196,25 +190,9 @@ INSTANTIATE_TEST_CASE_P(
|
||||
::testing::Bool()),
|
||||
BatchGroupedConvolution2DTestDataToString);
|
||||
|
||||
XLA_TEST_P(BatchGroupedConvolution2DDepthTest, DoItDepth) {
|
||||
const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam());
|
||||
bool use_bfloat16 = ::testing::get<1>(GetParam());
|
||||
|
||||
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16
|
||||
if (use_bfloat16) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
const string hlo_text =
|
||||
BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16);
|
||||
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices,
|
||||
BatchGroupedConvolution2DDepthTest,
|
||||
BatchGroupedConvolution2DTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)),
|
||||
::testing::Bool()),
|
||||
|
Loading…
x
Reference in New Issue
Block a user