[MLIR] Add kLog and kRsqrt to lhlo/hlo emitters.
PiperOrigin-RevId: 294413426 Change-Id: Ie1c6aaa3faf3f103a388bc0b1707c62a7fe4b4f2
This commit is contained in:
parent
55c6d2eefb
commit
f752cf8155
@ -64,6 +64,8 @@ StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
|
||||
return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
|
||||
case HloOpcode::kExp:
|
||||
return {func_builder.create<hlo::ExpOp>(loc, rets, args, attrs)};
|
||||
case HloOpcode::kLog:
|
||||
return {func_builder.create<hlo::LogOp>(loc, rets, args, attrs)};
|
||||
case HloOpcode::kMaximum:
|
||||
return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)};
|
||||
case HloOpcode::kMinimum:
|
||||
@ -74,6 +76,8 @@ StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
|
||||
return {func_builder.create<hlo::NegOp>(loc, rets, args, attrs)};
|
||||
case HloOpcode::kRemainder:
|
||||
return {func_builder.create<hlo::RemOp>(loc, rets, args, attrs)};
|
||||
case HloOpcode::kRsqrt:
|
||||
return {func_builder.create<hlo::RsqrtOp>(loc, rets, args, attrs)};
|
||||
case HloOpcode::kSelect:
|
||||
return {func_builder.create<hlo::SelectOp>(loc, rets, args, attrs)};
|
||||
case HloOpcode::kSign:
|
||||
|
@ -86,6 +86,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
|
||||
case HloOpcode::kExp:
|
||||
func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
|
||||
break;
|
||||
case HloOpcode::kLog:
|
||||
func_builder.create<lhlo::LogOp>(loc, rets, args, attrs);
|
||||
break;
|
||||
case HloOpcode::kMaximum:
|
||||
func_builder.create<lhlo::MaxOp>(loc, rets, args, attrs);
|
||||
break;
|
||||
@ -101,6 +104,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
|
||||
case HloOpcode::kRemainder:
|
||||
func_builder.create<lhlo::RemOp>(loc, rets, args, attrs);
|
||||
break;
|
||||
case HloOpcode::kRsqrt:
|
||||
func_builder.create<lhlo::RsqrtOp>(loc, rets, args, attrs);
|
||||
break;
|
||||
case HloOpcode::kSelect:
|
||||
func_builder.create<lhlo::SelectOp>(loc, rets, args, attrs);
|
||||
break;
|
||||
|
@ -130,6 +130,21 @@ ENTRY %Exp (x: f32[2,2]) -> f32[2,2] {
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(LhloGenTest, Log) {
|
||||
CompileAndVerifyIr(R"(
|
||||
HloModule Log
|
||||
|
||||
ENTRY %Log (x: f32[2,2]) -> f32[2,2] {
|
||||
%x = f32[2,2]{1,0} parameter(0)
|
||||
ROOT %log = f32[2,2]{1,0} log(f32[2,2]{1,0} %x)
|
||||
})",
|
||||
R"(
|
||||
;CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
|
||||
;CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
|
||||
;CHECK: }
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(LhloGenTest, AddInGPUDialect) {
|
||||
CompileAndVerifyIr(R"(
|
||||
HloModule Add
|
||||
@ -478,6 +493,21 @@ ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(LhloGenTest, Rsqrt) {
|
||||
CompileAndVerifyIr(R"(
|
||||
HloModule Rsqrt
|
||||
|
||||
ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] {
|
||||
%x = f32[2,2]{1,0} parameter(0)
|
||||
ROOT %rsqrt = f32[2,2]{1,0} rsqrt(f32[2,2]{1,0} %x)
|
||||
})",
|
||||
R"(
|
||||
;CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
|
||||
;CHECK: "xla_lhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
|
||||
;CHECK: }
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(LhloGenTest, Sign) {
|
||||
CompileAndVerifyIr(R"(
|
||||
HloModule Sign
|
||||
|
Loading…
Reference in New Issue
Block a user