Add documentation and tests
This commit is contained in:
		
							parent
							
								
									5fea6eecde
								
							
						
					
					
						commit
						c77c165dcd
					
				@ -1382,6 +1382,10 @@ For a more intuitive description, see the "Informal Description" section below.
 | 
				
			|||||||
| `indices_are_sorted`   | `bool`              | Whether the indices are       |
 | 
					| `indices_are_sorted`   | `bool`              | Whether the indices are       |
 | 
				
			||||||
:                        :                     : guaranteed to be sorted by    :
 | 
					:                        :                     : guaranteed to be sorted by    :
 | 
				
			||||||
:                        :                     : the caller.                   :
 | 
					:                        :                     : the caller.                   :
 | 
				
			||||||
 | 
					| `use_atomic`           | `bool`              | Whether to use atomic         |
 | 
				
			||||||
 | 
					:                        :                     : operation for the update. To  :
 | 
				
			||||||
 | 
					:                        :                     : use only when the the caller  :
 | 
				
			||||||
 | 
					:                        :                     : guarante no duplicate indices :
 | 
				
			||||||
 | 
					
 | 
				
			||||||
For convenience, we label dimensions in the output array not in `offset_dims`
 | 
					For convenience, we label dimensions in the output array not in `offset_dims`
 | 
				
			||||||
as `batch_dims`.
 | 
					as `batch_dims`.
 | 
				
			||||||
@ -1450,6 +1454,9 @@ If `indices_are_sorted` is set to true then XLA can assume that `start_indices`
 | 
				
			|||||||
are sorted (in ascending `start_index_map` order) by the user. If they are not
 | 
					are sorted (in ascending `start_index_map` order) by the user. If they are not
 | 
				
			||||||
then the semantics is implementation defined.
 | 
					then the semantics is implementation defined.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					If `use_atomic` is set to false then XLA will not use atomic operation. This
 | 
				
			||||||
 | 
					is only safe when there is no duplicate indices.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Informal Description and Examples
 | 
					### Informal Description and Examples
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Informally, every index `Out` in the output array corresponds to an element `E`
 | 
					Informally, every index `Out` in the output array corresponds to an element `E`
 | 
				
			||||||
 | 
				
			|||||||
@ -1529,7 +1529,8 @@ TEST_F(HloInstructionTest, StringifyScatter) {
 | 
				
			|||||||
              /*inserted_window_dims=*/{},
 | 
					              /*inserted_window_dims=*/{},
 | 
				
			||||||
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
 | 
					              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
 | 
				
			||||||
              /*index_vector_dim=*/2),
 | 
					              /*index_vector_dim=*/2),
 | 
				
			||||||
          /*indices_are_sorted=*/false));
 | 
					              /*indices_are_sorted=*/false,
 | 
				
			||||||
 | 
					              /*use_atomic=*/true));
 | 
				
			||||||
  module->AddEntryComputation(builder.Build());
 | 
					  module->AddEntryComputation(builder.Build());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  EXPECT_EQ(
 | 
					  EXPECT_EQ(
 | 
				
			||||||
 | 
				
			|||||||
@ -934,6 +934,25 @@ ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7
 | 
				
			|||||||
  ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3
 | 
					  ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					)"
 | 
				
			||||||
 | 
					},
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					"AtomicScatter",
 | 
				
			||||||
 | 
					R"(HloModule StringifyAtomicScatter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
 | 
				
			||||||
 | 
					  %lhs = f32[] parameter(0)
 | 
				
			||||||
 | 
					  %rhs = f32[] parameter(1)
 | 
				
			||||||
 | 
					  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
 | 
				
			||||||
 | 
					  %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
 | 
				
			||||||
 | 
					  %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
 | 
				
			||||||
 | 
					  %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
 | 
				
			||||||
 | 
					  ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, use_atomic=false, to_apply=%add_F32.v3
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
)"
 | 
					)"
 | 
				
			||||||
},
 | 
					},
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user