Created a new class, SchedulerState, to encapsulate all of the scheduler state-related functionality we would like to reuse across different scheduler implementations. Scheduler state-related member functions/variables in VirtualScheduler have been moved to SchedulerState accordingly.

PiperOrigin-RevId: 308204183
Change-Id: Ie8ffe167d31844cc82c865ec5ac4a28d0e53d3a9
This commit is contained in:
A. Unique TensorFlower 2020-04-23 23:42:35 -07:00 committed by TensorFlower Gardener
parent ecfcb090c5
commit 3921264ef7
2 changed files with 196 additions and 95 deletions

View File

@ -40,12 +40,6 @@ namespace {
using ::tensorflow::strings::HumanReadableNumBytes;
constexpr char kAttrInputSrc[] = "input_source_";
constexpr char kAttrSrcDevice[] = "send_device";
constexpr char kAttrDstDevice[] = "recv_device";
constexpr char kAttrTensorName[] = "tensor_name";
constexpr char kChannelDevice[] = "Channel";
float Round2(const float x) {
// Not using std::round from <cmath> here because not all platforms seem to
// support that (specifically Android).
@ -347,13 +341,11 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
return nullptr;
}
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer)
: ready_nodes_(ready_nodes),
graph_costs_(Costs::ZeroCosts()),
SchedulerState::SchedulerState(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster,
std::unique_ptr<VirtualPlacer> placer)
: graph_costs_(Costs::ZeroCosts()),
cluster_(cluster),
use_static_shapes_(use_static_shapes),
use_aggressive_shape_inference_(use_aggressive_shape_inference),
@ -364,10 +356,12 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
track_mem_usage_snapshot_ = VLOG_IS_ON(1);
}
Status VirtualScheduler::Init(const GrapplerItem* item) {
Status SchedulerState::Init(const GrapplerItem* item,
std::vector<const NodeDef*>* initial_nodes,
bool create_explicit_channel_device) {
initialized_ = false;
// Clear all internal states so that the VirtualScheduler is reusable for
// Clear all internal states so that the SchedulerState is reusable for
// different GrapplerItems
node_map_.clear();
device_.clear();
@ -380,14 +374,12 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
op_counts_.clear();
op_costs_.clear();
// Init() preprocesses the input grappler_item and graph_properties to extract
// necessary information for emulating tensorflow op scheduling and
// construct internal data structures (NodeState and DeviceState) for virtual
// scheduling.
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
initial_nodes->clear();
// Constructs graph properties and performs shape inference.
graph_properties_ = absl::make_unique<GraphProperties>(*item);
// TODO(safeen,dyoon): Will we ever use InferDynamically? If not we may want
// to get rid of use_static_shapes_ and cluster_.
if (use_static_shapes_) {
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
true, use_aggressive_shape_inference_, true));
@ -399,6 +391,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
const auto& graph = grappler_item_->graph;
const auto& fetch_nodes = grappler_item_->fetch;
std::set<string> feed_nodes;
for (const auto& f : grappler_item_->feed) {
auto iter_and_inserted_flag = feed_nodes.insert(f.first);
QCHECK(iter_and_inserted_flag.second)
@ -486,8 +479,9 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
} else {
// Different device, no cached copy; transfer input_node to the
// curr_node's device.
auto send_and_recv = CreateSendRecv(input_node, curr_node, input_node,
input_node_name);
auto send_and_recv =
CreateSendRecv(input_node, curr_node, input_node, input_node_name,
create_explicit_channel_device);
// Note that CreateSendRecv() already connected input/output between
// _Send and _Recv ops.
const auto* send = send_and_recv.first;
@ -514,7 +508,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
if (given_as_feed || has_no_inputs) {
curr_node_state.time_ready = Costs::Duration();
ready_nodes_->AddNode(curr_node);
initial_nodes->push_back(curr_node);
VLOG(3) << "Added ready node: " << curr_node->name();
}
@ -530,7 +524,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
}
}
if (ready_nodes_->Empty()) {
if (initial_nodes->empty()) {
return errors::InvalidArgument("No ready nodes in the graph.");
}
@ -546,20 +540,20 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
return Status::OK();
}
void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) {
CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
// This method is called when NodeState is created and adds input and output
// properties for a few exceptional cases that GraphProperties cannot provide
// input/output properties.
if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
// _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc
// _Send and _Recv ops created from SchedulerState have kAttrInputSrc
// attr; normal _Send and _Recv ops (from the input graph) do not have that
// attr.
auto& node_state = node_map_[node];
auto& inputs = node_state.input_properties;
auto& outputs = node_state.output_properties;
// _Send and _Recv ops are created from VirtualScheduler, so
// _Send and _Recv ops are created from SchedulerState, so
// there should be no inputs TensorProperties.
CHECK(inputs.empty());
CHECK(outputs.empty());
@ -595,27 +589,27 @@ void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
}
}
string VirtualScheduler::DeviceName(const NodeDef* node) const {
string SchedulerState::DeviceName(const NodeDef* node) const {
return placer_->get_canonical_device_name(*node);
}
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
string SchedulerState::SanitizedDeviceName(const NodeDef* node) const {
// Replace the ":" characters that may be present in the device name with "_".
// This makes it possible to then use the resulting string in a node name.
return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
{{":", "_"}});
}
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
const NodeDef* to) const {
string SchedulerState::ChannelDeviceName(const NodeDef* from,
const NodeDef* to) const {
CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
"_to_", SanitizedDeviceName(to));
}
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
const string& input_name) {
const string& input_name, bool create_channel_device) {
CHECK(!initialized_) << "CreateSendRecv is called after Init().";
// Connect "from" node to "to" node with _Send and _Recv such that
@ -643,7 +637,9 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
"_to_" + SanitizedDeviceName(to));
send->set_op("_Send");
send->add_input(from->name());
send->set_device(ChannelDeviceName(from, to));
auto send_device =
create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from);
send->set_device(send_device);
auto& send_attr = *(send->mutable_attr());
send_attr[kAttrInputSrc].set_s(input_name);
send_attr[kAttrSrcDevice].set_s(DeviceName(from));
@ -687,9 +683,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
return std::make_pair(send, recv);
}
OpContext VirtualScheduler::GetCurrNode() const {
const NodeDef* node = ready_nodes_->GetCurrNode();
OpContext SchedulerState::CreateOpContext(const NodeDef* node) const {
// Get the device from the placer.
DeviceProperties device;
device = placer_->get_device(*node);
@ -721,7 +715,7 @@ OpContext VirtualScheduler::GetCurrNode() const {
return op_context;
}
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
auto it = node_map_.find(node);
@ -766,8 +760,9 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
return it->second;
}
void VirtualScheduler::AddOutputNodesToReadyQueue(
const NodeDef* node, const Costs::Duration& curr_time) {
void SchedulerState::GetOutputNodes(const NodeDef* node,
const Costs::Duration& curr_time,
std::vector<const NodeDef*>* output_nodes) {
// Checks whether the Switch's output slots change over iterations.
int slot = -1;
if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
@ -780,7 +775,6 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
}
}
}
// Increment num_inputs_ready of the output nodes and maybe add to ready
// nodes.
auto& node_state = node_map_[node];
@ -799,16 +793,15 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
IsMerge(*output_node)) {
// This output node is now ready.
output_state.time_ready = curr_time;
ready_nodes_->AddNode(output_node);
output_nodes->push_back(output_node);
VLOG(3) << " Add output: " << output_node->name();
}
}
}
}
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
// Update graph_costs_ and per-op costs.
const NodeDef* node = ready_nodes_->GetCurrNode();
std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
const NodeDef* node, const Costs& node_costs, const OpContext& op_context) {
auto& node_state = node_map_[node];
// TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and
// Merge -- it can create a loop which may include loop-carried dependency,
@ -834,8 +827,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
if (VLOG_IS_ON(2)) {
// Also keep track of op counts and costs per op (with their shapes).
OpContext op_context = GetCurrNode();
string node_description = GetOpDescription(op_context.op_info);
op_counts_[node_description] += 1;
op_costs_[node_description] =
@ -886,7 +877,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
<< ", ready: " << node_state.time_ready.count()
<< ", scheduled: " << node_state.time_scheduled.count()
<< ", finished: " << node_state.time_finished.count();
std::vector<const NodeDef*> new_nodes;
if (previously_executed_merge) {
// Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
@ -894,7 +885,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
<< "Skip scheduling its output nodes.";
} else {
// Checks outputs, and adds ready nodes to queue.
AddOutputNodesToReadyQueue(node, curr_time);
GetOutputNodes(node, curr_time, &new_nodes);
}
// Increment num_outputs_executed of the input nodes and maybe update memory.
@ -929,13 +920,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
}
}
}
ready_nodes_->RemoveCurrNode();
return !ready_nodes_->Empty();
return new_nodes;
}
Costs VirtualScheduler::Summary() const {
Costs SchedulerState::Summary() const {
// Overall statement about accuracy
VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
<< graph_costs_.num_ops_with_unknown_shapes
@ -1109,12 +1097,12 @@ Costs VirtualScheduler::Summary() const {
return critical_path_costs;
}
Costs VirtualScheduler::Summary(RunMetadata* metadata) {
Costs SchedulerState::Summary(RunMetadata* metadata) {
if (metadata) GenerateRunMetadata(metadata);
return Summary();
}
void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
// Fill RunMetadata's step_stats and partition_graphs fields.
StepStats* stepstats = metadata->mutable_step_stats();
for (const auto& device : device_) {
@ -1176,7 +1164,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
nodestate.time_scheduled.count());
auto* mem_stats = node_stats->mutable_memory_stats();
// VirtualScheduler does not specify scratch pad memory usage.
// SchedulerState does not specify scratch pad memory usage.
mem_stats->set_temp_memory_size(0);
int64 persistent_memory_size = 0;
if (IsPersistent(*node_def)) {
@ -1188,7 +1176,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
}
}
const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
const std::unordered_map<string, int64> SchedulerState::GetPeakMemoryUsage()
const {
std::unordered_map<string, int64> result;
for (const auto& device : device_) {
@ -1200,7 +1188,7 @@ const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
}
const std::unordered_map<string, int64>
VirtualScheduler::GetPersistentMemoryUsage() const {
SchedulerState::GetPersistentMemoryUsage() const {
std::unordered_map<string, int64> result;
for (const auto& device : device_) {
const string& name = device.first;
@ -1217,5 +1205,51 @@ VirtualScheduler::GetPersistentMemoryUsage() const {
}
return result;
}
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer)
: scheduler_state_(use_static_shapes, use_aggressive_shape_inference,
cluster, std::move(placer)),
ready_nodes_(ready_nodes) {}
Status VirtualScheduler::Init(const GrapplerItem* item) {
// SchedulerState::Init() preprocesses the input grappler_item and
// graph_properties to extract necessary information for emulating tensorflow
// op scheduling and construct internal data structures (NodeState and
// DeviceState) for virtual scheduling.
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
std::vector<const NodeDef*> initial_nodes;
auto status = scheduler_state_.Init(item, &initial_nodes);
if (status.ok()) {
// Add the set of initial nodes to ready_nodes_
for (auto node : initial_nodes) {
ready_nodes_->AddNode(node);
}
}
return status;
}
OpContext VirtualScheduler::GetCurrNode() const {
const NodeDef* node = ready_nodes_->GetCurrNode();
return scheduler_state_.CreateOpContext(node);
}
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
// Update graph_costs_ and per-op costs.
const NodeDef* node = ready_nodes_->GetCurrNode();
auto new_nodes = scheduler_state_.MarkNodeExecuted(
node, node_costs,
scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode()));
ready_nodes_->RemoveCurrNode();
// Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
for (auto node : new_nodes) {
ready_nodes_->AddNode(node);
}
return !ready_nodes_->Empty();
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -32,6 +32,12 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
inline constexpr char kAttrInputSrc[] = "input_source_";
inline constexpr char kAttrSrcDevice[] = "send_device";
inline constexpr char kAttrDstDevice[] = "recv_device";
inline constexpr char kAttrTensorName[] = "tensor_name";
inline constexpr char kChannelDevice[] = "Channel";
struct NodeState {
// A node (i.e., an op) takes a set of input:port pairs and produces
// a set of output ports.
@ -233,7 +239,7 @@ class HeapReadyManager : public ReadyNodeManager {
// functor for keeping the smallest time_ready node at the front of heap.
std::function<bool(const NodeDef*, const NodeDef*)> greater_;
// NodeState structure from VirtualScheduler to get time_ready of ready nodes.
// NodeState structure from SchedulerState to get time_ready of ready nodes.
// Not owned by FirstReadyManager.
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
};
@ -298,7 +304,7 @@ class CompositeNodeManager : public ReadyNodeManager {
FirstReadyManager send_manager_;
FirstReadyManager recv_manager_;
// NodeState structure from VirtualScheduler to get time_ready of ready nodes.
// NodeState structure from SchedulerState to get time_ready of ready nodes.
// Not owned by CompositeReadyManager.
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
@ -310,32 +316,22 @@ class CompositeNodeManager : public ReadyNodeManager {
std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
const string& ready_node_manager);
// The virtual scheduler emulates execution of nodes in a graph, considering
// dependencies, device, etc.
class VirtualScheduler {
// Encapsulates all of the various pieces uses to track state of a scheduler;
// enables reuse of all scheduler state-related utilities across different
// scheduler implementations.
class SchedulerState {
public:
// Does not take ownership of cluster or ready_nodes.
VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference, Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer);
SchedulerState(const bool use_static_shapes,
const bool use_aggressive_shape_inference, Cluster* cluster,
std::unique_ptr<VirtualPlacer> placer);
// Sets up the graph while also performing some necessary transformations
// initial_nodes is the set of nodes (primary inputs) discovered by Init()
// which may be added by a ReadyNodeManager (or related/derivative scheduler)
// to begin node schedule and graph simulation.
Status Init(const GrapplerItem* item,
std::vector<const NodeDef*>* initial_nodes,
bool create_explicit_channel_device = true);
// Initializes the scheduler for the specific grappler item.
// Should be called immediately after the c'tor or when the scheduler will be
// reused for a new grappler item. All internal states of the scheduler
// related to the previous grappler item will be reset/cleared.
//
// This function should be called at least once after the scheduler is
// constructed. An uninitialized or failed-to-initialize scheduler will cause
// undefined behavior.
Status Init(const GrapplerItem* item);
OpContext GetCurrNode() const;
// Returns true if there is any node to be scheduled.
bool MarkCurrNodeExecuted(const Costs& node_costs);
// Prints out summary of execution (timing, memory usage, etc.)
Costs Summary() const;
// Like the above, but writes detailed stats to RunMetadata.
// If metadata is nullptr, then just calls and return Summary().
@ -347,34 +343,40 @@ class VirtualScheduler {
// Returns per device memory usage.
const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const;
// Returns VirtualScheduler (read only) device and node states.
void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
// Returns (read only) device and node states.
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
return &device_;
}
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
return &node_map_;
}
void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
OpContext CreateOpContext(const NodeDef* node) const;
std::vector<const NodeDef*> MarkNodeExecuted(const NodeDef* node,
const Costs& node_costs,
const OpContext& op_context);
private:
// Methods called from Init(). Fails if initialize_ is set.
void MaybeUpdateInputOutput(const NodeDef* node);
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
// Creates a Send_ and Recv_ pair between from and to. The argument
// create_channel_device tells the function to create an explicit device for
// the channel.
std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
const string& input_name);
const string& input_name, bool create_channel_device);
string DeviceName(const NodeDef* node) const;
string SanitizedDeviceName(const NodeDef* node) const;
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
// Helper methods.
void AddOutputNodesToReadyQueue(const NodeDef* node,
const Costs::Duration& curr_time);
void GetOutputNodes(const NodeDef* node, const Costs::Duration& curr_time,
std::vector<const NodeDef*>* output_nodes);
// Scheduler states:
ReadyNodeManager* ready_nodes_; // Not owned.
std::unordered_map<const NodeDef*, NodeState> node_map_;
std::unordered_map<string, DeviceState> device_;
@ -396,16 +398,81 @@ class VirtualScheduler {
// Auxiliary data structures for constructing NodeState and DeviceState.
std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init().
Cluster* cluster_; // Not owned.
const GrapplerItem* grappler_item_; // Not owned.
bool use_static_shapes_;
bool initialized_;
bool track_mem_usage_snapshot_;
const bool use_aggressive_shape_inference_;
std::unique_ptr<VirtualPlacer> placer_;
};
// The virtual scheduler emulates execution of nodes in a graph, considering
// dependencies, device, etc.
class VirtualScheduler {
public:
// Does not take ownership of cluster or ready_nodes.
VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference, Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer);
// Initializes the scheduler for the specific grappler item.
// Should be called immediately after the c'tor or when the scheduler will be
// reused for a new grappler item. All internal states of the scheduler
// related to the previous grappler item will be reset/cleared.
//
// This function should be called at least once after the scheduler is
// constructed. An uninitialized or failed-to-initialize scheduler will cause
// undefined behavior.
Status Init(const GrapplerItem* item);
// Gets the current scheduled node for execution; the caller of this function
// can accordingly simulate the execution of the current scheduled node.
OpContext GetCurrNode() const;
// Marks the current scheduled node as executed. Note that we should call this
// function only after the execution of the node has been simulated;
// node_costs_ capture the simulated costs of the node.
// Returns true if there is any node to be scheduled.
bool MarkCurrNodeExecuted(const Costs& node_costs);
// Prints out summary of execution (timing, memory usage, etc.)
Costs Summary() const { return scheduler_state_.Summary(); }
// Like the above, but writes detailed stats to RunMetadata.
// If metadata is nullptr, then just calls and return Summary().
Costs Summary(RunMetadata* metadata) {
return scheduler_state_.Summary(metadata);
}
// Generates RunMetadata's step_stats and partition_graphs fields from results
// of the virtual execution of the graph.
void GenerateRunMetadata(RunMetadata* metadata) {
scheduler_state_.GenerateRunMetadata(metadata);
}
// Returns per device memory usage.
const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
return scheduler_state_.GetPeakMemoryUsage();
}
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
return scheduler_state_.GetPersistentMemoryUsage();
}
// Returns VirtualScheduler (read only) device and node states.
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
return scheduler_state_.GetDeviceStates();
}
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
return scheduler_state_.GetNodeStates();
}
void enable_mem_usage_tracking() {
scheduler_state_.enable_mem_usage_tracking();
}
private:
// The state of the scheduler and the execution of the graph is encapsulated
// by the scheduler_state_ object.
SchedulerState scheduler_state_;
// ready_nodes_ is responsible for ordering the traversal of the graph.
ReadyNodeManager* ready_nodes_; // Not owned.
};
} // namespace grappler
} // end namespace tensorflow