legalize operation 'tf.AllToAll'. Legalization target is xla_hlo.all_to_all.
PiperOrigin-RevId: 310947977 Change-Id: Ib906830860821b590f1bc55e497b4ffbb193e1da
This commit is contained in:
parent
f4de924576
commit
e4939f779e
tensorflow/compiler/mlir
@ -192,6 +192,44 @@ retained with length 1.
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> {
|
||||
let summary = "An Op to exchange data across TPU replicas.";
|
||||
|
||||
let description = [{
|
||||
On each replica, the input is split into `split_count` blocks along
|
||||
`split_dimension` and send to the other replicas given group_assignment. After
|
||||
receiving `split_count` - 1 blocks from other replicas, we concatenate the
|
||||
blocks along `concat_dimension` as the output.
|
||||
|
||||
For example, suppose there are 2 TPU replicas:
|
||||
replica 0 receives input: `[[A, B]]`
|
||||
replica 1 receives input: `[[C, D]]`
|
||||
|
||||
group_assignment=`[[0, 1]]`
|
||||
concat_dimension=0
|
||||
split_dimension=1
|
||||
split_count=2
|
||||
|
||||
replica 0's output: `[[A], [C]]`
|
||||
replica 1's output: `[[B], [D]]`
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
|
||||
I32Tensor:$group_assignment,
|
||||
|
||||
I64Attr:$concat_dimension,
|
||||
I64Attr:$split_dimension,
|
||||
I64Attr:$split_count
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let summary = "Returns the argument of a complex number.";
|
||||
|
||||
|
@ -4096,6 +4096,21 @@ func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllToAll op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @alltoall_basic
|
||||
func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> {
|
||||
%group_assignment = "tf.Const" () {
|
||||
value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32>
|
||||
} : () -> tensor<3x4xi32>
|
||||
%result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32>
|
||||
// CHECK: xla_hlo.all_to_all
|
||||
// CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64>
|
||||
return %result : tensor<10xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cumsum op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -273,6 +273,13 @@ def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)),
|
||||
(HLO_CrossReplicaSumOp $input,
|
||||
(CastElementsToI64Elements $group_assignment))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// All2All op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count),
|
||||
(HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FFT op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
Reference in New Issue
Block a user