Created a new class, SchedulerState, to encapsulate all of the scheduler state-related functionality we would like to reuse across different scheduler implementations. Scheduler state-related member functions/variables in VirtualScheduler have been moved to SchedulerState accordingly.
PiperOrigin-RevId: 308204183 Change-Id: Ie8ffe167d31844cc82c865ec5ac4a28d0e53d3a9
This commit is contained in:
parent
ecfcb090c5
commit
3921264ef7
|
@ -40,12 +40,6 @@ namespace {
|
||||||
|
|
||||||
using ::tensorflow::strings::HumanReadableNumBytes;
|
using ::tensorflow::strings::HumanReadableNumBytes;
|
||||||
|
|
||||||
constexpr char kAttrInputSrc[] = "input_source_";
|
|
||||||
constexpr char kAttrSrcDevice[] = "send_device";
|
|
||||||
constexpr char kAttrDstDevice[] = "recv_device";
|
|
||||||
constexpr char kAttrTensorName[] = "tensor_name";
|
|
||||||
constexpr char kChannelDevice[] = "Channel";
|
|
||||||
|
|
||||||
float Round2(const float x) {
|
float Round2(const float x) {
|
||||||
// Not using std::round from <cmath> here because not all platforms seem to
|
// Not using std::round from <cmath> here because not all platforms seem to
|
||||||
// support that (specifically Android).
|
// support that (specifically Android).
|
||||||
|
@ -347,13 +341,11 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
SchedulerState::SchedulerState(const bool use_static_shapes,
|
||||||
const bool use_aggressive_shape_inference,
|
const bool use_aggressive_shape_inference,
|
||||||
Cluster* cluster,
|
Cluster* cluster,
|
||||||
ReadyNodeManager* ready_nodes,
|
std::unique_ptr<VirtualPlacer> placer)
|
||||||
std::unique_ptr<VirtualPlacer> placer)
|
: graph_costs_(Costs::ZeroCosts()),
|
||||||
: ready_nodes_(ready_nodes),
|
|
||||||
graph_costs_(Costs::ZeroCosts()),
|
|
||||||
cluster_(cluster),
|
cluster_(cluster),
|
||||||
use_static_shapes_(use_static_shapes),
|
use_static_shapes_(use_static_shapes),
|
||||||
use_aggressive_shape_inference_(use_aggressive_shape_inference),
|
use_aggressive_shape_inference_(use_aggressive_shape_inference),
|
||||||
|
@ -364,10 +356,12 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||||
track_mem_usage_snapshot_ = VLOG_IS_ON(1);
|
track_mem_usage_snapshot_ = VLOG_IS_ON(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status VirtualScheduler::Init(const GrapplerItem* item) {
|
Status SchedulerState::Init(const GrapplerItem* item,
|
||||||
|
std::vector<const NodeDef*>* initial_nodes,
|
||||||
|
bool create_explicit_channel_device) {
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
|
|
||||||
// Clear all internal states so that the VirtualScheduler is reusable for
|
// Clear all internal states so that the SchedulerState is reusable for
|
||||||
// different GrapplerItems
|
// different GrapplerItems
|
||||||
node_map_.clear();
|
node_map_.clear();
|
||||||
device_.clear();
|
device_.clear();
|
||||||
|
@ -380,14 +374,12 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||||
op_counts_.clear();
|
op_counts_.clear();
|
||||||
op_costs_.clear();
|
op_costs_.clear();
|
||||||
|
|
||||||
// Init() preprocesses the input grappler_item and graph_properties to extract
|
initial_nodes->clear();
|
||||||
// necessary information for emulating tensorflow op scheduling and
|
|
||||||
// construct internal data structures (NodeState and DeviceState) for virtual
|
|
||||||
// scheduling.
|
|
||||||
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
|
|
||||||
|
|
||||||
// Constructs graph properties and performs shape inference.
|
// Constructs graph properties and performs shape inference.
|
||||||
graph_properties_ = absl::make_unique<GraphProperties>(*item);
|
graph_properties_ = absl::make_unique<GraphProperties>(*item);
|
||||||
|
// TODO(safeen,dyoon): Will we ever use InferDynamically? If not we may want
|
||||||
|
// to get rid of use_static_shapes_ and cluster_.
|
||||||
if (use_static_shapes_) {
|
if (use_static_shapes_) {
|
||||||
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
|
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
|
||||||
true, use_aggressive_shape_inference_, true));
|
true, use_aggressive_shape_inference_, true));
|
||||||
|
@ -399,6 +391,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||||
const auto& graph = grappler_item_->graph;
|
const auto& graph = grappler_item_->graph;
|
||||||
const auto& fetch_nodes = grappler_item_->fetch;
|
const auto& fetch_nodes = grappler_item_->fetch;
|
||||||
std::set<string> feed_nodes;
|
std::set<string> feed_nodes;
|
||||||
|
|
||||||
for (const auto& f : grappler_item_->feed) {
|
for (const auto& f : grappler_item_->feed) {
|
||||||
auto iter_and_inserted_flag = feed_nodes.insert(f.first);
|
auto iter_and_inserted_flag = feed_nodes.insert(f.first);
|
||||||
QCHECK(iter_and_inserted_flag.second)
|
QCHECK(iter_and_inserted_flag.second)
|
||||||
|
@ -486,8 +479,9 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||||
} else {
|
} else {
|
||||||
// Different device, no cached copy; transfer input_node to the
|
// Different device, no cached copy; transfer input_node to the
|
||||||
// curr_node's device.
|
// curr_node's device.
|
||||||
auto send_and_recv = CreateSendRecv(input_node, curr_node, input_node,
|
auto send_and_recv =
|
||||||
input_node_name);
|
CreateSendRecv(input_node, curr_node, input_node, input_node_name,
|
||||||
|
create_explicit_channel_device);
|
||||||
// Note that CreateSendRecv() already connected input/output between
|
// Note that CreateSendRecv() already connected input/output between
|
||||||
// _Send and _Recv ops.
|
// _Send and _Recv ops.
|
||||||
const auto* send = send_and_recv.first;
|
const auto* send = send_and_recv.first;
|
||||||
|
@ -514,7 +508,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||||
|
|
||||||
if (given_as_feed || has_no_inputs) {
|
if (given_as_feed || has_no_inputs) {
|
||||||
curr_node_state.time_ready = Costs::Duration();
|
curr_node_state.time_ready = Costs::Duration();
|
||||||
ready_nodes_->AddNode(curr_node);
|
initial_nodes->push_back(curr_node);
|
||||||
VLOG(3) << "Added ready node: " << curr_node->name();
|
VLOG(3) << "Added ready node: " << curr_node->name();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -530,7 +524,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ready_nodes_->Empty()) {
|
if (initial_nodes->empty()) {
|
||||||
return errors::InvalidArgument("No ready nodes in the graph.");
|
return errors::InvalidArgument("No ready nodes in the graph.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -546,20 +540,20 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
|
void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) {
|
||||||
CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
|
CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
|
||||||
// This method is called when NodeState is created and adds input and output
|
// This method is called when NodeState is created and adds input and output
|
||||||
// properties for a few exceptional cases that GraphProperties cannot provide
|
// properties for a few exceptional cases that GraphProperties cannot provide
|
||||||
// input/output properties.
|
// input/output properties.
|
||||||
if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
|
if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
|
||||||
// _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc
|
// _Send and _Recv ops created from SchedulerState have kAttrInputSrc
|
||||||
// attr; normal _Send and _Recv ops (from the input graph) do not have that
|
// attr; normal _Send and _Recv ops (from the input graph) do not have that
|
||||||
// attr.
|
// attr.
|
||||||
auto& node_state = node_map_[node];
|
auto& node_state = node_map_[node];
|
||||||
auto& inputs = node_state.input_properties;
|
auto& inputs = node_state.input_properties;
|
||||||
auto& outputs = node_state.output_properties;
|
auto& outputs = node_state.output_properties;
|
||||||
|
|
||||||
// _Send and _Recv ops are created from VirtualScheduler, so
|
// _Send and _Recv ops are created from SchedulerState, so
|
||||||
// there should be no inputs TensorProperties.
|
// there should be no inputs TensorProperties.
|
||||||
CHECK(inputs.empty());
|
CHECK(inputs.empty());
|
||||||
CHECK(outputs.empty());
|
CHECK(outputs.empty());
|
||||||
|
@ -595,27 +589,27 @@ void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
string SchedulerState::DeviceName(const NodeDef* node) const {
|
||||||
return placer_->get_canonical_device_name(*node);
|
return placer_->get_canonical_device_name(*node);
|
||||||
}
|
}
|
||||||
|
|
||||||
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
|
string SchedulerState::SanitizedDeviceName(const NodeDef* node) const {
|
||||||
// Replace the ":" characters that may be present in the device name with "_".
|
// Replace the ":" characters that may be present in the device name with "_".
|
||||||
// This makes it possible to then use the resulting string in a node name.
|
// This makes it possible to then use the resulting string in a node name.
|
||||||
return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
|
return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
|
||||||
{{":", "_"}});
|
{{":", "_"}});
|
||||||
}
|
}
|
||||||
|
|
||||||
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
|
string SchedulerState::ChannelDeviceName(const NodeDef* from,
|
||||||
const NodeDef* to) const {
|
const NodeDef* to) const {
|
||||||
CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
|
CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
|
||||||
return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
|
return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
|
||||||
"_to_", SanitizedDeviceName(to));
|
"_to_", SanitizedDeviceName(to));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
|
||||||
const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
|
const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
|
||||||
const string& input_name) {
|
const string& input_name, bool create_channel_device) {
|
||||||
CHECK(!initialized_) << "CreateSendRecv is called after Init().";
|
CHECK(!initialized_) << "CreateSendRecv is called after Init().";
|
||||||
|
|
||||||
// Connect "from" node to "to" node with _Send and _Recv such that
|
// Connect "from" node to "to" node with _Send and _Recv such that
|
||||||
|
@ -643,7 +637,9 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
||||||
"_to_" + SanitizedDeviceName(to));
|
"_to_" + SanitizedDeviceName(to));
|
||||||
send->set_op("_Send");
|
send->set_op("_Send");
|
||||||
send->add_input(from->name());
|
send->add_input(from->name());
|
||||||
send->set_device(ChannelDeviceName(from, to));
|
auto send_device =
|
||||||
|
create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from);
|
||||||
|
send->set_device(send_device);
|
||||||
auto& send_attr = *(send->mutable_attr());
|
auto& send_attr = *(send->mutable_attr());
|
||||||
send_attr[kAttrInputSrc].set_s(input_name);
|
send_attr[kAttrInputSrc].set_s(input_name);
|
||||||
send_attr[kAttrSrcDevice].set_s(DeviceName(from));
|
send_attr[kAttrSrcDevice].set_s(DeviceName(from));
|
||||||
|
@ -687,9 +683,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
||||||
return std::make_pair(send, recv);
|
return std::make_pair(send, recv);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpContext VirtualScheduler::GetCurrNode() const {
|
OpContext SchedulerState::CreateOpContext(const NodeDef* node) const {
|
||||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
|
||||||
|
|
||||||
// Get the device from the placer.
|
// Get the device from the placer.
|
||||||
DeviceProperties device;
|
DeviceProperties device;
|
||||||
device = placer_->get_device(*node);
|
device = placer_->get_device(*node);
|
||||||
|
@ -721,7 +715,7 @@ OpContext VirtualScheduler::GetCurrNode() const {
|
||||||
return op_context;
|
return op_context;
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
|
NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
|
||||||
CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
|
CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
|
||||||
|
|
||||||
auto it = node_map_.find(node);
|
auto it = node_map_.find(node);
|
||||||
|
@ -766,8 +760,9 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void VirtualScheduler::AddOutputNodesToReadyQueue(
|
void SchedulerState::GetOutputNodes(const NodeDef* node,
|
||||||
const NodeDef* node, const Costs::Duration& curr_time) {
|
const Costs::Duration& curr_time,
|
||||||
|
std::vector<const NodeDef*>* output_nodes) {
|
||||||
// Checks whether the Switch's output slots change over iterations.
|
// Checks whether the Switch's output slots change over iterations.
|
||||||
int slot = -1;
|
int slot = -1;
|
||||||
if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
|
if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
|
||||||
|
@ -780,7 +775,6 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment num_inputs_ready of the output nodes and maybe add to ready
|
// Increment num_inputs_ready of the output nodes and maybe add to ready
|
||||||
// nodes.
|
// nodes.
|
||||||
auto& node_state = node_map_[node];
|
auto& node_state = node_map_[node];
|
||||||
|
@ -799,16 +793,15 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
|
||||||
IsMerge(*output_node)) {
|
IsMerge(*output_node)) {
|
||||||
// This output node is now ready.
|
// This output node is now ready.
|
||||||
output_state.time_ready = curr_time;
|
output_state.time_ready = curr_time;
|
||||||
ready_nodes_->AddNode(output_node);
|
output_nodes->push_back(output_node);
|
||||||
VLOG(3) << " Add output: " << output_node->name();
|
VLOG(3) << " Add output: " << output_node->name();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
|
||||||
// Update graph_costs_ and per-op costs.
|
const NodeDef* node, const Costs& node_costs, const OpContext& op_context) {
|
||||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
|
||||||
auto& node_state = node_map_[node];
|
auto& node_state = node_map_[node];
|
||||||
// TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and
|
// TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and
|
||||||
// Merge -- it can create a loop which may include loop-carried dependency,
|
// Merge -- it can create a loop which may include loop-carried dependency,
|
||||||
|
@ -834,8 +827,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
|
|
||||||
if (VLOG_IS_ON(2)) {
|
if (VLOG_IS_ON(2)) {
|
||||||
// Also keep track of op counts and costs per op (with their shapes).
|
// Also keep track of op counts and costs per op (with their shapes).
|
||||||
OpContext op_context = GetCurrNode();
|
|
||||||
|
|
||||||
string node_description = GetOpDescription(op_context.op_info);
|
string node_description = GetOpDescription(op_context.op_info);
|
||||||
op_counts_[node_description] += 1;
|
op_counts_[node_description] += 1;
|
||||||
op_costs_[node_description] =
|
op_costs_[node_description] =
|
||||||
|
@ -886,7 +877,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
<< ", ready: " << node_state.time_ready.count()
|
<< ", ready: " << node_state.time_ready.count()
|
||||||
<< ", scheduled: " << node_state.time_scheduled.count()
|
<< ", scheduled: " << node_state.time_scheduled.count()
|
||||||
<< ", finished: " << node_state.time_finished.count();
|
<< ", finished: " << node_state.time_finished.count();
|
||||||
|
std::vector<const NodeDef*> new_nodes;
|
||||||
if (previously_executed_merge) {
|
if (previously_executed_merge) {
|
||||||
// Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
|
// Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
|
||||||
VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
|
VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
|
||||||
|
@ -894,7 +885,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
<< "Skip scheduling its output nodes.";
|
<< "Skip scheduling its output nodes.";
|
||||||
} else {
|
} else {
|
||||||
// Checks outputs, and adds ready nodes to queue.
|
// Checks outputs, and adds ready nodes to queue.
|
||||||
AddOutputNodesToReadyQueue(node, curr_time);
|
GetOutputNodes(node, curr_time, &new_nodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment num_outputs_executed of the input nodes and maybe update memory.
|
// Increment num_outputs_executed of the input nodes and maybe update memory.
|
||||||
|
@ -929,13 +920,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return new_nodes;
|
||||||
ready_nodes_->RemoveCurrNode();
|
|
||||||
|
|
||||||
return !ready_nodes_->Empty();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Costs VirtualScheduler::Summary() const {
|
Costs SchedulerState::Summary() const {
|
||||||
// Overall statement about accuracy
|
// Overall statement about accuracy
|
||||||
VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
|
VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
|
||||||
<< graph_costs_.num_ops_with_unknown_shapes
|
<< graph_costs_.num_ops_with_unknown_shapes
|
||||||
|
@ -1109,12 +1097,12 @@ Costs VirtualScheduler::Summary() const {
|
||||||
return critical_path_costs;
|
return critical_path_costs;
|
||||||
}
|
}
|
||||||
|
|
||||||
Costs VirtualScheduler::Summary(RunMetadata* metadata) {
|
Costs SchedulerState::Summary(RunMetadata* metadata) {
|
||||||
if (metadata) GenerateRunMetadata(metadata);
|
if (metadata) GenerateRunMetadata(metadata);
|
||||||
return Summary();
|
return Summary();
|
||||||
}
|
}
|
||||||
|
|
||||||
void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
|
void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
|
||||||
// Fill RunMetadata's step_stats and partition_graphs fields.
|
// Fill RunMetadata's step_stats and partition_graphs fields.
|
||||||
StepStats* stepstats = metadata->mutable_step_stats();
|
StepStats* stepstats = metadata->mutable_step_stats();
|
||||||
for (const auto& device : device_) {
|
for (const auto& device : device_) {
|
||||||
|
@ -1176,7 +1164,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
|
||||||
nodestate.time_scheduled.count());
|
nodestate.time_scheduled.count());
|
||||||
|
|
||||||
auto* mem_stats = node_stats->mutable_memory_stats();
|
auto* mem_stats = node_stats->mutable_memory_stats();
|
||||||
// VirtualScheduler does not specify scratch pad memory usage.
|
// SchedulerState does not specify scratch pad memory usage.
|
||||||
mem_stats->set_temp_memory_size(0);
|
mem_stats->set_temp_memory_size(0);
|
||||||
int64 persistent_memory_size = 0;
|
int64 persistent_memory_size = 0;
|
||||||
if (IsPersistent(*node_def)) {
|
if (IsPersistent(*node_def)) {
|
||||||
|
@ -1188,7 +1176,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
|
const std::unordered_map<string, int64> SchedulerState::GetPeakMemoryUsage()
|
||||||
const {
|
const {
|
||||||
std::unordered_map<string, int64> result;
|
std::unordered_map<string, int64> result;
|
||||||
for (const auto& device : device_) {
|
for (const auto& device : device_) {
|
||||||
|
@ -1200,7 +1188,7 @@ const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::unordered_map<string, int64>
|
const std::unordered_map<string, int64>
|
||||||
VirtualScheduler::GetPersistentMemoryUsage() const {
|
SchedulerState::GetPersistentMemoryUsage() const {
|
||||||
std::unordered_map<string, int64> result;
|
std::unordered_map<string, int64> result;
|
||||||
for (const auto& device : device_) {
|
for (const auto& device : device_) {
|
||||||
const string& name = device.first;
|
const string& name = device.first;
|
||||||
|
@ -1217,5 +1205,51 @@ VirtualScheduler::GetPersistentMemoryUsage() const {
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||||
|
const bool use_aggressive_shape_inference,
|
||||||
|
Cluster* cluster,
|
||||||
|
ReadyNodeManager* ready_nodes,
|
||||||
|
std::unique_ptr<VirtualPlacer> placer)
|
||||||
|
: scheduler_state_(use_static_shapes, use_aggressive_shape_inference,
|
||||||
|
cluster, std::move(placer)),
|
||||||
|
ready_nodes_(ready_nodes) {}
|
||||||
|
|
||||||
|
Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||||
|
// SchedulerState::Init() preprocesses the input grappler_item and
|
||||||
|
// graph_properties to extract necessary information for emulating tensorflow
|
||||||
|
// op scheduling and construct internal data structures (NodeState and
|
||||||
|
// DeviceState) for virtual scheduling.
|
||||||
|
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
|
||||||
|
std::vector<const NodeDef*> initial_nodes;
|
||||||
|
auto status = scheduler_state_.Init(item, &initial_nodes);
|
||||||
|
if (status.ok()) {
|
||||||
|
// Add the set of initial nodes to ready_nodes_
|
||||||
|
for (auto node : initial_nodes) {
|
||||||
|
ready_nodes_->AddNode(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
OpContext VirtualScheduler::GetCurrNode() const {
|
||||||
|
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||||
|
return scheduler_state_.CreateOpContext(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
|
// Update graph_costs_ and per-op costs.
|
||||||
|
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||||
|
auto new_nodes = scheduler_state_.MarkNodeExecuted(
|
||||||
|
node, node_costs,
|
||||||
|
scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode()));
|
||||||
|
ready_nodes_->RemoveCurrNode();
|
||||||
|
// Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
|
||||||
|
for (auto node : new_nodes) {
|
||||||
|
ready_nodes_->AddNode(node);
|
||||||
|
}
|
||||||
|
return !ready_nodes_->Empty();
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace grappler
|
} // end namespace grappler
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
|
@ -32,6 +32,12 @@ limitations under the License.
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
|
inline constexpr char kAttrInputSrc[] = "input_source_";
|
||||||
|
inline constexpr char kAttrSrcDevice[] = "send_device";
|
||||||
|
inline constexpr char kAttrDstDevice[] = "recv_device";
|
||||||
|
inline constexpr char kAttrTensorName[] = "tensor_name";
|
||||||
|
inline constexpr char kChannelDevice[] = "Channel";
|
||||||
|
|
||||||
struct NodeState {
|
struct NodeState {
|
||||||
// A node (i.e., an op) takes a set of input:port pairs and produces
|
// A node (i.e., an op) takes a set of input:port pairs and produces
|
||||||
// a set of output ports.
|
// a set of output ports.
|
||||||
|
@ -233,7 +239,7 @@ class HeapReadyManager : public ReadyNodeManager {
|
||||||
// functor for keeping the smallest time_ready node at the front of heap.
|
// functor for keeping the smallest time_ready node at the front of heap.
|
||||||
std::function<bool(const NodeDef*, const NodeDef*)> greater_;
|
std::function<bool(const NodeDef*, const NodeDef*)> greater_;
|
||||||
|
|
||||||
// NodeState structure from VirtualScheduler to get time_ready of ready nodes.
|
// NodeState structure from SchedulerState to get time_ready of ready nodes.
|
||||||
// Not owned by FirstReadyManager.
|
// Not owned by FirstReadyManager.
|
||||||
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
|
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
|
||||||
};
|
};
|
||||||
|
@ -298,7 +304,7 @@ class CompositeNodeManager : public ReadyNodeManager {
|
||||||
FirstReadyManager send_manager_;
|
FirstReadyManager send_manager_;
|
||||||
FirstReadyManager recv_manager_;
|
FirstReadyManager recv_manager_;
|
||||||
|
|
||||||
// NodeState structure from VirtualScheduler to get time_ready of ready nodes.
|
// NodeState structure from SchedulerState to get time_ready of ready nodes.
|
||||||
// Not owned by CompositeReadyManager.
|
// Not owned by CompositeReadyManager.
|
||||||
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
|
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
|
||||||
|
|
||||||
|
@ -310,32 +316,22 @@ class CompositeNodeManager : public ReadyNodeManager {
|
||||||
std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
|
std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
|
||||||
const string& ready_node_manager);
|
const string& ready_node_manager);
|
||||||
|
|
||||||
// The virtual scheduler emulates execution of nodes in a graph, considering
|
// Encapsulates all of the various pieces uses to track state of a scheduler;
|
||||||
// dependencies, device, etc.
|
// enables reuse of all scheduler state-related utilities across different
|
||||||
class VirtualScheduler {
|
// scheduler implementations.
|
||||||
|
class SchedulerState {
|
||||||
public:
|
public:
|
||||||
// Does not take ownership of cluster or ready_nodes.
|
SchedulerState(const bool use_static_shapes,
|
||||||
VirtualScheduler(const bool use_static_shapes,
|
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||||
const bool use_aggressive_shape_inference, Cluster* cluster,
|
std::unique_ptr<VirtualPlacer> placer);
|
||||||
ReadyNodeManager* ready_nodes,
|
// Sets up the graph while also performing some necessary transformations
|
||||||
std::unique_ptr<VirtualPlacer> placer);
|
// initial_nodes is the set of nodes (primary inputs) discovered by Init()
|
||||||
|
// which may be added by a ReadyNodeManager (or related/derivative scheduler)
|
||||||
|
// to begin node schedule and graph simulation.
|
||||||
|
Status Init(const GrapplerItem* item,
|
||||||
|
std::vector<const NodeDef*>* initial_nodes,
|
||||||
|
bool create_explicit_channel_device = true);
|
||||||
|
|
||||||
// Initializes the scheduler for the specific grappler item.
|
|
||||||
// Should be called immediately after the c'tor or when the scheduler will be
|
|
||||||
// reused for a new grappler item. All internal states of the scheduler
|
|
||||||
// related to the previous grappler item will be reset/cleared.
|
|
||||||
//
|
|
||||||
// This function should be called at least once after the scheduler is
|
|
||||||
// constructed. An uninitialized or failed-to-initialize scheduler will cause
|
|
||||||
// undefined behavior.
|
|
||||||
Status Init(const GrapplerItem* item);
|
|
||||||
|
|
||||||
OpContext GetCurrNode() const;
|
|
||||||
|
|
||||||
// Returns true if there is any node to be scheduled.
|
|
||||||
bool MarkCurrNodeExecuted(const Costs& node_costs);
|
|
||||||
|
|
||||||
// Prints out summary of execution (timing, memory usage, etc.)
|
|
||||||
Costs Summary() const;
|
Costs Summary() const;
|
||||||
// Like the above, but writes detailed stats to RunMetadata.
|
// Like the above, but writes detailed stats to RunMetadata.
|
||||||
// If metadata is nullptr, then just calls and return Summary().
|
// If metadata is nullptr, then just calls and return Summary().
|
||||||
|
@ -347,34 +343,40 @@ class VirtualScheduler {
|
||||||
// Returns per device memory usage.
|
// Returns per device memory usage.
|
||||||
const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
|
const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
|
||||||
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const;
|
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const;
|
||||||
|
void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
|
||||||
// Returns VirtualScheduler (read only) device and node states.
|
// Returns (read only) device and node states.
|
||||||
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
|
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
|
||||||
return &device_;
|
return &device_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
|
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
|
||||||
return &node_map_;
|
return &node_map_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
|
OpContext CreateOpContext(const NodeDef* node) const;
|
||||||
|
std::vector<const NodeDef*> MarkNodeExecuted(const NodeDef* node,
|
||||||
|
const Costs& node_costs,
|
||||||
|
const OpContext& op_context);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Methods called from Init(). Fails if initialize_ is set.
|
// Methods called from Init(). Fails if initialize_ is set.
|
||||||
|
|
||||||
void MaybeUpdateInputOutput(const NodeDef* node);
|
void MaybeUpdateInputOutput(const NodeDef* node);
|
||||||
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
|
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
|
||||||
|
// Creates a Send_ and Recv_ pair between from and to. The argument
|
||||||
|
// create_channel_device tells the function to create an explicit device for
|
||||||
|
// the channel.
|
||||||
std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
|
std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
|
||||||
const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
|
const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
|
||||||
const string& input_name);
|
const string& input_name, bool create_channel_device);
|
||||||
string DeviceName(const NodeDef* node) const;
|
string DeviceName(const NodeDef* node) const;
|
||||||
string SanitizedDeviceName(const NodeDef* node) const;
|
string SanitizedDeviceName(const NodeDef* node) const;
|
||||||
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
|
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
|
||||||
|
|
||||||
// Helper methods.
|
// Helper methods.
|
||||||
void AddOutputNodesToReadyQueue(const NodeDef* node,
|
void GetOutputNodes(const NodeDef* node, const Costs::Duration& curr_time,
|
||||||
const Costs::Duration& curr_time);
|
std::vector<const NodeDef*>* output_nodes);
|
||||||
|
|
||||||
// Scheduler states:
|
|
||||||
ReadyNodeManager* ready_nodes_; // Not owned.
|
|
||||||
std::unordered_map<const NodeDef*, NodeState> node_map_;
|
std::unordered_map<const NodeDef*, NodeState> node_map_;
|
||||||
std::unordered_map<string, DeviceState> device_;
|
std::unordered_map<string, DeviceState> device_;
|
||||||
|
|
||||||
|
@ -396,16 +398,81 @@ class VirtualScheduler {
|
||||||
// Auxiliary data structures for constructing NodeState and DeviceState.
|
// Auxiliary data structures for constructing NodeState and DeviceState.
|
||||||
std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init().
|
std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init().
|
||||||
Cluster* cluster_; // Not owned.
|
Cluster* cluster_; // Not owned.
|
||||||
|
|
||||||
const GrapplerItem* grappler_item_; // Not owned.
|
const GrapplerItem* grappler_item_; // Not owned.
|
||||||
bool use_static_shapes_;
|
bool use_static_shapes_;
|
||||||
bool initialized_;
|
bool initialized_;
|
||||||
bool track_mem_usage_snapshot_;
|
bool track_mem_usage_snapshot_;
|
||||||
const bool use_aggressive_shape_inference_;
|
const bool use_aggressive_shape_inference_;
|
||||||
|
|
||||||
std::unique_ptr<VirtualPlacer> placer_;
|
std::unique_ptr<VirtualPlacer> placer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// The virtual scheduler emulates execution of nodes in a graph, considering
|
||||||
|
// dependencies, device, etc.
|
||||||
|
class VirtualScheduler {
|
||||||
|
public:
|
||||||
|
// Does not take ownership of cluster or ready_nodes.
|
||||||
|
VirtualScheduler(const bool use_static_shapes,
|
||||||
|
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||||
|
ReadyNodeManager* ready_nodes,
|
||||||
|
std::unique_ptr<VirtualPlacer> placer);
|
||||||
|
|
||||||
|
// Initializes the scheduler for the specific grappler item.
|
||||||
|
// Should be called immediately after the c'tor or when the scheduler will be
|
||||||
|
// reused for a new grappler item. All internal states of the scheduler
|
||||||
|
// related to the previous grappler item will be reset/cleared.
|
||||||
|
//
|
||||||
|
// This function should be called at least once after the scheduler is
|
||||||
|
// constructed. An uninitialized or failed-to-initialize scheduler will cause
|
||||||
|
// undefined behavior.
|
||||||
|
Status Init(const GrapplerItem* item);
|
||||||
|
|
||||||
|
// Gets the current scheduled node for execution; the caller of this function
|
||||||
|
// can accordingly simulate the execution of the current scheduled node.
|
||||||
|
OpContext GetCurrNode() const;
|
||||||
|
// Marks the current scheduled node as executed. Note that we should call this
|
||||||
|
// function only after the execution of the node has been simulated;
|
||||||
|
// node_costs_ capture the simulated costs of the node.
|
||||||
|
// Returns true if there is any node to be scheduled.
|
||||||
|
bool MarkCurrNodeExecuted(const Costs& node_costs);
|
||||||
|
|
||||||
|
// Prints out summary of execution (timing, memory usage, etc.)
|
||||||
|
Costs Summary() const { return scheduler_state_.Summary(); }
|
||||||
|
// Like the above, but writes detailed stats to RunMetadata.
|
||||||
|
// If metadata is nullptr, then just calls and return Summary().
|
||||||
|
Costs Summary(RunMetadata* metadata) {
|
||||||
|
return scheduler_state_.Summary(metadata);
|
||||||
|
}
|
||||||
|
// Generates RunMetadata's step_stats and partition_graphs fields from results
|
||||||
|
// of the virtual execution of the graph.
|
||||||
|
void GenerateRunMetadata(RunMetadata* metadata) {
|
||||||
|
scheduler_state_.GenerateRunMetadata(metadata);
|
||||||
|
}
|
||||||
|
// Returns per device memory usage.
|
||||||
|
const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
|
||||||
|
return scheduler_state_.GetPeakMemoryUsage();
|
||||||
|
}
|
||||||
|
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
|
||||||
|
return scheduler_state_.GetPersistentMemoryUsage();
|
||||||
|
}
|
||||||
|
// Returns VirtualScheduler (read only) device and node states.
|
||||||
|
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
|
||||||
|
return scheduler_state_.GetDeviceStates();
|
||||||
|
}
|
||||||
|
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
|
||||||
|
return scheduler_state_.GetNodeStates();
|
||||||
|
}
|
||||||
|
void enable_mem_usage_tracking() {
|
||||||
|
scheduler_state_.enable_mem_usage_tracking();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The state of the scheduler and the execution of the graph is encapsulated
|
||||||
|
// by the scheduler_state_ object.
|
||||||
|
SchedulerState scheduler_state_;
|
||||||
|
// ready_nodes_ is responsible for ordering the traversal of the graph.
|
||||||
|
ReadyNodeManager* ready_nodes_; // Not owned.
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue