[MLIR][XLA] Add complex, imag, and real to lhlo/hlo emitters.

PiperOrigin-RevId: 307840887
Change-Id: I67ac19dc231e4e1fa7d1c68880a9455e92d7dada
This commit is contained in:
A. Unique TensorFlower 2020-04-22 10:09:42 -07:00 committed by TensorFlower Gardener
parent c105190702
commit 65fd3b702b
5 changed files with 49 additions and 0 deletions

View File

@ -58,6 +58,8 @@ StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
return {func_builder.create<hlo::AndOp>(loc, rets, args, attrs)};
case HloOpcode::kCeil:
return {func_builder.create<hlo::CeilOp>(loc, rets, args, attrs)};
case HloOpcode::kComplex:
return {func_builder.create<hlo::ComplexOp>(loc, rets, args, attrs)};
case HloOpcode::kCopy:
return {func_builder.create<hlo::CopyOp>(loc, rets, args, attrs)};
case HloOpcode::kCos:
@ -66,6 +68,8 @@ StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
case HloOpcode::kExp:
return {func_builder.create<hlo::ExpOp>(loc, rets, args, attrs)};
case HloOpcode::kImag:
return {func_builder.create<hlo::ImagOp>(loc, rets, args, attrs)};
case HloOpcode::kLog:
return {func_builder.create<hlo::LogOp>(loc, rets, args, attrs)};
case HloOpcode::kMaximum:
@ -76,6 +80,8 @@ StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
return {func_builder.create<hlo::MulOp>(loc, rets, args, attrs)};
case HloOpcode::kNegate:
return {func_builder.create<hlo::NegOp>(loc, rets, args, attrs)};
case HloOpcode::kReal:
return {func_builder.create<hlo::RealOp>(loc, rets, args, attrs)};
case HloOpcode::kRemainder:
return {func_builder.create<hlo::RemOp>(loc, rets, args, attrs)};
case HloOpcode::kRsqrt:

View File

@ -77,6 +77,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
case HloOpcode::kCeil:
func_builder.create<lhlo::CeilOp>(loc, rets, args, attrs);
break;
case HloOpcode::kComplex:
func_builder.create<lhlo::ComplexOp>(loc, rets, args, attrs);
break;
case HloOpcode::kCopy:
func_builder.create<lhlo::CopyOp>(loc, rets, args, attrs);
break;
@ -89,6 +92,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
case HloOpcode::kExp:
func_builder.create<lhlo::ExpOp>(loc, rets, args, attrs);
break;
case HloOpcode::kImag:
func_builder.create<lhlo::ImagOp>(loc, rets, args, attrs);
break;
case HloOpcode::kLog:
func_builder.create<lhlo::LogOp>(loc, rets, args, attrs);
break;
@ -104,6 +110,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
case HloOpcode::kNegate:
func_builder.create<lhlo::NegOp>(loc, rets, args, attrs);
break;
case HloOpcode::kReal:
func_builder.create<lhlo::RealOp>(loc, rets, args, attrs);
break;
case HloOpcode::kRemainder:
func_builder.create<lhlo::RemOp>(loc, rets, args, attrs);
break;

View File

@ -0,0 +1,12 @@
// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure
HloModule Complex
ENTRY %Complex (real: f32[2,2]{0,1}, imag: f32[2,2]{0,1}) -> c64[2,2] {
%real = f32[2,2]{0,1} parameter(0)
%imag = f32[2,2]{0,1} parameter(1)
ROOT %compl = c64[2,2]{0,1} complex(%real, %imag)
}
// CHECK: func @complex(%[[REAL:.*]]: [[BUF_F32:.*]], %[[IMAG:.*]]: [[BUF_F32]], %[[OUT:.*]]: [[BUF_C64:.*]]) {
// CHECK: "xla_lhlo.complex"(%[[REAL]], %[[IMAG]], %[[OUT]]) : ([[BUF_F32]], [[BUF_F32]], [[BUF_C64]]) -> ()
// CHECK: }

View File

@ -0,0 +1,11 @@
// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure
HloModule Imag
ENTRY %Imag (x: c64[2,2]{0,1}) -> f32[2,2] {
%x = c64[2,2]{0,1} parameter(0)
ROOT %imag = f32[2,2]{0,1} imag(%x)
}
// CHECK: func @imag(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) {
// CHECK: "xla_lhlo.imag"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> ()
// CHECK: }

View File

@ -0,0 +1,11 @@
// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure
HloModule Real
ENTRY %Real (x: c64[2,2]{0,1}) -> f32[2,2] {
%x = c64[2,2]{0,1} parameter(0)
ROOT %real = f32[2,2]{0,1} real(%x)
}
// CHECK: func @real(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) {
// CHECK: "xla_lhlo.real"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> ()
// CHECK: }