Legalize operation 'tf.AllToAll'. Legalization target is xla_hlo.all_to_all.
PiperOrigin-RevId: 310947977 Change-Id: Ib906830860821b590f1bc55e497b4ffbb193e1da
This commit is contained in:
		
							parent
							
								
									f4de924576
								
							
						
					
					
						commit
						e4939f779e
					
				| @ -192,6 +192,44 @@ retained with length 1. | ||||
|   let verifier = [{ return Verify(*this); }]; | ||||
| } | ||||
| 
 | ||||
| // TensorFlow dialect definition of the `tf.AllToAll` op. It takes a data tensor | ||||
| // and an i32 `group_assignment` tensor, carries three i64 attributes | ||||
| // (`concat_dimension`, `split_dimension`, `split_count`), and yields one tensor | ||||
| // result. The op is declared with the NoSideEffect trait. | ||||
| def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> { | ||||
|   let summary = "An Op to exchange data across TPU replicas."; | ||||
| 
 | ||||
|   let description = [{ | ||||
| On each replica, the input is split into `split_count` blocks along | ||||
| `split_dimension` and sent to the other replicas given by `group_assignment`. | ||||
| After receiving `split_count` - 1 blocks from other replicas, we concatenate | ||||
| the blocks along `concat_dimension` as the output. | ||||
| 
 | ||||
| For example, suppose there are 2 TPU replicas: | ||||
| replica 0 receives input: `[[A, B]]` | ||||
| replica 1 receives input: `[[C, D]]` | ||||
| 
 | ||||
| group_assignment=`[[0, 1]]` | ||||
| concat_dimension=0 | ||||
| split_dimension=1 | ||||
| split_count=2 | ||||
| 
 | ||||
| replica 0's output: `[[A], [C]]` | ||||
| replica 1's output: `[[B], [D]]` | ||||
|   }]; | ||||
| 
 | ||||
|   let arguments = (ins | ||||
|     TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, | ||||
|     I32Tensor:$group_assignment, | ||||
| 
 | ||||
|     I64Attr:$concat_dimension, | ||||
|     I64Attr:$split_dimension, | ||||
|     I64Attr:$split_count | ||||
|   ); | ||||
| 
 | ||||
|   let results = (outs | ||||
|     TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output | ||||
|   ); | ||||
| 
 | ||||
|   // `T` is derived from the type of operand 0 (`input`) rather than stored. | ||||
|   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; | ||||
| } | ||||
| 
 | ||||
| def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> { | ||||
|   let summary = "Returns the argument of a complex number."; | ||||
| 
 | ||||
|  | ||||
| @ -4096,6 +4096,21 @@ func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg | ||||
|   return %0 : tensor<4xf32> | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===// | ||||
| // AllToAll op legalizations. | ||||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| // Legalizes a tf.AllToAll whose group_assignment comes from a tf.Const and | ||||
| // expects an xla_hlo.all_to_all whose replica_groups attribute holds the same | ||||
| // values with i64 element type (the source constant is i32). | ||||
| // NOTE(review): concat_dimension = 1 is out of range for the rank-1 | ||||
| // tensor<10xf32> input; this looks like it only exercises the pattern rewrite, | ||||
| // not op semantics -- confirm no verifier is expected to reject it. | ||||
| // CHECK-LABEL: func @alltoall_basic | ||||
| func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> { | ||||
|   %group_assignment = "tf.Const" () { | ||||
|     value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32> | ||||
|   } : () -> tensor<3x4xi32> | ||||
|   %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} :  (tensor<10xf32>, tensor<3x4xi32>)  -> tensor<10xf32> | ||||
|   // CHECK: xla_hlo.all_to_all | ||||
|   // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64> | ||||
|   return %result : tensor<10xf32> | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===// | ||||
| // Cumsum op legalizations. | ||||
| //===----------------------------------------------------------------------===// | ||||
|  | ||||
| @ -273,6 +273,13 @@ def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), | ||||
|           (HLO_CrossReplicaSumOp $input, | ||||
|             (CastElementsToI64Elements $group_assignment))>; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===// | ||||
| // AllToAll op patterns. | ||||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| // Rewrites TF_AllToAllOp into HLO_AllToAllOp. The pattern only matches when | ||||
| // the input is a ranked tensor and group_assignment is produced by a | ||||
| // TF_ConstOp. Note the attribute order differs between the two ops: the source | ||||
| // lists concat_dimension first, while the target is given split_dimension | ||||
| // before concat_dimension. The constant group assignment is passed through | ||||
| // CastElementsToI64Elements (presumably converting its elements to i64; the | ||||
| // helper is defined outside this view -- confirm). | ||||
| def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), | ||||
|           (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===// | ||||
| // FFT op patterns. | ||||
| //===----------------------------------------------------------------------===// | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user