From eabc157fd52200f21f06c61f7686230ea20295b9 Mon Sep 17 00:00:00 2001
From: Jiri Simsa <jsimsa@google.com>
Date: Mon, 6 Apr 2020 10:34:34 -0700
Subject: [PATCH] [tf.data] Adding a metric for bytes produced and consumed by
 individual transformations, refactoring infrastructure for recording tf.data
 metrics, and moving the metrics API and implementation from `common_runtime`
 to `framework`.

PiperOrigin-RevId: 305062865
Change-Id: I63911f00154baf36aa225f66dbef0843239b7392
---
 tensorflow/core/BUILD                         |   3 +-
 tensorflow/core/common_runtime/BUILD          |   1 -
 tensorflow/core/common_runtime/metrics.h      |  90 +--------
 tensorflow/core/framework/BUILD               |   5 +
 tensorflow/core/framework/dataset.cc          |   2 +-
 tensorflow/core/framework/dataset.h           |   9 +-
 .../{common_runtime => framework}/metrics.cc  |  46 +++--
 tensorflow/core/framework/metrics.h           | 123 ++++++++++++
 tensorflow/core/framework/model.cc            |  32 ++--
 tensorflow/core/framework/model.h             | 176 ++++++++++++------
 .../core/kernels/data/model_dataset_op.cc     |  18 +-
 11 files changed, 319 insertions(+), 186 deletions(-)
 rename tensorflow/core/{common_runtime => framework}/metrics.cc (92%)
 create mode 100644 tensorflow/core/framework/metrics.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 7d62274e87f..b309a1f2e24 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2194,7 +2194,8 @@ filegroup(
 filegroup(
     name = "framework_internal_public_headers",
     srcs = [
-        "//tensorflow/core/framework:model.h",  # only needed for tests
+        "//tensorflow/core/framework:metrics.h",
+        "//tensorflow/core/framework:model.h",
         "//tensorflow/core/framework:op_segment.h",
         "//tensorflow/core/framework:rendezvous.h",  # only needed for tests
         "//tensorflow/core/framework:resource_var.h",
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index bbfed7f8f5b..b64e14212c9 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -303,7 +303,6 @@ tf_cuda_library(
         "lower_if_op.cc",
         "lower_while_op.cc",
         "memory_types.cc",
-        "metrics.cc",
         "mkl_cpu_allocator.cc",
         "optimization_registry.cc",
         "parallel_concat_optimizer.cc",
diff --git a/tensorflow/core/common_runtime/metrics.h b/tensorflow/core/common_runtime/metrics.h
index e95e0495c04..f359eec9490 100644
--- a/tensorflow/core/common_runtime/metrics.h
+++ b/tensorflow/core/common_runtime/metrics.h
@@ -16,93 +16,9 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_
 
-#include "tensorflow/core/lib/monitoring/counter.h"
-#include "tensorflow/core/platform/types.h"
+// TODO(jsimsa): Remove this forwarding header once all users are migrated to
+// using the one in framework.
 
-namespace tensorflow {
-namespace metrics {
-
-// Records that a tf.data.Dataset executed by the program used autotuning.
-//
-// The `name` argument identifies the Dataset type (e.g. "ParallelMap").
-void RecordTFDataAutotune(const string& name);
-
-// Returns a counter than can be used to record the number of bytes read from
-// the filesystem by a tf.data.Dataset source.
-//
-// The `name` argument identifies the Dataset type (e.g. "TFRecordDataset").
-monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name);
-
-// Records the number of bytes fetched from tf.data.Dataset iterator.
-void RecordTFDataBytesFetched(int64 num_bytes);
-
-// Records the time spent in ItertatorResource::GetNext() in microseconds.
-void RecordTFDataGetNextDuration(uint64 duration_us);
-
-// Records the number of elements produced by a tf.data.Dataset.
-//
-// The `name` argument identifies the Dataset type (e.g. "Batch" or "Map").
-void RecordTFDataElements(const string& name, int64 num_elements);
-
-// Records the number of times each tf.data fingerprint is used
-// to measure duplicate pre-processing.
-//
-// The `name` argument identifies the Dataset graph fingerprint,
-// created using GraphHash().
-void RecordTFDataFingerprint(const string& name);
-
-// Records the number of independent graph changes resulting from the
-// application of a tf.data optimization.
-//
-// The `name` argument identifies the optimization (e.g. "noop_elimination").
-void RecordTFDataOptimization(const string& name, int64 num_changes);
-
-// Records parsing of dense tensor features.
-void RecordParseDenseFeature(int64 num_features);
-
-// Records parsing of sparse tensor features.
-void RecordParseSparseFeature(int64 num_features);
-
-// Records parsing of ragged tensor features.
-void RecordParseRaggedFeature(int64 num_features);
-
-// Records the size of input/output tensors in bytes.
-void RecordGraphInputTensors(const size_t size);
-void RecordGraphOutputTensors(const size_t size);
-
-void UpdateGraphExecTime(const uint64 running_time_usecs);
-
-// Records that one output of an op of type `op_name` was unused.
-void RecordUnusedOutput(const string& op_name);
-
-// Updates the metrics stored about time spent building graphs.
-//
-// By "GraphBuild", we refer to building a client graph, which is a sub-graph of
-// the full graph, induced by a set of options. In particular, these options
-// include the feeds and fetches requested.
-//
-// This includes time spent:
-//   * optimizing the graphs with Grappler
-//   * pruning the sub-graph (unless the place_pruned_graph option is set)
-//
-// When executing eagerly, this will not record any activity.
-//
-// TODO(jtkeeling): Should we record building/optimizing tf.functions?
-void UpdateGraphBuildTime(const uint64 running_time_usecs);
-
-// Updates the metrics stored about graph optimizations.
-void UpdateGraphOptimizationPassTime(const string& pass_name,
-                                     const uint64 running_time_usecs);
-void UpdateGrapplerPassTime(const string& pass_name,
-                            const uint64 running_time_usecs);
-
-// Updates the metrics stored about time XLA spents compiling graphs.
-void UpdateXlaCompilationTime(const uint64 compilation_time_usecs);
-
-// Increment the number of jobs that failed during import to mlir.
-void IncrementMLIRImportFailureCount();
-
-}  // namespace metrics
-}  // namespace tensorflow
+#include "tensorflow/core/framework/metrics.h"
 
 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index 0d3d2a27d73..234c5a403bc 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -47,6 +47,7 @@ exports_files(
         "logging.h",
         "lookup_interface.h",
         "memory_types.h",
+        "metrics.h",
         "model.h",
         "node_def_builder.h",
         "numeric_op.h",
@@ -176,6 +177,7 @@ filegroup(
         "logging.h",
         "lookup_interface.h",
         "memory_types.h",
+        "metrics.h",
         "model.h",
         "node_def_builder.h",
         "node_def_util.h",
@@ -246,6 +248,7 @@ filegroup(
         "logging.cc",
         "lookup_interface.cc",
         "memory_types.cc",
+        "metrics.cc",
         "model.cc",
         "node_def_builder.cc",
         "op_kernel.cc",
@@ -346,6 +349,8 @@ filegroup(
         "lookup_interface.h",
         "memory_types.cc",
         "memory_types.h",
+        "metrics.cc",
+        "metrics.h",
         "model.cc",
         "model.h",
         "node_def_builder.cc",
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index cccbdd5d8e4..24d6bb3bdc9 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -484,7 +484,7 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
   DVLOG(3) << prefix() << " GetNext enter";
   RecordStart(ctx, /*stop_output=*/true);
   Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
-  if (s.ok() && !*end_of_sequence) RecordElement(ctx);
+  if (s.ok() && !*end_of_sequence) RecordElement(ctx, out_tensors);
   RecordStop(ctx, /*start_output=*/true);
   if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
     s = errors::Internal("Iterator \"", params_.prefix,
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index c559dd45fce..a909e8594a2 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -962,10 +962,15 @@ class DatasetBaseIterator : public IteratorBase {
   }
 
   // When modeling is enabled, this method records the fact that this iterator
-  // has produced an element.
-  void RecordElement(IteratorContext* ctx) {
+  // has produced an element and its size in bytes.
+  void RecordElement(IteratorContext* ctx, std::vector<Tensor>* out_tensors) {
     if (node_) {
+      int64 num_bytes = GetAllocatedBytes(*out_tensors);
       node_->record_element();
+      node_->record_bytes_produced(num_bytes);
+      if (node_->output()) {
+        node_->output()->record_bytes_consumed(num_bytes);
+      }
     }
   }
 
diff --git a/tensorflow/core/common_runtime/metrics.cc b/tensorflow/core/framework/metrics.cc
similarity index 92%
rename from tensorflow/core/common_runtime/metrics.cc
rename to tensorflow/core/framework/metrics.cc
index a2065ff2bf1..4af3d7cffcf 100644
--- a/tensorflow/core/common_runtime/metrics.cc
+++ b/tensorflow/core/framework/metrics.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/common_runtime/metrics.h"
+#include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/lib/monitoring/counter.h"
 #include "tensorflow/core/lib/monitoring/sampler.h"
 
@@ -61,6 +61,14 @@ auto* graph_unused_outputs = monitoring::Counter<1>::New(
 auto* tf_data_autotune_counter = monitoring::Counter<1>::New(
     "/tensorflow/data/autotune", "tf.data autotuning", "name");
 
+auto* tf_data_bytes_consumed_counter = monitoring::Counter<1>::New(
+    "/tensorflow/data/bytes_consumed",
+    "The number of bytes consumed by a tf.data Dataset.", "name");
+
+auto* tf_data_bytes_produced_counter = monitoring::Counter<1>::New(
+    "/tensorflow/data/bytes_produced",
+    "The number of bytes produced by a tf.data Dataset.", "name");
+
 auto* tf_data_bytes_read_counter = monitoring::Counter<1>::New(
     "/tensorflow/data/bytes_read",
     "The number of bytes read by tf.data Dataset sources.", "name");
@@ -69,18 +77,18 @@ auto* tf_data_bytes_fetched_counter = monitoring::Counter<0>::New(
     "/tensorflow/data/bytes_fetched",
     "The number of bytes fetched from tf.data Dataset iterator.");
 
-auto* tf_data_getnext_duration_counter = monitoring::Sampler<0>::New(
-    {"/tensorflow/data/getnext_duration",
-     "Microseconds spent fetching an element from tf.data Dataset iterator."},
-    // Power of 2 with bucket count 10 (1024 ms)
-    {monitoring::Buckets::Exponential(1, 2, 10)});
-
 auto* tf_data_elements_counter = monitoring::Counter<1>::New(
     "/tensorflow/data/elements", "tf.data elements", "name");
 
 auto* tf_data_fingerprint_counter = monitoring::Counter<1>::New(
     "/tensorflow/data/fingerprint", "tf.data fingerprint", "name");
 
+auto* tf_data_getnext_duration_counter = monitoring::Sampler<0>::New(
+    {"/tensorflow/data/getnext_duration",
+     "Microseconds spent fetching an element from tf.data Dataset iterator."},
+    // Power of 2 with bucket count 10 (1024 ms)
+    {monitoring::Buckets::Exponential(1, 2, 10)});
+
 auto* tf_data_optimization_counter = monitoring::Counter<1>::New(
     "/tensorflow/data/optimization", "tf.data optimization", "name");
 
@@ -132,28 +140,36 @@ void RecordTFDataAutotune(const string& name) {
   tf_data_autotune_counter->GetCell(name)->IncrementBy(1);
 }
 
+monitoring::CounterCell* GetTFDataBytesConsumedCounter(const string& name) {
+  return tf_data_bytes_consumed_counter->GetCell(name);
+}
+
+monitoring::CounterCell* GetTFDataBytesProducedCounter(const string& name) {
+  return tf_data_bytes_produced_counter->GetCell(name);
+}
+
 monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name) {
   return tf_data_bytes_read_counter->GetCell(name);
 }
 
+monitoring::CounterCell* GetTFDataElementsCounter(const string& name) {
+  return tf_data_elements_counter->GetCell(name);
+}
+
 void RecordTFDataBytesFetched(int64 num_bytes) {
   tf_data_bytes_fetched_counter->GetCell()->IncrementBy(num_bytes);
 }
 
+void RecordTFDataFingerprint(const string& name) {
+  tf_data_fingerprint_counter->GetCell(name)->IncrementBy(1);
+}
+
 void RecordTFDataGetNextDuration(uint64 duration_us) {
   static auto* tfdata_getnext_duration_cell =
       tf_data_getnext_duration_counter->GetCell();
   tfdata_getnext_duration_cell->Add(duration_us);
 }
 
-void RecordTFDataElements(const string& name, int64 num_elements) {
-  tf_data_elements_counter->GetCell(name)->IncrementBy(num_elements);
-}
-
-void RecordTFDataFingerprint(const string& name) {
-  tf_data_fingerprint_counter->GetCell(name)->IncrementBy(1);
-}
-
 void RecordTFDataOptimization(const string& name, int64 num_changes) {
   tf_data_optimization_counter->GetCell(name)->IncrementBy(num_changes);
 }
diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h
new file mode 100644
index 00000000000..7d281f97c66
--- /dev/null
+++ b/tensorflow/core/framework/metrics.h
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_METRICS_H_
+#define TENSORFLOW_CORE_FRAMEWORK_METRICS_H_
+
+#include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Records that a tf.data.Dataset executed by the program used autotuning.
+//
+// The `name` argument identifies the Dataset type (e.g. "ParallelMap").
+void RecordTFDataAutotune(const string& name);
+
+// Returns a counter that can be used to record the number of bytes consumed by
+// a tf.data.Dataset.
+//
+// The `name` argument identifies the Dataset type (e.g. "Batch" or "Map").
+monitoring::CounterCell* GetTFDataBytesConsumedCounter(const string& name);
+
+// Returns a counter that can be used to record the number of bytes produced by
+// a tf.data.Dataset.
+//
+// The `name` argument identifies the Dataset type (e.g. "Batch" or "Map").
+monitoring::CounterCell* GetTFDataBytesProducedCounter(const string& name);
+
+// Returns a counter that can be used to record the number of bytes read from
+// the filesystem by a tf.data.Dataset source.
+//
+// The `name` argument identifies the Dataset type (e.g. "TFRecordDataset").
+//
+// TODO(jsimsa): Remove this now that we have GetTFDataBytesConsumedCounter?
+monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name);
+
+// Returns a counter that can be used to record the number of elements produced
+// by a tf.data.Dataset.
+//
+// The `name` argument identifies the Dataset type (e.g. "Batch" or "Map").
+monitoring::CounterCell* GetTFDataElementsCounter(const string& name);
+
+// Records the number of bytes fetched from tf.data.Dataset iterator.
+void RecordTFDataBytesFetched(int64 num_bytes);
+
+// Records the time spent in IteratorResource::GetNext() in microseconds.
+void RecordTFDataGetNextDuration(uint64 duration_us);
+
+// Records the number of times each tf.data fingerprint is used
+// to measure duplicate pre-processing.
+//
+// The `name` argument identifies the Dataset graph fingerprint,
+// created using GraphHash().
+void RecordTFDataFingerprint(const string& name);
+
+// Records the number of independent graph changes resulting from the
+// application of a tf.data optimization.
+//
+// The `name` argument identifies the optimization (e.g. "noop_elimination").
+void RecordTFDataOptimization(const string& name, int64 num_changes);
+
+// Records parsing of dense tensor features.
+void RecordParseDenseFeature(int64 num_features);
+
+// Records parsing of sparse tensor features.
+void RecordParseSparseFeature(int64 num_features);
+
+// Records parsing of ragged tensor features.
+void RecordParseRaggedFeature(int64 num_features);
+
+// Records the size of input/output tensors in bytes.
+void RecordGraphInputTensors(const size_t size);
+void RecordGraphOutputTensors(const size_t size);
+
+void UpdateGraphExecTime(const uint64 running_time_usecs);
+
+// Records that one output of an op of type `op_name` was unused.
+void RecordUnusedOutput(const string& op_name);
+
+// Updates the metrics stored about time spent building graphs.
+//
+// By "GraphBuild", we refer to building a client graph, which is a sub-graph of
+// the full graph, induced by a set of options. In particular, these options
+// include the feeds and fetches requested.
+//
+// This includes time spent:
+//   * optimizing the graphs with Grappler
+//   * pruning the sub-graph (unless the place_pruned_graph option is set)
+//
+// When executing eagerly, this will not record any activity.
+//
+// TODO(jtkeeling): Should we record building/optimizing tf.functions?
+void UpdateGraphBuildTime(const uint64 running_time_usecs);
+
+// Updates the metrics stored about graph optimizations.
+void UpdateGraphOptimizationPassTime(const string& pass_name,
+                                     const uint64 running_time_usecs);
+void UpdateGrapplerPassTime(const string& pass_name,
+                            const uint64 running_time_usecs);
+
+// Updates the metrics stored about time XLA spends compiling graphs.
+void UpdateXlaCompilationTime(const uint64 compilation_time_usecs);
+
+// Increment the number of jobs that failed during import to mlir.
+void IncrementMLIRImportFailureCount();
+
+}  // namespace metrics
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_FRAMEWORK_METRICS_H_
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 48560d16084..16f36c65d9c 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -711,23 +711,10 @@ void Model::AddProcessingTime(const string& name, int64 delta) {
   }
 }
 
-void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
-                     int64 ram_budget) {
-  switch (algorithm) {
-    case AutotuneAlgorithm::HILL_CLIMB:
-      OptimizeHillClimb(cpu_budget, ram_budget);
-      break;
-    case AutotuneAlgorithm::GRADIENT_DESCENT:
-      OptimizeGradientDescent(cpu_budget, ram_budget);
-      break;
-  }
-}
-
-void Model::RecordElement(const string& name) {
+void Model::FlushMetrics() {
   tf_shared_lock l(mu_);
-  auto node = gtl::FindOrNull(lookup_table_, name);
-  if (node) {
-    (*node)->record_element();
+  for (const auto& pair : lookup_table_) {
+    pair.second->FlushMetrics();
   }
 }
 
@@ -740,6 +727,18 @@ int64 Model::NumElements(const string& name) {
   return 0;
 }
 
+void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
+                     int64 ram_budget) {
+  switch (algorithm) {
+    case AutotuneAlgorithm::HILL_CLIMB:
+      OptimizeHillClimb(cpu_budget, ram_budget);
+      break;
+    case AutotuneAlgorithm::GRADIENT_DESCENT:
+      OptimizeGradientDescent(cpu_budget, ram_budget);
+      break;
+  }
+}
+
 void Model::RecordStart(const string& name, bool stop_output) {
   tf_shared_lock l(mu_);
   auto node = gtl::FindOrNull(lookup_table_, name);
@@ -772,7 +771,6 @@ void Model::RemoveNode(const string& name) {
       (*node)->output()->remove_input(*node);
     }
     VLOG(3) << "Removing " << (*node)->long_name();
-    remove_node_hook_(*node);
   }
   lookup_table_.erase(name);
 }
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 48d600a9f18..81dfda7acb6 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
@@ -124,9 +125,19 @@ class Node {
   using Factory = std::function<std::shared_ptr<Node>(Args)>;
 
   explicit Node(Args args)
-      : id_(args.id), name_(args.name), output_(args.output.get()) {}
+      : id_(args.id),
+        name_(std::move(args.name)),
+        autotune_(true),
+        buffered_bytes_(0),
+        buffered_elements_(0),
+        bytes_consumed_(0),
+        bytes_produced_(0),
+        num_elements_(0),
+        record_metrics_(true),
+        metrics_(name_),
+        output_(args.output.get()) {}
 
-  virtual ~Node() {}
+  virtual ~Node() { FlushMetrics(); }
 
   // Adds an input.
   void add_input(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_) {
@@ -142,22 +153,29 @@ class Node {
 
   // Returns an indication whether autotuning is enabled for this node.
   bool autotune() const TF_LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
     return autotune_;
   }
 
   // Returns the number of bytes stored in this node's buffer.
   int64 buffered_bytes() const TF_LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
     return buffered_bytes_;
   }
 
   // Returns the number of elements stored in this node's buffer.
   int64 buffered_elements() const TF_LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
     return buffered_elements_;
   }
 
+  // Returns the number of bytes consumed by the node.
+  int64 bytes_consumed() const TF_LOCKS_EXCLUDED(mu_) {
+    return bytes_consumed_;
+  }
+
+  // Returns the number of bytes produced by the node.
+  int64 bytes_produced() const TF_LOCKS_EXCLUDED(mu_) {
+    return bytes_produced_;
+  }
+
   // Indicates whether the node has tunable parameters.
   bool has_tunable_parameters() const TF_LOCKS_EXCLUDED(mu_) {
     tf_shared_lock l(mu_);
@@ -184,7 +202,6 @@ class Node {
 
   // Returns the number of elements produced by the node.
   int64 num_elements() const TF_LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
     return num_elements_;
   }
 
@@ -197,17 +214,20 @@ class Node {
     return processing_time_;
   }
 
+  // Records that the node consumed the given number of bytes.
+  void record_bytes_consumed(int64 num_bytes) { bytes_consumed_ += num_bytes; }
+
+  // Records that the node produced the given number of bytes.
+  void record_bytes_produced(int64 num_bytes) { bytes_produced_ += num_bytes; }
+
   // Records the change in this node's buffer.
-  void record_buffer_event(int64 bytes_delta, int64 elements_delta)
-      TF_LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
+  void record_buffer_event(int64 bytes_delta, int64 elements_delta) {
     buffered_bytes_ += bytes_delta;
     buffered_elements_ += elements_delta;
   }
 
   // Records that the node produced an element.
   void record_element() TF_LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
     num_elements_++;
   }
 
@@ -226,8 +246,7 @@ class Node {
       processing_time_ += time_nanos - iter->second;
       work_start_.erase(iter);
     } else {
-      VLOG(1)
-          << "Encountered a stop event that was not preceded by a start event.";
+      VLOG(1) << "Encountered a stop event without a matching start event.";
     }
   }
 
@@ -239,18 +258,17 @@ class Node {
 
   // Sets the value that determines whether autotuning is enabled for this node.
   void set_autotune(bool autotune) TF_LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    autotune_ = autotune;
+    autotune_.store(autotune);
   }
 
   // Collects tunable parameters in the subtree rooted in this node.
   void CollectTunableParameters(
       std::map<string, std::shared_ptr<Parameter>>* parameters) const
       TF_LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
     if (!autotune_) {
       return;
     }
+    tf_shared_lock l(mu_);
     for (auto& pair : parameters_) {
       if (pair.second->state->tunable) {
         parameters->insert(std::make_pair(long_name(), pair.second));
@@ -266,10 +284,17 @@ class Node {
     tf_shared_lock l(mu_);
     string result;
     strings::StrAppend(&result, long_name(), ":\n");
-    strings::StrAppend(&result, "  autotune=", autotune_, "\n");
-    strings::StrAppend(&result, "  buffered_bytes=", buffered_bytes_, "\n");
+    strings::StrAppend(&result, "  autotune=", autotune_.load(), "\n");
+    strings::StrAppend(&result, "  buffered_bytes=", buffered_bytes_.load(),
+                       "\n");
+    strings::StrAppend(&result,
+                       "  buffered_elements=", buffered_elements_.load(), "\n");
+    strings::StrAppend(&result, "  bytes_consumed=", bytes_consumed_.load(),
+                       "\n");
+    strings::StrAppend(&result, "  bytes_produced=", bytes_produced_.load(),
+                       "\n");
     strings::StrAppend(&result, "  processing_time=", processing_time_, "\n");
-    strings::StrAppend(&result, "  num_elements=", num_elements_, "\n");
+    strings::StrAppend(&result, "  num_elements=", num_elements_.load(), "\n");
     string inputs;
     for (auto& input : inputs_) {
       strings::StrAppend(&inputs, input->long_name(), ",");
@@ -281,6 +306,16 @@ class Node {
     return result;
   }
 
+  // Flushes the metrics recorded by this node.
+  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_) {
+    if (!record_metrics_) {
+      return;
+    }
+    metrics_.record_bytes_consumed(bytes_consumed_);
+    metrics_.record_bytes_produced(bytes_produced_);
+    metrics_.record_num_elements(num_elements_);
+  }
+
   // Returns the per-element output time for this node and if `gradient` is not
   // `nullptr`, collects the gradient of the output time w.r.t. tunable
   // parameters of the subtree rooted in this node and the last input time.
@@ -301,13 +336,16 @@ class Node {
     tf_shared_lock l(mu_);
     std::shared_ptr<Node> result = Clone(output);
     {
+      result->autotune_.store(autotune_);
+      result->buffered_bytes_.store(buffered_bytes_);
+      result->buffered_elements_.store(buffered_elements_);
+      result->bytes_consumed_.store(bytes_consumed_);
+      result->bytes_produced_.store(bytes_produced_);
+      result->num_elements_.store(num_elements_);
+      result->record_metrics_.store(false);
       mutex_lock l2(result->mu_);
-      result->autotune_ = autotune_;
-      result->buffered_bytes_ = buffered_bytes_;
-      result->buffered_elements_ = buffered_elements_;
-      result->processing_time_ = processing_time_;
-      result->num_elements_ = num_elements_;
       result->parameters_ = parameters_;
+      result->processing_time_ = processing_time_;
     }
     for (auto& input : inputs_) {
       result->add_input(input->Snapshot(result));
@@ -324,10 +362,10 @@ class Node {
   // Returns the total number of bytes buffered in all nodes in the subtree for
   // which autotuning is enabled.
   double TotalBufferedBytes() const TF_LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
     if (!autotune_) {
       return 0;
     }
+    tf_shared_lock l(mu_);
     double result = 0;
     auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
     if (!parameter) {
@@ -346,10 +384,10 @@ class Node {
   // autotuning is enabled. This number represents the amount of memory that
   // would be used by the subtree nodes if all of their buffers were full.
   double TotalMaximumBufferedBytes() const TF_LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
     if (!autotune_) {
       return 0;
     }
+    tf_shared_lock l(mu_);
     double result = 0;
     auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
     if (!parameter) {
@@ -374,6 +412,50 @@ class Node {
   }
 
  protected:
+  // Used for (incrementally) recording metrics. The class is thread-safe.
+  class Metrics {
+   public:
+    explicit Metrics(const string& name)
+        : bytes_consumed_counter_(metrics::GetTFDataBytesConsumedCounter(name)),
+          bytes_produced_counter_(metrics::GetTFDataBytesProducedCounter(name)),
+          num_elements_counter_(metrics::GetTFDataElementsCounter(name)),
+          recorded_bytes_consumed_(0),
+          recorded_bytes_produced_(0),
+          recorded_num_elements_(0) {}
+
+    // Expects the total number of bytes consumed and records the delta since
+    // last invocation.
+    void record_bytes_consumed(int64 total_bytes) {
+      int64 delta =
+          total_bytes - recorded_bytes_consumed_.exchange(total_bytes);
+      bytes_consumed_counter_->IncrementBy(delta);
+    }
+
+    // Expects the total number of bytes produced and records the delta since
+    // last invocation.
+    void record_bytes_produced(int64 total_bytes) {
+      int64 delta =
+          total_bytes - recorded_bytes_produced_.exchange(total_bytes);
+      bytes_produced_counter_->IncrementBy(delta);
+    }
+
+    // Expects the total number of elements produced and records the delta since
+    // last invocation.
+    void record_num_elements(int64 total_elements) {
+      int64 delta =
+          total_elements - recorded_num_elements_.exchange(total_elements);
+      num_elements_counter_->IncrementBy(delta);
+    }
+
+   private:
+    monitoring::CounterCell* const bytes_consumed_counter_;
+    monitoring::CounterCell* const bytes_produced_counter_;
+    monitoring::CounterCell* const num_elements_counter_;
+    std::atomic<int64> recorded_bytes_consumed_;
+    std::atomic<int64> recorded_bytes_produced_;
+    std::atomic<int64> recorded_num_elements_;
+  };
+
   // Returns the number of inputs.
   int64 num_inputs() const TF_SHARED_LOCKS_REQUIRED(mu_) {
     int64 num_inputs = 0;
@@ -495,13 +577,17 @@ class Node {
   // Indicates whether the subtree rooted in this node should be included in
   // autotuning. In particular, if this is `false`, then the subtree is excluded
   // from computation of output time and processing time.
-  bool autotune_ TF_GUARDED_BY(mu_) = true;
-  int64 buffered_bytes_ TF_GUARDED_BY(mu_) = 0;
-  int64 buffered_elements_ TF_GUARDED_BY(mu_) = 0;
-  int64 processing_time_ TF_GUARDED_BY(mu_) = 0;
-  int64 num_elements_ TF_GUARDED_BY(mu_) = 0;
-  std::map<std::thread::id, int64> work_start_ TF_GUARDED_BY(mu_);
+  std::atomic<bool> autotune_;
+  std::atomic<int64> buffered_bytes_;
+  std::atomic<int64> buffered_elements_;
+  std::atomic<int64> bytes_consumed_;
+  std::atomic<int64> bytes_produced_;
+  std::atomic<int64> num_elements_;
+  std::atomic<bool> record_metrics_;
+  Metrics metrics_;
   std::map<string, std::shared_ptr<Parameter>> parameters_ TF_GUARDED_BY(mu_);
+  int64 processing_time_ TF_GUARDED_BY(mu_) = 0;
+  std::map<std::thread::id, int64> work_start_ TF_GUARDED_BY(mu_);
 
   // Statistic of inputs processing time history.
   double input_processing_time_sum_ = 0.0L;
@@ -561,19 +647,8 @@ std::shared_ptr<Node> MakeUnknownNode(Node::Args args);
 // implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
 class Model {
  public:
-  using NodeHook = std::function<void(std::shared_ptr<Node>)>;
-
   // Creates a new model.
-  //
-  // The `remove_node_hook` argument can be used to specify functionality that
-  // should be invoked before a node is removed from the model. The hook can be
-  // used for dependency injection -- to allow the model to invoke functionality
-  // from modules that it could not depend on statically.
-  Model(NodeHook remove_node_hook)
-      : collect_resource_usage_(false),
-        remove_node_hook_(std::move(remove_node_hook)) {
-    DCHECK(remove_node_hook_ != nullptr);
-  }
+  Model() : collect_resource_usage_(false) {}
 
   // Indicates whether to collect resource usage.
   bool collect_resource_usage() const { return collect_resource_usage_; }
@@ -588,16 +663,16 @@ class Model {
   void AddProcessingTime(const string& name, int64 delta)
       TF_LOCKS_EXCLUDED(mu_);
 
-  // Uses the given algorithm to perform the autotuning optimization.
-  void Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget, int64 ram_budget)
-      TF_LOCKS_EXCLUDED(mu_);
-
-  // Records that a node has produced an element.
-  void RecordElement(const string& name) TF_LOCKS_EXCLUDED(mu_);
+  // Flushes metrics recorded by the model.
+  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
 
   // Returns the number of elements that the input pipeline has produced.
   int64 NumElements(const string& name) TF_LOCKS_EXCLUDED(mu_);
 
+  // Uses the given algorithm to perform the autotuning optimization.
+  void Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget, int64 ram_budget)
+      TF_LOCKS_EXCLUDED(mu_);
+
   // Records that the given node has started work. If `stop_output` is set, it
   // also records that the output of the given node has stopped work.
   void RecordStart(const string& name, bool stop_output) TF_LOCKS_EXCLUDED(mu_);
@@ -674,9 +749,6 @@ class Model {
   // tunable parameter (because the information is used for for tuning the value
   // of the parameter) and never stops.
   std::atomic<bool> collect_resource_usage_;
-
-  // A hook invoked immediately before a node is removed from the model.
-  const NodeHook remove_node_hook_;
 };
 
 }  // namespace model
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index 87e61e1d37c..8c630fd9646 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -14,8 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "absl/memory/memory.h"
-#include "tensorflow/core/common_runtime/metrics.h"
 #include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/framework/model.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -110,10 +110,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
      public:
       explicit Iterator(const Params& params)
           : DatasetIterator<Dataset>(params) {
-        auto remove_node_hook = [](std::shared_ptr<model::Node> node) {
-          metrics::RecordTFDataElements(node->name(), node->num_elements());
-        };
-        model_ = std::make_shared<model::Model>(std::move(remove_node_hook));
+        model_ = std::make_shared<model::Model>();
       }
 
       ~Iterator() override {
@@ -168,16 +165,16 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
      private:
       Status EnsureOptimizeThreadStarted(IteratorContext* ctx)
           TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        if (!optimize_thread_) {
+        if (!model_thread_) {
           std::shared_ptr<IteratorContext> new_ctx =
               std::make_shared<IteratorContext>(*ctx);
-          optimize_thread_ = ctx->StartThread(
-              "tf_data_model", [this, new_ctx]() { OptimizeThread(new_ctx); });
+          model_thread_ = ctx->StartThread(
+              "tf_data_model", [this, new_ctx]() { ModelThread(new_ctx); });
         }
         return Status::OK();
       }
 
-      void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) {
+      void ModelThread(const std::shared_ptr<IteratorContext>& ctx) {
         int64 last_optimization_ms = 0;
         int64 optimization_period_ms = 10;
         int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
@@ -205,13 +202,14 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
           }
           current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
           last_optimization_ms = current_time_ms;
+          model_->FlushMetrics();
         }
       }
 
       mutex mu_;
       condition_variable cond_var_;
       std::shared_ptr<model::Model> model_;
-      std::unique_ptr<Thread> optimize_thread_ TF_GUARDED_BY(mu_);
+      std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
       bool cancelled_ TF_GUARDED_BY(mu_) = false;
       std::unique_ptr<IteratorBase> input_impl_;
     };