[XLA:CPU] Move bf16->f32 conversion before fusion

Otherwise the conversion pass will change fusions, confusing the pattern
matching in the emitter.

PiperOrigin-RevId: 350537358
Change-Id: I8c0ddd41fdefed8e10cde47cb1f7d04b2cc73e06
Author: Benjamin Kramer (2021-01-07 05:03:37 -08:00), committed by TensorFlower Gardener
parent c21f7e2db5
commit 7fa2309477
2 changed files with 21 additions and 2 deletions
tensorflow/compiler/xla

View File

@ -295,6 +295,10 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
// Convert BF16 operations to F32 operations so that the CPU backend can
// support BF16 operations without directly implementing a BF16 lowering for
// most ops.
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
// After canonicalization, there may be more batch dots that can be
// simplified.
pipeline.AddPass<BatchDotSimplification>();
@ -404,8 +408,6 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
}
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
module->config().intra_op_parallelism_threads() > 0

View File

@ -1614,6 +1614,23 @@ ENTRY MatrixVectorComplex {
EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{4e-3, 4e-3}));
}
XLA_TEST_F(DotOperationTextTest, MatrixVectorBF16) {
  // Runs a bf16 dot followed by an add and compares against the reference
  // backend. Per the commit description, this exercises the BF16->F32
  // conversion ordering relative to fusion on the CPU backend.
  constexpr absl::string_view kHloText =
      R"(
HloModule MatrixVectorBF16
ENTRY MatrixVectorBF16 {
p0 = bf16[128] parameter(0)
p1 = bf16[128,256] parameter(1)
p2 = bf16[256] parameter(2)
dot = bf16[256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
ROOT add = bf16[256] add(dot, p2)
}
)";
  // bf16 has ~8 bits of mantissa, hence the loose 4e-3 tolerance.
  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{4e-3, 4e-3}));
}
// Regression test for b/138155357, where we were incorrectly creating a dot-add
// fusion where the dot had a batch dimension. This isn't supported on the CPU
// backend.