Assign device to split node.

When tensor is split on multiple dimensions, the later split refers to previous split's assigned device.

PiperOrigin-RevId: 325279999
Change-Id: I08220371f52be6582d0a76cd7ac88a5d7b92c170
This commit is contained in:
Yuanzhong Xu 2020-08-06 12:04:14 -07:00 committed by TensorFlower Gardener
parent d9b5042f03
commit b9a5452924

View File

@ -685,6 +685,7 @@ xla::StatusOr<Node*> CreateSplitNode(int num_splits, int dim,
split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
split_def.add_input(absl::StrCat(orig_src->name(), ":", orig_src_output));
Node* split_node = graph->AddNode(split_def, &s);
split_node->set_assigned_device_name(input_assigned_device);
TF_RETURN_IF_ERROR(s);
graph->AddEdge(split_dim_node, 0, split_node, 0);