From 675799d1056fd05acc61ba85bdf8db9b86be0139 Mon Sep 17 00:00:00 2001
From: Stephan Herhut <herhut@google.com>
Date: Wed, 12 Feb 2020 06:27:31 -0800
Subject: [PATCH] Be more selective when disabling mlir_gpu_lhlo_gen_test.

PiperOrigin-RevId: 294653870
Change-Id: I6e0dcf432f630550e49b776390a738c013a1e891
---
 .../compiler/xla/service/mlir_gpu/tests/BUILD | 22 ++--
 .../mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc  | 84 ++++++++++---------
 2 files changed, 54 insertions(+), 52 deletions(-)

diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD
index 2fd9154d4d4..c0b90910b01 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD
@@ -18,14 +18,14 @@ package_group(
     ],
 )
 
-# tf_cc_test(
-#     name = "mlir_gpu_lhlo_gen_test",
-#     srcs = ["mlir_gpu_lhlo_gen_test.cc"],
-#     tags = tf_cuda_tests_tags() + ["no_rocm"],
-#     deps = [
-#         "//tensorflow/compiler/xla/service:gpu_plugin_mlir",
-#         "//tensorflow/compiler/xla/service/mlir_gpu:mlir_irgen_test_base",
-#         "//tensorflow/core:test_main",
-#         "//tensorflow/stream_executor/lib",
-#     ],
-# )
+tf_cc_test(
+    name = "mlir_gpu_lhlo_gen_test",
+    srcs = ["mlir_gpu_lhlo_gen_test.cc"],
+    tags = tf_cuda_tests_tags() + ["no_rocm"],
+    deps = [
+        "//tensorflow/compiler/xla/service:gpu_plugin_mlir",
+        "//tensorflow/compiler/xla/service/mlir_gpu:mlir_irgen_test_base",
+        "//tensorflow/core:test_main",
+        "//tensorflow/stream_executor/lib",
+    ],
+)

diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
index 5de07e2bc3c..9a23ff8748e 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
@@ -240,7 +240,8 @@ ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
                      LoweringStage::KERNEL);
 }
 
-TEST_F(LhloGenTest, AddMultiply) {
+// TODO(b/149302060): Reenable once fusion is fixed.
+TEST_F(LhloGenTest, DISABLED_AddMultiply) {
   CompileAndVerifyIr(R"(
 HloModule AddMultiply
 
@@ -257,7 +258,7 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
 ;CHECK: %[[REF0:.*]] = tensor_load %[[ARG0]] : [[TYPE]]
 ;CHECK: %[[REF1:.*]] = tensor_load %[[ARG1]] : [[TYPE]]
 ;CHECK: %[[REF2:.*]] = tensor_load %[[ARG2]] : [[TYPE]]
-;CHECK: %[[ADD:.*]] = xla_hlo.add %[[REF1]], %[[REF2]] 
+;CHECK: %[[ADD:.*]] = xla_hlo.add %[[REF1]], %[[REF2]]
 ;CHECK: %[[MUL:.*]] = xla_hlo.mul %[[ADD]], %[[REF0]]
 ;CHECK: tensor_store %[[MUL]], %[[RESULT]]
 ;CHECK: "xla_lhlo.terminator"()
@@ -265,7 +266,8 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
 )");
 }
 
-TEST_F(LhloGenTest, IotaAddMultiply) {
+// TODO(b/149302060): Reenable once fusion is fixed.
+TEST_F(LhloGenTest, DISABLED_IotaAddMultiply) {
   CompileAndVerifyIr(R"(
 HloModule AddMultiply
 
@@ -315,44 +317,44 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
 }
 
 // TODO(b/137624192): Reenable once we can fuse reductions.
-// TEST_F(LhloGenTest, FusedReduce) {
-//   CompileAndVerifyIr(R"(
-// HloModule FusedReduce
-//
-// %add (x: f32[], y: f32[]) -> f32[] {
-//   %x = f32[] parameter(0)
-//   %y = f32[] parameter(1)
-//   ROOT %add = f32[] add(f32[] %x, f32[] %y)
-// }
-//
-// %fused_computation (param: f32[100,10]) -> f32[10] {
-//   %param = f32[100,10] parameter(0)
-//   %constant = f32[] constant(0)
-//   ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant),
-//       dimensions={0}, to_apply=%add
-// }
-//
-// ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] {
-//   %x = f32[100,10] parameter(0)
-//   ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput,
-//       calls=%fused_computation
-// }
-// )",
-//                      R"(
-// ;CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]])
-// ;CHECK: "xla_lhlo.fusion"() ( {
-// ;CHECK: %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]]
-// ;CHECK: %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00>
-// ;CHECK: %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( {
-// ;CHECK: ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]])
-// ;CHECK: %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]]
-// ;CHECK: "xla_hlo.return"(%[[ADD]])
-// ;CHECK: })
-// ;CHECK: tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]]
-// ;CHECK: "xla_lhlo.terminator"()
-// ;CHECK-NEXT: })
-// )");
-// }
+TEST_F(LhloGenTest, DISABLED_FusedReduce) {
+  CompileAndVerifyIr(R"(
+HloModule FusedReduce
+
+%add (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %x, f32[] %y)
+}
+
+%fused_computation (param: f32[100,10]) -> f32[10] {
+  %param = f32[100,10] parameter(0)
+  %constant = f32[] constant(0)
+  ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant),
+      dimensions={0}, to_apply=%add
+}
+
+ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] {
+  %x = f32[100,10] parameter(0)
+  ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput,
+      calls=%fused_computation
+}
+)",
+                     R"(
+;CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]])
+;CHECK: "xla_lhlo.fusion"() ( {
+;CHECK: %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]]
+;CHECK: %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00>
+;CHECK: %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( {
+;CHECK: ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]])
+;CHECK: %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]]
+;CHECK: "xla_hlo.return"(%[[ADD]])
+;CHECK: })
+;CHECK: tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]]
+;CHECK: "xla_lhlo.terminator"()
+;CHECK-NEXT: })
+)");
+}
 
 TEST_F(LhloGenTest, Broadcast) {
   CompileAndVerifyIr(R"(
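
Note on the DISABLED_ convention used in this patch: googletest treats any
test whose name starts with DISABLED_ as compiled but skipped, so the test
body keeps building (it cannot silently bit-rot) while the runner excludes it
by default and reports it in a "YOU HAVE n DISABLED TESTS" summary. Below is
a minimal standalone sketch of that behavior; it is not part of this patch,
and the DisabledConventionDemo suite and its test names are hypothetical.

#include "gtest/gtest.h"

// No prefix: discovered and executed as usual.
TEST(DisabledConventionDemo, RunsByDefault) {
  EXPECT_EQ(2 + 2, 4);
}

// DISABLED_ prefix: still compiled and linked, but skipped at run time.
TEST(DisabledConventionDemo, DISABLED_SkippedByDefault) {
  EXPECT_TRUE(false);  // Would fail if it were ever executed.
}

Disabled tests can still be forced to run with the standard googletest flag
--gtest_also_run_disabled_tests, which is how one would check whether
b/149302060 is fixed before dropping the prefix. With the BUILD rule restored
above, the suite itself runs via
bazel test //tensorflow/compiler/xla/service/mlir_gpu/tests:mlir_gpu_lhlo_gen_test
on a CUDA-enabled build (per tf_cuda_tests_tags()).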