SavedModel loading: Create temporary graphs for importing functions to avoid leaking memory
We currently need to register functions with some tf.Graph when importing them, but that graph does not need to be held in a global variable. This change leaves the memory leak in place for functions containing TRTEngineOp, since tests fail otherwise; I've filed a bug to investigate.
PiperOrigin-RevId: 298701457
Change-Id: I5e2c8a07eeaec10c019bcec419a992cb65adc5f7
This commit is contained in:
parent 77a6f3b983
commit 3421416220
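The core pattern of the change, as a minimal standalone sketch (the helper name _import_graph is mine, not from the commit): when executing eagerly, imported function definitions are registered with a throwaway Graph that can be garbage-collected after loading, instead of accumulating in the process-wide default graph.

from tensorflow.python.framework import ops

def _import_graph():
  # Hypothetical helper illustrating the commit's graph selection.
  if ops.executing_eagerly_outside_functions():
    # Temporary graph: dropped, along with every function registered in it,
    # once loading finishes and the last reference goes away.
    return ops.Graph()
  # When building a graph, the imported functions must live in that graph.
  return ops.get_default_graph()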
@@ -22,7 +22,6 @@ import collections
 import re
 
 from tensorflow.core.framework import function_pb2
-from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as function_lib
 from tensorflow.python.framework import func_graph as func_graph_lib
@@ -298,6 +297,19 @@ def load_function_def_library(library, load_shared_name_suffix=None):
   functions = {}
   renamed_functions = {}
 
+  # Our graph building code currently requires functions to be registered with
+  # some tf.Graph in order to import functions using the
+  # op-name-is-function-name calling convention. To avoid leaking memory into
+  # the global default graph when executing eagerly, we create a temporary
+  # Graph.
+  #
+  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
+  # fixing function_def_to_graph_def.
+  if ops.executing_eagerly_outside_functions():
+    graph = ops.Graph()
+  else:
+    graph = ops.get_default_graph()
+
   if load_shared_name_suffix is None:
     load_shared_name_suffix = "_load_{}".format(ops.uid())
   for fdef in _sort_function_defs(library, library_function_names):
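Note that the new branch tests ops.executing_eagerly_outside_functions() rather than context.executing_eagerly(), which is why the context import is dropped in the first hunk. A hedged repro sketch of the leak this hunk addresses (the SavedModel path is an assumption for illustration):

import tensorflow as tf
from tensorflow.python.framework import ops

g = ops.get_default_graph()
for _ in range(3):
  tf.saved_model.load("/tmp/saved_model")  # assumed to point at a SavedModel
  # Before this commit, eager loads registered their functions here, so this
  # count grew on every iteration and the definitions were never collected.
  print(len(g.as_graph_def().library.function))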
@@ -308,18 +320,22 @@ def load_function_def_library(library, load_shared_name_suffix=None):
     # extra function definitions are a no-op since they already imported as a
     # function before and passed in explicitly (due to the topologic sort
     # import).
-    func_graph = function_def_lib.function_def_to_graph(copy)
+    with graph.as_default():
+      func_graph = function_def_lib.function_def_to_graph(copy)
     _restore_gradient_functions(func_graph, renamed_functions)
 
     for dep in _list_function_deps(fdef, library_function_names):
       functions[dep].add_to_graph(func_graph)
     func = function_lib.ConcreteFunction(func_graph)
-    func.add_to_graph()
-    if context.executing_eagerly():
-      func.add_to_graph(ops.get_default_graph())
+    func.add_to_graph(graph)
 
     functions[fdef.signature.name] = func
     renamed_functions[func.name] = func
+    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
+      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
+      # is fixed. Currently it's leaking memory to maintain bug compatibility
+      # with previous behavior.
+      func.add_to_graph(ops.get_default_graph())
 
   return functions
 
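For context, ConcreteFunction.add_to_graph copies a function's definition into the target graph's function library; the hunk above redirects it at the temporary graph, except for TRTEngineOp functions, which keep the old leaking behavior. A small usage sketch, with fn and its signature as assumptions for illustration:

import tensorflow as tf
from tensorflow.python.framework import ops

@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def fn(x):  # hypothetical function for illustration
  return x + 1.0

concrete = fn.get_concrete_function()
scratch = ops.Graph()
concrete.add_to_graph(scratch)  # registered in scratch, not the default graph
print(len(scratch.as_graph_def().library.function))  # at least 1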
@@ -1992,6 +1992,21 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
     self.assertAllClose({"output_0": 13},
                         imported.signatures["serving_default"]())
 
+  # TODO(allenl, kkb): Use the new memory checker here once it's fast enough (3
+  # iterations took hundreds of seconds). It would be really nice to check
+  # allocations at a lower level.
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def test_functions_cleaned(self):
+    if sys.version_info.major < 3:
+      self.skipTest("Not working in Python 2")
+    root = module.Module()
+    root.v = variables.Variable(1.)
+    root.f = def_function.function(
+        lambda x: x + root.v,
+        input_signature=[
+            tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32)])
+    cycle(root, 1)
+
 
 if __name__ == "__main__":
   test.main()
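The decorator test_util.assert_no_new_pyobjects_executing_eagerly is what turns this into a leak regression test. Conceptually (my paraphrase, not TensorFlow's actual implementation) it does something like:

import gc

def assert_no_new_pyobjects(f):
  # Simplified sketch: run once to warm caches, then check that repeated runs
  # do not grow the set of live Python objects.
  def wrapped(*args, **kwargs):
    f(*args, **kwargs)
    gc.collect()
    before = len(gc.get_objects())
    for _ in range(3):
      f(*args, **kwargs)
    gc.collect()
    assert len(gc.get_objects()) <= before, "leaked Python objects"
  return wrapped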