Assign device to split node.
When tensor is split on multiple dimensions, the later split refers to previous split's assigned device. PiperOrigin-RevId: 325279999 Change-Id: I08220371f52be6582d0a76cd7ac88a5d7b92c170
This commit is contained in:
parent
d9b5042f03
commit
b9a5452924
@ -685,6 +685,7 @@ xla::StatusOr<Node*> CreateSplitNode(int num_splits, int dim,
|
||||
split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
|
||||
split_def.add_input(absl::StrCat(orig_src->name(), ":", orig_src_output));
|
||||
Node* split_node = graph->AddNode(split_def, &s);
|
||||
split_node->set_assigned_device_name(input_assigned_device);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
|
||||
graph->AddEdge(split_dim_node, 0, split_node, 0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user