[XLA:CPU] Move bf16->f32 conversion before fusion
Otherwise the conversion pass will change fusions, confusing the pattern matching in the emitter.

PiperOrigin-RevId: 350537358
Change-Id: I8c0ddd41fdefed8e10cde47cb1f7d04b2cc73e06
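For context, a rough sketch (not part of the commit; instruction names and the exact output are illustrative) of how HloElementTypeConverter(BF16, F32) rewrites a bf16 dot feeding a bf16 add: each BF16 op gets its operands converted to F32, is performed in F32, and has its result converted back to BF16, roughly:

  p0.f32 = f32[128] convert(p0)
  p1.f32 = f32[128,256] convert(p1)
  dot.f32 = f32[256] dot(p0.f32, p1.f32), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  dot = bf16[256] convert(dot.f32)
  dot.as_f32 = f32[256] convert(dot)
  p2.f32 = f32[256] convert(p2)
  add.f32 = f32[256] add(dot.as_f32, p2.f32)
  ROOT add = bf16[256] convert(add.f32)

If this rewrite runs after fusion, the inserted converts land inside an already-formed dot+add fusion computation, so the emitter's dot+add pattern no longer matches; running the conversion before fusion lets the fusion pass form its patterns on the F32 graph instead.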
parent c21f7e2db5
commit 7fa2309477
tensorflow/compiler/xla
@@ -295,6 +295,10 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
   pipeline.AddPass<BatchDotSimplification>();
   pipeline.AddPass<DotDecomposer>();
+  // Convert BF16 operations to F32 operations so that the CPU backend can
+  // support BF16 operations without directly implementing a BF16 lowering for
+  // most ops.
+  pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
   // After canonicalization, there may be more batch dots that can be
   // simplified.
   pipeline.AddPass<BatchDotSimplification>();
@@ -404,8 +408,6 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
     pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
   }
 
-  pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
-
   // Outline ops in the entry computation into calls to subcomputations.
   const int max_parallelism =
       module->config().intra_op_parallelism_threads() > 0
@@ -1614,6 +1614,23 @@ ENTRY MatrixVectorComplex {
   EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{4e-3, 4e-3}));
 }
 
+XLA_TEST_F(DotOperationTextTest, MatrixVectorBF16) {
+  absl::string_view hlo_string =
+      R"(
+HloModule MatrixVectorBF16
+
+ENTRY MatrixVectorBF16 {
+  p0 = bf16[128] parameter(0)
+  p1 = bf16[128,256] parameter(1)
+  p2 = bf16[256] parameter(2)
+  dot = bf16[256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+  ROOT add = bf16[256] add(dot, p2)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
+}
+
 // Regression test for b/138155357, where we were incorrectly creating a dot-add
 // fusion where the dot had a batch dimension. This isn't supported on the CPU
 // backend.
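A side note on the tolerance in the new test (my reading, not stated in the commit): bf16 keeps 8 significand bits, so its relative rounding error is roughly 2^-8 = 1/256 ≈ 3.9e-3, which is about what ErrorSpec{4e-3, 4e-3} allows for a result computed in F32 and rounded back to bf16.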