[MLIR] Add more ops support for LHLO.

PiperOrigin-RevId: 315997542
Change-Id: Id8a8c852d20d52cd5e49f4b7fba325afebbac085
Authored by Tim Shen on 2020-06-11 16:07:41 -07:00; committed by TensorFlower Gardener
parent 603e328a1c
commit f6938bff05
4 changed files with 285 additions and 23 deletions

File 1 of 4

@@ -379,15 +379,6 @@ def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
// XLA communication op definitions.
//===----------------------------------------------------------------------===//
- // Represents a unique identifier for each Send/Recv instruction pair or
- // optionally for collective instructions (AllReduce, CollectivePermute,
- // AllToAll). Non-positive channel_id handle is equivalent to no channel id.
- def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
- StructFieldAttr<"handle", I64Attr>,
- StructFieldAttr<"type", I64Attr>]> {
- let description = "two 64-bit integers 'handle' and 'type'";
- }
// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens.
def HLO_InfeedOp : HLO_Op<"infeed", []> {
@@ -451,7 +442,7 @@ def HLO_SendOp : HLO_Op<"send", []> {
let arguments = (ins
HLO_TensorOrTuple:$operand,
HLO_Token:$token,
- ChannelHandle:$channel_id,
+ ChannelHandle<HLO_Dialect>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
);
@@ -476,7 +467,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
let arguments = (ins
HLO_Token:$token,
- ChannelHandle:$channel_id,
+ ChannelHandle<HLO_Dialect>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
);
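(Aside: a minimal sketch, not part of this commit, of how the now-parameterized channel handle appears on a send at the tensor level. The operand shape, SSA names, and handle values here are made up for illustration.)

  %token2 = "xla_hlo.send"(%operand, %token) {
    channel_id = {handle = 5 : i64, type = 2 : i64},
    is_host_transfer = false
  } : (tensor<3xf32>, !xla_hlo.token) -> !xla_hlo.token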
@@ -564,16 +555,8 @@ def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>,
def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects,
- SameOperandsAndResultType]> {
- string summary = "While operator";
- string description = [{
- Returns the result of executing a body function until the cond body returns
- true.
- See https://www.tensorflow.org/xla/operation_semantics#while.
- }];
+ SameOperandsAndResultType]>,
+ BASE_HLO_WhileOp {
let arguments = (ins HLO_TensorOrTuple:$val);
let regions = (region AnyRegion:$cond, AnyRegion:$body);
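(Aside: a minimal sketch of this op's use, not part of this commit, assuming an enclosing function that defines %init and %bound as tensor<i64> values; HLO while regions may capture values from the enclosing scope.)

  %result = "xla_hlo.while"(%init) ({
  ^bb0(%arg: tensor<i64>):
    %lt = "xla_hlo.compare"(%arg, %bound) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "xla_hlo.return"(%lt) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg: tensor<i64>):
    %c1 = xla_hlo.constant dense<1> : tensor<i64>
    %next = xla_hlo.add %arg, %c1 : tensor<i64>
    "xla_hlo.return"(%next) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>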
@@ -590,7 +573,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
let arguments = (ins
HLO_Tensor:$operand,
I64ElementsAttr:$replica_groups,
- OptionalAttr<ChannelHandle>:$channel_id
+ OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id
);
let regions = (region SizedRegion<1>:$computation);
let results = (outs HLO_Tensor);
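(Aside: mirroring the LHLO test added at the end of this diff, the tensor-level form with an explicit channel would read as follows; not part of this commit.)

  %sum = "xla_hlo.all_reduce"(%operand) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %add = xla_hlo.add %lhs, %rhs : tensor<f32>
    "xla_hlo.return"(%add) : (tensor<f32>) -> ()
  }) {
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
    channel_id = {handle = 5 : i64, type = 2 : i64}
  } : (tensor<10xf32>) -> tensor<10xf32>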

File 2 of 4

@@ -584,6 +584,15 @@ class BASE_HLO_CaseOp {
// XLA parallelism related op definitions.
//===----------------------------------------------------------------------===//
+ // Represents a unique identifier for each Send/Recv instruction pair or
+ // optionally for collective instructions (AllReduce, CollectivePermute,
+ // AllToAll). Non-positive channel_id handle is equivalent to no channel id.
+ class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
+ StructFieldAttr<"handle", I64Attr>,
+ StructFieldAttr<"type", I64Attr>]> {
+ let description = "two 64-bit integers 'handle' and 'type'";
+ }
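(In the textual IR this StructAttr prints as a dictionary with exactly these two fields, e.g. channel_id = { handle = 5 : i64, type = 2 : i64 }, as exercised by the tests at the end of this diff.)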
class BASE_HLO_ReplicaIdOp {
string summary = "ReplicaId operator";
@@ -1273,4 +1282,30 @@ class BASE_HLO_ReducePrecisionOp {
}];
}
+ class BASE_HLO_InfeedOp {
+ string summary = "Infeed operator";
+ string description = [{
+ Reads a single data item from the implicit Infeed streaming interface of
+ the device, interpreting the data as the given shape and its layout, and
+ returns an LHLO op of the data. Multiple Infeed operations are allowed in a
+ computation, but there must be a total order among the Infeed operations.
+ For example, two Infeeds in the code below have a total order since there
+ is a dependency between the while loops.
+
+ See https://www.tensorflow.org/xla/operation_semantics#infeed
+ }];
+ }
+
+ class BASE_HLO_WhileOp {
+ string summary = "While operator";
+ string description = [{
+ Returns the result of executing a body function until the cond body returns
+ true.
+
+ See https://www.tensorflow.org/xla/operation_semantics#while.
+ }];
+ }
#endif // HLO_OPS_BASE

File 3 of 4

@@ -268,6 +268,16 @@ def LHLO_CaseOp: LHLO_Op<"case", [
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}
+ // TODO(timshen): Add a custom syntax for this.
+ def LHLO_WhileOp: LHLO_Op<"while", [SameTypeOperands]>, BASE_HLO_WhileOp {
+ let arguments = (ins
+ Arg<LHLO_BufferOrTuple, "", [MemRead]>:$val,
+ Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$output
+ );
+
+ let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
+ }
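(Unlike the tensor-level xla_hlo.while, which threads an SSA value through its regions, this buffer-level form reads the loop state from $val and writes the final state to $output; the condition region takes an extra memref<i1> buffer into which it stores the predicate, as shown in the @while_memrefs test at the end of this diff.)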
//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
//===----------------------------------------------------------------------===//
@@ -417,7 +427,23 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
// XLA Other op definitions.
//===----------------------------------------------------------------------===//
- def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
+ def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,
+ BASE_HLO_BatchNormGradOp {
+ let arguments = (ins
+ Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+ Arg<LHLO_Buffer, "", [MemRead]>:$scale,
+ Arg<LHLO_Buffer, "", [MemRead]>:$mean,
+ Arg<LHLO_Buffer, "", [MemRead]>:$variance,
+ Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
+ Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
+ F32Attr:$epsilon,
+ I64Attr:$feature_index
+ );
+ }
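(Per the XLA batch-norm semantics, $output here is a 3-tuple of buffers (grad_operand, grad_scale, grad_offset); the @batch_norm_grad_memrefs test below passes tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>> accordingly.)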
+ def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
BASE_HLO_BatchNormInferenceOp {
let arguments = (ins
@@ -432,6 +458,19 @@ def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
);
}
+ def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
+ BASE_HLO_BatchNormTrainingOp {
+ let arguments = (ins
+ Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+ Arg<LHLO_Buffer, "", [MemRead]>:$scale,
+ Arg<LHLO_Buffer, "", [MemRead]>:$offset,
+ Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
+ F32Attr:$epsilon,
+ I64Attr:$feature_index
+ );
+ }
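(Likewise, $output for training is the 3-tuple (output, batch_mean, batch_var) from the XLA BatchNormTraining semantics, matching the tuple type used in @batch_norm_training_memrefs below.)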
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
[]>, BASE_HLO_BroadcastOp {
let arguments = (ins
@@ -601,6 +640,78 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
);
}
+ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
+ BASE_HLO_AllReduceOp {
+ let arguments = (ins
+ Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+ Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+ I64ElementsAttr:$replica_groups,
+ DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
+ OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
+ DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
+ );
+ let regions = (region SizedRegion<1>:$computation);
+ }
+
+ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
+ BASE_HLO_CollectivePermuteOp {
+ let arguments = (ins
+ Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+ Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+ I64ElementsAttr:$source_target_pairs,
+ OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
+ );
+ }
+
+ def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
+ let arguments = (ins
+ Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+ Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+ HLO_FftTypeAttr:$fft_type,
+ I64ElementsAttr:$fft_length
+ );
+ }
+
+ def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp {
+ let arguments = (ins
+ Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
+ Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
+ DefaultValuedAttr<BoolAttr, "false">:$lower
+ );
+ }
+
+ def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp {
+ let arguments = (ins
+ Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+ DefaultValuedAttr<StrAttr, "">:$config
+ );
+ }
+
+ def LHLO_Outfeed: LHLO_Op<"outfeed", []> {
+ let arguments = (ins
+ Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+ DefaultValuedAttr<StrAttr, "">:$config
+ );
+ }
+
+ def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
+ let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
+ }
+
+ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
+ BASE_HLO_TriangularSolveOp {
+ let arguments = (ins
+ Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
+ Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
+ Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
+ BoolAttr:$left_side,
+ BoolAttr:$lower,
+ BoolAttr:$unit_diagonal,
+ HLO_TransposeAttr:$transpose_a
+ );
+ }
//===----------------------------------------------------------------------===//
// Late operations
//===----------------------------------------------------------------------===//

File 4 of 4

@@ -730,3 +730,136 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
+
+ // -----
+
+ // CHECK-LABEL: func @all_reduce_memrefs
+ func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
+ "xla_lhlo.all_reduce"(%arg0, %arg_out) ({
+ ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+ %max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
+ "xla_hlo.return"(%max) : (tensor<f32>) -> ()
+ })
+ { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
+
+ "xla_lhlo.all_reduce"(%arg0, %arg_out) ({
+ ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+ %max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
+ "xla_hlo.return"(%max) : (tensor<f32>) -> ()
+ })
+ {
+ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+ channel_id = { handle = 5 : i64, type = 2 : i64 },
+ constrain_layout = true,
+ use_global_device_ids = true
+ }: (memref<10xf32>, memref<10xf32>) -> ()
+ return
+ }
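(With this replica_groups value, devices {0, 2, 4, 6} and {1, 3, 5, 7} form two independent reduction groups; channel_id, constrain_layout, and use_global_device_ids mirror the corresponding fields on XLA's AllReduce instruction.)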
+
+ // -----
+
+ // CHECK-LABEL: func @collective_permute_memrefs
+ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
+ "xla_lhlo.collective_permute"(%arg0, %arg_out) {
+ source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
+ } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
+
+ "xla_lhlo.collective_permute"(%arg0, %arg_out) {
+ source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
+ channel_id = { handle = 5 : i64, type = 2 : i64 }
+ } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
+ return
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @fft_memrefs
+ func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
+ "xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
+ return
+ }
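(The shapes are consistent: an RFFT over fft_length = 9 real samples yields floor(9/2) + 1 = 5 complex coefficients per transformed row, hence memref<3x9xf32> to memref<3x5xcomplex<f32>>.)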
+
+ // -----
+
+ // CHECK-LABEL: func @batch_norm_grad_memrefs
+ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
+ %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
+ %arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
+ "xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
+ : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
+ tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
+ return
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @batch_norm_inference_memrefs
+ func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
+ %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
+ "xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
+ : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
+ return
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @batch_norm_training_memrefs
+ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
+ %arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
+ "xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
+ : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
+ return
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @cholesky_memrefs
+ func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
+ "xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
+ "xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
+ return
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @infeed_memrefs
+ func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
+ "xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
+ return
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @outfeed_memrefs
+ func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
+ "xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
+ return
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @replica_id_memrefs
+ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
+ "xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
+ return
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @triangular_solve_memrefs
+ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
+ "xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
+ : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
+ return
+ }
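(In the XLA semantics, left_side selects between solving op(a) * x = b when true and x * op(a) = b when false, with op(a) determined by transpose_a; lower picks which triangle of a is read, and unit_diagonal assumes ones on its diagonal.)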
+
+ // -----
+
+ // CHECK-LABEL: func @while_memrefs
+ func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
+ "xla_lhlo.while"(%arg0, %arg_out) (
+ { ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
+ { ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
+ ) : (memref<i64>, memref<i64>) -> ()
+ return
+ }