From 7fa2309477c0dcfa4288cb676890e13b0bf532fc Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Thu, 7 Jan 2021 05:03:37 -0800
Subject: [PATCH] [XLA:CPU] Move bf16->f32 conversion before fusion

Otherwise the conversion pass will change fusions, confusing the pattern
matching in the emitter.

PiperOrigin-RevId: 350537358
Change-Id: I8c0ddd41fdefed8e10cde47cb1f7d04b2cc73e06
---
 .../compiler/xla/service/cpu/cpu_compiler.cc |  6 ++++--
 .../compiler/xla/tests/dot_operation_test.cc | 17 +++++++++++++++++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 6b801585fef..bc6217bf10e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -295,6 +295,10 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
   pipeline.AddPass<BatchDotSimplification>();
   pipeline.AddPass<DotDecomposer>();
+  // Convert BF16 operations to F32 operations so that the CPU backend can
+  // support BF16 operations without directly implementing a BF16 lowering for
+  // most ops.
+  pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
   // After canonicalization, there may be more batch dots that can be
   // simplified.
   pipeline.AddPass<BatchDotSimplification>();
@@ -404,8 +408,6 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
     pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
   }
 
-  pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
-
   // Outline ops in the entry computation into calls to subcomputations.
   const int max_parallelism =
       module->config().intra_op_parallelism_threads() > 0
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 6baf08abcad..622139c497b 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -1614,6 +1614,23 @@ ENTRY MatrixVectorComplex {
   EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{4e-3, 4e-3}));
 }
 
+XLA_TEST_F(DotOperationTextTest, MatrixVectorBF16) {
+  absl::string_view hlo_string =
+      R"(
+HloModule MatrixVectorBF16
+
+ENTRY MatrixVectorBF16 {
+  p0 = bf16[128] parameter(0)
+  p1 = bf16[128,256] parameter(1)
+  p2 = bf16[256] parameter(2)
+  dot = bf16[256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+  ROOT add = bf16[256] add(dot, p2)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
+}
+
 // Regression test for b/138155357, where we were incorrectly creating a dot-add
 // fusion where the dot had a batch dimension. This isn't supported on the CPU
 // backend.
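
Note (not part of the patch): the comment introduced above describes the upcast-compute-downcast strategy the moved pass implements — rewrite BF16 ops to F32 with conversions at the boundaries rather than give every CPU op a native BF16 lowering. Below is a minimal standalone C++ sketch of that idea, mirroring the dot in the new MatrixVectorBF16 test. It is illustrative only, not XLA's implementation; the bf16 helpers assume the usual "top 16 bits of an IEEE-754 float32" encoding, and NaN handling is omitted.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// bfloat16 is the upper 16 bits of an IEEE-754 float32.
using bf16 = std::uint16_t;

float Bf16ToF32(bf16 x) {
  std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

bf16 F32ToBf16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Round-to-nearest-even on the 16 dropped bits (NaN not handled; sketch).
  bits += 0x7FFF + ((bits >> 16) & 1);
  return static_cast<bf16>(bits >> 16);
}

// bf16 dot(vector[k], matrix[k,n]) -> vector[n], contracting dimension 0 of
// both operands, as in the MatrixVectorBF16 test: every element is widened to
// f32, the accumulation runs in f32, and only the result is narrowed to bf16.
std::vector<bf16> DotBf16(const std::vector<bf16>& v,
                          const std::vector<bf16>& m, int k, int n) {
  std::vector<bf16> out(n);
  for (int j = 0; j < n; ++j) {
    float acc = 0.0f;  // f32 accumulator: the point of the conversion pass.
    for (int i = 0; i < k; ++i) {
      acc += Bf16ToF32(v[i]) * Bf16ToF32(m[i * n + j]);  // m is row-major [k,n]
    }
    out[j] = F32ToBf16(acc);
  }
  return out;
}

int main() {
  // Tiny smoke test: [1, 2] . [[1, 0], [0, 1]] == [1, 2].
  std::vector<bf16> v = {F32ToBf16(1.0f), F32ToBf16(2.0f)};
  std::vector<bf16> m = {F32ToBf16(1.0f), F32ToBf16(0.0f),
                         F32ToBf16(0.0f), F32ToBf16(1.0f)};
  std::vector<bf16> out = DotBf16(v, m, /*k=*/2, /*n=*/2);
  std::printf("%g %g\n", Bf16ToF32(out[0]), Bf16ToF32(out[1]));  // prints: 1 2
  return 0;
}

Running the conversion before fusion, as the patch does, means fusion decisions are made on the already-converted F32 graph, so the emitter's fusion pattern matching never sees shapes rewritten out from under it.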