diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc index fea0885d21e..bb67305c344 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc @@ -64,6 +64,8 @@ StatusOr InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kExp: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kLog: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kMaximum: return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kMinimum: @@ -74,6 +76,8 @@ StatusOr InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kRemainder: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kRsqrt: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kSelect: return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kSign: diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 01e829ae964..13009992ab5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -86,6 +86,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, case HloOpcode::kExp: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kLog: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kMaximum: func_builder.create(loc, rets, args, attrs); break; @@ -101,6 +104,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, case HloOpcode::kRemainder: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kRsqrt: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kSelect: func_builder.create(loc, rets, args, attrs); break; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc index c0c4bd6f67e..5ce65423bb7 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc @@ -130,6 +130,21 @@ ENTRY %Exp (x: f32[2,2]) -> f32[2,2] { )"); } +TEST_F(LhloGenTest, Log) { + CompileAndVerifyIr(R"( +HloModule Log + +ENTRY %Log (x: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + ROOT %log = f32[2,2]{1,0} log(f32[2,2]{1,0} %x) +})", + R"( +;CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +;CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +;CHECK: } + )"); +} + TEST_F(LhloGenTest, AddInGPUDialect) { CompileAndVerifyIr(R"( HloModule Add @@ -478,6 +493,21 @@ ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] { )"); } +TEST_F(LhloGenTest, Rsqrt) { + CompileAndVerifyIr(R"( +HloModule Rsqrt + +ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + ROOT %rsqrt = f32[2,2]{1,0} rsqrt(f32[2,2]{1,0} %x) +})", + R"( +;CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { +;CHECK: "xla_lhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +;CHECK: } + )"); +} + TEST_F(LhloGenTest, Sign) { CompileAndVerifyIr(R"( HloModule Sign