Colocate splits ops for spatial partition to input data.
PiperOrigin-RevId: 343594062 Change-Id: I243b94922aac5138ab9e3512ae0f225bae9bed29
This commit is contained in:
parent
6847e331e1
commit
03ab8d341c
@ -749,9 +749,16 @@ xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
|
||||
split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
|
||||
split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
|
||||
Node* split_node = graph->AddNode(split_def, &s);
|
||||
split_node->set_assigned_device_name(input_assigned_device);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
|
||||
split_node->set_assigned_device_name(input_assigned_device);
|
||||
|
||||
// If colocate the newly created split op to source node of input to TPU
|
||||
// computation.
|
||||
split_node->AddAttr(kColocationAttrName,
|
||||
std::vector<string>{absl::StrCat(kColocationGroupPrefix,
|
||||
orig_src->name())});
|
||||
|
||||
graph->AddEdge(split_dim_node, 0, split_node, 0);
|
||||
graph->AddEdge(to_split_node, to_split_index, split_node, 1);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user