[MLIR] Add kLog and kRsqrt to lhlo/hlo emitters.

PiperOrigin-RevId: 294413426
Change-Id: Ie1c6aaa3faf3f103a388bc0b1707c62a7fe4b4f2
This commit is contained in:
Adrian Kuegel 2020-02-11 04:43:10 -08:00 committed by TensorFlower Gardener
parent 55c6d2eefb
commit f752cf8155
3 changed files with 40 additions and 0 deletions

View File

@ -64,6 +64,8 @@ StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)}; return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
case HloOpcode::kExp: case HloOpcode::kExp:
return {func_builder.create<hlo::ExpOp>(loc, rets, args, attrs)}; return {func_builder.create<hlo::ExpOp>(loc, rets, args, attrs)};
case HloOpcode::kLog:
return {func_builder.create<hlo::LogOp>(loc, rets, args, attrs)};
case HloOpcode::kMaximum: case HloOpcode::kMaximum:
return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)}; return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)};
case HloOpcode::kMinimum: case HloOpcode::kMinimum:
@ -74,6 +76,8 @@ StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
return {func_builder.create<hlo::NegOp>(loc, rets, args, attrs)}; return {func_builder.create<hlo::NegOp>(loc, rets, args, attrs)};
case HloOpcode::kRemainder: case HloOpcode::kRemainder:
return {func_builder.create<hlo::RemOp>(loc, rets, args, attrs)}; return {func_builder.create<hlo::RemOp>(loc, rets, args, attrs)};
case HloOpcode::kRsqrt:
return {func_builder.create<hlo::RsqrtOp>(loc, rets, args, attrs)};
case HloOpcode::kSelect: case HloOpcode::kSelect:
return {func_builder.create<hlo::SelectOp>(loc, rets, args, attrs)}; return {func_builder.create<hlo::SelectOp>(loc, rets, args, attrs)};
case HloOpcode::kSign: case HloOpcode::kSign:

View File

@ -86,6 +86,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
case HloOpcode::kExp: case HloOpcode::kExp:
func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs); func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
break; break;
case HloOpcode::kLog:
func_builder.create<lhlo::LogOp>(loc, rets, args, attrs);
break;
case HloOpcode::kMaximum: case HloOpcode::kMaximum:
func_builder.create<lhlo::MaxOp>(loc, rets, args, attrs); func_builder.create<lhlo::MaxOp>(loc, rets, args, attrs);
break; break;
@ -101,6 +104,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
case HloOpcode::kRemainder: case HloOpcode::kRemainder:
func_builder.create<lhlo::RemOp>(loc, rets, args, attrs); func_builder.create<lhlo::RemOp>(loc, rets, args, attrs);
break; break;
case HloOpcode::kRsqrt:
func_builder.create<lhlo::RsqrtOp>(loc, rets, args, attrs);
break;
case HloOpcode::kSelect: case HloOpcode::kSelect:
func_builder.create<lhlo::SelectOp>(loc, rets, args, attrs); func_builder.create<lhlo::SelectOp>(loc, rets, args, attrs);
break; break;

View File

@ -130,6 +130,21 @@ ENTRY %Exp (x: f32[2,2]) -> f32[2,2] {
)"); )");
} }
TEST_F(LhloGenTest, Log) {
CompileAndVerifyIr(R"(
HloModule Log
ENTRY %Log (x: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
ROOT %log = f32[2,2]{1,0} log(f32[2,2]{1,0} %x)
})",
R"(
;CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
;CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
;CHECK: }
)");
}
TEST_F(LhloGenTest, AddInGPUDialect) { TEST_F(LhloGenTest, AddInGPUDialect) {
CompileAndVerifyIr(R"( CompileAndVerifyIr(R"(
HloModule Add HloModule Add
@ -478,6 +493,21 @@ ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
)"); )");
} }
TEST_F(LhloGenTest, Rsqrt) {
CompileAndVerifyIr(R"(
HloModule Rsqrt
ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] {
%x = f32[2,2]{1,0} parameter(0)
ROOT %rsqrt = f32[2,2]{1,0} rsqrt(f32[2,2]{1,0} %x)
})",
R"(
;CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
;CHECK: "xla_lhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
;CHECK: }
)");
}
TEST_F(LhloGenTest, Sign) { TEST_F(LhloGenTest, Sign) {
CompileAndVerifyIr(R"( CompileAndVerifyIr(R"(
HloModule Sign HloModule Sign