diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index a529c1fbd9b..e37a73b7b34 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -84,6 +84,30 @@ class VariableHolder(object): return wrapped +def _lift_single_variable(old_variable, graph, variable_holder): + """Lifts `old_variable` out of the `FuncGraph` `graph`.""" + new_variable = resource_variable_ops.UninitializedVariable( + shape=old_variable.shape, + dtype=old_variable.dtype, + name=old_variable.op.name, + trainable=old_variable.trainable, + extra_handle_data=old_variable.handle) + new_variable._initializer_op = old_variable._initializer_op # pylint: disable=protected-access + graph.inputs.append(old_variable.handle) + graph.captures[new_variable.handle] = old_variable.handle + # Now that we've added the new variable to graph.captures, + # graph.capture will use that cached value and do some post-processing + # on the capture like recording it on the tape. + graph.capture(new_variable.handle) + # pylint: disable=protected-access + variable_name = new_variable.name.split(":")[0] + variable_holder._variables_by_name[variable_name] = new_variable + graph._weak_variables.append(weakref.ref(new_variable)) + # pylint: enable=protected-access + graph.watch_variable(new_variable) + return new_variable + + def _lift_unlifted_variables(graph, variable_holder): """Finds resource variables and lifts them into the outer context. @@ -100,39 +124,44 @@ def _lift_unlifted_variables(graph, variable_holder): variable_holder: A VariableHolder to record the lifted variables in. """ with graph.as_default(): - collection_variables = ( - ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + - ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) + global_collection_variables = ops.get_collection( + ops.GraphKeys.GLOBAL_VARIABLES) + local_collection_variables = ops.get_collection( + ops.GraphKeys.LOCAL_VARIABLES) existing_captures = set(graph.internal_captures) lifted_variables = {} - for old_variable in collection_variables: - if (old_variable._in_graph_mode # pylint: disable=protected-access - and - isinstance(old_variable, resource_variable_ops.ResourceVariable)): - if old_variable.handle in existing_captures: - continue - new_variable = resource_variable_ops.UninitializedVariable( - shape=old_variable.shape, - dtype=old_variable.dtype, - name=old_variable.op.name, - trainable=old_variable.trainable, - extra_handle_data=old_variable.handle) - new_variable._initializer_op = old_variable._initializer_op # pylint: disable=protected-access - graph.inputs.append(old_variable.handle) - graph.captures[new_variable.handle] = old_variable.handle - # Now that we've added the new variable to graph.captures, - # graph.capture will use that cached value and do some post-processing - # on the capture like recording it on the tape. - graph.capture(new_variable.handle) - existing_captures.add(old_variable.handle) + + def _should_lift_variable(v): + return ((v._in_graph_mode # pylint: disable=protected-access + and v.graph.building_function) + and isinstance(v, resource_variable_ops.ResourceVariable) + and v.handle not in existing_captures) + + for old_variable in global_collection_variables: + if _should_lift_variable(old_variable): + new_variable = _lift_single_variable( + old_variable, graph, variable_holder) lifted_variables[old_variable] = new_variable - # pylint: disable=protected-access - variable_name = new_variable.name.split(":")[0] - variable_holder._variables_by_name[variable_name] = new_variable - graph._weak_variables.append(weakref.ref(new_variable)) - # pylint: enable=protected-access - graph.watch_variable(new_variable) - # Update the graph's collections, partly for the user and partly so this + existing_captures.add(old_variable.handle) + + for old_variable in local_collection_variables: + if _should_lift_variable(old_variable): + new_variable = _lift_single_variable( + old_variable, graph, variable_holder) + lifted_variables[old_variable] = new_variable + existing_captures.add(old_variable.handle) + if new_variable._in_graph_mode: # pylint: disable=protected-access + outer_graph = new_variable.graph + # Variables are added to the global collection by default. In this + # case we only want the variable in the local collection, so we'll pop + # it out. + global_collection = outer_graph.get_collection_ref( + ops.GraphKeys.GLOBAL_VARIABLES) + global_collection.remove(new_variable) + outer_graph.add_to_collection( + ops.GraphKeys.LOCAL_VARIABLES, new_variable) + + # Update the FuncGraph's collections, partly for the user and partly so this # function is idempotent when it runs again in prune() calls. for collection_name in [ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES @@ -148,9 +177,7 @@ class WrappedFunction(function.ConcreteFunction): def __init__(self, fn_graph, variable_holder, attrs=None, signature=None): self._variable_holder = variable_holder - if ops.executing_eagerly_outside_functions(): - # TODO(allenl): Make this work in 1.x? - _lift_unlifted_variables(fn_graph, variable_holder) + _lift_unlifted_variables(fn_graph, variable_holder) # We call __init__ after lifting variables so that the function's signature # properly reflects the new captured inputs. super(WrappedFunction, self).__init__( diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py index cb1464be780..15cd756bf7f 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2.py +++ b/tensorflow/python/saved_model/load_v1_in_v2.py @@ -20,10 +20,12 @@ from __future__ import print_function import functools +from tensorflow.python.eager import context from tensorflow.python.eager import lift_to_graph from tensorflow.python.eager import wrap_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import signature_serialization @@ -56,7 +58,7 @@ class _Initializer(tracking.CapturableResource): dtype=dtypes.resource, shape=[], name="unused_resource") def _initialize(self): - self._init_fn(*[path.asset_path for path in self._asset_paths]) + return self._init_fn(*[path.asset_path for path in self._asset_paths]) class _EagerSavedModelLoader(loader_impl.SavedModelLoader): @@ -90,11 +92,21 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): """Restores variables from the checkpoint.""" if saver is not None: saver_def = saver.saver_def + filename_tensor = wrapped.graph.as_graph_element( + saver_def.filename_tensor_name) + # We both feed and fetch filename_tensor so we have an operation to use to + # feed into variable initializers (only relevant for v1 graph building). restore_fn = wrapped.prune( - feeds=[wrapped.graph.as_graph_element( - saver_def.filename_tensor_name)], - fetches=[wrapped.graph.as_graph_element(saver_def.restore_op_name)]) - restore_fn(constant_op.constant(self._variables_path)) + feeds=[filename_tensor], + fetches=[filename_tensor, + wrapped.graph.as_graph_element(saver_def.restore_op_name)]) + initializer, _ = restore_fn(constant_op.constant(self._variables_path)) + if not ops.executing_eagerly_outside_functions(): + for variable in wrapped.graph.get_collection_ref( + ops.GraphKeys.GLOBAL_VARIABLES): + # pylint: disable=protected-access + variable._initializer_op = initializer + # pylint: enable=protected-access def _extract_signatures(self, wrapped, meta_graph_def): """Creates ConcreteFunctions for signatures in `meta_graph_def`.""" @@ -151,6 +163,8 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): with wrapped.graph.as_default(): init_op = loader_impl.get_init_op( meta_graph_def) or monitored_session.Scaffold.default_local_init_op() + # Add a dummy Tensor we know we can fetch to add control dependencies to. + init_anchor = constant_op.constant(0., name="dummy_fetch") root = tracking.AutoTrackable() asset_feed_tensors = [] @@ -161,9 +175,19 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): asset_paths.append(tracking.TrackableAsset(value)) init_fn = wrapped.prune( feeds=asset_feed_tensors, - fetches=[wrapped.graph.as_graph_element(init_op)]) + fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)]) initializer = _Initializer(init_fn, asset_paths) - initializer._initialize() # pylint: disable=protected-access + # pylint: disable=protected-access + local_init_op, _ = initializer._initialize() + # pylint: enable=protected-access + with ops.init_scope(): + if not context.executing_eagerly(): + ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op) + for variable in wrapped.graph.get_collection_ref( + ops.GraphKeys.LOCAL_VARIABLES): + # pylint: disable=protected-access + variable._initializer_op = local_init_op + # pylint: enable=protected-access root.initializer = initializer root.asset_paths = asset_paths signature_functions = self._extract_signatures(wrapped, meta_graph_def) @@ -182,3 +206,4 @@ def load(export_dir, tags): """Load a v1-style SavedModel as an object.""" loader = _EagerSavedModelLoader(export_dir) return loader.load(tags=tags) + diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py index 6a27a268a41..9f09d524424 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2_test.py +++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py @@ -28,6 +28,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.framework import versions from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops @@ -51,7 +52,7 @@ class LoadTest(test.TestCase): export_graph = ops.Graph() with export_graph.as_default(): start = array_ops.placeholder( - shape=[None], dtype=dtypes.float32, name="start") + shape=None, dtype=dtypes.float32, name="start") if use_resource: distractor = variables.RefVariable(-1., name="distractor") v = resource_variable_ops.ResourceVariable(3., name="v") @@ -81,17 +82,20 @@ class LoadTest(test.TestCase): legacy_init_op=local_variable.initializer) return path + @test_util.run_in_graph_and_eager_modes def test_resource_variable_import(self): imported = load.load(self._v1_single_metagraph_saved_model( use_resource=True)) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(variables.local_variables_initializer()) fn = imported.signatures["serving_default"] self.assertEqual({"output": 6.}, self.evaluate(fn(constant_op.constant(2.)))) self.assertAllEqual([3., 1.], self.evaluate(imported.variables)) - imported.variables[0].assign(4.) + self.evaluate(imported.variables[0].assign(4.)) self.assertEqual({"output": 8.}, self.evaluate(fn(start=constant_op.constant(2.)))) - imported.variables[1].assign(2.) + self.evaluate(imported.variables[1].assign(2.)) self.assertEqual({"output": 24.}, self.evaluate(fn(start=constant_op.constant(3.)))) self.assertTrue(imported.variables[0].trainable) @@ -99,7 +103,9 @@ class LoadTest(test.TestCase): with backprop.GradientTape() as tape: output = fn(start=constant_op.constant(4.)) self.assertEqual(imported.variables[:1], list(tape.watched_variables())) - self.assertEqual(8., tape.gradient(output, imported.variables[0]).numpy()) + self.assertEqual( + 8., + self.evaluate(tape.gradient(output, imported.variables[0]))) def test_ref_variable_import(self): saved = self._v1_single_metagraph_saved_model(use_resource=False) @@ -185,9 +191,11 @@ class LoadTest(test.TestCase): file_io.delete_file(vocab_path) return path + @test_util.run_in_graph_and_eager_modes def test_asset_loading(self): first_path = self._v1_asset_saved_model() imported = load.load(first_path) + self.evaluate(lookup_ops.tables_initializer()) fn = imported.signatures["serving_default"] self.assertAllClose({"output": [2, 0]}, fn(start=constant_op.constant(["gamma", "alpha"]))) @@ -195,7 +203,9 @@ class LoadTest(test.TestCase): str(ops.uid())) save.save(imported, second_path, signatures=imported.signatures) shutil.rmtree(first_path) + del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:] second_import = load.load(second_path) + self.evaluate(lookup_ops.tables_initializer()) fn = second_import.signatures["serving_default"] self.assertAllClose({"output": [2, 0]}, fn(start=constant_op.constant(["gamma", "alpha"]))) @@ -204,7 +214,9 @@ class LoadTest(test.TestCase): str(ops.uid())) save.save(second_import, third_path, signatures=second_import.signatures) shutil.rmtree(second_path) + del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:] third_import = load.load(third_path) + self.evaluate(lookup_ops.tables_initializer()) fn = third_import.signatures["serving_default"] self.assertAllClose({"output": [2, 0]}, fn(start=constant_op.constant(["gamma", "alpha"]))) @@ -368,3 +380,4 @@ class LoadTest(test.TestCase): if __name__ == "__main__": test.main() +