From 0c4e2e7bc7b5168efce7b0a73082fa4c6d5f08d1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 12 Oct 2020 11:02:21 -0700
Subject: [PATCH] This CL makes VirtualScheduler and SchedulerState
 polymorphic.

PiperOrigin-RevId: 336700189
Change-Id: I2dfe391f7e12ee325e88260d10f650b5e702cea7
---
 .../core/grappler/costs/virtual_scheduler.cc  | 24 +++++---
 .../core/grappler/costs/virtual_scheduler.h   | 56 ++++++++++++++-----
 2 files changed, 58 insertions(+), 22 deletions(-)

diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index c1643bc7bee..a8a337bc3fa 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -361,6 +361,8 @@ std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
   return nullptr;
 }
 
+SchedulerState::~SchedulerState() {}
+
 SchedulerState::SchedulerState(const bool use_static_shapes,
                                const bool use_aggressive_shape_inference,
                                Cluster* cluster,
@@ -1259,15 +1261,23 @@ void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
   node_state.time_scheduled = device.GetCurrTime();
 }
 
+VirtualScheduler::~VirtualScheduler() {}
+
 VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
                                    const bool use_aggressive_shape_inference,
                                    Cluster* cluster,
                                    ReadyNodeManager* ready_nodes,
                                    std::unique_ptr<VirtualPlacer> placer)
-    : scheduler_state_(use_static_shapes, use_aggressive_shape_inference,
-                       cluster, std::move(placer)),
+    : scheduler_state_(absl::make_unique<SchedulerState>(
+          use_static_shapes, use_aggressive_shape_inference, cluster,
+          std::move(placer))),
       ready_nodes_(ready_nodes) {}
 
