Add HasOutputProperties to check for pruned ops; Return
device name instead of casting it to a short name (GPU:0/CPU:0); VLOG(2) when printing op device placement since it is a lot of output. PiperOrigin-RevId: 157519077
This commit is contained in:
parent
2994444bf6
commit
5784e1e35e
tensorflow/core/grappler/costs
@ -253,6 +253,10 @@ Status GraphProperties::InferDynamically(Cluster* cluster) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool GraphProperties::HasOutputProperties(const string& name) const {
|
||||
return output_properties_.find(name) != output_properties_.end();
|
||||
}
|
||||
|
||||
std::vector<OpInfo::TensorProperties> GraphProperties::GetInputProperties(
|
||||
const string& node_name) const {
|
||||
auto it = input_properties_.find(node_name);
|
||||
|
@ -37,6 +37,7 @@ class GraphProperties {
|
||||
Status InferStatically();
|
||||
Status InferDynamically(Cluster* cluster);
|
||||
|
||||
bool HasOutputProperties(const string& name) const;
|
||||
std::vector<OpInfo::TensorProperties> GetInputProperties(
|
||||
const string& node_name) const;
|
||||
std::vector<OpInfo::TensorProperties> GetOutputProperties(
|
||||
|
@ -184,13 +184,17 @@ void VirtualScheduler::MaybeUpdateInputProperties(
|
||||
value->add_float_val(1);
|
||||
inputs->push_back(control_message);
|
||||
} else {
|
||||
const auto input_position = NodePosition(input_source_name);
|
||||
// Use the input source's output property as _Send and _Recv's input
|
||||
// property.
|
||||
auto outputs =
|
||||
graph_properties_.GetOutputProperties(NodeName(input_source_name));
|
||||
CHECK_GT(outputs.size(), input_position);
|
||||
inputs->push_back(outputs[input_position]);
|
||||
// Like with HasInputProperties, if a node does not have output
|
||||
// properties, it's likely it was pruned during the shape inference run.
|
||||
if (graph_properties_.HasOutputProperties(NodeName(input_source_name))) {
|
||||
const auto input_position = NodePosition(input_source_name);
|
||||
// Use the input source's output property as _Send and _Recv's input
|
||||
// property.
|
||||
auto outputs =
|
||||
graph_properties_.GetOutputProperties(NodeName(input_source_name));
|
||||
CHECK_GT(outputs.size(), input_position);
|
||||
inputs->push_back(outputs[input_position]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -211,16 +215,8 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
||||
const auto* to = node_state.outputs[0];
|
||||
return ChannelDeviceName(from, to);
|
||||
} else {
|
||||
const string& device = node->device().empty()
|
||||
? "/" + default_device_type_ + ":0"
|
||||
: node->device();
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
|
||||
LOG(WARNING) << "Device name parse failed: " << device;
|
||||
return device;
|
||||
}
|
||||
// Return a short name like /CPU:0 or /GPU:0.
|
||||
return "/" + DeviceNameUtils::LocalName(parsed.type, parsed.id);
|
||||
return node->device().empty() ? "/" + default_device_type_ + ":0"
|
||||
: node->device();
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user