Add HasOutputProperties to check for pruned ops; Return
device name instead of casting it to a short name (GPU:0/CPU:0); VLOG(2) when printing op device placement since it is a lot of output. PiperOrigin-RevId: 157519077
This commit is contained in:
parent
2994444bf6
commit
5784e1e35e
@ -253,6 +253,10 @@ Status GraphProperties::InferDynamically(Cluster* cluster) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GraphProperties::HasOutputProperties(const string& name) const {
|
||||||
|
return output_properties_.find(name) != output_properties_.end();
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<OpInfo::TensorProperties> GraphProperties::GetInputProperties(
|
std::vector<OpInfo::TensorProperties> GraphProperties::GetInputProperties(
|
||||||
const string& node_name) const {
|
const string& node_name) const {
|
||||||
auto it = input_properties_.find(node_name);
|
auto it = input_properties_.find(node_name);
|
||||||
|
@ -37,6 +37,7 @@ class GraphProperties {
|
|||||||
Status InferStatically();
|
Status InferStatically();
|
||||||
Status InferDynamically(Cluster* cluster);
|
Status InferDynamically(Cluster* cluster);
|
||||||
|
|
||||||
|
bool HasOutputProperties(const string& name) const;
|
||||||
std::vector<OpInfo::TensorProperties> GetInputProperties(
|
std::vector<OpInfo::TensorProperties> GetInputProperties(
|
||||||
const string& node_name) const;
|
const string& node_name) const;
|
||||||
std::vector<OpInfo::TensorProperties> GetOutputProperties(
|
std::vector<OpInfo::TensorProperties> GetOutputProperties(
|
||||||
|
@ -184,13 +184,17 @@ void VirtualScheduler::MaybeUpdateInputProperties(
|
|||||||
value->add_float_val(1);
|
value->add_float_val(1);
|
||||||
inputs->push_back(control_message);
|
inputs->push_back(control_message);
|
||||||
} else {
|
} else {
|
||||||
const auto input_position = NodePosition(input_source_name);
|
// Like with HasInputProperties, if a node does not have output
|
||||||
// Use the input source's output property as _Send and _Recv's input
|
// properties, it's likely it was pruned during the shape inference run.
|
||||||
// property.
|
if (graph_properties_.HasOutputProperties(NodeName(input_source_name))) {
|
||||||
auto outputs =
|
const auto input_position = NodePosition(input_source_name);
|
||||||
graph_properties_.GetOutputProperties(NodeName(input_source_name));
|
// Use the input source's output property as _Send and _Recv's input
|
||||||
CHECK_GT(outputs.size(), input_position);
|
// property.
|
||||||
inputs->push_back(outputs[input_position]);
|
auto outputs =
|
||||||
|
graph_properties_.GetOutputProperties(NodeName(input_source_name));
|
||||||
|
CHECK_GT(outputs.size(), input_position);
|
||||||
|
inputs->push_back(outputs[input_position]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -211,16 +215,8 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
|||||||
const auto* to = node_state.outputs[0];
|
const auto* to = node_state.outputs[0];
|
||||||
return ChannelDeviceName(from, to);
|
return ChannelDeviceName(from, to);
|
||||||
} else {
|
} else {
|
||||||
const string& device = node->device().empty()
|
return node->device().empty() ? "/" + default_device_type_ + ":0"
|
||||||
? "/" + default_device_type_ + ":0"
|
: node->device();
|
||||||
: node->device();
|
|
||||||
DeviceNameUtils::ParsedName parsed;
|
|
||||||
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
|
|
||||||
LOG(WARNING) << "Device name parse failed: " << device;
|
|
||||||
return device;
|
|
||||||
}
|
|
||||||
// Return a short name like /CPU:0 or /GPU:0.
|
|
||||||
return "/" + DeviceNameUtils::LocalName(parsed.type, parsed.id);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user