Merge pull request #45719 from geetachavan1/cherrypicks_AR9ZZ
[Cherrypick:r2.2] SavedModel loading: Create temporary graphs for importing functions to avoid leaking memory
This commit is contained in:
commit
10e6680ec3
tensorflow/python/saved_model
@ -22,7 +22,6 @@ import collections
|
||||
import re
|
||||
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function as function_lib
|
||||
from tensorflow.python.framework import func_graph as func_graph_lib
|
||||
@ -298,6 +297,19 @@ def load_function_def_library(library, load_shared_name_suffix=None):
|
||||
functions = {}
|
||||
renamed_functions = {}
|
||||
|
||||
# Our graph building code currently requires functions to be registered with
|
||||
# some tf.Graph in order to import functions using the
|
||||
# op-name-is-function-name calling convention. To avoid leaking memory into
|
||||
# the global default graph when executing eagerly, we create a temporary
|
||||
# Graph.
|
||||
#
|
||||
# TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
|
||||
# fixing function_def_to_graph_def.
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
graph = ops.Graph()
|
||||
else:
|
||||
graph = ops.get_default_graph()
|
||||
|
||||
if load_shared_name_suffix is None:
|
||||
load_shared_name_suffix = "_load_{}".format(ops.uid())
|
||||
for fdef in _sort_function_defs(library, library_function_names):
|
||||
@ -308,18 +320,22 @@ def load_function_def_library(library, load_shared_name_suffix=None):
|
||||
# extra function definitions are a no-op since they already imported as a
|
||||
# function before and passed in explicitly (due to the topologic sort
|
||||
# import).
|
||||
func_graph = function_def_lib.function_def_to_graph(copy)
|
||||
with graph.as_default():
|
||||
func_graph = function_def_lib.function_def_to_graph(copy)
|
||||
_restore_gradient_functions(func_graph, renamed_functions)
|
||||
|
||||
for dep in _list_function_deps(fdef, library_function_names):
|
||||
functions[dep].add_to_graph(func_graph)
|
||||
func = function_lib.ConcreteFunction(func_graph)
|
||||
func.add_to_graph()
|
||||
if context.executing_eagerly():
|
||||
func.add_to_graph(ops.get_default_graph())
|
||||
func.add_to_graph(graph)
|
||||
|
||||
functions[fdef.signature.name] = func
|
||||
renamed_functions[func.name] = func
|
||||
if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
|
||||
# TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
|
||||
# is fixed. Currently it's leaking memory to maintain bug compatibility
|
||||
# with previous behavior.
|
||||
func.add_to_graph(ops.get_default_graph())
|
||||
|
||||
return functions
|
||||
|
||||
|
@ -1992,6 +1992,21 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose({"output_0": 13},
|
||||
imported.signatures["serving_default"]())
|
||||
|
||||
# TODO(allenl, kkb): Use the new memory checker here once it's fast enough (3
|
||||
# iterations took hundreds of seconds). It would be really nice to check
|
||||
# allocations at a lower level.
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def test_functions_cleaned(self):
|
||||
if sys.version_info.major < 3:
|
||||
self.skipTest("Not working in Python 2")
|
||||
root = module.Module()
|
||||
root.v = variables.Variable(1.)
|
||||
root.f = def_function.function(
|
||||
lambda x: x + root.v,
|
||||
input_signature=[
|
||||
tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32)])
|
||||
cycle(root, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user