Make "map_resources" overridable by subclass of Trackable.

This allows moving the implementation of map_resources from `tf.saved_model.save` into subclasses of `Trackable`, e.g., Variable and DistributedVariable.

This is a non-functional change.

PiperOrigin-RevId: 317198449
Change-Id: I4aa48d4974b6547b5de8ac0f5c38f3da29d364bc
Authored by Chenkai Kuang on 2020-06-18 15:54:03 -07:00; committed by TensorFlower Gardener
parent 0deffad6ac
commit e08382691b
7 changed files with 66 additions and 31 deletions
Changed directories:
  tensorflow/python/distribute
  tensorflow/python/keras/mixed_precision/experimental
  tensorflow/python/ops
  tensorflow/python/saved_model
  tensorflow/python/training/tracking
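As the commit message notes, a `Trackable` subclass can now supply its own resource mapping for SavedModel export. Below is a minimal sketch of such an override, modeled on the AutoCastVariable change later in this diff; `_WrappedVariable` is a hypothetical class used only for illustration:

from tensorflow.python.training.tracking import base


class _WrappedVariable(base.Trackable):
  """Hypothetical Trackable exposing a wrapped resource variable's resource."""

  def __init__(self, variable):
    self._variable = variable  # a tf.Variable (resource variable)

  def _map_resources(self):
    # Delegate to the wrapped variable, which (after this change) knows how to
    # copy itself into the export graph, then alias the wrapper to that copy.
    obj_map, resource_map = self._variable._map_resources()  # pylint: disable=protected-access
    obj_map[self] = obj_map[self._variable]
    return obj_map, resource_map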


@@ -744,14 +744,12 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tf_export",
"//tensorflow/python:type_spec",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
"//tensorflow/python/training/saving:saveable_object",
"//tensorflow/python/training/saving:saveable_object_util",
"//tensorflow/python/training/tracking:base",


@@ -32,6 +32,7 @@ from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training.saving import saveable_object
@@ -793,6 +794,17 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
    return ops.convert_to_tensor(
        self._get(), dtype=dtype, name=name, as_ref=as_ref)

  def _map_resources(self):
    """For implementing `Trackable`."""
    new_obj = resource_variable_ops.copy_to_graph_uninitialized(self._primary)
    obj_map, resource_map = {}, {}
    for v in self._values:
      obj_map[v] = new_obj
      resource_map[v.handle] = new_obj.handle
    obj_map[self] = new_obj
    resource_map[self] = new_obj.handle
    return obj_map, resource_map


class _DistributedVariableSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a DistributedVariable."""


@@ -285,6 +285,13 @@ class AutoCastVariable(variables.Variable, core.Tensor):
    # models with normal variables, and vice versa.
    return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access

  def _map_resources(self):
    # By delegating this method to the wrapped variable, SavedModels with
    # AutoCastVariables are identical to SavedModels with normal variables.
    obj_map, resource_map = self._variable._map_resources()  # pylint:disable=protected-access
    obj_map[self] = obj_map[self._variable]
    return obj_map, resource_map

  # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
  # to_proto().
  def to_proto(self, export_scope=None):


@@ -633,6 +633,13 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
    return gen_state_ops.resource_count_up_to(self.handle, limit=limit,
                                              T=self.dtype)

  def _map_resources(self):
    """For implementing `Trackable`."""
    new_variable = copy_to_graph_uninitialized(self)
    obj_map = {self: new_variable}
    resource_map = {self._handle: new_variable.handle}
    return obj_map, resource_map

  def _read_variable_op(self):
    variable_accessed(self)
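A quick sanity-check sketch of the new `BaseResourceVariable._map_resources` contract, assuming the private API exactly as added in this change (later TF versions may alter its signature); the variable and graph setup below are illustrative only:

import tensorflow as tf

v = tf.Variable(3.0)

# tf.saved_model.save calls _map_resources with the export graph as default.
with tf.Graph().as_default():
  obj_map, resource_map = v._map_resources()  # pylint: disable=protected-access
  new_v = obj_map[v]  # an uninitialized copy of `v` living in this graph
  assert resource_map[v.handle] is new_v.handle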


@@ -19,14 +19,12 @@ from __future__ import division
from __future__ import print_function
import collections
import copy
import os
from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.distribute import distribute_utils as ds_utils
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
@@ -241,7 +239,7 @@ class _SaveableView(object):
    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.
    C++ loader API to interact with resources.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
@@ -265,33 +263,15 @@
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.CapturableResource):
        new_obj = object_map[obj] = copy.copy(obj)
        # pylint: disable=protected-access
        with ops.device(obj._resource_device):
          new_resource = new_obj._create_resource()
        new_obj._resource_handle = new_resource
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif (ds_utils.is_distributed_variable(obj) or
            resource_variable_ops.is_resource_variable(obj)):
        obj_to_copy = obj._primary if ds_utils.is_distributed_variable(  # pylint: disable=protected-access
            obj) else obj
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(
            obj_to_copy)
        if ds_utils.is_distributed_variable(obj):
          self.captured_tensor_node_ids[obj] = node_id
          for v in obj.values:
            object_map[v] = new_variable
            resource_map[v.handle] = new_variable.handle
            self.captured_tensor_node_ids[v.handle] = node_id
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.Asset):
      if isinstance(obj, tracking.Asset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id
      elif isinstance(obj, base.Trackable):
        node_object_map, node_resource_map = obj._map_resources()  # pylint: disable=protected-access
        for capturable in node_resource_map.keys():
          self.captured_tensor_node_ids[capturable] = node_id
        object_map.update(node_object_map)
        resource_map.update(node_resource_map)

    # Note: some concrete functions can have been realized when tracing other
    # functions, and might closure-capture tensors from their parent functions.
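Since the rendered hunk above interleaves the deleted branches with the new ones (the +/- markers were lost), here is approximately how the rewritten loop body reads after this change, reconstructed from the added lines above (an excerpt of `_SaveableView.map_resources`, not standalone code):

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.Asset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id
      elif isinstance(obj, base.Trackable):
        # Every other Trackable now describes its own resource remapping.
        node_object_map, node_resource_map = obj._map_resources()  # pylint: disable=protected-access
        for capturable in node_resource_map.keys():
          self.captured_tensor_node_ids[capturable] = node_id
        object_map.update(node_object_map)
        resource_map.update(node_resource_map)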


@@ -1021,3 +1021,21 @@ class Trackable(object):
    """
    del serialization_cache
    return dict()

  def _map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Internal sub-classes can override this to inform model saving how to add
    new resource handle ops to the main GraphDef of a SavedModel (TF 1.x style
    graph), which allows session-based APIs (e.g., the C++ loader API) to
    interact with resources owned by this object.

    Returns:
      A tuple of (object_map, resource_map):
        object_map: A dictionary mapping from objects that hold existing
          resource tensors to replacement objects created to hold the new
          resource tensors.
        resource_map: A dictionary mapping from existing resource tensors to
          newly created resource tensors.
    """
    return {}, {}
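To make the documented contract concrete, here is a minimal sketch, not part of this change, of a `Trackable` that owns two resource variables and builds its maps by merging theirs; the class name `_PairOfVariables` is hypothetical and purely illustrative:

from tensorflow.python.training.tracking import base


class _PairOfVariables(base.Trackable):
  """Hypothetical Trackable owning two resource variables."""

  def __init__(self, first, second):
    self._first = first    # a tf.Variable
    self._second = second  # a tf.Variable

  def _map_resources(self):
    obj_map, resource_map = {}, {}
    for v in (self._first, self._second):
      # Each resource variable copies itself into the export graph; see
      # BaseResourceVariable._map_resources added in this change.
      child_obj_map, child_resource_map = v._map_resources()  # pylint: disable=protected-access
      obj_map.update(child_obj_map)
      resource_map.update(child_resource_map)
    return obj_map, resource_map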


@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import functools
import weakref
@@ -243,6 +244,18 @@ class CapturableResource(base.Trackable):
        self._resource_handle = self._create_resource()
    return self._resource_handle

  def _map_resources(self):
    """For implementing `Trackable`."""
    new_obj = copy.copy(self)
    # pylint: disable=protected-access
    with ops.device(self._resource_device):
      new_resource = new_obj._create_resource()
    new_obj._resource_handle = new_resource
    # pylint: enable=protected-access
    obj_map = {self: new_obj}
    resource_map = {self.resource_handle: new_resource}
    return obj_map, resource_map

  def _list_functions_for_serialization(self, unused_functions):

    @def_function.function(input_signature=[], autograph=False)
    def _creator():
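For context, `tf.lookup.StaticHashTable` is a `CapturableResource`, so after this change its handle remapping during export goes through the `_map_resources` shown above. A small end-to-end usage sketch; the module layout and export path are illustrative assumptions:

import tensorflow as tf

keys = tf.constant(["a", "b"])
values = tf.constant([1, 2], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

module = tf.Module()
module.table = table
module.lookup = tf.function(
    lambda x: module.table.lookup(x),
    input_signature=[tf.TensorSpec([None], tf.string)])

# During save, CapturableResource._map_resources creates a new handle op in
# the export graph and maps the eager handle of `table` to it.
tf.saved_model.save(module, "/tmp/hash_table_module")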