[MLIR] Add kLog and kRsqrt to lhlo/hlo emitters.
PiperOrigin-RevId: 294413426 Change-Id: Ie1c6aaa3faf3f103a388bc0b1707c62a7fe4b4f2
This commit is contained in:
parent
55c6d2eefb
commit
f752cf8155
@ -64,6 +64,8 @@ StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
|
|||||||
return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
|
return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
|
||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
return {func_builder.create<hlo::ExpOp>(loc, rets, args, attrs)};
|
return {func_builder.create<hlo::ExpOp>(loc, rets, args, attrs)};
|
||||||
|
case HloOpcode::kLog:
|
||||||
|
return {func_builder.create<hlo::LogOp>(loc, rets, args, attrs)};
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)};
|
return {func_builder.create<hlo::MaxOp>(loc, rets, args, attrs)};
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
@ -74,6 +76,8 @@ StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
|
|||||||
return {func_builder.create<hlo::NegOp>(loc, rets, args, attrs)};
|
return {func_builder.create<hlo::NegOp>(loc, rets, args, attrs)};
|
||||||
case HloOpcode::kRemainder:
|
case HloOpcode::kRemainder:
|
||||||
return {func_builder.create<hlo::RemOp>(loc, rets, args, attrs)};
|
return {func_builder.create<hlo::RemOp>(loc, rets, args, attrs)};
|
||||||
|
case HloOpcode::kRsqrt:
|
||||||
|
return {func_builder.create<hlo::RsqrtOp>(loc, rets, args, attrs)};
|
||||||
case HloOpcode::kSelect:
|
case HloOpcode::kSelect:
|
||||||
return {func_builder.create<hlo::SelectOp>(loc, rets, args, attrs)};
|
return {func_builder.create<hlo::SelectOp>(loc, rets, args, attrs)};
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
|
@ -86,6 +86,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
|
|||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
|
func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
|
||||||
break;
|
break;
|
||||||
|
case HloOpcode::kLog:
|
||||||
|
func_builder.create<lhlo::LogOp>(loc, rets, args, attrs);
|
||||||
|
break;
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
func_builder.create<lhlo::MaxOp>(loc, rets, args, attrs);
|
func_builder.create<lhlo::MaxOp>(loc, rets, args, attrs);
|
||||||
break;
|
break;
|
||||||
@ -101,6 +104,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
|
|||||||
case HloOpcode::kRemainder:
|
case HloOpcode::kRemainder:
|
||||||
func_builder.create<lhlo::RemOp>(loc, rets, args, attrs);
|
func_builder.create<lhlo::RemOp>(loc, rets, args, attrs);
|
||||||
break;
|
break;
|
||||||
|
case HloOpcode::kRsqrt:
|
||||||
|
func_builder.create<lhlo::RsqrtOp>(loc, rets, args, attrs);
|
||||||
|
break;
|
||||||
case HloOpcode::kSelect:
|
case HloOpcode::kSelect:
|
||||||
func_builder.create<lhlo::SelectOp>(loc, rets, args, attrs);
|
func_builder.create<lhlo::SelectOp>(loc, rets, args, attrs);
|
||||||
break;
|
break;
|
||||||
|
@ -130,6 +130,21 @@ ENTRY %Exp (x: f32[2,2]) -> f32[2,2] {
|
|||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(LhloGenTest, Log) {
|
||||||
|
CompileAndVerifyIr(R"(
|
||||||
|
HloModule Log
|
||||||
|
|
||||||
|
ENTRY %Log (x: f32[2,2]) -> f32[2,2] {
|
||||||
|
%x = f32[2,2]{1,0} parameter(0)
|
||||||
|
ROOT %log = f32[2,2]{1,0} log(f32[2,2]{1,0} %x)
|
||||||
|
})",
|
||||||
|
R"(
|
||||||
|
;CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
|
||||||
|
;CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
|
||||||
|
;CHECK: }
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(LhloGenTest, AddInGPUDialect) {
|
TEST_F(LhloGenTest, AddInGPUDialect) {
|
||||||
CompileAndVerifyIr(R"(
|
CompileAndVerifyIr(R"(
|
||||||
HloModule Add
|
HloModule Add
|
||||||
@ -478,6 +493,21 @@ ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
|
|||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(LhloGenTest, Rsqrt) {
|
||||||
|
CompileAndVerifyIr(R"(
|
||||||
|
HloModule Rsqrt
|
||||||
|
|
||||||
|
ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] {
|
||||||
|
%x = f32[2,2]{1,0} parameter(0)
|
||||||
|
ROOT %rsqrt = f32[2,2]{1,0} rsqrt(f32[2,2]{1,0} %x)
|
||||||
|
})",
|
||||||
|
R"(
|
||||||
|
;CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
|
||||||
|
;CHECK: "xla_lhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
|
||||||
|
;CHECK: }
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(LhloGenTest, Sign) {
|
TEST_F(LhloGenTest, Sign) {
|
||||||
CompileAndVerifyIr(R"(
|
CompileAndVerifyIr(R"(
|
||||||
HloModule Sign
|
HloModule Sign
|
||||||
|
Loading…
Reference in New Issue
Block a user