Enable batch group count test on all backends.
PiperOrigin-RevId: 330537531 Change-Id: Ibcf082c6198648d2ed2cce0ab5d703bd5acfa990
This commit is contained in:
parent
b23d35059b
commit
62403106a5
@ -380,11 +380,6 @@ xla_test(
|
|||||||
name = "conv_depthwise_backprop_filter_test",
|
name = "conv_depthwise_backprop_filter_test",
|
||||||
timeout = "long",
|
timeout = "long",
|
||||||
srcs = ["conv_depthwise_backprop_filter_test.cc"],
|
srcs = ["conv_depthwise_backprop_filter_test.cc"],
|
||||||
# these backends do not natively handle batch group counts.
|
|
||||||
disabled_backends = [
|
|
||||||
"gpu",
|
|
||||||
"cpu",
|
|
||||||
],
|
|
||||||
shard_count = 40,
|
shard_count = 40,
|
||||||
deps = [
|
deps = [
|
||||||
":test_macros_header",
|
":test_macros_header",
|
||||||
|
@ -177,15 +177,9 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
const string hlo_text = BuildHloTextBatchGroupedConvolution2D(
|
const string hlo_text = BuildHloTextBatchGroupedConvolution2D(
|
||||||
spec, use_bfloat16, /*scheduled=*/true);
|
spec, use_bfloat16, /*scheduled=*/false);
|
||||||
|
|
||||||
EXPECT_TRUE(RunAndCompareNoHloPasses(
|
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}));
|
||||||
hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status {
|
|
||||||
BFloat16MixedPrecisionRemoval remover;
|
|
||||||
TF_RETURN_IF_ERROR(remover.Run(module).status());
|
|
||||||
Despecializer despecializer;
|
|
||||||
return despecializer.Run(module).status();
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_CASE_P(
|
||||||
@ -196,25 +190,9 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
::testing::Bool()),
|
::testing::Bool()),
|
||||||
BatchGroupedConvolution2DTestDataToString);
|
BatchGroupedConvolution2DTestDataToString);
|
||||||
|
|
||||||
XLA_TEST_P(BatchGroupedConvolution2DDepthTest, DoItDepth) {
|
|
||||||
const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam());
|
|
||||||
bool use_bfloat16 = ::testing::get<1>(GetParam());
|
|
||||||
|
|
||||||
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16
|
|
||||||
if (use_bfloat16) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
const string hlo_text =
|
|
||||||
BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16);
|
|
||||||
|
|
||||||
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}));
|
|
||||||
}
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_CASE_P(
|
||||||
BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices,
|
BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices,
|
||||||
BatchGroupedConvolution2DDepthTest,
|
BatchGroupedConvolution2DTest,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)),
|
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)),
|
||||||
::testing::Bool()),
|
::testing::Bool()),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user