Fix incorrect early-out in partitioning.

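Fixes a bug in which the arg and retval node indices weren't rewritten when
a function was repartitioned by a StatefulPartitionedCallOp:
UpdateArgAndRetMetadata returned early whenever the device already had an
entry in arg_and_ret_indices_, so the rewrite was skipped. Repartitioning can
occur when the kernel is visited with a previously unseen
FunctionLibraryRuntime.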

PiperOrigin-RevId: 208255736
This commit is contained in:
Akshay Agrawal 2018-08-10 13:06:09 -07:00 committed by TensorFlower Gardener
parent b4b56abe30
commit 0a359212bc

View File

@ -261,12 +261,6 @@ class PartitionedCallOp : public AsyncOpKernel {
// device, and
// (3) records which `Arg` and `Retval` nodes live in host memory.
Status UpdateArgAndRetMetadata(const string& device, Graph* subgraph) {
if (arg_and_ret_indices_.find(device) != arg_and_ret_indices_.end()) {
// This function has already been partitioned, albeit for a different
// function library.
return Status::OK();
}
ArgAndRetIndices indices;
std::vector<int>* arg_indices = &indices.first;
std::vector<int>* ret_indices = &indices.second;
@ -274,6 +268,8 @@ class PartitionedCallOp : public AsyncOpKernel {
std::vector<std::pair<Node*, int>> ret_nodes;
const AttrValue* attr_value;
// Find the Arg and Retval nodes, along with their corresponding indices
// in the original function.
for (Node* node : subgraph->op_nodes()) {
string node_type = node->type_string();
if (node_type == FunctionLibraryDefinition::kArgOp) {
@ -289,6 +285,8 @@ class PartitionedCallOp : public AsyncOpKernel {
}
}
// Rewrite the indices of the Arg and Retval nodes for this function
// to range from 0 to the number of Arg nodes, Retval nodes, respectively.
auto sort_by_index = [](std::pair<Node*, int> one,
std::pair<Node*, int> two) -> bool {
return one.second < two.second;
@ -318,7 +316,12 @@ class PartitionedCallOp : public AsyncOpKernel {
arg_and_ret_alloc_attrs_[device].second.push_back(alloc_attr);
}
arg_and_ret_indices_.emplace(device, indices);
// If this kernel execution corresponds to a StatefulPartitionedCallOp,
// `arg_and_ret_indices_` might have been populated by a previous
// invocation.
if (arg_and_ret_indices_.find(device) == arg_and_ret_indices_.end()) {
arg_and_ret_indices_.emplace(device, indices);
}
return Status::OK();
}