legalize operation 'tf.AllToAll'. Legalization target is xla_hlo.all_to_all.

PiperOrigin-RevId: 310947977
Change-Id: Ib906830860821b590f1bc55e497b4ffbb193e1da
This commit is contained in:
A. Unique TensorFlower 2020-05-11 10:57:33 -07:00 committed by TensorFlower Gardener
parent f4de924576
commit e4939f779e
3 changed files with 60 additions and 0 deletions
tensorflow/compiler/mlir

View File

@ -192,6 +192,44 @@ retained with length 1.
let verifier = [{ return Verify(*this); }];
}
def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> {
let summary = "An Op to exchange data across TPU replicas.";
let description = [{
On each replica, the input is split into `split_count` blocks along
`split_dimension` and send to the other replicas given group_assignment. After
receiving `split_count` - 1 blocks from other replicas, we concatenate the
blocks along `concat_dimension` as the output.
For example, suppose there are 2 TPU replicas:
replica 0 receives input: `[[A, B]]`
replica 1 receives input: `[[C, D]]`
group_assignment=`[[0, 1]]`
concat_dimension=0
split_dimension=1
split_count=2
replica 0's output: `[[A], [C]]`
replica 1's output: `[[B], [D]]`
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
I32Tensor:$group_assignment,
I64Attr:$concat_dimension,
I64Attr:$split_dimension,
I64Attr:$split_count
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "Returns the argument of a complex number.";

View File

@ -4096,6 +4096,21 @@ func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg
return %0 : tensor<4xf32>
}
//===----------------------------------------------------------------------===//
// AllToAll op legalizations.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @alltoall_basic
func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> {
%group_assignment = "tf.Const" () {
value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32>
} : () -> tensor<3x4xi32>
%result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32>
// CHECK: xla_hlo.all_to_all
// CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64>
return %result : tensor<10xf32>
}
//===----------------------------------------------------------------------===//
// Cumsum op legalizations.
//===----------------------------------------------------------------------===//

View File

@ -273,6 +273,13 @@ def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)),
(HLO_CrossReplicaSumOp $input,
(CastElementsToI64Elements $group_assignment))>;
//===----------------------------------------------------------------------===//
// All2All op patterns.
//===----------------------------------------------------------------------===//
def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count),
(HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>;
//===----------------------------------------------------------------------===//
// FFT op patterns.
//===----------------------------------------------------------------------===//