Colocate splits ops for spatial partition to input data.
PiperOrigin-RevId: 343594062 Change-Id: I243b94922aac5138ab9e3512ae0f225bae9bed29
This commit is contained in:
		
							parent
							
								
									6847e331e1
								
							
						
					
					
						commit
						03ab8d341c
					
				@ -749,9 +749,16 @@ xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
 | 
			
		||||
  split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
 | 
			
		||||
  split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
 | 
			
		||||
  Node* split_node = graph->AddNode(split_def, &s);
 | 
			
		||||
  split_node->set_assigned_device_name(input_assigned_device);
 | 
			
		||||
  TF_RETURN_IF_ERROR(s);
 | 
			
		||||
 | 
			
		||||
  split_node->set_assigned_device_name(input_assigned_device);
 | 
			
		||||
 | 
			
		||||
  // If colocate the newly created split op to source node of input to TPU
 | 
			
		||||
  // computation.
 | 
			
		||||
  split_node->AddAttr(kColocationAttrName,
 | 
			
		||||
                      std::vector<string>{absl::StrCat(kColocationGroupPrefix,
 | 
			
		||||
                                                       orig_src->name())});
 | 
			
		||||
 | 
			
		||||
  graph->AddEdge(split_dim_node, 0, split_node, 0);
 | 
			
		||||
  graph->AddEdge(to_split_node, to_split_index, split_node, 1);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user