+VirtualScheduler::VirtualScheduler(
+    ReadyNodeManager* ready_nodes,
+    std::unique_ptr<SchedulerState> scheduler_state)
+    : scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {}
+
 Status VirtualScheduler::Init(const GrapplerItem* item) {
   // SchedulerState::Init() preprocesses the input grappler_item and
   // graph_properties to extract necessary information for emulating tensorflow
@@ -1275,7 +1285,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
   // DeviceState) for virtual scheduling.
   TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
   std::vector<const NodeDef*> initial_nodes;
-  auto status = scheduler_state_.Init(item, &initial_nodes);
+  auto status = scheduler_state_->Init(item, &initial_nodes);
   if (status.ok()) {
     // Add the set of initial nodes to ready_nodes_
     for (auto node : initial_nodes) {
@@ -1285,17 +1295,17 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
   return status;
 }
 
-OpContext VirtualScheduler::GetCurrNode() const {
+OpContext VirtualScheduler::GetCurrNode() {
   const NodeDef* node = ready_nodes_->GetCurrNode();
-  return scheduler_state_.CreateOpContext(node);
+  return scheduler_state_->CreateOpContext(node);
 }
 
 bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
   // Update graph_costs_ and per-op costs.
   const NodeDef* node = ready_nodes_->GetCurrNode();
-  auto new_nodes = scheduler_state_.MarkNodeExecuted(
+  auto new_nodes = scheduler_state_->MarkNodeExecuted(
       node, node_costs,
-      scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode()));
+      scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode()));
   ready_nodes_->RemoveCurrNode();
   // Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
   for (auto node : new_nodes) {
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 0968d2ae11d..04f1e571ae5 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -324,6 +324,21 @@ class SchedulerState {
   SchedulerState(const bool use_static_shapes,
                  const bool use_aggressive_shape_inference, Cluster* cluster,
                  std::unique_ptr<VirtualPlacer> placer);
+  // Move constructor. Explicitly defined because it otherwise gets implicitly
+  // deleted. SchedulerState is a move-only class, as we have a <unique_ptr>
+  // for it in VirtualScheduler. A derivative of VirtualScheduler can move a
+  // <unique_ptr> SchedulerState to VirtualScheduler when it is constructed,
+  // which is where this move constructor is needed.
+  SchedulerState(SchedulerState&& arg) = default;
+  // We explicitly delete assignment and copy operators; this is done
+  // implicitly, but we state it here explicitly for clarity.
+  SchedulerState& operator=(SchedulerState&& arg) = delete;
+  SchedulerState(const SchedulerState&) = delete;
+  SchedulerState& operator=(const SchedulerState&) = delete;
+  // Destructor. Must be defined such that a derivative class can override it
+  // and allow proper destruction of the derivative class. If this is not done
+  // properly, memory leaks can occur.
+  virtual ~SchedulerState();
   // Sets up the graph while also performing some necessary transformations
   // initial_nodes is the set of nodes (primary inputs) discovered by Init()
   // which may be added by a ReadyNodeManager (or related/derivative scheduler)
@@ -332,12 +347,14 @@ class SchedulerState {
               std::vector<const NodeDef*>* initial_nodes,
               bool create_explicit_channel_device = true);
 
-  Costs Summary() const;
+  virtual Costs Summary() const;
   // Like the above, but writes detailed stats to RunMetadata.
   // If metadata is nullptr, then just calls and return Summary().
-  Costs Summary(RunMetadata* metadata);
+  virtual Costs Summary(RunMetadata* metadata);
   // Generates RunMetadata's step_stats and partition_graphs fields from results
   // of the virtual execution of the graph.
+  // TODO(rdegruijl) See if we can make this function and caller Summary()
+  // const.
   void GenerateRunMetadata(RunMetadata* metadata);
 
   // Returns per device memory usage.
@@ -438,6 +455,15 @@ class VirtualScheduler {
                    const bool use_aggressive_shape_inference, Cluster* cluster,
                    ReadyNodeManager* ready_nodes,
                    std::unique_ptr<VirtualPlacer> placer);
+  // This constructor can be called by a derivative of VirtualScheduler to
+  // construct the base class. It lets VirtualScheduler take ownership of
+  // a new SchedulerState or a derivative thereof.
+  // Note that this constructor does not set a VirtualPlacer; in this
+  // constructor the VirtualPlacer is passed as a member of the SchedulerState
+  // that is passed as an argument.
+  VirtualScheduler(ReadyNodeManager* ready_nodes,
+                   std::unique_ptr<SchedulerState> scheduler_state);
+  virtual ~VirtualScheduler();
 
   // Initializes the scheduler for the specific grappler item.
   // Should be called immediately after the c'tor or when the scheduler will be
@@ -447,51 +473,51 @@ class VirtualScheduler {
   // This function should be called at least once after the scheduler is
   // constructed. An uninitialized or failed-to-initialize scheduler will cause
   // undefined behavior.
-  Status Init(const GrapplerItem* item);
+  virtual Status Init(const GrapplerItem* item);
 
   // Gets the current scheduled node for execution; the caller of this function
   // can accordingly simulate the execution of the current scheduled node.
-  OpContext GetCurrNode() const;
+  virtual OpContext GetCurrNode();
   // Marks the current scheduled node as executed. Note that we should call this
   // function only after the execution of the node has been simulated;
   // node_costs_ capture the simulated costs of the node.
   // Returns true if there is any node to be scheduled.
-  bool MarkCurrNodeExecuted(const Costs& node_costs);
+  virtual bool MarkCurrNodeExecuted(const Costs& node_costs);
 
   // Prints out summary of execution (timing, memory usage, etc.)
-  Costs Summary() const { return scheduler_state_.Summary(); }
+  Costs Summary() const { return scheduler_state_->Summary(); }
   // Like the above, but writes detailed stats to RunMetadata.
   // If metadata is nullptr, then just calls and return Summary().
   Costs Summary(RunMetadata* metadata) {
-    return scheduler_state_.Summary(metadata);
+    return scheduler_state_->Summary(metadata);
   }
   // Generates RunMetadata's step_stats and partition_graphs fields from results
   // of the virtual execution of the graph.
   void GenerateRunMetadata(RunMetadata* metadata) {
-    scheduler_state_.GenerateRunMetadata(metadata);
+    scheduler_state_->GenerateRunMetadata(metadata);
   }
   // Returns per device memory usage.
   const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
-    return scheduler_state_.GetPeakMemoryUsage();
+    return scheduler_state_->GetPeakMemoryUsage();
   }
   const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
-    return scheduler_state_.GetPersistentMemoryUsage();
+    return scheduler_state_->GetPersistentMemoryUsage();
   }
   // Returns VirtualScheduler (read only) device and node states.
   const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
-    return scheduler_state_.GetDeviceStates();
+    return scheduler_state_->GetDeviceStates();
   }
   const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
-    return scheduler_state_.GetNodeStates();
+    return scheduler_state_->GetNodeStates();
   }
   void enable_mem_usage_tracking() {
-    scheduler_state_.enable_mem_usage_tracking();
+    scheduler_state_->enable_mem_usage_tracking();
   }
 
- private:
+ protected:
   // The state of the scheduler and the execution of the graph is encapsulated
   // by the scheduler_state_ object.
-  SchedulerState scheduler_state_;
+  std::unique_ptr<SchedulerState> scheduler_state_;
   // ready_nodes_ is responsible for ordering the traversal of the graph.
   ReadyNodeManager* ready_nodes_;  // Not owned.
 };