diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index c1643bc7bee..a8a337bc3fa 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -361,6 +361,8 @@ std::unique_ptr ReadyNodeManagerFactory( return nullptr; } +SchedulerState::~SchedulerState() {} + SchedulerState::SchedulerState(const bool use_static_shapes, const bool use_aggressive_shape_inference, Cluster* cluster, @@ -1259,15 +1261,23 @@ void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) { node_state.time_scheduled = device.GetCurrTime(); } +VirtualScheduler::~VirtualScheduler() {} + VirtualScheduler::VirtualScheduler(const bool use_static_shapes, const bool use_aggressive_shape_inference, Cluster* cluster, ReadyNodeManager* ready_nodes, std::unique_ptr placer) - : scheduler_state_(use_static_shapes, use_aggressive_shape_inference, - cluster, std::move(placer)), + : scheduler_state_(absl::make_unique( + use_static_shapes, use_aggressive_shape_inference, cluster, + std::move(placer))), ready_nodes_(ready_nodes) {} +VirtualScheduler::VirtualScheduler( + ReadyNodeManager* ready_nodes, + std::unique_ptr scheduler_state) + : scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {} + Status VirtualScheduler::Init(const GrapplerItem* item) { // SchedulerState::Init() preprocesses the input grappler_item and // graph_properties to extract necessary information for emulating tensorflow @@ -1275,7 +1285,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) { // DeviceState) for virtual scheduling. TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates())); std::vector initial_nodes; - auto status = scheduler_state_.Init(item, &initial_nodes); + auto status = scheduler_state_->Init(item, &initial_nodes); if (status.ok()) { // Add the set of initial nodes to ready_nodes_ for (auto node : initial_nodes) { @@ -1285,17 +1295,17 @@ Status VirtualScheduler::Init(const GrapplerItem* item) { return status; } -OpContext VirtualScheduler::GetCurrNode() const { +OpContext VirtualScheduler::GetCurrNode() { const NodeDef* node = ready_nodes_->GetCurrNode(); - return scheduler_state_.CreateOpContext(node); + return scheduler_state_->CreateOpContext(node); } bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { // Update graph_costs_ and per-op costs. const NodeDef* node = ready_nodes_->GetCurrNode(); - auto new_nodes = scheduler_state_.MarkNodeExecuted( + auto new_nodes = scheduler_state_->MarkNodeExecuted( node, node_costs, - scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode())); + scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode())); ready_nodes_->RemoveCurrNode(); // Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_. for (auto node : new_nodes) { diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 0968d2ae11d..04f1e571ae5 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -324,6 +324,21 @@ class SchedulerState { SchedulerState(const bool use_static_shapes, const bool use_aggressive_shape_inference, Cluster* cluster, std::unique_ptr placer); + // Move constructor. Explicitly defined because it otherwise gets implicitly + // deleted. SchedulerState is a move-only class, as we have a + // for it in VirtualScheduler. A derivative of VirtualScheduler can move a + // SchedulerState to VirtualScheduler when it is constructed, + // which is where this move constructor is needed. + SchedulerState(SchedulerState&& arg) = default; + // We explicitly delete assinment and copy operators, this is done implicitly, + // but we state it here explicitly for clarity. + SchedulerState& operator=(SchedulerState&& arg) = delete; + SchedulerState(const SchedulerState&) = delete; + SchedulerState& operator=(const SchedulerState&) = delete; + // Destructor. Must be defined such that a derivative class can override it + // and allow proper desctruction of the derivative class. If this is not done + // properly, memory leaks can occur. + virtual ~SchedulerState(); // Sets up the graph while also performing some necessary transformations // initial_nodes is the set of nodes (primary inputs) discovered by Init() // which may be added by a ReadyNodeManager (or related/derivative scheduler) @@ -332,12 +347,14 @@ class SchedulerState { std::vector* initial_nodes, bool create_explicit_channel_device = true); - Costs Summary() const; + virtual Costs Summary() const; // Like the above, but writes detailed stats to RunMetadata. // If metadata is nullptr, then just calls and return Summary(). - Costs Summary(RunMetadata* metadata); + virtual Costs Summary(RunMetadata* metadata); // Generates RunMetadata's step_stats and partition_graphs fields from results // of the virtual execution of the graph. + // TODO(rdegruijl) See if we can make this function and caller Summary() + // const. void GenerateRunMetadata(RunMetadata* metadata); // Returns per device memory usage. @@ -438,6 +455,15 @@ class VirtualScheduler { const bool use_aggressive_shape_inference, Cluster* cluster, ReadyNodeManager* ready_nodes, std::unique_ptr placer); + // This constructor can be called by a derivative of VirtualScheduler to + // construct the base class. It lets VirtualScheduler take ownership of + // a new SchedulerState or a derivative thereof. + // Note that this constructor does not set a VirtualPlacer, in this + // constructor the VirtialPlacer is passed as a member of the SchedulerState + // that is passed as an argument. + VirtualScheduler(ReadyNodeManager* ready_nodes, + std::unique_ptr scheduler_state); + virtual ~VirtualScheduler(); // Initializes the scheduler for the specific grappler item. // Should be called immediately after the c'tor or when the scheduler will be @@ -447,51 +473,51 @@ class VirtualScheduler { // This function should be called at least once after the scheduler is // constructed. An uninitialized or failed-to-initialize scheduler will cause // undefined behavior. - Status Init(const GrapplerItem* item); + virtual Status Init(const GrapplerItem* item); // Gets the current scheduled node for execution; the caller of this function // can accordingly simulate the execution of the current scheduled node. - OpContext GetCurrNode() const; + virtual OpContext GetCurrNode(); // Marks the current scheduled node as executed. Note that we should call this // function only after the execution of the node has been simulated; // node_costs_ capture the simulated costs of the node. // Returns true if there is any node to be scheduled. - bool MarkCurrNodeExecuted(const Costs& node_costs); + virtual bool MarkCurrNodeExecuted(const Costs& node_costs); // Prints out summary of execution (timing, memory usage, etc.) - Costs Summary() const { return scheduler_state_.Summary(); } + Costs Summary() const { return scheduler_state_->Summary(); } // Like the above, but writes detailed stats to RunMetadata. // If metadata is nullptr, then just calls and return Summary(). Costs Summary(RunMetadata* metadata) { - return scheduler_state_.Summary(metadata); + return scheduler_state_->Summary(metadata); } // Generates RunMetadata's step_stats and partition_graphs fields from results // of the virtual execution of the graph. void GenerateRunMetadata(RunMetadata* metadata) { - scheduler_state_.GenerateRunMetadata(metadata); + scheduler_state_->GenerateRunMetadata(metadata); } // Returns per device memory usage. const std::unordered_map GetPeakMemoryUsage() const { - return scheduler_state_.GetPeakMemoryUsage(); + return scheduler_state_->GetPeakMemoryUsage(); } const std::unordered_map GetPersistentMemoryUsage() const { - return scheduler_state_.GetPersistentMemoryUsage(); + return scheduler_state_->GetPersistentMemoryUsage(); } // Returns VirtualScheduler (read only) device and node states. const std::unordered_map* GetDeviceStates() const { - return scheduler_state_.GetDeviceStates(); + return scheduler_state_->GetDeviceStates(); } const std::unordered_map* GetNodeStates() const { - return scheduler_state_.GetNodeStates(); + return scheduler_state_->GetNodeStates(); } void enable_mem_usage_tracking() { - scheduler_state_.enable_mem_usage_tracking(); + scheduler_state_->enable_mem_usage_tracking(); } - private: + protected: // The state of the scheduler and the execution of the graph is encapsulated // by the scheduler_state_ object. - SchedulerState scheduler_state_; + std::unique_ptr scheduler_state_; // ready_nodes_ is responsible for ordering the traversal of the graph. ReadyNodeManager* ready_nodes_; // Not owned. };