From e08382691bfb897d584c5d5a8e8a0abe0472373d Mon Sep 17 00:00:00 2001
From: Chenkai Kuang <chenkai@google.com>
Date: Thu, 18 Jun 2020 15:54:03 -0700
Subject: [PATCH] Make "map_resources" overridable by subclass of `Trackable`.

This allows moving the implementation of map_resources from `tf.saved_model.save` to subclass of `Trackable`, e.g, Variable, DistributedVariable.

This is a non-functional change.

PiperOrigin-RevId: 317198449
Change-Id: I4aa48d4974b6547b5de8ac0f5c38f3da29d364bc
---
 tensorflow/python/distribute/BUILD            |  4 +--
 tensorflow/python/distribute/values.py        | 12 +++++++
 .../experimental/autocast_variable.py         |  7 ++++
 .../python/ops/resource_variable_ops.py       |  7 ++++
 tensorflow/python/saved_model/save.py         | 36 +++++--------------
 tensorflow/python/training/tracking/base.py   | 18 ++++++++++
 .../python/training/tracking/tracking.py      | 13 +++++++
 7 files changed, 66 insertions(+), 31 deletions(-)

diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 96559a9a740..7208807a18c 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -744,14 +744,12 @@ py_library(
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tensor_util",
+        "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tf_export",
         "//tensorflow/python:type_spec",
-        "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/eager:tape",
         "//tensorflow/python/training/saving:saveable_object",
         "//tensorflow/python/training/saving:saveable_object_util",
         "//tensorflow/python/training/tracking:base",
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index d0ed27c69de..60b2ea4fe31 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training.saving import saveable_object
@@ -793,6 +794,17 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
       return ops.convert_to_tensor(
           self._get(), dtype=dtype, name=name, as_ref=as_ref)
 
+  def _map_resources(self):
+    """For implementing `Trackable`."""
+    new_obj = resource_variable_ops.copy_to_graph_uninitialized(self._primary)
+    obj_map, resource_map = {}, {}
+    for v in self._values:
+      obj_map[v] = new_obj
+      resource_map[v.handle] = new_obj.handle
+    obj_map[self] = new_obj
+    resource_map[self] = new_obj.handle
+    return obj_map, resource_map
+
 
 class _DistributedVariableSaveable(saveable_object.SaveableObject):
   """Class for defining how to restore a DistributedVariable."""
diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
index 7d0abe30581..57e8ced65a0 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
@@ -285,6 +285,13 @@ class AutoCastVariable(variables.Variable, core.Tensor):
     # models with normal variables, and vice versa.
     return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access
 
+  def _map_resources(self):
+    # By delegating this method to the wrapped variable, SavedModel with
+    # AutoCastVariables are identical to SavedModel with normal variables.
+    obj_map, resource_map = self._variable._map_resources()  # pylint:disable=protected-access
+    obj_map[self] = obj_map[self._variable]
+    return obj_map, resource_map
+
   # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
   # to_proto().
   def to_proto(self, export_scope=None):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 25f6347f034..cb235fcbe2d 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -633,6 +633,13 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
     return gen_state_ops.resource_count_up_to(self.handle, limit=limit,
                                               T=self.dtype)
 
+  def _map_resources(self):
+    """For implementing `Trackable`."""
+    new_variable = copy_to_graph_uninitialized(self)
+    obj_map = {self: new_variable}
+    resource_map = {self._handle: new_variable.handle}
+    return obj_map, resource_map
+
   def _read_variable_op(self):
     variable_accessed(self)
 
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 5844c80995f..802ce1d61b7 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -19,14 +19,12 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import copy
 import os
 
 from tensorflow.core.framework import versions_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import saved_model_pb2
 from tensorflow.core.protobuf import saved_object_graph_pb2
-from tensorflow.python.distribute import distribute_utils as ds_utils
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
@@ -241,7 +239,7 @@ class _SaveableView(object):
     Creates resource handle ops in the current default graph, whereas
     `accessible_objects` will be from an eager context. Resource mapping adds
     resource handle ops to the main GraphDef of a SavedModel, which allows the
-    C++ loader API to interact with variables.
+    C++ loader API to interact with resources.
 
     Returns:
       A tuple of (object_map, resource_map, asset_info):
@@ -265,33 +263,15 @@ class _SaveableView(object):
         asset_index={})
 
     for node_id, obj in enumerate(self.nodes):
