From e08382691bfb897d584c5d5a8e8a0abe0472373d Mon Sep 17 00:00:00 2001 From: Chenkai Kuang <chenkai@google.com> Date: Thu, 18 Jun 2020 15:54:03 -0700 Subject: [PATCH] Make "map_resources" overridable by subclass of `Trackable`. This allows moving the implementation of map_resources from `tf.saved_model.save` to subclasses of `Trackable`, e.g., Variable, DistributedVariable. This is a non-functional change. PiperOrigin-RevId: 317198449 Change-Id: I4aa48d4974b6547b5de8ac0f5c38f3da29d364bc --- tensorflow/python/distribute/BUILD | 4 +-- tensorflow/python/distribute/values.py | 12 +++++++ .../experimental/autocast_variable.py | 7 ++++ .../python/ops/resource_variable_ops.py | 7 ++++ tensorflow/python/saved_model/save.py | 36 +++++-------------- tensorflow/python/training/tracking/base.py | 18 ++++++++++ .../python/training/tracking/tracking.py | 13 +++++++ 7 files changed, 66 insertions(+), 31 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 96559a9a740..7208807a18c 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -744,14 +744,12 @@ py_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:tensor_util", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tf_export", "//tensorflow/python:type_spec", - "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", "//tensorflow/python/training/saving:saveable_object", "//tensorflow/python/training/saving:saveable_object_util", "//tensorflow/python/training/tracking:base", diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index d0ed27c69de..60b2ea4fe31 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -32,6 +32,7 @@ from 
tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training.saving import saveable_object @@ -793,6 +794,17 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, return ops.convert_to_tensor( self._get(), dtype=dtype, name=name, as_ref=as_ref) + def _map_resources(self): + """For implementing `Trackable`.""" + new_obj = resource_variable_ops.copy_to_graph_uninitialized(self._primary) + obj_map, resource_map = {}, {} + for v in self._values: + obj_map[v] = new_obj + resource_map[v.handle] = new_obj.handle + obj_map[self] = new_obj + resource_map[self] = new_obj.handle + return obj_map, resource_map + class _DistributedVariableSaveable(saveable_object.SaveableObject): """Class for defining how to restore a DistributedVariable.""" diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py index 7d0abe30581..57e8ced65a0 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py +++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py @@ -285,6 +285,13 @@ class AutoCastVariable(variables.Variable, core.Tensor): # models with normal variables, and vice versa. return self._variable._gather_saveables_for_checkpoint() # pylint:disable=protected-access + def _map_resources(self): + # By delegating this method to the wrapped variable, SavedModel with + # AutoCastVariables are identical to SavedModel with normal variables. 
+ obj_map, resource_map = self._variable._map_resources() # pylint:disable=protected-access + obj_map[self] = obj_map[self._variable] + return obj_map, resource_map + # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in # to_proto(). def to_proto(self, export_scope=None): diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 25f6347f034..cb235fcbe2d 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -633,6 +633,13 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor): return gen_state_ops.resource_count_up_to(self.handle, limit=limit, T=self.dtype) + def _map_resources(self): + """For implementing `Trackable`.""" + new_variable = copy_to_graph_uninitialized(self) + obj_map = {self: new_variable} + resource_map = {self._handle: new_variable.handle} + return obj_map, resource_map + def _read_variable_op(self): variable_accessed(self) diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 5844c80995f..802ce1d61b7 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -19,14 +19,12 @@ from __future__ import division from __future__ import print_function import collections -import copy import os from tensorflow.core.framework import versions_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.core.protobuf import saved_object_graph_pb2 -from tensorflow.python.distribute import distribute_utils as ds_utils from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import function as defun @@ -241,7 +239,7 @@ class _SaveableView(object): Creates resource handle ops in the current default graph, whereas `accessible_objects` will be from an eager context. 
Resource mapping adds resource handle ops to the main GraphDef of a SavedModel, which allows the - C++ loader API to interact with variables. + C++ loader API to interact with resources. Returns: A tuple of (object_map, resource_map, asset_info): @@ -265,33 +263,15 @@ class _SaveableView(object): asset_index={}) for node_id, obj in enumerate(self.nodes): - if isinstance(obj, tracking.CapturableResource): - new_obj = object_map[obj] = copy.copy(obj) - # pylint: disable=protected-access - with ops.device(obj._resource_device): - new_resource = new_obj._create_resource() - new_obj._resource_handle = new_resource - # pylint: enable=protected-access - resource_map[obj.resource_handle] = new_resource - self.captured_tensor_node_ids[obj.resource_handle] = node_id - elif (ds_utils.is_distributed_variable(obj) or - resource_variable_ops.is_resource_variable(obj)): - obj_to_copy = obj._primary if ds_utils.is_distributed_variable( # pylint: disable=protected-access - obj) else obj - new_variable = resource_variable_ops.copy_to_graph_uninitialized( - obj_to_copy) - if ds_utils.is_distributed_variable(obj): - self.captured_tensor_node_ids[obj] = node_id - for v in obj.values: - object_map[v] = new_variable - resource_map[v.handle] = new_variable.handle - self.captured_tensor_node_ids[v.handle] = node_id - object_map[obj] = new_variable - resource_map[obj.handle] = new_variable.handle - self.captured_tensor_node_ids[obj.handle] = node_id - elif isinstance(obj, tracking.Asset): + if isinstance(obj, tracking.Asset): _process_asset(obj, asset_info, resource_map) self.captured_tensor_node_ids[obj.asset_path] = node_id + elif isinstance(obj, base.Trackable): + node_object_map, node_resource_map = obj._map_resources() # pylint: disable=protected-access + for capturable in node_resource_map.keys(): + self.captured_tensor_node_ids[capturable] = node_id + object_map.update(node_object_map) + resource_map.update(node_resource_map) # Note: some concrete functions can have been realized 
when tracing other # functions, and might closure-capture tensors from their parent functions. diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py index e3cd9828724..ea76ad8db47 100644 --- a/tensorflow/python/training/tracking/base.py +++ b/tensorflow/python/training/tracking/base.py @@ -1021,3 +1021,21 @@ class Trackable(object): """ del serialization_cache return dict() + + def _map_resources(self): + """Makes new resource handle ops corresponding to existing resource tensors. + + Internal sub-classes can override this to inform model saving how to add new + resource handle ops to the main GraphDef of a SavedModel (TF 1.x style + graph), which allows session-based APIs (e.g., the C++ loader API) to interact + with resources owned by this object. + + Returns: + A tuple of (object_map, resource_map): + object_map: A dictionary mapping from objects that hold existing + resource tensors to replacement objects created to hold the new + resource tensors. + resource_map: A dictionary mapping from existing resource tensors to + newly created resource tensors. 
+ """ + return {}, {} diff --git a/tensorflow/python/training/tracking/tracking.py b/tensorflow/python/training/tracking/tracking.py index 553f0ec73bf..fb2735e6445 100644 --- a/tensorflow/python/training/tracking/tracking.py +++ b/tensorflow/python/training/tracking/tracking.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import functools import weakref @@ -243,6 +244,18 @@ class CapturableResource(base.Trackable): self._resource_handle = self._create_resource() return self._resource_handle + def _map_resources(self): + """For implementing `Trackable`.""" + new_obj = copy.copy(self) + # pylint: disable=protected-access + with ops.device(self._resource_device): + new_resource = new_obj._create_resource() + new_obj._resource_handle = new_resource + # pylint: enable=protected-access + obj_map = {self: new_obj} + resource_map = {self.resource_handle: new_resource} + return obj_map, resource_map + def _list_functions_for_serialization(self, unused_functions): @def_function.function(input_signature=[], autograph=False) def _creator():