Fix incorrect early-out in partitioning.
Fixes a bug in which the arg and retval node indices weren't rewritten when a function was repartitioned by a StatefulPartitionedCallOp. Repartitioning can occur when the kernel is visited with a previously unseen FunctionLibraryRuntime. PiperOrigin-RevId: 208255736
This commit is contained in:
parent
b4b56abe30
commit
0a359212bc
@ -261,12 +261,6 @@ class PartitionedCallOp : public AsyncOpKernel {
|
||||
// device, and
|
||||
// (3) records which `Arg` and `Retval` nodes live in host memory.
|
||||
Status UpdateArgAndRetMetadata(const string& device, Graph* subgraph) {
|
||||
if (arg_and_ret_indices_.find(device) != arg_and_ret_indices_.end()) {
|
||||
// This function has already been partitioned, albeit for a different
|
||||
// function library.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ArgAndRetIndices indices;
|
||||
std::vector<int>* arg_indices = &indices.first;
|
||||
std::vector<int>* ret_indices = &indices.second;
|
||||
@ -274,6 +268,8 @@ class PartitionedCallOp : public AsyncOpKernel {
|
||||
std::vector<std::pair<Node*, int>> ret_nodes;
|
||||
const AttrValue* attr_value;
|
||||
|
||||
// Find the Arg and Retval nodes, along with their corresponding indices
|
||||
// in the original function.
|
||||
for (Node* node : subgraph->op_nodes()) {
|
||||
string node_type = node->type_string();
|
||||
if (node_type == FunctionLibraryDefinition::kArgOp) {
|
||||
@ -289,6 +285,8 @@ class PartitionedCallOp : public AsyncOpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
// Rewrite the indices of the Arg and Retval nodes for this function
|
||||
// to range from 0 to the number of Arg nodes, Retval nodes, respectively.
|
||||
auto sort_by_index = [](std::pair<Node*, int> one,
|
||||
std::pair<Node*, int> two) -> bool {
|
||||
return one.second < two.second;
|
||||
@ -318,7 +316,12 @@ class PartitionedCallOp : public AsyncOpKernel {
|
||||
arg_and_ret_alloc_attrs_[device].second.push_back(alloc_attr);
|
||||
}
|
||||
|
||||
arg_and_ret_indices_.emplace(device, indices);
|
||||
// If this kernel execution corresponds to a StatefulPartitionedCallOp,
|
||||
// `arg_and_ret_indices_` might have been populated by a previous
|
||||
// invocation.
|
||||
if (arg_and_ret_indices_.find(device) == arg_and_ret_indices_.end()) {
|
||||
arg_and_ret_indices_.emplace(device, indices);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user