This CL makes VirtualScheduler and SchedulerState polymorphic.
PiperOrigin-RevId: 336700189 Change-Id: I2dfe391f7e12ee325e88260d10f650b5e702cea7
This commit is contained in:
parent
ed5360e8f6
commit
0c4e2e7bc7
@ -361,6 +361,8 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
SchedulerState::~SchedulerState() {}
|
||||
|
||||
SchedulerState::SchedulerState(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference,
|
||||
Cluster* cluster,
|
||||
@ -1259,15 +1261,23 @@ void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
|
||||
node_state.time_scheduled = device.GetCurrTime();
|
||||
}
|
||||
|
||||
VirtualScheduler::~VirtualScheduler() {}
|
||||
|
||||
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference,
|
||||
Cluster* cluster,
|
||||
ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<VirtualPlacer> placer)
|
||||
: scheduler_state_(use_static_shapes, use_aggressive_shape_inference,
|
||||
cluster, std::move(placer)),
|
||||
: scheduler_state_(absl::make_unique<SchedulerState>(
|
||||
use_static_shapes, use_aggressive_shape_inference, cluster,
|
||||
std::move(placer))),
|
||||
ready_nodes_(ready_nodes) {}
|
||||
|
||||
VirtualScheduler::VirtualScheduler(
|
||||
ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<SchedulerState> scheduler_state)
|
||||
: scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {}
|
||||
|
||||
Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
// SchedulerState::Init() preprocesses the input grappler_item and
|
||||
// graph_properties to extract necessary information for emulating tensorflow
|
||||
@ -1275,7 +1285,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
// DeviceState) for virtual scheduling.
|
||||
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
|
||||
std::vector<const NodeDef*> initial_nodes;
|
||||
auto status = scheduler_state_.Init(item, &initial_nodes);
|
||||
auto status = scheduler_state_->Init(item, &initial_nodes);
|
||||
if (status.ok()) {
|
||||
// Add the set of initial nodes to ready_nodes_
|
||||
for (auto node : initial_nodes) {
|
||||
@ -1285,17 +1295,17 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
return status;
|
||||
}
|
||||
|
||||
OpContext VirtualScheduler::GetCurrNode() const {
|
||||
OpContext VirtualScheduler::GetCurrNode() {
|
||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||
return scheduler_state_.CreateOpContext(node);
|
||||
return scheduler_state_->CreateOpContext(node);
|
||||
}
|
||||
|
||||
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||
// Update graph_costs_ and per-op costs.
|
||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||
auto new_nodes = scheduler_state_.MarkNodeExecuted(
|
||||
auto new_nodes = scheduler_state_->MarkNodeExecuted(
|
||||
node, node_costs,
|
||||
scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode()));
|
||||
scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode()));
|
||||
ready_nodes_->RemoveCurrNode();
|
||||
// Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
|
||||
for (auto node : new_nodes) {
|
||||
|
@ -324,6 +324,21 @@ class SchedulerState {
|
||||
SchedulerState(const bool use_static_shapes,
|
||||
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||
std::unique_ptr<VirtualPlacer> placer);
|
||||
// Move constructor. Explicitly defined because it otherwise gets implicitly
|
||||
// deleted. SchedulerState is a move-only class, as we have a <unique_ptr>
|
||||
// for it in VirtualScheduler. A derivative of VirtualScheduler can move a
|
||||
// <unique_ptr> SchedulerState to VirtualScheduler when it is constructed,
|
||||
// which is where this move constructor is needed.
|
||||
SchedulerState(SchedulerState&& arg) = default;
|
||||
// We explicitly delete assignment and copy operators; this is done implicitly,
|
||||
// but we state it here explicitly for clarity.
|
||||
SchedulerState& operator=(SchedulerState&& arg) = delete;
|
||||
SchedulerState(const SchedulerState&) = delete;
|
||||
SchedulerState& operator=(const SchedulerState&) = delete;
|
||||
// Destructor. Must be defined such that a derivative class can override it
|
||||
// and allow proper destruction of the derivative class. If this is not done
|
||||
// properly, memory leaks can occur.
|
||||
virtual ~SchedulerState();
|
||||
// Sets up the graph while also performing some necessary transformations
|
||||
// initial_nodes is the set of nodes (primary inputs) discovered by Init()
|
||||
// which may be added by a ReadyNodeManager (or related/derivative scheduler)
|
||||
@ -332,12 +347,14 @@ class SchedulerState {
|
||||
std::vector<const NodeDef*>* initial_nodes,
|
||||
bool create_explicit_channel_device = true);
|
||||
|
||||
Costs Summary() const;
|
||||
virtual Costs Summary() const;
|
||||
// Like the above, but writes detailed stats to RunMetadata.
|
||||
// If metadata is nullptr, then just calls and returns Summary().
|
||||
Costs Summary(RunMetadata* metadata);
|
||||
virtual Costs Summary(RunMetadata* metadata);
|
||||
// Generates RunMetadata's step_stats and partition_graphs fields from results
|
||||
// of the virtual execution of the graph.
|
||||
// TODO(rdegruijl) See if we can make this function and caller Summary()
|
||||
// const.
|
||||
void GenerateRunMetadata(RunMetadata* metadata);
|
||||
|
||||
// Returns per device memory usage.
|
||||
@ -438,6 +455,15 @@ class VirtualScheduler {
|
||||
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||
ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<VirtualPlacer> placer);
|
||||
// This constructor can be called by a derivative of VirtualScheduler to
|
||||
// construct the base class. It lets VirtualScheduler take ownership of
|
||||
// a new SchedulerState or a derivative thereof.
|
||||
// Note that this constructor does not set a VirtualPlacer, in this
|
||||
// constructor the VirtualPlacer is passed as a member of the SchedulerState
|
||||
// that is passed as an argument.
|
||||
VirtualScheduler(ReadyNodeManager* ready_nodes,
|
||||
std::unique_ptr<SchedulerState> scheduler_state);
|
||||
virtual ~VirtualScheduler();
|
||||
|
||||
// Initializes the scheduler for the specific grappler item.
|
||||
// Should be called immediately after the c'tor or when the scheduler will be
|
||||
@ -447,51 +473,51 @@ class VirtualScheduler {
|
||||
// This function should be called at least once after the scheduler is
|
||||
// constructed. An uninitialized or failed-to-initialize scheduler will cause
|
||||
// undefined behavior.
|
||||
Status Init(const GrapplerItem* item);
|
||||
virtual Status Init(const GrapplerItem* item);
|
||||
|
||||
// Gets the current scheduled node for execution; the caller of this function
|
||||
// can accordingly simulate the execution of the current scheduled node.
|
||||
OpContext GetCurrNode() const;
|
||||
virtual OpContext GetCurrNode();
|
||||
// Marks the current scheduled node as executed. Note that we should call this
|
||||
// function only after the execution of the node has been simulated;
|
||||
// node_costs_ capture the simulated costs of the node.
|
||||
// Returns true if there is any node to be scheduled.
|
||||
bool MarkCurrNodeExecuted(const Costs& node_costs);
|
||||
virtual bool MarkCurrNodeExecuted(const Costs& node_costs);
|
||||
|
||||
// Prints out summary of execution (timing, memory usage, etc.)
|
||||
Costs Summary() const { return scheduler_state_.Summary(); }
|
||||
Costs Summary() const { return scheduler_state_->Summary(); }
|
||||
// Like the above, but writes detailed stats to RunMetadata.
|
||||
// If metadata is nullptr, then just calls and returns Summary().
|
||||
Costs Summary(RunMetadata* metadata) {
|
||||
return scheduler_state_.Summary(metadata);
|
||||
return scheduler_state_->Summary(metadata);
|
||||
}
|
||||
// Generates RunMetadata's step_stats and partition_graphs fields from results
|
||||
// of the virtual execution of the graph.
|
||||
void GenerateRunMetadata(RunMetadata* metadata) {
|
||||
scheduler_state_.GenerateRunMetadata(metadata);
|
||||
scheduler_state_->GenerateRunMetadata(metadata);
|
||||
}
|
||||
// Returns per device memory usage.
|
||||
const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
|
||||
return scheduler_state_.GetPeakMemoryUsage();
|
||||
return scheduler_state_->GetPeakMemoryUsage();
|
||||
}
|
||||
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
|
||||
return scheduler_state_.GetPersistentMemoryUsage();
|
||||
return scheduler_state_->GetPersistentMemoryUsage();
|
||||
}
|
||||
// Returns VirtualScheduler (read only) device and node states.
|
||||
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
|
||||
return scheduler_state_.GetDeviceStates();
|
||||
return scheduler_state_->GetDeviceStates();
|
||||
}
|
||||
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
|
||||
return scheduler_state_.GetNodeStates();
|
||||
return scheduler_state_->GetNodeStates();
|
||||
}
|
||||
void enable_mem_usage_tracking() {
|
||||
scheduler_state_.enable_mem_usage_tracking();
|
||||
scheduler_state_->enable_mem_usage_tracking();
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
// The state of the scheduler and the execution of the graph is encapsulated
|
||||
// by the scheduler_state_ object.
|
||||
SchedulerState scheduler_state_;
|
||||
std::unique_ptr<SchedulerState> scheduler_state_;
|
||||
// ready_nodes_ is responsible for ordering the traversal of the graph.
|
||||
ReadyNodeManager* ready_nodes_; // Not owned.
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user