[MLIR] Add more ops support for LHLO.
PiperOrigin-RevId: 315997542
Change-Id: Id8a8c852d20d52cd5e49f4b7fba325afebbac085
@@ -379,15 +379,6 @@ def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
 // XLA communication op definitions.
 //===----------------------------------------------------------------------===//
 
-// Represents a unique identifier for each Send/Recv instruction pair or
-// optionally for collective instructions (AllReduce, CollectivePermute,
-// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
-def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
-    StructFieldAttr<"handle", I64Attr>,
-    StructFieldAttr<"type", I64Attr>]> {
-  let description = "two 64-bit integers 'handle' and 'type'";
-}
-
 // InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
 // InfeedWithToken allows ordering of infeed HLO instructions using tokens.
 def HLO_InfeedOp : HLO_Op<"infeed", []> {
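
For reference, a ChannelHandle materializes in MLIR assembly as a dictionary-style struct attribute with the two i64 fields; the sketch below is hypothetical (op, operands, and values are illustrative only), but the { handle, type } form matches the all_reduce and collective_permute tests at the end of this diff:

  // Hypothetical send carrying a channel id; values are arbitrary.
  %token1 = "xla_hlo.send"(%operand, %token0) {
    channel_id = { handle = 5 : i64, type = 2 : i64 },
    is_host_transfer = false
  } : (tensor<3xf32>, !xla_hlo.token) -> !xla_hlo.token
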
@@ -451,7 +442,7 @@ def HLO_SendOp : HLO_Op<"send", []> {
   let arguments = (ins
     HLO_TensorOrTuple:$operand,
     HLO_Token:$token,
-    ChannelHandle:$channel_id,
+    ChannelHandle<HLO_Dialect>:$channel_id,
     DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
   );
 
@@ -476,7 +467,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
 
   let arguments = (ins
     HLO_Token:$token,
-    ChannelHandle:$channel_id,
+    ChannelHandle<HLO_Dialect>:$channel_id,
     DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
   );
 
@@ -564,16 +555,8 @@ def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>,
 
 
 def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects,
-                                  SameOperandsAndResultType]> {
-  string summary = "While operator";
-
-  string description = [{
-    Returns the result of executing a body function until the cond body returns
-    true.
-
-    See https://www.tensorflow.org/xla/operation_semantics#while.
-  }];
-
+                                  SameOperandsAndResultType]>,
+                 BASE_HLO_WhileOp {
   let arguments = (ins HLO_TensorOrTuple:$val);
 
   let regions = (region AnyRegion:$cond, AnyRegion:$body);
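
As a usage sketch (not part of this commit; shapes, constants, and the loop bound are hypothetical), the tensor-level while threads $val through both regions and yields a value of the same type, per SameOperandsAndResultType:

  %init = xla_hlo.constant dense<0> : tensor<i64>
  %limit = xla_hlo.constant dense<10> : tensor<i64>
  %result = "xla_hlo.while"(%init) ({
  ^bb0(%arg: tensor<i64>):
    // $cond region: yield a scalar predicate.
    %lt = "xla_hlo.compare"(%arg, %limit) {comparison_direction = "LT"}
        : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "xla_hlo.return"(%lt) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg: tensor<i64>):
    // $body region: yield the next loop-carried value.
    %one = xla_hlo.constant dense<1> : tensor<i64>
    %next = xla_hlo.add %arg, %one : tensor<i64>
    "xla_hlo.return"(%next) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>
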
@@ -590,7 +573,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
   let arguments = (ins
     HLO_Tensor:$operand,
     I64ElementsAttr:$replica_groups,
-    OptionalAttr<ChannelHandle>:$channel_id
+    OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id
   );
   let regions = (region SizedRegion<1>:$computation);
   let results = (outs HLO_Tensor);
@@ -584,6 +584,15 @@ class BASE_HLO_CaseOp {
 // XLA parallelism related op definitions.
 //===----------------------------------------------------------------------===//
 
+// Represents a unique identifier for each Send/Recv instruction pair or
+// optionally for collective instructions (AllReduce, CollectivePermute,
+// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
+class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
+    StructFieldAttr<"handle", I64Attr>,
+    StructFieldAttr<"type", I64Attr>]> {
+  let description = "two 64-bit integers 'handle' and 'type'";
+}
+
 class BASE_HLO_ReplicaIdOp {
   string summary = "ReplicaId operator";
 
@@ -1273,4 +1282,30 @@ class BASE_HLO_ReducePrecisionOp {
   }];
 }
 
+class BASE_HLO_InfeedOp {
+  string summary = "Infeed operator";
+
+  string description = [{
+    Reads a single data item from the implicit Infeed streaming interface of
+    the device, interpreting the data as the given shape and its layout, and
+    returns an LHLO op of the data. Multiple Infeed operations are allowed in a
+    computation, but there must be a total order among the Infeed operations.
+    For example, two Infeeds in the code below have a total order since there
+    is a dependency between the while loops.
+
+    See https://www.tensorflow.org/xla/operation_semantics#infeed
+  }];
+}
+
+class BASE_HLO_WhileOp {
+  string summary = "While operator";
+
+  string description = [{
+    Returns the result of executing a body function until the cond body returns
+    true.
+
+    See https://www.tensorflow.org/xla/operation_semantics#while.
+  }];
+}
+
 #endif // HLO_OPS_BASE
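
A minimal buffer-level illustration of the Infeed description above (mirroring the @infeed_memrefs test added below; the buffer names and config string are arbitrary): each infeed writes one data item into its output buffer, and multiple infeeds in one computation must be totally ordered.

  // Two infeeds, totally ordered here by their position in the block.
  "xla_lhlo.infeed"(%buf0) { config = "x" } : (memref<3xf32>) -> ()
  "xla_lhlo.infeed"(%buf1) { config = "x" } : (memref<3xf32>) -> ()
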
@@ -268,6 +268,16 @@ def LHLO_CaseOp: LHLO_Op<"case", [
   let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
 }
 
+// TODO(timshen): Add a custom syntax for this.
+def LHLO_WhileOp: LHLO_Op<"while", [SameTypeOperands]>, BASE_HLO_WhileOp {
+  let arguments = (ins
+    Arg<LHLO_BufferOrTuple, "", [MemRead]>:$val,
+    Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$output
+  );
+
+  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
+}
+
 //===----------------------------------------------------------------------===//
 // XLA tuple op definitions.
 //===----------------------------------------------------------------------===//
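
Unlike the tensor-level HLO_WhileOp, this buffer form is destination-passing: the op reads $val and writes the final loop state to $output, and each region writes its result into a trailing output buffer instead of returning a value. A sketch mirroring the @while_memrefs test added below (region bodies stubbed out with terminators):

  "xla_lhlo.while"(%val, %out) (
    // $cond writes its predicate into the memref<i1> block argument.
    { ^bb0(%arg: memref<i64>, %cond_out: memref<i1>): "xla_lhlo.terminator"() : () -> () },
    // $body writes the next loop state into %body_out.
    { ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
  ) : (memref<i64>, memref<i64>) -> ()
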
@@ -417,7 +427,23 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
 // XLA Other op definitions.
 //===----------------------------------------------------------------------===//
 
-def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
+def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,
+                           BASE_HLO_BatchNormGradOp {
+
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
+    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
+    Arg<LHLO_Buffer, "", [MemRead]>:$variance,
+    Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
+    Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
+    F32Attr:$epsilon,
+    I64Attr:$feature_index
+  );
+
+}
+
+def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
                                 BASE_HLO_BatchNormInferenceOp {
 
   let arguments = (ins
@@ -432,6 +458,19 @@ def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
   );
 }
 
+def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
+                               BASE_HLO_BatchNormTrainingOp {
+
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
+    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
+    Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
+    F32Attr:$epsilon,
+    I64Attr:$feature_index
+  );
+}
+
 def LHLO_BroadcastOp : LHLO_Op<"broadcast",
     []>, BASE_HLO_BroadcastOp {
   let arguments = (ins
@@ -601,6 +640,78 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
   );
 }
 
+def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
+                       BASE_HLO_AllReduceOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    I64ElementsAttr:$replica_groups,
+    DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
+    OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
+    DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
+  );
+  let regions = (region SizedRegion<1>:$computation);
+}
+
+def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
+                              BASE_HLO_CollectivePermuteOp {
+
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    I64ElementsAttr:$source_target_pairs,
+    OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
+  );
+}
+
+def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    HLO_FftTypeAttr:$fft_type,
+    I64ElementsAttr:$fft_length
+  );
+}
+
+def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp {
+  let arguments = (ins
+    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
+    Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
+    DefaultValuedAttr<BoolAttr, "false">:$lower
+  );
+}
+
+def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    DefaultValuedAttr<StrAttr, "">:$config
+  );
+}
+
+def LHLO_Outfeed: LHLO_Op<"outfeed", []> {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
+    DefaultValuedAttr<StrAttr, "">:$config
+  );
+}
+
+def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
+  let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
+}
+
+def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
+                            BASE_HLO_TriangularSolveOp {
+  let arguments = (ins
+    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
+    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
+    Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
+    BoolAttr:$left_side,
+    BoolAttr:$lower,
+    BoolAttr:$unit_diagonal,
+    HLO_TransposeAttr:$transpose_a
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Late operations
 //===----------------------------------------------------------------------===//
@@ -730,3 +730,136 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
   "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @all_reduce_memrefs
+func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
+  "xla_lhlo.all_reduce"(%arg0, %arg_out) ({
+    ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+      %max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
+      "xla_hlo.return"(%max) : (tensor<f32>) -> ()
+    })
+  { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
+
+  "xla_lhlo.all_reduce"(%arg0, %arg_out) ({
+    ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+      %max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
+      "xla_hlo.return"(%max) : (tensor<f32>) -> ()
+    })
+  {
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    channel_id = { handle = 5 : i64, type = 2 : i64 },
+    constrain_layout = true,
+    use_global_device_ids = true
+  }: (memref<10xf32>, memref<10xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @collective_permute_memrefs
+func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
+  "xla_lhlo.collective_permute"(%arg0, %arg_out) {
+    source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
+  } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
+
+  "xla_lhlo.collective_permute"(%arg0, %arg_out) {
+    source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
+    channel_id = { handle = 5 : i64, type = 2 : i64 }
+  } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @fft_memrefs
+func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
+  "xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_norm_grad_memrefs
+func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
+                              %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
+                              %arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
+  "xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
+      : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
+         tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_norm_inference_memrefs
+func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
+                                   %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
+  "xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
+      : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_norm_training_memrefs
+func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
+                                  %arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
+  "xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
+      : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @cholesky_memrefs
+func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
+  "xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
+  "xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @infeed_memrefs
+func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
+  "xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @outfeed_memrefs
+func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
+  "xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @replica_id_memrefs
+func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
+  "xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @triangular_solve_memrefs
+func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
+  "xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
+      : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @while_memrefs
+func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
+  "xla_lhlo.while"(%arg0, %arg_out) (
+    { ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
+    { ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
+  ) : (memref<i64>, memref<i64>) -> ()
+  return
+}