Legalize operation 'tf.AllToAll'. Legalization target is xla_hlo.all_to_all.
PiperOrigin-RevId: 310947977 Change-Id: Ib906830860821b590f1bc55e497b4ffbb193e1da
This commit is contained in:
		
							parent
							
								
									f4de924576
								
							
						
					
					
						commit
						e4939f779e
					
				| @ -192,6 +192,44 @@ retained with length 1. | |||||||
|   let verifier = [{ return Verify(*this); }]; |   let verifier = [{ return Verify(*this); }]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> { | ||||||
|  |   let summary = "An Op to exchange data across TPU replicas."; | ||||||
|  | 
 | ||||||
|  |   let description = [{ | ||||||
|  | On each replica, the input is split into `split_count` blocks along | ||||||
|  | `split_dimension` and sent to the other replicas given group_assignment. After | ||||||
|  | receiving `split_count` - 1 blocks from other replicas, we concatenate the | ||||||
|  | blocks along `concat_dimension` as the output. | ||||||
|  | 
 | ||||||
|  | For example, suppose there are 2 TPU replicas: | ||||||
|  | replica 0 receives input: `[[A, B]]` | ||||||
|  | replica 1 receives input: `[[C, D]]` | ||||||
|  | 
 | ||||||
|  | group_assignment=`[[0, 1]]` | ||||||
|  | concat_dimension=0 | ||||||
|  | split_dimension=1 | ||||||
|  | split_count=2 | ||||||
|  | 
 | ||||||
|  | replica 0's output: `[[A], [C]]` | ||||||
|  | replica 1's output: `[[B], [D]]` | ||||||
|  |   }]; | ||||||
|  | 
 | ||||||
|  |   let arguments = (ins | ||||||
|  |     TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, | ||||||
|  |     I32Tensor:$group_assignment, | ||||||
|  | 
 | ||||||
|  |     I64Attr:$concat_dimension, | ||||||
|  |     I64Attr:$split_dimension, | ||||||
|  |     I64Attr:$split_count | ||||||
|  |   ); | ||||||
|  | 
 | ||||||
|  |   let results = (outs | ||||||
|  |     TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output | ||||||
|  |   ); | ||||||
|  | 
 | ||||||
|  |   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> { | def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> { | ||||||
|   let summary = "Returns the argument of a complex number."; |   let summary = "Returns the argument of a complex number."; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -4096,6 +4096,21 @@ func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg | |||||||
|   return %0 : tensor<4xf32> |   return %0 : tensor<4xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
|  | // AllToAll op legalizations. | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: func @alltoall_basic | ||||||
|  | func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> { | ||||||
|  |   %group_assignment = "tf.Const" () { | ||||||
|  |     value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32> | ||||||
|  |   } : () -> tensor<3x4xi32> | ||||||
|  |   %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} :  (tensor<10xf32>, tensor<3x4xi32>)  -> tensor<10xf32> | ||||||
|  |   // CHECK: xla_hlo.all_to_all | ||||||
|  |   // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64> | ||||||
|  |   return %result : tensor<10xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
| //===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||||
| // Cumsum op legalizations. | // Cumsum op legalizations. | ||||||
| //===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||||
|  | |||||||
| @ -273,6 +273,13 @@ def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), | |||||||
|           (HLO_CrossReplicaSumOp $input, |           (HLO_CrossReplicaSumOp $input, | ||||||
|             (CastElementsToI64Elements $group_assignment))>; |             (CastElementsToI64Elements $group_assignment))>; | ||||||
| 
 | 
 | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
|  | // AllToAll op patterns. | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
|  | 
 | ||||||
|  | def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), | ||||||
|  |           (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>; | ||||||
|  | 
 | ||||||
| //===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||||
| // FFT op patterns. | // FFT op patterns. | ||||||
| //===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user