Colocate splits ops for spatial partition to input data.

PiperOrigin-RevId: 343594062
Change-Id: I243b94922aac5138ab9e3512ae0f225bae9bed29
This commit is contained in:
A. Unique TensorFlower 2020-11-20 17:22:36 -08:00 committed by TensorFlower Gardener
parent 6847e331e1
commit 03ab8d341c

View File

@ -749,9 +749,16 @@ xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
Node* split_node = graph->AddNode(split_def, &s);
split_node->set_assigned_device_name(input_assigned_device);
TF_RETURN_IF_ERROR(s);
split_node->set_assigned_device_name(input_assigned_device);
// If colocate the newly created split op to source node of input to TPU
// computation.
split_node->AddAttr(kColocationAttrName,
std::vector<string>{absl::StrCat(kColocationGroupPrefix,
orig_src->name())});
graph->AddEdge(split_dim_node, 0, split_node, 0);
graph->AddEdge(to_split_node, to_split_index, split_node, 1);