Created a new class, SchedulerState, to encapsulate all of the scheduler state-related functionality we would like to reuse across different scheduler implementations. Scheduler state-related member functions/variables in VirtualScheduler have been moved to SchedulerState accordingly.

PiperOrigin-RevId: 308204183
Change-Id: Ie8ffe167d31844cc82c865ec5ac4a28d0e53d3a9
This commit is contained in:
A. Unique TensorFlower 2020-04-23 23:42:35 -07:00 committed by TensorFlower Gardener
parent ecfcb090c5
commit 3921264ef7
2 changed files with 196 additions and 95 deletions

View File

@ -40,12 +40,6 @@ namespace {
using ::tensorflow::strings::HumanReadableNumBytes; using ::tensorflow::strings::HumanReadableNumBytes;
constexpr char kAttrInputSrc[] = "input_source_";
constexpr char kAttrSrcDevice[] = "send_device";
constexpr char kAttrDstDevice[] = "recv_device";
constexpr char kAttrTensorName[] = "tensor_name";
constexpr char kChannelDevice[] = "Channel";
float Round2(const float x) { float Round2(const float x) {
// Not using std::round from <cmath> here because not all platforms seem to // Not using std::round from <cmath> here because not all platforms seem to
// support that (specifically Android). // support that (specifically Android).
@ -347,13 +341,11 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
return nullptr; return nullptr;
} }
VirtualScheduler::VirtualScheduler(const bool use_static_shapes, SchedulerState::SchedulerState(const bool use_static_shapes,
const bool use_aggressive_shape_inference, const bool use_aggressive_shape_inference,
Cluster* cluster, Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer) std::unique_ptr<VirtualPlacer> placer)
: ready_nodes_(ready_nodes), : graph_costs_(Costs::ZeroCosts()),
graph_costs_(Costs::ZeroCosts()),
cluster_(cluster), cluster_(cluster),
use_static_shapes_(use_static_shapes), use_static_shapes_(use_static_shapes),
use_aggressive_shape_inference_(use_aggressive_shape_inference), use_aggressive_shape_inference_(use_aggressive_shape_inference),
@ -364,10 +356,12 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
track_mem_usage_snapshot_ = VLOG_IS_ON(1); track_mem_usage_snapshot_ = VLOG_IS_ON(1);
} }
Status VirtualScheduler::Init(const GrapplerItem* item) { Status SchedulerState::Init(const GrapplerItem* item,
std::vector<const NodeDef*>* initial_nodes,
bool create_explicit_channel_device) {
initialized_ = false; initialized_ = false;
// Clear all internal states so that the VirtualScheduler is reusable for // Clear all internal states so that the SchedulerState is reusable for
// different GrapplerItems // different GrapplerItems
node_map_.clear(); node_map_.clear();
device_.clear(); device_.clear();
@ -380,14 +374,12 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
op_counts_.clear(); op_counts_.clear();
op_costs_.clear(); op_costs_.clear();
// Init() preprocesses the input grappler_item and graph_properties to extract initial_nodes->clear();
// necessary information for emulating tensorflow op scheduling and
// construct internal data structures (NodeState and DeviceState) for virtual
// scheduling.
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
// Constructs graph properties and performs shape inference. // Constructs graph properties and performs shape inference.
graph_properties_ = absl::make_unique<GraphProperties>(*item); graph_properties_ = absl::make_unique<GraphProperties>(*item);
// TODO(safeen,dyoon): Will we ever use InferDynamically? If not we may want
// to get rid of use_static_shapes_ and cluster_.
if (use_static_shapes_) { if (use_static_shapes_) {
TF_RETURN_IF_ERROR(graph_properties_->InferStatically( TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
true, use_aggressive_shape_inference_, true)); true, use_aggressive_shape_inference_, true));
@ -399,6 +391,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
const auto& graph = grappler_item_->graph; const auto& graph = grappler_item_->graph;
const auto& fetch_nodes = grappler_item_->fetch; const auto& fetch_nodes = grappler_item_->fetch;
std::set<string> feed_nodes; std::set<string> feed_nodes;
for (const auto& f : grappler_item_->feed) { for (const auto& f : grappler_item_->feed) {
auto iter_and_inserted_flag = feed_nodes.insert(f.first); auto iter_and_inserted_flag = feed_nodes.insert(f.first);
QCHECK(iter_and_inserted_flag.second) QCHECK(iter_and_inserted_flag.second)
@ -486,8 +479,9 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
} else { } else {
// Different device, no cached copy; transfer input_node to the // Different device, no cached copy; transfer input_node to the
// curr_node's device. // curr_node's device.
auto send_and_recv = CreateSendRecv(input_node, curr_node, input_node, auto send_and_recv =
input_node_name); CreateSendRecv(input_node, curr_node, input_node, input_node_name,
create_explicit_channel_device);
// Note that CreateSendRecv() already connected input/output between // Note that CreateSendRecv() already connected input/output between
// _Send and _Recv ops. // _Send and _Recv ops.
const auto* send = send_and_recv.first; const auto* send = send_and_recv.first;
@ -514,7 +508,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
if (given_as_feed || has_no_inputs) { if (given_as_feed || has_no_inputs) {
curr_node_state.time_ready = Costs::Duration(); curr_node_state.time_ready = Costs::Duration();
ready_nodes_->AddNode(curr_node); initial_nodes->push_back(curr_node);
VLOG(3) << "Added ready node: " << curr_node->name(); VLOG(3) << "Added ready node: " << curr_node->name();
} }
@ -530,7 +524,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
} }
} }
if (ready_nodes_->Empty()) { if (initial_nodes->empty()) {
return errors::InvalidArgument("No ready nodes in the graph."); return errors::InvalidArgument("No ready nodes in the graph.");
} }
@ -546,20 +540,20 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
return Status::OK(); return Status::OK();
} }
void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) { void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) {
CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init()."; CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
// This method is called when NodeState is created and adds input and output // This method is called when NodeState is created and adds input and output
// properties for a few exceptional cases that GraphProperties cannot provide // properties for a few exceptional cases that GraphProperties cannot provide
// input/output properties. // input/output properties.
if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) { if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
// _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc // _Send and _Recv ops created from SchedulerState have kAttrInputSrc
// attr; normal _Send and _Recv ops (from the input graph) do not have that // attr; normal _Send and _Recv ops (from the input graph) do not have that
// attr. // attr.
auto& node_state = node_map_[node]; auto& node_state = node_map_[node];
auto& inputs = node_state.input_properties; auto& inputs = node_state.input_properties;
auto& outputs = node_state.output_properties; auto& outputs = node_state.output_properties;
// _Send and _Recv ops are created from VirtualScheduler, so // _Send and _Recv ops are created from SchedulerState, so
// there should be no inputs TensorProperties. // there should be no inputs TensorProperties.
CHECK(inputs.empty()); CHECK(inputs.empty());
CHECK(outputs.empty()); CHECK(outputs.empty());
@ -595,27 +589,27 @@ void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
} }
} }
string VirtualScheduler::DeviceName(const NodeDef* node) const { string SchedulerState::DeviceName(const NodeDef* node) const {
return placer_->get_canonical_device_name(*node); return placer_->get_canonical_device_name(*node);
} }
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const { string SchedulerState::SanitizedDeviceName(const NodeDef* node) const {
// Replace the ":" characters that may be present in the device name with "_". // Replace the ":" characters that may be present in the device name with "_".
// This makes it possible to then use the resulting string in a node name. // This makes it possible to then use the resulting string in a node name.
return absl::StrReplaceAll(placer_->get_canonical_device_name(*node), return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
{{":", "_"}}); {{":", "_"}});
} }
string VirtualScheduler::ChannelDeviceName(const NodeDef* from, string SchedulerState::ChannelDeviceName(const NodeDef* from,
const NodeDef* to) const { const NodeDef* to) const {
CHECK(!initialized_) << "ChannelDeviceName is called after Init()."; CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from), return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
"_to_", SanitizedDeviceName(to)); "_to_", SanitizedDeviceName(to));
} }
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
const NodeDef* from, const NodeDef* to, const NodeDef* input_node, const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
const string& input_name) { const string& input_name, bool create_channel_device) {
CHECK(!initialized_) << "CreateSendRecv is called after Init()."; CHECK(!initialized_) << "CreateSendRecv is called after Init().";
// Connect "from" node to "to" node with _Send and _Recv such that // Connect "from" node to "to" node with _Send and _Recv such that
@ -643,7 +637,9 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
"_to_" + SanitizedDeviceName(to)); "_to_" + SanitizedDeviceName(to));
send->set_op("_Send"); send->set_op("_Send");
send->add_input(from->name()); send->add_input(from->name());
send->set_device(ChannelDeviceName(from, to)); auto send_device =
create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from);
send->set_device(send_device);
auto& send_attr = *(send->mutable_attr()); auto& send_attr = *(send->mutable_attr());
send_attr[kAttrInputSrc].set_s(input_name); send_attr[kAttrInputSrc].set_s(input_name);
send_attr[kAttrSrcDevice].set_s(DeviceName(from)); send_attr[kAttrSrcDevice].set_s(DeviceName(from));
@ -687,9 +683,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
return std::make_pair(send, recv); return std::make_pair(send, recv);
} }
OpContext VirtualScheduler::GetCurrNode() const { OpContext SchedulerState::CreateOpContext(const NodeDef* node) const {
const NodeDef* node = ready_nodes_->GetCurrNode();
// Get the device from the placer. // Get the device from the placer.
DeviceProperties device; DeviceProperties device;
device = placer_->get_device(*node); device = placer_->get_device(*node);
@ -721,7 +715,7 @@ OpContext VirtualScheduler::GetCurrNode() const {
return op_context; return op_context;
} }
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) { NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init()."; CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
auto it = node_map_.find(node); auto it = node_map_.find(node);
@ -766,8 +760,9 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
return it->second; return it->second;
} }
void VirtualScheduler::AddOutputNodesToReadyQueue( void SchedulerState::GetOutputNodes(const NodeDef* node,
const NodeDef* node, const Costs::Duration& curr_time) { const Costs::Duration& curr_time,
std::vector<const NodeDef*>* output_nodes) {
// Checks whether the Switch's output slots change over iterations. // Checks whether the Switch's output slots change over iterations.
int slot = -1; int slot = -1;
if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 && if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
@ -780,7 +775,6 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
} }
} }
} }
// Increment num_inputs_ready of the output nodes and maybe add to ready // Increment num_inputs_ready of the output nodes and maybe add to ready
// nodes. // nodes.
auto& node_state = node_map_[node]; auto& node_state = node_map_[node];
@ -799,16 +793,15 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
IsMerge(*output_node)) { IsMerge(*output_node)) {
// This output node is now ready. // This output node is now ready.
output_state.time_ready = curr_time; output_state.time_ready = curr_time;
ready_nodes_->AddNode(output_node); output_nodes->push_back(output_node);
VLOG(3) << " Add output: " << output_node->name(); VLOG(3) << " Add output: " << output_node->name();
} }
} }
} }
} }
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
// Update graph_costs_ and per-op costs. const NodeDef* node, const Costs& node_costs, const OpContext& op_context) {
const NodeDef* node = ready_nodes_->GetCurrNode();
auto& node_state = node_map_[node]; auto& node_state = node_map_[node];
// TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and // TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and
// Merge -- it can create a loop which may include loop-carried dependency, // Merge -- it can create a loop which may include loop-carried dependency,
@ -834,8 +827,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
if (VLOG_IS_ON(2)) { if (VLOG_IS_ON(2)) {
// Also keep track of op counts and costs per op (with their shapes). // Also keep track of op counts and costs per op (with their shapes).
OpContext op_context = GetCurrNode();
string node_description = GetOpDescription(op_context.op_info); string node_description = GetOpDescription(op_context.op_info);
op_counts_[node_description] += 1; op_counts_[node_description] += 1;
op_costs_[node_description] = op_costs_[node_description] =
@ -886,7 +877,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
<< ", ready: " << node_state.time_ready.count() << ", ready: " << node_state.time_ready.count()
<< ", scheduled: " << node_state.time_scheduled.count() << ", scheduled: " << node_state.time_scheduled.count()
<< ", finished: " << node_state.time_finished.count(); << ", finished: " << node_state.time_finished.count();
std::vector<const NodeDef*> new_nodes;
if (previously_executed_merge) { if (previously_executed_merge) {
// Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge. // Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] " VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
@ -894,7 +885,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
<< "Skip scheduling its output nodes."; << "Skip scheduling its output nodes.";
} else { } else {
// Checks outputs, and adds ready nodes to queue. // Checks outputs, and adds ready nodes to queue.
AddOutputNodesToReadyQueue(node, curr_time); GetOutputNodes(node, curr_time, &new_nodes);
} }
// Increment num_outputs_executed of the input nodes and maybe update memory. // Increment num_outputs_executed of the input nodes and maybe update memory.
@ -929,13 +920,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
} }
} }
} }
return new_nodes;
ready_nodes_->RemoveCurrNode();
return !ready_nodes_->Empty();
} }
Costs VirtualScheduler::Summary() const { Costs SchedulerState::Summary() const {
// Overall statement about accuracy // Overall statement about accuracy
VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with " VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
<< graph_costs_.num_ops_with_unknown_shapes << graph_costs_.num_ops_with_unknown_shapes
@ -1109,12 +1097,12 @@ Costs VirtualScheduler::Summary() const {
return critical_path_costs; return critical_path_costs;
} }
Costs VirtualScheduler::Summary(RunMetadata* metadata) { Costs SchedulerState::Summary(RunMetadata* metadata) {
if (metadata) GenerateRunMetadata(metadata); if (metadata) GenerateRunMetadata(metadata);
return Summary(); return Summary();
} }
void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) { void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
// Fill RunMetadata's step_stats and partition_graphs fields. // Fill RunMetadata's step_stats and partition_graphs fields.
StepStats* stepstats = metadata->mutable_step_stats(); StepStats* stepstats = metadata->mutable_step_stats();
for (const auto& device : device_) { for (const auto& device : device_) {
@ -1176,7 +1164,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
nodestate.time_scheduled.count()); nodestate.time_scheduled.count());
auto* mem_stats = node_stats->mutable_memory_stats(); auto* mem_stats = node_stats->mutable_memory_stats();
// VirtualScheduler does not specify scratch pad memory usage. // SchedulerState does not specify scratch pad memory usage.
mem_stats->set_temp_memory_size(0); mem_stats->set_temp_memory_size(0);
int64 persistent_memory_size = 0; int64 persistent_memory_size = 0;
if (IsPersistent(*node_def)) { if (IsPersistent(*node_def)) {
@ -1188,7 +1176,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
} }
} }
const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage() const std::unordered_map<string, int64> SchedulerState::GetPeakMemoryUsage()
const { const {
std::unordered_map<string, int64> result; std::unordered_map<string, int64> result;
for (const auto& device : device_) { for (const auto& device : device_) {
@ -1200,7 +1188,7 @@ const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
} }
const std::unordered_map<string, int64> const std::unordered_map<string, int64>
VirtualScheduler::GetPersistentMemoryUsage() const { SchedulerState::GetPersistentMemoryUsage() const {
std::unordered_map<string, int64> result; std::unordered_map<string, int64> result;
for (const auto& device : device_) { for (const auto& device : device_) {
const string& name = device.first; const string& name = device.first;
@ -1217,5 +1205,51 @@ VirtualScheduler::GetPersistentMemoryUsage() const {
} }
return result; return result;
} }
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer)
: scheduler_state_(use_static_shapes, use_aggressive_shape_inference,
cluster, std::move(placer)),
ready_nodes_(ready_nodes) {}
Status VirtualScheduler::Init(const GrapplerItem* item) {
// SchedulerState::Init() preprocesses the input grappler_item and
// graph_properties to extract necessary information for emulating tensorflow
// op scheduling and construct internal data structures (NodeState and
// DeviceState) for virtual scheduling.
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
std::vector<const NodeDef*> initial_nodes;
auto status = scheduler_state_.Init(item, &initial_nodes);
if (status.ok()) {
// Add the set of initial nodes to ready_nodes_
for (auto node : initial_nodes) {
ready_nodes_->AddNode(node);
}
}
return status;
}
OpContext VirtualScheduler::GetCurrNode() const {
const NodeDef* node = ready_nodes_->GetCurrNode();
return scheduler_state_.CreateOpContext(node);
}
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
// Update graph_costs_ and per-op costs.
const NodeDef* node = ready_nodes_->GetCurrNode();
auto new_nodes = scheduler_state_.MarkNodeExecuted(
node, node_costs,
scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode()));
ready_nodes_->RemoveCurrNode();
// Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
for (auto node : new_nodes) {
ready_nodes_->AddNode(node);
}
return !ready_nodes_->Empty();
}
} // end namespace grappler } // end namespace grappler
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -32,6 +32,12 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
inline constexpr char kAttrInputSrc[] = "input_source_";
inline constexpr char kAttrSrcDevice[] = "send_device";
inline constexpr char kAttrDstDevice[] = "recv_device";
inline constexpr char kAttrTensorName[] = "tensor_name";
inline constexpr char kChannelDevice[] = "Channel";
struct NodeState { struct NodeState {
// A node (i.e., an op) takes a set of input:port pairs and produces // A node (i.e., an op) takes a set of input:port pairs and produces
// a set of output ports. // a set of output ports.
@ -233,7 +239,7 @@ class HeapReadyManager : public ReadyNodeManager {
// functor for keeping the smallest time_ready node at the front of heap. // functor for keeping the smallest time_ready node at the front of heap.
std::function<bool(const NodeDef*, const NodeDef*)> greater_; std::function<bool(const NodeDef*, const NodeDef*)> greater_;
// NodeState structure from VirtualScheduler to get time_ready of ready nodes. // NodeState structure from SchedulerState to get time_ready of ready nodes.
// Not owned by FirstReadyManager. // Not owned by FirstReadyManager.
const std::unordered_map<const NodeDef*, NodeState>* node_map_; const std::unordered_map<const NodeDef*, NodeState>* node_map_;
}; };
@ -298,7 +304,7 @@ class CompositeNodeManager : public ReadyNodeManager {
FirstReadyManager send_manager_; FirstReadyManager send_manager_;
FirstReadyManager recv_manager_; FirstReadyManager recv_manager_;
// NodeState structure from VirtualScheduler to get time_ready of ready nodes. // NodeState structure from SchedulerState to get time_ready of ready nodes.
// Not owned by CompositeReadyManager. // Not owned by CompositeReadyManager.
const std::unordered_map<const NodeDef*, NodeState>* node_map_; const std::unordered_map<const NodeDef*, NodeState>* node_map_;
@ -310,32 +316,22 @@ class CompositeNodeManager : public ReadyNodeManager {
std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory( std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
const string& ready_node_manager); const string& ready_node_manager);
// The virtual scheduler emulates execution of nodes in a graph, considering // Encapsulates all of the various pieces uses to track state of a scheduler;
// dependencies, device, etc. // enables reuse of all scheduler state-related utilities across different
class VirtualScheduler { // scheduler implementations.
class SchedulerState {
public: public:
// Does not take ownership of cluster or ready_nodes. SchedulerState(const bool use_static_shapes,
VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference, Cluster* cluster, const bool use_aggressive_shape_inference, Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer); std::unique_ptr<VirtualPlacer> placer);
// Sets up the graph while also performing some necessary transformations
// initial_nodes is the set of nodes (primary inputs) discovered by Init()
// which may be added by a ReadyNodeManager (or related/derivative scheduler)
// to begin node schedule and graph simulation.
Status Init(const GrapplerItem* item,
std::vector<const NodeDef*>* initial_nodes,
bool create_explicit_channel_device = true);
// Initializes the scheduler for the specific grappler item.
// Should be called immediately after the c'tor or when the scheduler will be
// reused for a new grappler item. All internal states of the scheduler
// related to the previous grappler item will be reset/cleared.
//
// This function should be called at least once after the scheduler is
// constructed. An uninitialized or failed-to-initialize scheduler will cause
// undefined behavior.
Status Init(const GrapplerItem* item);
OpContext GetCurrNode() const;
// Returns true if there is any node to be scheduled.
bool MarkCurrNodeExecuted(const Costs& node_costs);
// Prints out summary of execution (timing, memory usage, etc.)
Costs Summary() const; Costs Summary() const;
// Like the above, but writes detailed stats to RunMetadata. // Like the above, but writes detailed stats to RunMetadata.
// If metadata is nullptr, then just calls and return Summary(). // If metadata is nullptr, then just calls and return Summary().
@ -347,34 +343,40 @@ class VirtualScheduler {
// Returns per device memory usage. // Returns per device memory usage.
const std::unordered_map<string, int64> GetPeakMemoryUsage() const; const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const; const std::unordered_map<string, int64> GetPersistentMemoryUsage() const;
void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
// Returns VirtualScheduler (read only) device and node states. // Returns (read only) device and node states.
const std::unordered_map<string, DeviceState>* GetDeviceStates() const { const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
return &device_; return &device_;
} }
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const { const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
return &node_map_; return &node_map_;
} }
void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; } OpContext CreateOpContext(const NodeDef* node) const;
std::vector<const NodeDef*> MarkNodeExecuted(const NodeDef* node,
const Costs& node_costs,
const OpContext& op_context);
private: private:
// Methods called from Init(). Fails if initialize_ is set. // Methods called from Init(). Fails if initialize_ is set.
void MaybeUpdateInputOutput(const NodeDef* node); void MaybeUpdateInputOutput(const NodeDef* node);
NodeState& GetNodeStateOrCreateIt(const NodeDef* node); NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
// Creates a Send_ and Recv_ pair between from and to. The argument
// create_channel_device tells the function to create an explicit device for
// the channel.
std::pair<const NodeDef*, const NodeDef*> CreateSendRecv( std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
const NodeDef* from, const NodeDef* to, const NodeDef* input_node, const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
const string& input_name); const string& input_name, bool create_channel_device);
string DeviceName(const NodeDef* node) const; string DeviceName(const NodeDef* node) const;
string SanitizedDeviceName(const NodeDef* node) const; string SanitizedDeviceName(const NodeDef* node) const;
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const; string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
// Helper methods. // Helper methods.
void AddOutputNodesToReadyQueue(const NodeDef* node, void GetOutputNodes(const NodeDef* node, const Costs::Duration& curr_time,
const Costs::Duration& curr_time); std::vector<const NodeDef*>* output_nodes);
// Scheduler states:
ReadyNodeManager* ready_nodes_; // Not owned.
std::unordered_map<const NodeDef*, NodeState> node_map_; std::unordered_map<const NodeDef*, NodeState> node_map_;
std::unordered_map<string, DeviceState> device_; std::unordered_map<string, DeviceState> device_;
@ -396,16 +398,81 @@ class VirtualScheduler {
// Auxiliary data structures for constructing NodeState and DeviceState. // Auxiliary data structures for constructing NodeState and DeviceState.
std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init(). std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init().
Cluster* cluster_; // Not owned. Cluster* cluster_; // Not owned.
const GrapplerItem* grappler_item_; // Not owned. const GrapplerItem* grappler_item_; // Not owned.
bool use_static_shapes_; bool use_static_shapes_;
bool initialized_; bool initialized_;
bool track_mem_usage_snapshot_; bool track_mem_usage_snapshot_;
const bool use_aggressive_shape_inference_; const bool use_aggressive_shape_inference_;
std::unique_ptr<VirtualPlacer> placer_; std::unique_ptr<VirtualPlacer> placer_;
}; };
// The virtual scheduler emulates execution of nodes in a graph, considering
// dependencies, device, etc.
class VirtualScheduler {
public:
// Does not take ownership of cluster or ready_nodes.
VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference, Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer);
// Initializes the scheduler for the specific grappler item.
// Should be called immediately after the c'tor or when the scheduler will be
// reused for a new grappler item. All internal states of the scheduler
// related to the previous grappler item will be reset/cleared.
//
// This function should be called at least once after the scheduler is
// constructed. An uninitialized or failed-to-initialize scheduler will cause
// undefined behavior.
Status Init(const GrapplerItem* item);
// Gets the current scheduled node for execution; the caller of this function
// can accordingly simulate the execution of the current scheduled node.
OpContext GetCurrNode() const;
// Marks the current scheduled node as executed. Note that we should call this
// function only after the execution of the node has been simulated;
// node_costs_ capture the simulated costs of the node.
// Returns true if there is any node to be scheduled.
bool MarkCurrNodeExecuted(const Costs& node_costs);
// Prints out summary of execution (timing, memory usage, etc.)
Costs Summary() const { return scheduler_state_.Summary(); }
// Like the above, but writes detailed stats to RunMetadata.
// If metadata is nullptr, then just calls and return Summary().
Costs Summary(RunMetadata* metadata) {
return scheduler_state_.Summary(metadata);
}
// Generates RunMetadata's step_stats and partition_graphs fields from results
// of the virtual execution of the graph.
void GenerateRunMetadata(RunMetadata* metadata) {
scheduler_state_.GenerateRunMetadata(metadata);
}
// Returns per device memory usage.
const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
return scheduler_state_.GetPeakMemoryUsage();
}
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
return scheduler_state_.GetPersistentMemoryUsage();
}
// Returns VirtualScheduler (read only) device and node states.
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
return scheduler_state_.GetDeviceStates();
}
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
return scheduler_state_.GetNodeStates();
}
void enable_mem_usage_tracking() {
scheduler_state_.enable_mem_usage_tracking();
}
private:
// The state of the scheduler and the execution of the graph is encapsulated
// by the scheduler_state_ object.
SchedulerState scheduler_state_;
// ready_nodes_ is responsible for ordering the traversal of the graph.
ReadyNodeManager* ready_nodes_; // Not owned.
};
} // namespace grappler } // namespace grappler
} // end namespace tensorflow } // end namespace tensorflow