-      if isinstance(obj, tracking.CapturableResource):
-        new_obj = object_map[obj] = copy.copy(obj)
-        # pylint: disable=protected-access
-        with ops.device(obj._resource_device):
-          new_resource = new_obj._create_resource()
-        new_obj._resource_handle = new_resource
-        # pylint: enable=protected-access
-        resource_map[obj.resource_handle] = new_resource
-        self.captured_tensor_node_ids[obj.resource_handle] = node_id
-      elif (ds_utils.is_distributed_variable(obj) or
-            resource_variable_ops.is_resource_variable(obj)):
-        obj_to_copy = obj._primary if ds_utils.is_distributed_variable(  # pylint: disable=protected-access
-            obj) else obj
-        new_variable = resource_variable_ops.copy_to_graph_uninitialized(
-            obj_to_copy)
-        if ds_utils.is_distributed_variable(obj):
-          self.captured_tensor_node_ids[obj] = node_id
-          for v in obj.values:
-            object_map[v] = new_variable
-            resource_map[v.handle] = new_variable.handle
-            self.captured_tensor_node_ids[v.handle] = node_id
-        object_map[obj] = new_variable
-        resource_map[obj.handle] = new_variable.handle
-        self.captured_tensor_node_ids[obj.handle] = node_id
-      elif isinstance(obj, tracking.Asset):
+      if isinstance(obj, tracking.Asset):
         _process_asset(obj, asset_info, resource_map)
         self.captured_tensor_node_ids[obj.asset_path] = node_id
+      elif isinstance(obj, base.Trackable):
+        node_object_map, node_resource_map = obj._map_resources()  # pylint: disable=protected-access
+        for capturable in node_resource_map.keys():
+          self.captured_tensor_node_ids[capturable] = node_id
+        object_map.update(node_object_map)
+        resource_map.update(node_resource_map)
 
     # Note: some concrete functions can have been realized when tracing other
     # functions, and might closure-capture tensors from their parent functions.
diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py
index e3cd9828724..ea76ad8db47 100644
--- a/tensorflow/python/training/tracking/base.py
+++ b/tensorflow/python/training/tracking/base.py
@@ -1021,3 +1021,21 @@ class Trackable(object):
     """
     del serialization_cache
     return dict()
+
+  def _map_resources(self):
+    """Makes new resource handle ops corresponding to existing resource tensors.
+
+    Internal sub-classes can override this to inform model saving how to add new
+    resource handle ops to the main GraphDef of a SavedModel (TF 1.x style
+    graph), which allows session based APIs (e.g, C++ loader API) to interact
+    with resources owned by this object.
+
+    Returns:
+      A tuple of (object_map, resource_map):
+        object_map: A dictionary mapping from objects that hold existing
+          resource tensors to replacement objects created to hold the new
+          resource tensors.
+        resource_map: A dictionary mapping from existing resource tensors to
+          newly created resource tensors.
+    """
+    return {}, {}
diff --git a/tensorflow/python/training/tracking/tracking.py b/tensorflow/python/training/tracking/tracking.py
index 553f0ec73bf..fb2735e6445 100644
--- a/tensorflow/python/training/tracking/tracking.py
+++ b/tensorflow/python/training/tracking/tracking.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import copy
 import functools
 import weakref
 
@@ -243,6 +244,18 @@ class CapturableResource(base.Trackable):
         self._resource_handle = self._create_resource()
     return self._resource_handle
 
+  def _map_resources(self):
+    """For implementing `Trackable`."""
+    new_obj = copy.copy(self)
+    # pylint: disable=protected-access
+    with ops.device(self._resource_device):
+      new_resource = new_obj._create_resource()
+    new_obj._resource_handle = new_resource
+    # pylint: enable=protected-access
+    obj_map = {self: new_obj}
+    resource_map = {self.resource_handle: new_resource}
+    return obj_map, resource_map
+
   def _list_functions_for_serialization(self, unused_functions):
     @def_function.function(input_signature=[], autograph=False)
     def _creator():