while_v2: Move more reduction ops to forward graph

When applicable, also move TensorListElementShape and
TensorListLength ops to the forward graph as an optimization
for Control Flow v2.

PiperOrigin-RevId: 347699857
Change-Id: I98e4bd2df4d79cb7e3d4bc3c2c2f8c86e76aef9a
Authored by Victor de Souza on 2020-12-15 15:01:54 -08:00; committed by TensorFlower Gardener
parent 5ff1abdd7a
commit deeb7f2e74
7 changed files with 192 additions and 46 deletions
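
The optimization applies to gradients of Control Flow v2 while loops that build TensorLists. Below is a minimal sketch of such a program, not taken from this change and written against public TF 2.x APIs (tf.while_loop, tf.TensorArray, tf.GradientTape); whether the new fast path fires for it depends on the backward pass needing the list's length or element shape, which the new while_v2_test case further down exercises directly.

import tensorflow as tf


@tf.function
def loop_and_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    def body(i, ta):
      # Each iteration appends i * x, so the stacked result sums to (0+1+2)*x.
      return i + 1, ta.write(i, tf.cast(i, tf.float32) * x)

    _, ta = tf.while_loop(lambda i, _: i < 3, body, [0, ta])
    y = tf.reduce_sum(ta.stack())
  # The gradient graph of the while loop may query the forward TensorList's
  # length/element shape; this commit lets such ops live in the forward graph
  # instead of accumulating the list as a loop intermediate.
  return tape.gradient(y, x)


print(loop_and_grad(tf.constant(2.0)).numpy())  # 3.0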

View File

@@ -616,12 +616,12 @@ def enable_output_all_intermediates(fn):
     The wrapped function
   """

-  def wrapper(self, *args, **kwargs):
+  def wrapper(*args, **kwargs):
     output_all_intermediates_old = \
         control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
     control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
     try:
-      return fn(self, *args, **kwargs)
+      return fn(*args, **kwargs)
     finally:
       control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \
           output_all_intermediates_old

View File

@@ -34,6 +34,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
@@ -594,6 +595,7 @@ class GRUV2Test(keras_parameterized.TestCase):
       outputs_trimmed = lstm(inputs[:, :masksteps])
     self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)

+  @tf_test_util.enable_output_all_intermediates
   def test_v1_session_behavior(self):
     with ops.get_default_graph().as_default():
       # See b/139132348 for more details.

View File

@@ -35,6 +35,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers import recurrent as rnn_v1
@@ -795,6 +796,7 @@ class LSTMV2Test(keras_parameterized.TestCase):
       outputs_trimmed = lstm(inputs[:, :masksteps])
     self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)

+  @tf_test_util.enable_output_all_intermediates
   def test_v1_session_behavior(self):
     with ops.get_default_graph().as_default():
       # See b/139132348 for more details.

View File

@@ -28,6 +28,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
@@ -629,33 +630,39 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):

   def test_bidirectional_statefulness(self):
     # Bidirectional and stateful
-    rnn = keras.layers.SimpleRNN
-    samples = 2
-    dim = 2
-    timesteps = 2
-    output_dim = 2
-    mode = 'sum'
+    def run_test():
+      rnn = keras.layers.SimpleRNN
+      samples = 2
+      dim = 2
+      timesteps = 2
+      output_dim = 2
+      mode = 'sum'

-    with self.cached_session():
-      x = np.random.random((samples, timesteps, dim))
-      target_dim = 2 * output_dim if mode == 'concat' else output_dim
-      y = np.random.random((samples, target_dim))
+      with self.cached_session():
+        x = np.random.random((samples, timesteps, dim))
+        target_dim = 2 * output_dim if mode == 'concat' else output_dim
+        y = np.random.random((samples, target_dim))

-      inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
-      bidi_rnn = keras.layers.Bidirectional(
-          rnn(output_dim, stateful=True), merge_mode=mode)
-      self.assertTrue(bidi_rnn.stateful)
-      output = bidi_rnn(inputs)
-      model = keras.models.Model(inputs, output)
+        inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
+        bidi_rnn = keras.layers.Bidirectional(
+            rnn(output_dim, stateful=True), merge_mode=mode)
+        self.assertTrue(bidi_rnn.stateful)
+        output = bidi_rnn(inputs)
+        model = keras.models.Model(inputs, output)

-      y_1 = model.predict(x, batch_size=1)
-      model.reset_states()
-      y_2 = model.predict(x, batch_size=1)
+        y_1 = model.predict(x, batch_size=1)
+        model.reset_states()
+        y_2 = model.predict(x, batch_size=1)

-      self.assertAllClose(y_1, y_2)
+        self.assertAllClose(y_1, y_2)

-      model.compile(loss='mse', optimizer='sgd')
-      model.fit(x, y, epochs=1, batch_size=1)
+        model.compile(loss='mse', optimizer='sgd')
+        model.fit(x, y, epochs=1, batch_size=1)
+
+    if context.executing_eagerly():
+      run_test()
+    else:
+      tf_test_util.enable_output_all_intermediates(run_test)()

   @parameterized.parameters(['sum', 'mul', 'ave', 'concat', None])
   def test_Bidirectional_merged_value(self, merge_mode):

View File

@@ -138,10 +138,11 @@ class MapFnTest(test.TestCase):
       elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
       y = map_fn.map_fn(
           lambda x: math_ops.multiply(math_ops.square(x), param), elems)
-      r = gradients_impl.gradients(y, param)[0]
-      self.assertAllEqual(91.0, self.evaluate(r))
-      r = gradients_impl.gradients(y, elems)[0]
-      self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r))
+      r_param = gradients_impl.gradients(y, param)[0]
+      r_elems = gradients_impl.gradients(y, elems)[0]
+      self.assertAllEqual(91.0, self.evaluate(r_param))
+      self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0],
+                          self.evaluate(r_elems))

   @test_util.run_in_graph_and_eager_modes
   def testMap_SimpleNotTensor(self):

View File

@@ -33,6 +33,7 @@ from tensorflow.python.framework import function
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.ops import array_ops
@@ -42,6 +43,7 @@ from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_list_ops
 from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
@@ -1316,6 +1318,50 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):

     Fn()

+  def testDoNotAccumulateForwardTensorsForTensorListReductionOps(self):
+
+    @def_function.function
+    def Fn():
+      with backprop.GradientTape() as tape:
+        e = constant_op.constant(2.)
+        x = list_ops.empty_tensor_list(
+            element_dtype=dtypes.float32, element_shape=e.shape)
+        x = list_ops.tensor_list_push_back(x, e)
+        tape.watch(x)
+
+        def Body(i, x):
+          forward_graph = ops.get_default_graph()
+
+          @custom_gradient.custom_gradient
+          def IdentityWithZeroGrad(x):
+
+            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
+              del variables
+              gradient_graph = ops.get_default_graph()
+              shape = gen_list_ops.tensor_list_element_shape(
+                  x, shape_type=dtypes.int32)
+              assert shape.graph is forward_graph
+              size = gen_list_ops.tensor_list_length(x)
+              assert size.graph is forward_graph
+              zeros = gen_list_ops.tensor_list_reserve(shape, size,
+                                                       dtypes.float32)
+              assert zeros.graph is gradient_graph
+              return zeros
+
+            return x, Grad
+
+          return i + 1, IdentityWithZeroGrad(x)
+
+        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
+        ones_like = list_ops.tensor_list_from_tensor(
+            array_ops.ones_like(
+                list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)),
+            element_shape=tensor_shape.TensorShape([]))
+        grad = tape.gradient(result, x, output_gradients=[ones_like])
+        return grad
+
+    Fn()
+
   @test_util.run_v2_only
   def testInheritParentNameScope(self):
View File

@@ -23,6 +23,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.client import pywrap_tf_session as c_api
 from tensorflow.python.eager import backprop_util
@@ -862,6 +864,19 @@ def _get_accumulator(tensor):
   return None


+OptimizedReductionOpsCacheKey = collections.namedtuple(
+    "OptimizedReductionOpsCacheKey", [
+        "op_type",
+        "inputs",
+        "dtypes",
+        "input_types",
+        "name",
+        "attrs",
+        "op_def",
+        "compute_device",
+    ])
+
+
 class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
   """FuncGraph for the gradient function of the body of a While op.
@@ -957,29 +972,25 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
     # This optimization is currently also disabled when under a persistent tape,
     # since it leads to an unbounded number of side outputs. With caching it may
     # be possible to re-enable it.
-    if (op_type in {"Shape", "Size", "Rank"} and
+    optimized_reduction_ops = {
+        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
+    }
+    if (op_type in optimized_reduction_ops and
         not util.output_all_intermediates() and
         all(input.graph is self._forward_graph for input in inputs) and
         all(_get_accumulator(input) is None for input in inputs) and
         not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
         not util.graph_wrapped_for_higher_order_tape_gradients(
             self._forward_graph)):
-      with self._forward_graph.as_default():
-        # `name` was built using name_scope stack of gradient graph and may not
-        # be unique in the forward graph. `Graph.create_op` does not uniquify
-        # names which are name scopes i.e. end in `/`. To ensure that the op
-        # created gets a unique name in the forward graph we get rid of the
-        # trailing slash.
-        name = ops.name_from_scope_name(name)
-        result = self._forward_graph._create_op_internal(
-            op_type,
-            inputs,
-            dtypes=dtypes,
-            input_types=input_types,
-            name=name,
-            attrs=attrs,
-            op_def=op_def,
-            compute_device=compute_device)
-      return result
+      return self._move_op_to_forward_graph(
+          op_type,
+          inputs,
+          dtypes=dtypes,
+          input_types=input_types,
+          name=name,
+          attrs=attrs,
+          op_def=op_def,
+          compute_device=compute_device)

     return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
         op_type,
@@ -991,6 +1002,83 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
         op_def=op_def,
         compute_device=compute_device)

+  def _move_op_to_forward_graph(
+      self,
+      op_type,
+      inputs,
+      dtypes=None,  # pylint: disable=redefined-outer-name
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_device=True):
+    # We have a cache of reduction ops that have already been moved to the
+    # forward graph, and we will check it first to avoid moving an op twice.
+    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
+      self._forward_graph._optimized_reduction_ops_cache = {}
+    cache_key = self._get_optimized_reduction_ops_cache_key(
+        op_type, inputs, dtypes, input_types, name, attrs, op_def,
+        compute_device)
+    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
+        cache_key)
+    if cached_op is not None:
+      # This op has already been moved to the forward graph and we have it in
+      # the cache.
+      return cached_op
+
+    with self._forward_graph.as_default():
+      # `name` was built using name_scope stack of gradient graph and may not
+      # be unique in the forward graph. `Graph.create_op` does not uniquify
+      # names which are name scopes i.e. end in `/`. To ensure that the op
+      # created gets a unique name in the forward graph we get rid of the
+      # trailing slash.
+      name = ops.name_from_scope_name(name)
+      result = self._forward_graph._create_op_internal(
+          op_type,
+          inputs,
+          dtypes=dtypes,
+          input_types=input_types,
+          name=name,
+          attrs=attrs,
+          op_def=op_def,
+          compute_device=compute_device)
+
+      # Store the op we just moved to the forward graph so that it does
+      # not need to be added there again.
+      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
+      return result
+
+  def _get_optimized_reduction_ops_cache_key(
+      self,
+      op_type,
+      inputs,
+      dtypes=None,  # pylint: disable=redefined-outer-name
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_device=True):
+    # We need all elements of CacheKey to be hashable.
+    inputs = tuple(map(lambda t: t.ref(), inputs))
+
+    if dtypes is not None:
+      dtypes = tuple(dtypes)
+
+    if input_types is not None:
+      input_types = tuple(input_types)
+
+    if attrs is not None:
+      hashable_attrs = []
+      for attr_name, attr_value in sorted(attrs.items()):
+        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
+      attrs = tuple(hashable_attrs)
+
+    if op_def is not None:
+      op_def = op_def.SerializeToString()
+
+    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
+                                         name, attrs, op_def, compute_device)
+
   def _capture_helper(self, tensor, name):
     """Implements the capturing described in the class docstring."""
     captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))