Merge pull request from geetachavan1/cherrypicks_AR9ZZ

[Cherrypick:r2.2] SavedModel loading: Create temporary graphs for importing functions to avoid leaking memory
This commit is contained in:
Mihai Maruseac 2020-12-17 09:25:20 -08:00 committed by GitHub
commit 10e6680ec3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 5 deletions
tensorflow/python/saved_model

View File

@ -22,7 +22,6 @@ import collections
import re
from tensorflow.core.framework import function_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.framework import func_graph as func_graph_lib
@ -298,6 +297,19 @@ def load_function_def_library(library, load_shared_name_suffix=None):
functions = {}
renamed_functions = {}
# Our graph building code currently requires functions to be registered with
# some tf.Graph in order to import functions using the
# op-name-is-function-name calling convention. To avoid leaking memory into
# the global default graph when executing eagerly, we create a temporary
# Graph.
#
# TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
# fixing function_def_to_graph_def.
if ops.executing_eagerly_outside_functions():
graph = ops.Graph()
else:
graph = ops.get_default_graph()
# NOTE(review): the suffix presumably keeps shared_names unique when the same
# SavedModel is loaded more than once in a process — confirm against callers.
if load_shared_name_suffix is None:
load_shared_name_suffix = "_load_{}".format(ops.uid())
for fdef in _sort_function_defs(library, library_function_names):
@ -308,18 +320,22 @@ def load_function_def_library(library, load_shared_name_suffix=None):
# extra function definitions are a no-op since they already imported as a
# function before and passed in explicitly (due to the topological sort
# import).
# NOTE(review): this is a diff rendering — the next line is the pre-change
# version; the two lines after it are its replacement, which imports under
# the temporary (or default) graph chosen above so function registration
# does not leak into the global default graph when executing eagerly.
func_graph = function_def_lib.function_def_to_graph(copy)
with graph.as_default():
func_graph = function_def_lib.function_def_to_graph(copy)
_restore_gradient_functions(func_graph, renamed_functions)
# Register already-loaded dependencies into this function's graph so that
# by-name references inside func_graph resolve during construction.
for dep in _list_function_deps(fdef, library_function_names):
functions[dep].add_to_graph(func_graph)
func = function_lib.ConcreteFunction(func_graph)
# Pre-change line (replaced by the conditional registration below):
func.add_to_graph()
if context.executing_eagerly():
# Eager mode still needs the function in the default graph/context so it
# can be invoked by name.
func.add_to_graph(ops.get_default_graph())
# Also register with the (possibly temporary) graph so later functions in
# the topological order can find this one as a dependency.
func.add_to_graph(graph)
functions[fdef.signature.name] = func
renamed_functions[func.name] = func
if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
# TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
# is fixed. Currently it's leaking memory to maintain bug compatibility
# with previous behavior.
func.add_to_graph(ops.get_default_graph())
return functions

View File

@ -1992,6 +1992,21 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
self.assertAllClose({"output_0": 13},
imported.signatures["serving_default"]())
# TODO(allenl, kkb): Use the new memory checker here once it's fast enough (3
# iterations took hundreds of seconds). It would be really nice to check
# allocations at a lower level.
@test_util.assert_no_new_pyobjects_executing_eagerly
def test_functions_cleaned(self):
"""Saving/loading a model must not leak Python objects (regression test)."""
if sys.version_info.major < 3:
self.skipTest("Not working in Python 2")
# Build a minimal traceable module: a variable plus a tf.function that
# captures it, with a fixed scalar-float input signature.
root = module.Module()
root.v = variables.Variable(1.)
root.f = def_function.function(
lambda x: x + root.v,
input_signature=[
tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32)])
# `cycle` presumably saves `root` to a SavedModel and reloads it once; the
# decorator asserts no new Python objects survive — TODO confirm helper.
cycle(root, 1)
# Standard TensorFlow test entry point: run this file's tests when executed
# directly as a script.
if __name__ == "__main__":
test.main()