Created a new class, SchedulerState, to encapsulate all of the scheduler state-related functionality we would like to reuse across different scheduler implementations. Scheduler state-related member functions/variables in VirtualScheduler have been moved to SchedulerState accordingly.
PiperOrigin-RevId: 308204183 Change-Id: Ie8ffe167d31844cc82c865ec5ac4a28d0e53d3a9
This commit is contained in:
parent
ecfcb090c5
commit
3921264ef7
|
@ -40,12 +40,6 @@ namespace {
|
|||
|
||||
using ::tensorflow::strings::HumanReadableNumBytes;
|
||||
|
||||
constexpr char kAttrInputSrc[] = "input_source_";
|
||||
constexpr char kAttrSrcDevice[] = "send_device";
|
||||
constexpr char kAttrDstDevice[] = "recv_device";
|
||||
constexpr char kAttrTensorName[] = "tensor_name";
|
||||
constexpr char kChannelDevice[] = "Channel";
|
||||
|
||||
float Round2(const float x) {
|
||||
// Not using std::round from <cmath> here because not all platforms seem to
|
||||
// support that (specifically Android).
|
||||
|
@ -347,13 +341,11 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||
SchedulerState::SchedulerState(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference,
|
||||
Cluster* cluster,
|
||||
ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<VirtualPlacer> placer)
|
||||
: ready_nodes_(ready_nodes),
|
||||
graph_costs_(Costs::ZeroCosts()),
|
||||
: graph_costs_(Costs::ZeroCosts()),
|
||||
cluster_(cluster),
|
||||
use_static_shapes_(use_static_shapes),
|
||||
use_aggressive_shape_inference_(use_aggressive_shape_inference),
|
||||
|
@ -364,10 +356,12 @@ VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
|||
track_mem_usage_snapshot_ = VLOG_IS_ON(1);
|
||||
}
|
||||
|
||||
Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
Status SchedulerState::Init(const GrapplerItem* item,
|
||||
std::vector<const NodeDef*>* initial_nodes,
|
||||
bool create_explicit_channel_device) {
|
||||
initialized_ = false;
|
||||
|
||||
// Clear all internal states so that the VirtualScheduler is reusable for
|
||||
// Clear all internal states so that the SchedulerState is reusable for
|
||||
// different GrapplerItems
|
||||
node_map_.clear();
|
||||
device_.clear();
|
||||
|
@ -380,14 +374,12 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
|||
op_counts_.clear();
|
||||
op_costs_.clear();
|
||||
|
||||
// Init() preprocesses the input grappler_item and graph_properties to extract
|
||||
// necessary information for emulating tensorflow op scheduling and
|
||||
// construct internal data structures (NodeState and DeviceState) for virtual
|
||||
// scheduling.
|
||||
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
|
||||
initial_nodes->clear();
|
||||
|
||||
// Constructs graph properties and performs shape inference.
|
||||
graph_properties_ = absl::make_unique<GraphProperties>(*item);
|
||||
// TODO(safeen,dyoon): Will we ever use InferDynamically? If not we may want
|
||||
// to get rid of use_static_shapes_ and cluster_.
|
||||
if (use_static_shapes_) {
|
||||
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
|
||||
true, use_aggressive_shape_inference_, true));
|
||||
|
@ -399,6 +391,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
|||
const auto& graph = grappler_item_->graph;
|
||||
const auto& fetch_nodes = grappler_item_->fetch;
|
||||
std::set<string> feed_nodes;
|
||||
|
||||
for (const auto& f : grappler_item_->feed) {
|
||||
auto iter_and_inserted_flag = feed_nodes.insert(f.first);
|
||||
QCHECK(iter_and_inserted_flag.second)
|
||||
|
@ -486,8 +479,9 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
|||
} else {
|
||||
// Different device, no cached copy; transfer input_node to the
|
||||
// curr_node's device.
|
||||
auto send_and_recv = CreateSendRecv(input_node, curr_node, input_node,
|
||||
input_node_name);
|
||||
auto send_and_recv =
|
||||
CreateSendRecv(input_node, curr_node, input_node, input_node_name,
|
||||
create_explicit_channel_device);
|
||||
// Note that CreateSendRecv() already connected input/output between
|
||||
// _Send and _Recv ops.
|
||||
const auto* send = send_and_recv.first;
|
||||
|
@ -514,7 +508,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
|||
|
||||
if (given_as_feed || has_no_inputs) {
|
||||
curr_node_state.time_ready = Costs::Duration();
|
||||
ready_nodes_->AddNode(curr_node);
|
||||
initial_nodes->push_back(curr_node);
|
||||
VLOG(3) << "Added ready node: " << curr_node->name();
|
||||
}
|
||||
|
||||
|
@ -530,7 +524,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
|||
}
|
||||
}
|
||||
|
||||
if (ready_nodes_->Empty()) {
|
||||
if (initial_nodes->empty()) {
|
||||
return errors::InvalidArgument("No ready nodes in the graph.");
|
||||
}
|
||||
|
||||
|
@ -546,20 +540,20 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
|
||||
void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) {
|
||||
CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
|
||||
// This method is called when NodeState is created and adds input and output
|
||||
// properties for a few exceptional cases that GraphProperties cannot provide
|
||||
// input/output properties.
|
||||
if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
|
||||
// _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc
|
||||
// _Send and _Recv ops created from SchedulerState have kAttrInputSrc
|
||||
// attr; normal _Send and _Recv ops (from the input graph) do not have that
|
||||
// attr.
|
||||
auto& node_state = node_map_[node];
|
||||
auto& inputs = node_state.input_properties;
|
||||
auto& outputs = node_state.output_properties;
|
||||
|
||||
// _Send and _Recv ops are created from VirtualScheduler, so
|
||||
// _Send and _Recv ops are created from SchedulerState, so
|
||||
// there should be no inputs TensorProperties.
|
||||
CHECK(inputs.empty());
|
||||
CHECK(outputs.empty());
|
||||
|
@ -595,27 +589,27 @@ void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
|
|||
}
|
||||
}
|
||||
|
||||
string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
||||
string SchedulerState::DeviceName(const NodeDef* node) const {
|
||||
return placer_->get_canonical_device_name(*node);
|
||||
}
|
||||
|
||||
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
|
||||
string SchedulerState::SanitizedDeviceName(const NodeDef* node) const {
|
||||
// Replace the ":" characters that may be present in the device name with "_".
|
||||
// This makes it possible to then use the resulting string in a node name.
|
||||
return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
|
||||
{{":", "_"}});
|
||||
}
|
||||
|
||||
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
|
||||
string SchedulerState::ChannelDeviceName(const NodeDef* from,
|
||||
const NodeDef* to) const {
|
||||
CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
|
||||
return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
|
||||
"_to_", SanitizedDeviceName(to));
|
||||
}
|
||||
|
||||
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
||||
std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
|
||||
const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
|
||||
const string& input_name) {
|
||||
const string& input_name, bool create_channel_device) {
|
||||
CHECK(!initialized_) << "CreateSendRecv is called after Init().";
|
||||
|
||||
// Connect "from" node to "to" node with _Send and _Recv such that
|
||||
|
@ -643,7 +637,9 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
|||
"_to_" + SanitizedDeviceName(to));
|
||||
send->set_op("_Send");
|
||||
send->add_input(from->name());
|
||||
send->set_device(ChannelDeviceName(from, to));
|
||||
auto send_device =
|
||||
create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from);
|
||||
send->set_device(send_device);
|
||||
auto& send_attr = *(send->mutable_attr());
|
||||
send_attr[kAttrInputSrc].set_s(input_name);
|
||||
send_attr[kAttrSrcDevice].set_s(DeviceName(from));
|
||||
|
@ -687,9 +683,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
|||
return std::make_pair(send, recv);
|
||||
}
|
||||
|
||||
OpContext VirtualScheduler::GetCurrNode() const {
|
||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||
|
||||
OpContext SchedulerState::CreateOpContext(const NodeDef* node) const {
|
||||
// Get the device from the placer.
|
||||
DeviceProperties device;
|
||||
device = placer_->get_device(*node);
|
||||
|
@ -721,7 +715,7 @@ OpContext VirtualScheduler::GetCurrNode() const {
|
|||
return op_context;
|
||||
}
|
||||
|
||||
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
|
||||
NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
|
||||
CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
|
||||
|
||||
auto it = node_map_.find(node);
|
||||
|
@ -766,8 +760,9 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
|
|||
return it->second;
|
||||
}
|
||||
|
||||
void VirtualScheduler::AddOutputNodesToReadyQueue(
|
||||
const NodeDef* node, const Costs::Duration& curr_time) {
|
||||
void SchedulerState::GetOutputNodes(const NodeDef* node,
|
||||
const Costs::Duration& curr_time,
|
||||
std::vector<const NodeDef*>* output_nodes) {
|
||||
// Checks whether the Switch's output slots change over iterations.
|
||||
int slot = -1;
|
||||
if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
|
||||
|
@ -780,7 +775,6 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Increment num_inputs_ready of the output nodes and maybe add to ready
|
||||
// nodes.
|
||||
auto& node_state = node_map_[node];
|
||||
|
@ -799,16 +793,15 @@ void VirtualScheduler::AddOutputNodesToReadyQueue(
|
|||
IsMerge(*output_node)) {
|
||||
// This output node is now ready.
|
||||
output_state.time_ready = curr_time;
|
||||
ready_nodes_->AddNode(output_node);
|
||||
output_nodes->push_back(output_node);
|
||||
VLOG(3) << " Add output: " << output_node->name();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||
// Update graph_costs_ and per-op costs.
|
||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||
std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
|
||||
const NodeDef* node, const Costs& node_costs, const OpContext& op_context) {
|
||||
auto& node_state = node_map_[node];
|
||||
// TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and
|
||||
// Merge -- it can create a loop which may include loop-carried dependency,
|
||||
|
@ -834,8 +827,6 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
|||
|
||||
if (VLOG_IS_ON(2)) {
|
||||
// Also keep track of op counts and costs per op (with their shapes).
|
||||
OpContext op_context = GetCurrNode();
|
||||
|
||||
string node_description = GetOpDescription(op_context.op_info);
|
||||
op_counts_[node_description] += 1;
|
||||
op_costs_[node_description] =
|
||||
|
@ -886,7 +877,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
|||
<< ", ready: " << node_state.time_ready.count()
|
||||
<< ", scheduled: " << node_state.time_scheduled.count()
|
||||
<< ", finished: " << node_state.time_finished.count();
|
||||
|
||||
std::vector<const NodeDef*> new_nodes;
|
||||
if (previously_executed_merge) {
|
||||
// Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
|
||||
VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
|
||||
|
@ -894,7 +885,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
|||
<< "Skip scheduling its output nodes.";
|
||||
} else {
|
||||
// Checks outputs, and adds ready nodes to queue.
|
||||
AddOutputNodesToReadyQueue(node, curr_time);
|
||||
GetOutputNodes(node, curr_time, &new_nodes);
|
||||
}
|
||||
|
||||
// Increment num_outputs_executed of the input nodes and maybe update memory.
|
||||
|
@ -929,13 +920,10 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
ready_nodes_->RemoveCurrNode();
|
||||
|
||||
return !ready_nodes_->Empty();
|
||||
return new_nodes;
|
||||
}
|
||||
|
||||
Costs VirtualScheduler::Summary() const {
|
||||
Costs SchedulerState::Summary() const {
|
||||
// Overall statement about accuracy
|
||||
VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
|
||||
<< graph_costs_.num_ops_with_unknown_shapes
|
||||
|
@ -1109,12 +1097,12 @@ Costs VirtualScheduler::Summary() const {
|
|||
return critical_path_costs;
|
||||
}
|
||||
|
||||
Costs VirtualScheduler::Summary(RunMetadata* metadata) {
|
||||
Costs SchedulerState::Summary(RunMetadata* metadata) {
|
||||
if (metadata) GenerateRunMetadata(metadata);
|
||||
return Summary();
|
||||
}
|
||||
|
||||
void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
|
||||
void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
|
||||
// Fill RunMetadata's step_stats and partition_graphs fields.
|
||||
StepStats* stepstats = metadata->mutable_step_stats();
|
||||
for (const auto& device : device_) {
|
||||
|
@ -1176,7 +1164,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
|
|||
nodestate.time_scheduled.count());
|
||||
|
||||
auto* mem_stats = node_stats->mutable_memory_stats();
|
||||
// VirtualScheduler does not specify scratch pad memory usage.
|
||||
// SchedulerState does not specify scratch pad memory usage.
|
||||
mem_stats->set_temp_memory_size(0);
|
||||
int64 persistent_memory_size = 0;
|
||||
if (IsPersistent(*node_def)) {
|
||||
|
@ -1188,7 +1176,7 @@ void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
|
|||
}
|
||||
}
|
||||
|
||||
const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
|
||||
const std::unordered_map<string, int64> SchedulerState::GetPeakMemoryUsage()
|
||||
const {
|
||||
std::unordered_map<string, int64> result;
|
||||
for (const auto& device : device_) {
|
||||
|
@ -1200,7 +1188,7 @@ const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
|
|||
}
|
||||
|
||||
const std::unordered_map<string, int64>
|
||||
VirtualScheduler::GetPersistentMemoryUsage() const {
|
||||
SchedulerState::GetPersistentMemoryUsage() const {
|
||||
std::unordered_map<string, int64> result;
|
||||
for (const auto& device : device_) {
|
||||
const string& name = device.first;
|
||||
|
@ -1217,5 +1205,51 @@ VirtualScheduler::GetPersistentMemoryUsage() const {
|
|||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference,
|
||||
Cluster* cluster,
|
||||
ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<VirtualPlacer> placer)
|
||||
: scheduler_state_(use_static_shapes, use_aggressive_shape_inference,
|
||||
cluster, std::move(placer)),
|
||||
ready_nodes_(ready_nodes) {}
|
||||
|
||||
Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
// SchedulerState::Init() preprocesses the input grappler_item and
|
||||
// graph_properties to extract necessary information for emulating tensorflow
|
||||
// op scheduling and construct internal data structures (NodeState and
|
||||
// DeviceState) for virtual scheduling.
|
||||
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
|
||||
std::vector<const NodeDef*> initial_nodes;
|
||||
auto status = scheduler_state_.Init(item, &initial_nodes);
|
||||
if (status.ok()) {
|
||||
// Add the set of initial nodes to ready_nodes_
|
||||
for (auto node : initial_nodes) {
|
||||
ready_nodes_->AddNode(node);
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
OpContext VirtualScheduler::GetCurrNode() const {
|
||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||
return scheduler_state_.CreateOpContext(node);
|
||||
}
|
||||
|
||||
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||
// Update graph_costs_ and per-op costs.
|
||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||
auto new_nodes = scheduler_state_.MarkNodeExecuted(
|
||||
node, node_costs,
|
||||
scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode()));
|
||||
ready_nodes_->RemoveCurrNode();
|
||||
// Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
|
||||
for (auto node : new_nodes) {
|
||||
ready_nodes_->AddNode(node);
|
||||
}
|
||||
return !ready_nodes_->Empty();
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
|
|
@ -32,6 +32,12 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
inline constexpr char kAttrInputSrc[] = "input_source_";
|
||||
inline constexpr char kAttrSrcDevice[] = "send_device";
|
||||
inline constexpr char kAttrDstDevice[] = "recv_device";
|
||||
inline constexpr char kAttrTensorName[] = "tensor_name";
|
||||
inline constexpr char kChannelDevice[] = "Channel";
|
||||
|
||||
struct NodeState {
|
||||
// A node (i.e., an op) takes a set of input:port pairs and produces
|
||||
// a set of output ports.
|
||||
|
@ -233,7 +239,7 @@ class HeapReadyManager : public ReadyNodeManager {
|
|||
// functor for keeping the smallest time_ready node at the front of heap.
|
||||
std::function<bool(const NodeDef*, const NodeDef*)> greater_;
|
||||
|
||||
// NodeState structure from VirtualScheduler to get time_ready of ready nodes.
|
||||
// NodeState structure from SchedulerState to get time_ready of ready nodes.
|
||||
// Not owned by FirstReadyManager.
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
|
||||
};
|
||||
|
@ -298,7 +304,7 @@ class CompositeNodeManager : public ReadyNodeManager {
|
|||
FirstReadyManager send_manager_;
|
||||
FirstReadyManager recv_manager_;
|
||||
|
||||
// NodeState structure from VirtualScheduler to get time_ready of ready nodes.
|
||||
// NodeState structure from SchedulerState to get time_ready of ready nodes.
|
||||
// Not owned by CompositeReadyManager.
|
||||
const std::unordered_map<const NodeDef*, NodeState>* node_map_;
|
||||
|
||||
|
@ -310,32 +316,22 @@ class CompositeNodeManager : public ReadyNodeManager {
|
|||
std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
|
||||
const string& ready_node_manager);
|
||||
|
||||
// The virtual scheduler emulates execution of nodes in a graph, considering
|
||||
// dependencies, device, etc.
|
||||
class VirtualScheduler {
|
||||
// Encapsulates all of the various pieces uses to track state of a scheduler;
|
||||
// enables reuse of all scheduler state-related utilities across different
|
||||
// scheduler implementations.
|
||||
class SchedulerState {
|
||||
public:
|
||||
// Does not take ownership of cluster or ready_nodes.
|
||||
VirtualScheduler(const bool use_static_shapes,
|
||||
SchedulerState(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||
ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<VirtualPlacer> placer);
|
||||
// Sets up the graph while also performing some necessary transformations
|
||||
// initial_nodes is the set of nodes (primary inputs) discovered by Init()
|
||||
// which may be added by a ReadyNodeManager (or related/derivative scheduler)
|
||||
// to begin node schedule and graph simulation.
|
||||
Status Init(const GrapplerItem* item,
|
||||
std::vector<const NodeDef*>* initial_nodes,
|
||||
bool create_explicit_channel_device = true);
|
||||
|
||||
// Initializes the scheduler for the specific grappler item.
|
||||
// Should be called immediately after the c'tor or when the scheduler will be
|
||||
// reused for a new grappler item. All internal states of the scheduler
|
||||
// related to the previous grappler item will be reset/cleared.
|
||||
//
|
||||
// This function should be called at least once after the scheduler is
|
||||
// constructed. An uninitialized or failed-to-initialize scheduler will cause
|
||||
// undefined behavior.
|
||||
Status Init(const GrapplerItem* item);
|
||||
|
||||
OpContext GetCurrNode() const;
|
||||
|
||||
// Returns true if there is any node to be scheduled.
|
||||
bool MarkCurrNodeExecuted(const Costs& node_costs);
|
||||
|
||||
// Prints out summary of execution (timing, memory usage, etc.)
|
||||
Costs Summary() const;
|
||||
// Like the above, but writes detailed stats to RunMetadata.
|
||||
// If metadata is nullptr, then just calls and return Summary().
|
||||
|
@ -347,34 +343,40 @@ class VirtualScheduler {
|
|||
// Returns per device memory usage.
|
||||
const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
|
||||
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const;
|
||||
|
||||
// Returns VirtualScheduler (read only) device and node states.
|
||||
void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
|
||||
// Returns (read only) device and node states.
|
||||
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
|
||||
return &device_;
|
||||
}
|
||||
|
||||
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
|
||||
return &node_map_;
|
||||
}
|
||||
|
||||
void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
|
||||
OpContext CreateOpContext(const NodeDef* node) const;
|
||||
std::vector<const NodeDef*> MarkNodeExecuted(const NodeDef* node,
|
||||
const Costs& node_costs,
|
||||
const OpContext& op_context);
|
||||
|
||||
private:
|
||||
// Methods called from Init(). Fails if initialize_ is set.
|
||||
|
||||
void MaybeUpdateInputOutput(const NodeDef* node);
|
||||
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
|
||||
// Creates a Send_ and Recv_ pair between from and to. The argument
|
||||
// create_channel_device tells the function to create an explicit device for
|
||||
// the channel.
|
||||
std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
|
||||
const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
|
||||
const string& input_name);
|
||||
const string& input_name, bool create_channel_device);
|
||||
string DeviceName(const NodeDef* node) const;
|
||||
string SanitizedDeviceName(const NodeDef* node) const;
|
||||
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
|
||||
|
||||
// Helper methods.
|
||||
void AddOutputNodesToReadyQueue(const NodeDef* node,
|
||||
const Costs::Duration& curr_time);
|
||||
void GetOutputNodes(const NodeDef* node, const Costs::Duration& curr_time,
|
||||
std::vector<const NodeDef*>* output_nodes);
|
||||
|
||||
// Scheduler states:
|
||||
ReadyNodeManager* ready_nodes_; // Not owned.
|
||||
std::unordered_map<const NodeDef*, NodeState> node_map_;
|
||||
std::unordered_map<string, DeviceState> device_;
|
||||
|
||||
|
@ -396,16 +398,81 @@ class VirtualScheduler {
|
|||
// Auxiliary data structures for constructing NodeState and DeviceState.
|
||||
std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init().
|
||||
Cluster* cluster_; // Not owned.
|
||||
|
||||
const GrapplerItem* grappler_item_; // Not owned.
|
||||
bool use_static_shapes_;
|
||||
bool initialized_;
|
||||
bool track_mem_usage_snapshot_;
|
||||
const bool use_aggressive_shape_inference_;
|
||||
|
||||
std::unique_ptr<VirtualPlacer> placer_;
|
||||
};
|
||||
|
||||
// The virtual scheduler emulates execution of nodes in a graph, considering
|
||||
// dependencies, device, etc.
|
||||
class VirtualScheduler {
|
||||
public:
|
||||
// Does not take ownership of cluster or ready_nodes.
|
||||
VirtualScheduler(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||
ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<VirtualPlacer> placer);
|
||||
|
||||
// Initializes the scheduler for the specific grappler item.
|
||||
// Should be called immediately after the c'tor or when the scheduler will be
|
||||
// reused for a new grappler item. All internal states of the scheduler
|
||||
// related to the previous grappler item will be reset/cleared.
|
||||
//
|
||||
// This function should be called at least once after the scheduler is
|
||||
// constructed. An uninitialized or failed-to-initialize scheduler will cause
|
||||
// undefined behavior.
|
||||
Status Init(const GrapplerItem* item);
|
||||
|
||||
// Gets the current scheduled node for execution; the caller of this function
|
||||
// can accordingly simulate the execution of the current scheduled node.
|
||||
OpContext GetCurrNode() const;
|
||||
// Marks the current scheduled node as executed. Note that we should call this
|
||||
// function only after the execution of the node has been simulated;
|
||||
// node_costs_ capture the simulated costs of the node.
|
||||
// Returns true if there is any node to be scheduled.
|
||||
bool MarkCurrNodeExecuted(const Costs& node_costs);
|
||||
|
||||
// Prints out summary of execution (timing, memory usage, etc.)
|
||||
Costs Summary() const { return scheduler_state_.Summary(); }
|
||||
// Like the above, but writes detailed stats to RunMetadata.
|
||||
// If metadata is nullptr, then just calls and return Summary().
|
||||
Costs Summary(RunMetadata* metadata) {
|
||||
return scheduler_state_.Summary(metadata);
|
||||
}
|
||||
// Generates RunMetadata's step_stats and partition_graphs fields from results
|
||||
// of the virtual execution of the graph.
|
||||
void GenerateRunMetadata(RunMetadata* metadata) {
|
||||
scheduler_state_.GenerateRunMetadata(metadata);
|
||||
}
|
||||
// Returns per device memory usage.
|
||||
const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
|
||||
return scheduler_state_.GetPeakMemoryUsage();
|
||||
}
|
||||
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
|
||||
return scheduler_state_.GetPersistentMemoryUsage();
|
||||
}
|
||||
// Returns VirtualScheduler (read only) device and node states.
|
||||
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
|
||||
return scheduler_state_.GetDeviceStates();
|
||||
}
|
||||
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
|
||||
return scheduler_state_.GetNodeStates();
|
||||
}
|
||||
void enable_mem_usage_tracking() {
|
||||
scheduler_state_.enable_mem_usage_tracking();
|
||||
}
|
||||
|
||||
private:
|
||||
// The state of the scheduler and the execution of the graph is encapsulated
|
||||
// by the scheduler_state_ object.
|
||||
SchedulerState scheduler_state_;
|
||||
// ready_nodes_ is responsible for ordering the traversal of the graph.
|
||||
ReadyNodeManager* ready_nodes_; // Not owned.
|
||||
};
|
||||
|
||||
} // namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
|
Loading…
Reference in New Issue