Fix memory issue in control flow

PiperOrigin-RevId: 267487468
Author: Yanhua Sun
Date:   2019-09-05 17:00:30 -07:00
Committed by: TensorFlower Gardener
Commit: 216a46b130 (parent 1e2c750e62)
3 changed files with 81 additions and 17 deletions
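
Note on the change: both function_def_to_graph.py and control_flow_util_v2.py below replace a single-graph function lookup with one that walks the FuncGraph `outer_graph` chain, so a function registered on an enclosing graph no longer has to be copied into every nested graph that references it. A minimal standalone sketch of that pattern, using hypothetical stand-in classes rather than real TensorFlow graphs:

# Plain-Python sketch of the outer-graph lookup; FakeGraph and find_function
# are stand-ins, not TensorFlow API.
class FakeGraph(object):

  def __init__(self, functions, outer_graph=None):
    self._functions = functions     # maps function name -> definition
    self.outer_graph = outer_graph  # enclosing graph, or None at top level


def find_function(graph, fname):
  # Walk from the innermost graph outward until the name resolves.
  while graph is not None:
    f = graph._functions.get(fname)
    if f is not None:
      return f
    graph = getattr(graph, "outer_graph", None)
  return None


outer = FakeGraph({"shared_fn": "definition stored once"})
inner = FakeGraph({}, outer_graph=outer)
assert find_function(inner, "shared_fn") == "definition stored once"

Since nested graphs can resolve shared definitions by walking outward, they no longer each hold a private copy, which is presumably the memory growth the commit title refers to.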

tensorflow/python/framework/function_def_to_graph.py

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from tensorflow.core.framework import function_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
@@ -29,6 +30,7 @@ from tensorflow.python.framework import versions
 from tensorflow.python.framework.func_graph import FuncGraph


-def function_def_to_graph(fdef, input_shapes=None):
+# TODO(yanhuasun): remove copy_functions.
+def function_def_to_graph(fdef, input_shapes=None, copy_functions=True):
   """Converts a FunctionDef to a FuncGraph (sub-class Graph).
@@ -111,7 +113,14 @@ def is_function(fname):
   if context.executing_eagerly():
     return context.context().has_function(fname)
   else:
-    return ops.get_default_graph()._is_function(fname)  # pylint: disable=protected-access
+    graph = ops.get_default_graph()
+    while graph is not None:
+      if graph._is_function(fname):  # pylint: disable=protected-access
+        return True
+      if hasattr(graph, "outer_graph"):
+        graph = graph.outer_graph
+      else:
+        return False


-def function_def_to_graph_def(fdef, input_shapes=None):
+def function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
@@ -201,17 +210,28 @@ def function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
       nested_to_flat_tensor_name[control_name] = control_name

   for node_def in fdef.node_def:
-    f = default_graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
-    if f is not None and hasattr(f, "signature"):
-      op_def = f.signature
+    graph = default_graph
+    while True:
+      f = graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
+      if f is not None or not hasattr(graph, "outer_graph"):
+        break
+      graph = graph.outer_graph
+    if f is not None:
+      op_def = f.definition.signature
       if node_def.op not in copied_functions:
         # Since this function is referenced as an op type, we have no choice but
         # to copy it into the GraphDef if we want downstream tools to process
         # it.
         graph_def.library.function.add().CopyFrom(f.definition)
         copied_functions.add(node_def.op)
+      if f.grad_func_name:
+        grad_def = function_pb2.GradientDef()
+        grad_def.function_name = f.name
+        grad_def.gradient_func = f.grad_func_name
+        graph_def.library.gradient.extend([grad_def])
     else:
-      op_def = ops.get_default_graph()._get_op_def(node_def.op)  # pylint: disable=protected-access
+      op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access

     for attr in op_def.attr:
       if attr.type == "func":
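
The GradientDef handling added above keeps a copied function paired with its registered gradient in the emitted GraphDef's function library. A small standalone sketch of that proto plumbing (the function names here are hypothetical):

# Pair a function with its gradient in a FunctionDefLibrary.
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2

graph_def = graph_pb2.GraphDef()
grad_def = function_pb2.GradientDef()
grad_def.function_name = "Forward"      # hypothetical forward function name
grad_def.gradient_func = "ForwardGrad"  # hypothetical gradient function name
graph_def.library.gradient.extend([grad_def])
print(graph_def.library)

Without this, downstream tools consuming the GraphDef would still see the copied FunctionDef but lose its association with a custom gradient.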

tensorflow/python/kernel_tests/while_v2_test.py

@@ -25,11 +25,9 @@ from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.ops import control_flow_util_v2
-from tensorflow.python.ops import control_flow_v2_toggles
-from tensorflow.python.ops import random_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -37,16 +35,18 @@ from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import control_flow_util_v2
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import while_v2
 from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
 from tensorflow.python.platform import test


 def random_gamma(shape):  # pylint: disable=invalid-name
   return random_ops.random_gamma(shape, 1.0)
@@ -859,9 +859,41 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
         Body, [m, sum_of_powers],
         return_same_structure=False)[1]
     grad = gradients_impl.gradients(result, [n])
-    with self.cached_session() as sess:
-      self.assertEqual(self.evaluate(result), 364.)
-      self.assertSequenceEqual(self.evaluate(grad), [547.])
+    self.assertEqual(self.evaluate(result), 364.)
+    self.assertSequenceEqual(self.evaluate(grad), [547.])
+
+  @test_util.run_deprecated_v1
+  def testNestedWhileWithLegacyDefun(self):
+    n = constant_op.constant(3.)
+    m = constant_op.constant(5.)
+    sum_of_powers = constant_op.constant(0.)
+
+    def Body(i, previous_sum):
+      prod = constant_op.constant(1.)
+
+      def InnerBodyWrapper(c, v):
+
+        @function.Defun(dtypes.float32, dtypes.float32)
+        def InnerBody(c, v):
+          return c - 1., v * n
+
+        results = InnerBody(c, v)
+        results[0].set_shape([])
+        results[1].set_shape([])
+        return results
+
+      return i - 1., previous_sum + while_loop_v2(
+          lambda c, _: c > 0,
+          InnerBodyWrapper, [i, prod],
+          return_same_structure=False)[1]
+
+    result = while_loop_v2(
+        lambda i, _: i >= 0,
+        Body, [m, sum_of_powers],
+        return_same_structure=False)[1]
+    grad = gradients_impl.gradients(result, [n])
+    self.assertEqual(self.evaluate(result), 364.)
+    self.assertSequenceEqual(self.evaluate(grad), [547.])

   @test_util.run_deprecated_v1
   def testIdentityNodeInBody(self):
@@ -875,9 +907,8 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     ret = while_loop_v2(
         lambda v: v < 8., Body, [x], return_same_structure=False)
     grad = gradients_impl.gradients(ret, [x])
-    with self.cached_session() as sess:
-      self.assertEqual(self.evaluate(ret), 16.)
-      self.assertSequenceEqual(self.evaluate(grad), [32.])
+    self.assertEqual(self.evaluate(ret), 16.)
+    self.assertSequenceEqual(self.evaluate(grad), [32.])

   @test_util.run_deprecated_v1
   def testForwardPassRewrite(self):
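
A quick plain-Python check of the constants asserted in testNestedWhileWithLegacyDefun: the Defun-wrapped inner loop multiplies v by n a total of i times (computing n**i), and the outer loop sums those powers for i = m down to 0, so with n = 3 and m = 5:

# Expected values for the nested-while test above.
n, m = 3.0, 5
result = sum(n**i for i in range(m + 1))             # 1 + 3 + 9 + 27 + 81 + 243
grad = sum(i * n**(i - 1) for i in range(1, m + 1))  # d(result)/dn
assert (result, grad) == (364.0, 547.0)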

tensorflow/python/ops/control_flow_util_v2.py

@@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import control_flow_v2_func_graphs
 from tensorflow.python.util import tf_contextlib

 _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None

 CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
@@ -258,7 +259,18 @@ def output_all_intermediates():
 def get_func_graph(op, input_shapes, func_name):
   """Generates and returns a FuncGraph for the given op and input_shapes."""
-  fdef = op.graph._get_function(func_name).definition  # pylint: disable=protected-access
+  graph = op.graph
+  # Recursively search the func in graphs.
+  while graph is not None:
+    func = graph._get_function(func_name)  # pylint: disable=protected-access
+    if func is not None:
+      fdef = func.definition
+      break
+    if hasattr(graph, "outer_graph"):
+      graph = graph.outer_graph
+    else:
+      break
+
   # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
   # in the case of nested if ops or when the gradient is being computed
   # from inside a Defun. We build the `func_graph` with `op.graph` as its
@@ -266,5 +278,6 @@ def get_func_graph(op, input_shapes, func_name):
   # forward pass. We need this so that we can resolve references to tensors
   # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
   with op.graph.as_default():
-    func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+    func_graph = function_def_to_graph.function_def_to_graph(
+        fdef, input_shapes, copy_functions=False)
   return func_graph
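
Why get_func_graph can now pass copy_functions=False: the FuncGraph it builds resolves referenced functions through its enclosing graphs instead of carrying private copies. A toy illustration (plain Python, not TF code) of how copies scale with nesting depth, under the simplifying assumption that each enclosing graph references every function defined beneath it:

def copies_with_duplication(depth):
  # Every one of the `depth` enclosing graphs stores its own copy of each
  # inner function, so total copies grow quadratically with depth.
  return sum(range(1, depth + 1))


def copies_with_outer_graph_lookup(depth):
  # Each definition is stored once, in the graph that created it.
  return depth


for d in (2, 4, 8, 16):
  print(d, copies_with_duplication(d), copies_with_outer_graph_lookup(d))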