From 216a46b13082e911e88efefe0b3ad40460b1cea0 Mon Sep 17 00:00:00 2001
From: Yanhua Sun
Date: Thu, 5 Sep 2019 17:00:30 -0700
Subject: [PATCH] Fix memory issue in control flow

PiperOrigin-RevId: 267487468
---
 .../python/framework/function_def_to_graph.py | 30 +++++++++--
 .../python/kernel_tests/while_v2_test.py      | 51 +++++++++++++++----
 tensorflow/python/ops/control_flow_util_v2.py | 17 ++++++-
 3 files changed, 81 insertions(+), 17 deletions(-)

diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 7e12dffb9a2..cd5d491ff89 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.core.framework import function_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
@@ -29,6 +30,7 @@ from tensorflow.python.framework import versions
 from tensorflow.python.framework.func_graph import FuncGraph
 
 
+# TODO(yanhuasun): remove copy_functions.
 def function_def_to_graph(fdef, input_shapes=None, copy_functions=True):
   """Converts a FunctionDef to a FuncGraph (sub-class Graph).
 
@@ -111,7 +113,14 @@ def is_function(fname):
   if context.executing_eagerly():
     return context.context().has_function(fname)
   else:
-    return ops.get_default_graph()._is_function(fname)  # pylint: disable=protected-access
+    graph = ops.get_default_graph()
+    while graph is not None:
+      if graph._is_function(fname):  # pylint: disable=protected-access
+        return True
+      if hasattr(graph, "outer_graph"):
+        graph = graph.outer_graph
+      else:
+        return False
 
 
 def function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
@@ -201,17 +210,28 @@ def function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
       nested_to_flat_tensor_name[control_name] = control_name
 
   for node_def in fdef.node_def:
-    f = default_graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
-    if f is not None and hasattr(f, "signature"):
-      op_def = f.signature
+    graph = default_graph
+    while True:
+      f = graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
+      if f is not None or not hasattr(graph, "outer_graph"):
+        break
+      graph = graph.outer_graph
+
+    if f is not None:
+      op_def = f.definition.signature
       if node_def.op not in copied_functions:
         # Since this function is referenced as an op type, we have no choice but
         # to copy it into the GraphDef if we want downstream tools to process
         # it.
         graph_def.library.function.add().CopyFrom(f.definition)
         copied_functions.add(node_def.op)
+        if f.grad_func_name:
+          grad_def = function_pb2.GradientDef()
+          grad_def.function_name = f.name
+          grad_def.gradient_func = f.grad_func_name
+          graph_def.library.gradient.extend([grad_def])
     else:
-      op_def = ops.get_default_graph()._get_op_def(node_def.op)  # pylint: disable=protected-access
+      op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access
 
     for attr in op_def.attr:
       if attr.type == "func":
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
index fefeb594bea..3bb88c9aad3 100644
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -25,11 +25,9 @@ from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.ops import control_flow_util_v2
-from tensorflow.python.ops import control_flow_v2_toggles
-from tensorflow.python.ops import random_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -37,16 +35,18 @@ from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import control_flow_util_v2
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import while_v2
 from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
 from tensorflow.python.platform import test
-
 
 def random_gamma(shape):  # pylint: disable=invalid-name
   return random_ops.random_gamma(shape, 1.0)
 
@@ -859,9 +859,41 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
         Body, [m, sum_of_powers],
         return_same_structure=False)[1]
     grad = gradients_impl.gradients(result, [n])
-    with self.cached_session() as sess:
-      self.assertEqual(self.evaluate(result), 364.)
-      self.assertSequenceEqual(self.evaluate(grad), [547.])
+    self.assertEqual(self.evaluate(result), 364.)
+    self.assertSequenceEqual(self.evaluate(grad), [547.])
+
+  @test_util.run_deprecated_v1
+  def testNestedWhileWithLegacyDefun(self):
+    n = constant_op.constant(3.)
+    m = constant_op.constant(5.)
+    sum_of_powers = constant_op.constant(0.)
+
+    def Body(i, previous_sum):
+      prod = constant_op.constant(1.)
+
+      def InnerBodyWrapper(c, v):
+
+        @function.Defun(dtypes.float32, dtypes.float32)
+        def InnerBody(c, v):
+          return c - 1., v * n
+
+        results = InnerBody(c, v)
+        results[0].set_shape([])
+        results[1].set_shape([])
+        return results
+
+      return i - 1., previous_sum + while_loop_v2(
+          lambda c, _: c > 0,
+          InnerBodyWrapper, [i, prod],
+          return_same_structure=False)[1]
+
+    result = while_loop_v2(
+        lambda i, _: i >= 0,
+        Body, [m, sum_of_powers],
+        return_same_structure=False)[1]
+    grad = gradients_impl.gradients(result, [n])
+    self.assertEqual(self.evaluate(result), 364.)
+    self.assertSequenceEqual(self.evaluate(grad), [547.])
 
   @test_util.run_deprecated_v1
   def testIdentityNodeInBody(self):
@@ -875,9 +907,8 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     ret = while_loop_v2(
         lambda v: v < 8., Body, [x], return_same_structure=False)
     grad = gradients_impl.gradients(ret, [x])
-    with self.cached_session() as sess:
-      self.assertEqual(self.evaluate(ret), 16.)
-      self.assertSequenceEqual(self.evaluate(grad), [32.])
+    self.assertEqual(self.evaluate(ret), 16.)
+    self.assertSequenceEqual(self.evaluate(grad), [32.])
 
   @test_util.run_deprecated_v1
   def testForwardPassRewrite(self):
diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py
index 3aec9192698..0f9a1d4ef9e 100644
--- a/tensorflow/python/ops/control_flow_util_v2.py
+++ b/tensorflow/python/ops/control_flow_util_v2.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import control_flow_v2_func_graphs
 from tensorflow.python.util import tf_contextlib
 
+
 _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
 
 CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
@@ -258,7 +259,18 @@ def output_all_intermediates():
 
 def get_func_graph(op, input_shapes, func_name):
   """Generates and returns a FuncGraph for the given op and input_shapes."""
-  fdef = op.graph._get_function(func_name).definition  # pylint: disable=protected-access
+  graph = op.graph
+  # Recursively search the func in graphs.
+  while graph is not None:
+    func = graph._get_function(func_name)  # pylint: disable=protected-access
+    if func is not None:
+      fdef = func.definition
+      break
+    if hasattr(graph, "outer_graph"):
+      graph = graph.outer_graph
+    else:
+      break
+
   # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
   # in the case of nested if ops or when the gradient is being computed
   # from inside a Defun. We build the `func_graph` with `op.graph` as its
@@ -266,5 +278,6 @@
   # forward pass. We need this so that we can resolve references to tensors
   # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
   with op.graph.as_default():
-    func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+    func_graph = function_def_to_graph.function_def_to_graph(
+        fdef, input_shapes, copy_functions=False)
   return func_graph
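
All three hunks apply the same lookup pattern: rather than consulting only the
innermost graph's function library, walk the `outer_graph` chain until the
function is found, so the library need not be copied into every nested
while/cond body (hence `copy_functions=False` and the TODO to remove the
flag). A minimal sketch of that pattern, with `find_function` as a
hypothetical stand-in for the private lookups patched above (`_is_function`,
`_functions.get`, `_get_function`):

def find_function(graph, fname):
  """Sketch: resolve `fname` in `graph` or any of its enclosing graphs."""
  while graph is not None:
    func = graph._functions.get(fname, None)  # pylint: disable=protected-access
    if func is not None:
      return func
    # A FuncGraph (e.g. a while/cond body) records the graph that was active
    # when it was built; a top-level Graph has no outer graph, ending the walk.
    graph = getattr(graph, "outer_graph", None)
  return None

With a lookup like this, a nested graph can reference a function defined in
any enclosing graph by name, which is what allows `get_func_graph` to pass
`copy_functions=False` without breaking function resolution.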