This CL makes VirtualScheduler and SchedulerState polymorphic.
PiperOrigin-RevId: 336700189 Change-Id: I2dfe391f7e12ee325e88260d10f650b5e702cea7
This commit is contained in:
parent
ed5360e8f6
commit
0c4e2e7bc7
@ -361,6 +361,8 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SchedulerState::~SchedulerState() {}
|
||||||
|
|
||||||
SchedulerState::SchedulerState(const bool use_static_shapes,
|
SchedulerState::SchedulerState(const bool use_static_shapes,
|
||||||
const bool use_aggressive_shape_inference,
|
const bool use_aggressive_shape_inference,
|
||||||
Cluster* cluster,
|
Cluster* cluster,
|
||||||
@ -1259,15 +1261,23 @@ void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
|
|||||||
node_state.time_scheduled = device.GetCurrTime();
|
node_state.time_scheduled = device.GetCurrTime();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
VirtualScheduler::~VirtualScheduler() {}
|
||||||
|
|
||||||
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
|
||||||
const bool use_aggressive_shape_inference,
|
const bool use_aggressive_shape_inference,
|
||||||
Cluster* cluster,
|
Cluster* cluster,
|
||||||
ReadyNodeManager* ready_nodes,
|
ReadyNodeManager* ready_nodes,
|
||||||
std::unique_ptr<VirtualPlacer> placer)
|
std::unique_ptr<VirtualPlacer> placer)
|
||||||
: scheduler_state_(use_static_shapes, use_aggressive_shape_inference,
|
: scheduler_state_(absl::make_unique<SchedulerState>(
|
||||||
cluster, std::move(placer)),
|
use_static_shapes, use_aggressive_shape_inference, cluster,
|
||||||
|
std::move(placer))),
|
||||||
ready_nodes_(ready_nodes) {}
|
ready_nodes_(ready_nodes) {}
|
||||||
|
|
||||||
|
VirtualScheduler::VirtualScheduler(
|
||||||
|
ReadyNodeManager* ready_nodes,
|
||||||
|
std::unique_ptr<SchedulerState> scheduler_state)
|
||||||
|
: scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {}
|
||||||
|
|
||||||
Status VirtualScheduler::Init(const GrapplerItem* item) {
|
Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||||
// SchedulerState::Init() preprocesses the input grappler_item and
|
// SchedulerState::Init() preprocesses the input grappler_item and
|
||||||
// graph_properties to extract necessary information for emulating tensorflow
|
// graph_properties to extract necessary information for emulating tensorflow
|
||||||
@ -1275,7 +1285,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
|||||||
// DeviceState) for virtual scheduling.
|
// DeviceState) for virtual scheduling.
|
||||||
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
|
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
|
||||||
std::vector<const NodeDef*> initial_nodes;
|
std::vector<const NodeDef*> initial_nodes;
|
||||||
auto status = scheduler_state_.Init(item, &initial_nodes);
|
auto status = scheduler_state_->Init(item, &initial_nodes);
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
// Add the set of initial nodes to ready_nodes_
|
// Add the set of initial nodes to ready_nodes_
|
||||||
for (auto node : initial_nodes) {
|
for (auto node : initial_nodes) {
|
||||||
@ -1285,17 +1295,17 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpContext VirtualScheduler::GetCurrNode() const {
|
OpContext VirtualScheduler::GetCurrNode() {
|
||||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||||
return scheduler_state_.CreateOpContext(node);
|
return scheduler_state_->CreateOpContext(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
// Update graph_costs_ and per-op costs.
|
// Update graph_costs_ and per-op costs.
|
||||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||||
auto new_nodes = scheduler_state_.MarkNodeExecuted(
|
auto new_nodes = scheduler_state_->MarkNodeExecuted(
|
||||||
node, node_costs,
|
node, node_costs,
|
||||||
scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode()));
|
scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode()));
|
||||||
ready_nodes_->RemoveCurrNode();
|
ready_nodes_->RemoveCurrNode();
|
||||||
// Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
|
// Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
|
||||||
for (auto node : new_nodes) {
|
for (auto node : new_nodes) {
|
||||||
|
@ -324,6 +324,21 @@ class SchedulerState {
|
|||||||
SchedulerState(const bool use_static_shapes,
|
SchedulerState(const bool use_static_shapes,
|
||||||
const bool use_aggressive_shape_inference, Cluster* cluster,
|
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||||
std::unique_ptr<VirtualPlacer> placer);
|
std::unique_ptr<VirtualPlacer> placer);
|
||||||
|
// Move constructor. Explicitly defined because it otherwise gets implicitly
|
||||||
|
// deleted. SchedulerState is a move-only class, as we have a <unique_ptr>
|
||||||
|
// for it in VirtualScheduler. A derivative of VirtualScheduler can move a
|
||||||
|
// <unique_ptr> SchedulerState to VirtualScheduler when it is constructed,
|
||||||
|
// which is where this move constructor is needed.
|
||||||
|
SchedulerState(SchedulerState&& arg) = default;
|
||||||
|
// We explicitly delete assinment and copy operators, this is done implicitly,
|
||||||
|
// but we state it here explicitly for clarity.
|
||||||
|
SchedulerState& operator=(SchedulerState&& arg) = delete;
|
||||||
|
SchedulerState(const SchedulerState&) = delete;
|
||||||
|
SchedulerState& operator=(const SchedulerState&) = delete;
|
||||||
|
// Destructor. Must be defined such that a derivative class can override it
|
||||||
|
// and allow proper desctruction of the derivative class. If this is not done
|
||||||
|
// properly, memory leaks can occur.
|
||||||
|
virtual ~SchedulerState();
|
||||||
// Sets up the graph while also performing some necessary transformations
|
// Sets up the graph while also performing some necessary transformations
|
||||||
// initial_nodes is the set of nodes (primary inputs) discovered by Init()
|
// initial_nodes is the set of nodes (primary inputs) discovered by Init()
|
||||||
// which may be added by a ReadyNodeManager (or related/derivative scheduler)
|
// which may be added by a ReadyNodeManager (or related/derivative scheduler)
|
||||||
@ -332,12 +347,14 @@ class SchedulerState {
|
|||||||
std::vector<const NodeDef*>* initial_nodes,
|
std::vector<const NodeDef*>* initial_nodes,
|
||||||
bool create_explicit_channel_device = true);
|
bool create_explicit_channel_device = true);
|
||||||
|
|
||||||
Costs Summary() const;
|
virtual Costs Summary() const;
|
||||||
// Like the above, but writes detailed stats to RunMetadata.
|
// Like the above, but writes detailed stats to RunMetadata.
|
||||||
// If metadata is nullptr, then just calls and return Summary().
|
// If metadata is nullptr, then just calls and return Summary().
|
||||||
Costs Summary(RunMetadata* metadata);
|
virtual Costs Summary(RunMetadata* metadata);
|
||||||
// Generates RunMetadata's step_stats and partition_graphs fields from results
|
// Generates RunMetadata's step_stats and partition_graphs fields from results
|
||||||
// of the virtual execution of the graph.
|
// of the virtual execution of the graph.
|
||||||
|
// TODO(rdegruijl) See if we can make this function and caller Summary()
|
||||||
|
// const.
|
||||||
void GenerateRunMetadata(RunMetadata* metadata);
|
void GenerateRunMetadata(RunMetadata* metadata);
|
||||||
|
|
||||||
// Returns per device memory usage.
|
// Returns per device memory usage.
|
||||||
@ -438,6 +455,15 @@ class VirtualScheduler {
|
|||||||
const bool use_aggressive_shape_inference, Cluster* cluster,
|
const bool use_aggressive_shape_inference, Cluster* cluster,
|
||||||
ReadyNodeManager* ready_nodes,
|
ReadyNodeManager* ready_nodes,
|
||||||
std::unique_ptr<VirtualPlacer> placer);
|
std::unique_ptr<VirtualPlacer> placer);
|
||||||
|
// This constructor can be called by a derivative of VirtualScheduler to
|
||||||
|
// construct the base class. It lets VirtualScheduler take ownership of
|
||||||
|
// a new SchedulerState or a derivative thereof.
|
||||||
|
// Note that this constructor does not set a VirtualPlacer, in this
|
||||||
|
// constructor the VirtialPlacer is passed as a member of the SchedulerState
|
||||||
|
// that is passed as an argument.
|
||||||
|
VirtualScheduler(ReadyNodeManager* ready_nodes,
|
||||||
|
std::unique_ptr<SchedulerState> scheduler_state);
|
||||||
|
virtual ~VirtualScheduler();
|
||||||
|
|
||||||
// Initializes the scheduler for the specific grappler item.
|
// Initializes the scheduler for the specific grappler item.
|
||||||
// Should be called immediately after the c'tor or when the scheduler will be
|
// Should be called immediately after the c'tor or when the scheduler will be
|
||||||
@ -447,51 +473,51 @@ class VirtualScheduler {
|
|||||||
// This function should be called at least once after the scheduler is
|
// This function should be called at least once after the scheduler is
|
||||||
// constructed. An uninitialized or failed-to-initialize scheduler will cause
|
// constructed. An uninitialized or failed-to-initialize scheduler will cause
|
||||||
// undefined behavior.
|
// undefined behavior.
|
||||||
Status Init(const GrapplerItem* item);
|
virtual Status Init(const GrapplerItem* item);
|
||||||
|
|
||||||
// Gets the current scheduled node for execution; the caller of this function
|
// Gets the current scheduled node for execution; the caller of this function
|
||||||
// can accordingly simulate the execution of the current scheduled node.
|
// can accordingly simulate the execution of the current scheduled node.
|
||||||
OpContext GetCurrNode() const;
|
virtual OpContext GetCurrNode();
|
||||||
// Marks the current scheduled node as executed. Note that we should call this
|
// Marks the current scheduled node as executed. Note that we should call this
|
||||||
// function only after the execution of the node has been simulated;
|
// function only after the execution of the node has been simulated;
|
||||||
// node_costs_ capture the simulated costs of the node.
|
// node_costs_ capture the simulated costs of the node.
|
||||||
// Returns true if there is any node to be scheduled.
|
// Returns true if there is any node to be scheduled.
|
||||||
bool MarkCurrNodeExecuted(const Costs& node_costs);
|
virtual bool MarkCurrNodeExecuted(const Costs& node_costs);
|
||||||
|
|
||||||
// Prints out summary of execution (timing, memory usage, etc.)
|
// Prints out summary of execution (timing, memory usage, etc.)
|
||||||
Costs Summary() const { return scheduler_state_.Summary(); }
|
Costs Summary() const { return scheduler_state_->Summary(); }
|
||||||
// Like the above, but writes detailed stats to RunMetadata.
|
// Like the above, but writes detailed stats to RunMetadata.
|
||||||
// If metadata is nullptr, then just calls and return Summary().
|
// If metadata is nullptr, then just calls and return Summary().
|
||||||
Costs Summary(RunMetadata* metadata) {
|
Costs Summary(RunMetadata* metadata) {
|
||||||
return scheduler_state_.Summary(metadata);
|
return scheduler_state_->Summary(metadata);
|
||||||
}
|
}
|
||||||
// Generates RunMetadata's step_stats and partition_graphs fields from results
|
// Generates RunMetadata's step_stats and partition_graphs fields from results
|
||||||
// of the virtual execution of the graph.
|
// of the virtual execution of the graph.
|
||||||
void GenerateRunMetadata(RunMetadata* metadata) {
|
void GenerateRunMetadata(RunMetadata* metadata) {
|
||||||
scheduler_state_.GenerateRunMetadata(metadata);
|
scheduler_state_->GenerateRunMetadata(metadata);
|
||||||
}
|
}
|
||||||
// Returns per device memory usage.
|
// Returns per device memory usage.
|
||||||
const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
|
const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
|
||||||
return scheduler_state_.GetPeakMemoryUsage();
|
return scheduler_state_->GetPeakMemoryUsage();
|
||||||
}
|
}
|
||||||
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
|
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
|
||||||
return scheduler_state_.GetPersistentMemoryUsage();
|
return scheduler_state_->GetPersistentMemoryUsage();
|
||||||
}
|
}
|
||||||
// Returns VirtualScheduler (read only) device and node states.
|
// Returns VirtualScheduler (read only) device and node states.
|
||||||
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
|
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
|
||||||
return scheduler_state_.GetDeviceStates();
|
return scheduler_state_->GetDeviceStates();
|
||||||
}
|
}
|
||||||
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
|
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
|
||||||
return scheduler_state_.GetNodeStates();
|
return scheduler_state_->GetNodeStates();
|
||||||
}
|
}
|
||||||
void enable_mem_usage_tracking() {
|
void enable_mem_usage_tracking() {
|
||||||
scheduler_state_.enable_mem_usage_tracking();
|
scheduler_state_->enable_mem_usage_tracking();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
// The state of the scheduler and the execution of the graph is encapsulated
|
// The state of the scheduler and the execution of the graph is encapsulated
|
||||||
// by the scheduler_state_ object.
|
// by the scheduler_state_ object.
|
||||||
SchedulerState scheduler_state_;
|
std::unique_ptr<SchedulerState> scheduler_state_;
|
||||||
// ready_nodes_ is responsible for ordering the traversal of the graph.
|
// ready_nodes_ is responsible for ordering the traversal of the graph.
|
||||||
ReadyNodeManager* ready_nodes_; // Not owned.
|
ReadyNodeManager* ready_nodes_; // Not owned.
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user