This CL makes VirtualScheduler and SchedulerState polymorphic.

PiperOrigin-RevId: 336700189
Change-Id: I2dfe391f7e12ee325e88260d10f650b5e702cea7
This commit is contained in:
A. Unique TensorFlower 2020-10-12 11:02:21 -07:00 committed by TensorFlower Gardener
parent ed5360e8f6
commit 0c4e2e7bc7
2 changed files with 58 additions and 22 deletions

View File

@ -361,6 +361,8 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
return nullptr;
}
SchedulerState::~SchedulerState() {}
SchedulerState::SchedulerState(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster,
@ -1259,15 +1261,23 @@ void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
node_state.time_scheduled = device.GetCurrTime();
}
VirtualScheduler::~VirtualScheduler() {}
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer)
: scheduler_state_(use_static_shapes, use_aggressive_shape_inference,
cluster, std::move(placer)),
: scheduler_state_(absl::make_unique<SchedulerState>(
use_static_shapes, use_aggressive_shape_inference, cluster,
std::move(placer))),
ready_nodes_(ready_nodes) {}
VirtualScheduler::VirtualScheduler(
ReadyNodeManager* ready_nodes,
std::unique_ptr<SchedulerState> scheduler_state)
: scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {}
Status VirtualScheduler::Init(const GrapplerItem* item) {
// SchedulerState::Init() preprocesses the input grappler_item and
// graph_properties to extract necessary information for emulating tensorflow
@ -1275,7 +1285,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
// DeviceState) for virtual scheduling.
TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
std::vector<const NodeDef*> initial_nodes;
auto status = scheduler_state_.Init(item, &initial_nodes);
auto status = scheduler_state_->Init(item, &initial_nodes);
if (status.ok()) {
// Add the set of initial nodes to ready_nodes_
for (auto node : initial_nodes) {
@ -1285,17 +1295,17 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
return status;
}
OpContext VirtualScheduler::GetCurrNode() const {
OpContext VirtualScheduler::GetCurrNode() {
const NodeDef* node = ready_nodes_->GetCurrNode();
return scheduler_state_.CreateOpContext(node);
return scheduler_state_->CreateOpContext(node);
}
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
// Update graph_costs_ and per-op costs.
const NodeDef* node = ready_nodes_->GetCurrNode();
auto new_nodes = scheduler_state_.MarkNodeExecuted(
auto new_nodes = scheduler_state_->MarkNodeExecuted(
node, node_costs,
scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode()));
scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode()));
ready_nodes_->RemoveCurrNode();
// Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
for (auto node : new_nodes) {

View File

@ -324,6 +324,21 @@ class SchedulerState {
SchedulerState(const bool use_static_shapes,
const bool use_aggressive_shape_inference, Cluster* cluster,
std::unique_ptr<VirtualPlacer> placer);
// Move constructor. Explicitly defined because it otherwise gets implicitly
// deleted. SchedulerState is a move-only class, as we have a <unique_ptr>
// for it in VirtualScheduler. A derivative of VirtualScheduler can move a
// <unique_ptr> SchedulerState to VirtualScheduler when it is constructed,
// which is where this move constructor is needed.
SchedulerState(SchedulerState&& arg) = default;
// We explicitly delete assinment and copy operators, this is done implicitly,
// but we state it here explicitly for clarity.
SchedulerState& operator=(SchedulerState&& arg) = delete;
SchedulerState(const SchedulerState&) = delete;
SchedulerState& operator=(const SchedulerState&) = delete;
// Destructor. Must be defined such that a derivative class can override it
// and allow proper desctruction of the derivative class. If this is not done
// properly, memory leaks can occur.
virtual ~SchedulerState();
// Sets up the graph while also performing some necessary transformations
// initial_nodes is the set of nodes (primary inputs) discovered by Init()
// which may be added by a ReadyNodeManager (or related/derivative scheduler)
@ -332,12 +347,14 @@ class SchedulerState {
std::vector<const NodeDef*>* initial_nodes,
bool create_explicit_channel_device = true);
Costs Summary() const;
virtual Costs Summary() const;
// Like the above, but writes detailed stats to RunMetadata.
// If metadata is nullptr, then just calls and return Summary().
Costs Summary(RunMetadata* metadata);
virtual Costs Summary(RunMetadata* metadata);
// Generates RunMetadata's step_stats and partition_graphs fields from results
// of the virtual execution of the graph.
// TODO(rdegruijl) See if we can make this function and caller Summary()
// const.
void GenerateRunMetadata(RunMetadata* metadata);
// Returns per device memory usage.
@ -438,6 +455,15 @@ class VirtualScheduler {
const bool use_aggressive_shape_inference, Cluster* cluster,
ReadyNodeManager* ready_nodes,
std::unique_ptr<VirtualPlacer> placer);
// This constructor can be called by a derivative of VirtualScheduler to
// construct the base class. It lets VirtualScheduler take ownership of
// a new SchedulerState or a derivative thereof.
// Note that this constructor does not set a VirtualPlacer, in this
// constructor the VirtialPlacer is passed as a member of the SchedulerState
// that is passed as an argument.
VirtualScheduler(ReadyNodeManager* ready_nodes,
std::unique_ptr<SchedulerState> scheduler_state);
virtual ~VirtualScheduler();
// Initializes the scheduler for the specific grappler item.
// Should be called immediately after the c'tor or when the scheduler will be
@ -447,51 +473,51 @@ class VirtualScheduler {
// This function should be called at least once after the scheduler is
// constructed. An uninitialized or failed-to-initialize scheduler will cause
// undefined behavior.
Status Init(const GrapplerItem* item);
virtual Status Init(const GrapplerItem* item);
// Gets the current scheduled node for execution; the caller of this function
// can accordingly simulate the execution of the current scheduled node.
OpContext GetCurrNode() const;
virtual OpContext GetCurrNode();
// Marks the current scheduled node as executed. Note that we should call this
// function only after the execution of the node has been simulated;
// node_costs_ capture the simulated costs of the node.
// Returns true if there is any node to be scheduled.
bool MarkCurrNodeExecuted(const Costs& node_costs);
virtual bool MarkCurrNodeExecuted(const Costs& node_costs);
// Prints out summary of execution (timing, memory usage, etc.)
Costs Summary() const { return scheduler_state_.Summary(); }
Costs Summary() const { return scheduler_state_->Summary(); }
// Like the above, but writes detailed stats to RunMetadata.
// If metadata is nullptr, then just calls and return Summary().
Costs Summary(RunMetadata* metadata) {
return scheduler_state_.Summary(metadata);
return scheduler_state_->Summary(metadata);
}
// Generates RunMetadata's step_stats and partition_graphs fields from results
// of the virtual execution of the graph.
void GenerateRunMetadata(RunMetadata* metadata) {
scheduler_state_.GenerateRunMetadata(metadata);
scheduler_state_->GenerateRunMetadata(metadata);
}
// Returns per device memory usage.
const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
return scheduler_state_.GetPeakMemoryUsage();
return scheduler_state_->GetPeakMemoryUsage();
}
const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
return scheduler_state_.GetPersistentMemoryUsage();
return scheduler_state_->GetPersistentMemoryUsage();
}
// Returns VirtualScheduler (read only) device and node states.
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
return scheduler_state_.GetDeviceStates();
return scheduler_state_->GetDeviceStates();
}
const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
return scheduler_state_.GetNodeStates();
return scheduler_state_->GetNodeStates();
}
void enable_mem_usage_tracking() {
scheduler_state_.enable_mem_usage_tracking();
scheduler_state_->enable_mem_usage_tracking();
}
private:
protected:
// The state of the scheduler and the execution of the graph is encapsulated
// by the scheduler_state_ object.
SchedulerState scheduler_state_;
std::unique_ptr<SchedulerState> scheduler_state_;
// ready_nodes_ is responsible for ordering the traversal of the graph.
ReadyNodeManager* ready_nodes_; // Not owned.
};