Fix memory issue in control flow
PiperOrigin-RevId: 267487468
commit 216a46b130
parent 1e2c750e62
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.core.framework import function_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
@@ -29,6 +30,7 @@ from tensorflow.python.framework import versions
 from tensorflow.python.framework.func_graph import FuncGraph
 
 
+# TODO(yanhuasun): remove copy_functions.
 def function_def_to_graph(fdef, input_shapes=None, copy_functions=True):
   """Converts a FunctionDef to a FuncGraph (sub-class Graph).
 
@@ -111,7 +113,14 @@ def is_function(fname):
   if context.executing_eagerly():
     return context.context().has_function(fname)
   else:
-    return ops.get_default_graph()._is_function(fname)  # pylint: disable=protected-access
+    graph = ops.get_default_graph()
+    while graph is not None:
+      if graph._is_function(fname):  # pylint: disable=protected-access
+        return True
+      if hasattr(graph, "outer_graph"):
+        graph = graph.outer_graph
+      else:
+        return False
 
 
 def function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
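
The new lookup climbs the `outer_graph` chain instead of consulting only the
default graph, so a function defined in an enclosing graph is found without
copying it inward. A minimal standalone sketch of that walk, using a
hypothetical stand-in class rather than the real FuncGraph machinery:

class Graph(object):
  """Stand-in for a (Func)Graph; nested graphs point at their enclosing graph."""

  def __init__(self, outer_graph=None):
    self.outer_graph = outer_graph  # None for the top-level graph.
    self._functions = {}            # name -> function, as in the diff above.


def find_function(graph, fname):
  """Returns the function named `fname`, searching `graph` and its ancestors."""
  while graph is not None:
    func = graph._functions.get(fname)
    if func is not None:
      return func
    # Climb one level; the top-level graph has no outer graph, ending the walk.
    graph = getattr(graph, "outer_graph", None)
  return None
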
@@ -201,17 +210,28 @@ def function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
     nested_to_flat_tensor_name[control_name] = control_name
 
   for node_def in fdef.node_def:
-    f = default_graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
-    if f is not None and hasattr(f, "signature"):
-      op_def = f.signature
+    graph = default_graph
+    while True:
+      f = graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
+      if f is not None or not hasattr(graph, "outer_graph"):
+        break
+      graph = graph.outer_graph
+
+    if f is not None:
+      op_def = f.definition.signature
       if node_def.op not in copied_functions:
         # Since this function is referenced as an op type, we have no choice but
         # to copy it into the GraphDef if we want downstream tools to process
         # it.
         graph_def.library.function.add().CopyFrom(f.definition)
         copied_functions.add(node_def.op)
+        if f.grad_func_name:
+          grad_def = function_pb2.GradientDef()
+          grad_def.function_name = f.name
+          grad_def.gradient_func = f.grad_func_name
+          graph_def.library.gradient.extend([grad_def])
     else:
-      op_def = ops.get_default_graph()._get_op_def(node_def.op)  # pylint: disable=protected-access
+      op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access
 
     for attr in op_def.attr:
       if attr.type == "func":
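
One detail in the hunk above worth calling out: when a function is copied into
the GraphDef because an op references it by type, its registered gradient (if
any) must travel with it, which is what the new GradientDef lines do. A hedged
sketch of that proto plumbing in isolation, with hypothetical function names:

from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2

graph_def = graph_pb2.GraphDef()
# Pair a copied function with its gradient function by name, mirroring
# `graph_def.library.gradient.extend([grad_def])` in the diff.
grad_def = function_pb2.GradientDef()
grad_def.function_name = "Foo"      # hypothetical forward function
grad_def.gradient_func = "FooGrad"  # hypothetical gradient function
graph_def.library.gradient.extend([grad_def])
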
@@ -25,11 +25,9 @@ from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.ops import control_flow_util_v2
-from tensorflow.python.ops import control_flow_v2_toggles
-from tensorflow.python.ops import random_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -37,16 +35,18 @@ from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import control_flow_util_v2
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import while_v2
 from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
 from tensorflow.python.platform import test
 
 
 def random_gamma(shape):  # pylint: disable=invalid-name
   return random_ops.random_gamma(shape, 1.0)
@@ -859,9 +859,41 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
         Body, [m, sum_of_powers],
         return_same_structure=False)[1]
     grad = gradients_impl.gradients(result, [n])
-    with self.cached_session() as sess:
-      self.assertEqual(self.evaluate(result), 364.)
-      self.assertSequenceEqual(self.evaluate(grad), [547.])
+    self.assertEqual(self.evaluate(result), 364.)
+    self.assertSequenceEqual(self.evaluate(grad), [547.])
+
+  @test_util.run_deprecated_v1
+  def testNestedWhileWithLegacyDefun(self):
+    n = constant_op.constant(3.)
+    m = constant_op.constant(5.)
+    sum_of_powers = constant_op.constant(0.)
+
+    def Body(i, previous_sum):
+      prod = constant_op.constant(1.)
+
+      def InnerBodyWrapper(c, v):
+
+        @function.Defun(dtypes.float32, dtypes.float32)
+        def InnerBody(c, v):
+          return c - 1., v * n
+
+        results = InnerBody(c, v)
+        results[0].set_shape([])
+        results[1].set_shape([])
+        return results
+
+      return i - 1., previous_sum + while_loop_v2(
+          lambda c, _: c > 0,
+          InnerBodyWrapper, [i, prod],
+          return_same_structure=False)[1]
+
+    result = while_loop_v2(
+        lambda i, _: i >= 0,
+        Body, [m, sum_of_powers],
+        return_same_structure=False)[1]
+    grad = gradients_impl.gradients(result, [n])
+    self.assertEqual(self.evaluate(result), 364.)
+    self.assertSequenceEqual(self.evaluate(grad), [547.])
 
   @test_util.run_deprecated_v1
   def testIdentityNodeInBody(self):
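
The expected constants in the new test follow from its loop structure: the
inner Defun-wrapped loop computes n**i, and the outer loop accumulates it for
i = 5, 4, ..., 0, so with n = 3 the result is 3**5 + 3**4 + 3**3 + 3**2 + 3 + 1
= 364, and the gradient is the derivative of that sum in n,
5*3**4 + 4*3**3 + 3*3**2 + 2*3 + 1 = 547.
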
@@ -875,9 +907,8 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     ret = while_loop_v2(
         lambda v: v < 8., Body, [x], return_same_structure=False)
     grad = gradients_impl.gradients(ret, [x])
-    with self.cached_session() as sess:
-      self.assertEqual(self.evaluate(ret), 16.)
-      self.assertSequenceEqual(self.evaluate(grad), [32.])
+    self.assertEqual(self.evaluate(ret), 16.)
+    self.assertSequenceEqual(self.evaluate(grad), [32.])
 
   @test_util.run_deprecated_v1
   def testForwardPassRewrite(self):
@@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import control_flow_v2_func_graphs
 from tensorflow.python.util import tf_contextlib
 
 
+_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
 
 CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
@@ -258,7 +259,18 @@ def output_all_intermediates():
 
 def get_func_graph(op, input_shapes, func_name):
   """Generates and returns a FuncGraph for the given op and input_shapes."""
-  fdef = op.graph._get_function(func_name).definition  # pylint: disable=protected-access
+  graph = op.graph
+  # Recursively search the func in graphs.
+  while graph is not None:
+    func = graph._get_function(func_name)  # pylint: disable=protected-access
+    if func is not None:
+      fdef = func.definition
+      break
+    if hasattr(graph, "outer_graph"):
+      graph = graph.outer_graph
+    else:
+      break
+
   # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
   # in the case of nested if ops or when the gradient is being computed
   # from inside a Defun. We build the `func_graph` with `op.graph` as its
@@ -266,5 +278,6 @@ def get_func_graph(op, input_shapes, func_name):
   # forward pass. We need this so that we can resolve references to tensors
   # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
   with op.graph.as_default():
-    func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+    func_graph = function_def_to_graph.function_def_to_graph(
+        fdef, input_shapes, copy_functions=False)
   return func_graph
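
Taken together with the function_def_to_graph changes above, the gradient path
now builds the FuncGraph without copying library functions into it. A hedged
usage sketch of the new call; `build_func_graph` is a hypothetical wrapper, and
`fdef`/`input_shapes` are assumed to be obtained as in `get_func_graph` above:

from tensorflow.python.framework import function_def_to_graph


def build_func_graph(fdef, input_shapes):
  # copy_functions=False skips duplicating every referenced FunctionDef into
  # the new FuncGraph; the outer-graph lookups added in this commit make those
  # copies unnecessary, which is presumably the memory issue the title names.
  return function_def_to_graph.function_def_to_graph(
      fdef, input_shapes, copy_functions=False)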