while_v2: Move more reduction ops to forward graph
When applicable, also move TensorListElementShape and TensorListLength to the forward graph as an optimization to Control Flow v2.

PiperOrigin-RevId: 347699857
Change-Id: I98e4bd2df4d79cb7e3d4bc3c2c2f8c86e76aef9a
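Context (illustrative sketch, not part of the commit): the change targets gradients of v2 while loops that build TensorLists. When the backward pass only needs metadata about a forward tensor (its shape, size, rank, or now a TensorList's element shape and length), while_v2 creates that cheap op directly in the forward graph instead of accumulating the full intermediate tensor across iterations. A minimal public-API example of the kind of computation that exercises this path (names and values are made up):

import tensorflow as tf

@tf.function
def squares_sum(xs):
  # The TensorArray is backed by a TensorList when traced into a graph.
  ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

  def body(i, ta):
    return i + 1, ta.write(i, xs[i] * xs[i])

  _, ta = tf.while_loop(lambda i, _: i < tf.size(xs), body, [tf.constant(0), ta])
  return tf.reduce_sum(ta.stack())

xs = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
  tape.watch(xs)
  y = squares_sum(xs)
print(tape.gradient(y, xs))  # expected: [2., 4., 6.]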
parent 5ff1abdd7a
commit deeb7f2e74
@@ -616,12 +616,12 @@ def enable_output_all_intermediates(fn):
     The wrapped function
   """

-  def wrapper(self, *args, **kwargs):
+  def wrapper(*args, **kwargs):
     output_all_intermediates_old = \
         control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
     control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
     try:
-      return fn(self, *args, **kwargs)
+      return fn(*args, **kwargs)
     finally:
       control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \
           output_all_intermediates_old
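Note (standalone sketch, not TensorFlow code): dropping the explicit `self` parameter turns `wrapper` into a plain pass-through, so the decorator now works both on test methods and on ordinary callables such as the `run_test` closure in the Bidirectional statefulness test below. The pattern is the usual save/override/restore of a module-level flag:

import functools

_OVERRIDE = None  # stand-in for the module-level flag being toggled


def enable_flag(fn):
  """Temporarily force the flag on for the duration of the wrapped call."""

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):  # works for bound methods and plain functions alike
    global _OVERRIDE
    old = _OVERRIDE
    _OVERRIDE = True
    try:
      return fn(*args, **kwargs)
    finally:
      _OVERRIDE = old  # always restore, even if fn raises

  return wrapper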
@@ -34,6 +34,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
@@ -594,6 +595,7 @@ class GRUV2Test(keras_parameterized.TestCase):
     outputs_trimmed = lstm(inputs[:, :masksteps])
     self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)

+  @tf_test_util.enable_output_all_intermediates
   def test_v1_session_behavior(self):
     with ops.get_default_graph().as_default():
       # See b/139132348 for more details.
@@ -35,6 +35,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers import recurrent as rnn_v1
@@ -795,6 +796,7 @@ class LSTMV2Test(keras_parameterized.TestCase):
     outputs_trimmed = lstm(inputs[:, :masksteps])
     self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)

+  @tf_test_util.enable_output_all_intermediates
   def test_v1_session_behavior(self):
     with ops.get_default_graph().as_default():
       # See b/139132348 for more details.
@@ -28,6 +28,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
@@ -629,33 +630,39 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):

   def test_bidirectional_statefulness(self):
     # Bidirectional and stateful
-    rnn = keras.layers.SimpleRNN
-    samples = 2
-    dim = 2
-    timesteps = 2
-    output_dim = 2
-    mode = 'sum'
+    def run_test():
+      rnn = keras.layers.SimpleRNN
+      samples = 2
+      dim = 2
+      timesteps = 2
+      output_dim = 2
+      mode = 'sum'

-    with self.cached_session():
-      x = np.random.random((samples, timesteps, dim))
-      target_dim = 2 * output_dim if mode == 'concat' else output_dim
-      y = np.random.random((samples, target_dim))
+      with self.cached_session():
+        x = np.random.random((samples, timesteps, dim))
+        target_dim = 2 * output_dim if mode == 'concat' else output_dim
+        y = np.random.random((samples, target_dim))

-      inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
-      bidi_rnn = keras.layers.Bidirectional(
-          rnn(output_dim, stateful=True), merge_mode=mode)
-      self.assertTrue(bidi_rnn.stateful)
-      output = bidi_rnn(inputs)
-      model = keras.models.Model(inputs, output)
+        inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
+        bidi_rnn = keras.layers.Bidirectional(
+            rnn(output_dim, stateful=True), merge_mode=mode)
+        self.assertTrue(bidi_rnn.stateful)
+        output = bidi_rnn(inputs)
+        model = keras.models.Model(inputs, output)

-      y_1 = model.predict(x, batch_size=1)
-      model.reset_states()
-      y_2 = model.predict(x, batch_size=1)
+        y_1 = model.predict(x, batch_size=1)
+        model.reset_states()
+        y_2 = model.predict(x, batch_size=1)

-      self.assertAllClose(y_1, y_2)
+        self.assertAllClose(y_1, y_2)

-      model.compile(loss='mse', optimizer='sgd')
-      model.fit(x, y, epochs=1, batch_size=1)
+        model.compile(loss='mse', optimizer='sgd')
+        model.fit(x, y, epochs=1, batch_size=1)
+
+    if context.executing_eagerly():
+      run_test()
+    else:
+      tf_test_util.enable_output_all_intermediates(run_test)()

   @parameterized.parameters(['sum', 'mul', 'ave', 'concat', None])
   def test_Bidirectional_merged_value(self, merge_mode):
@@ -138,10 +138,11 @@ class MapFnTest(test.TestCase):
     elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
     y = map_fn.map_fn(
        lambda x: math_ops.multiply(math_ops.square(x), param), elems)
-    r = gradients_impl.gradients(y, param)[0]
-    self.assertAllEqual(91.0, self.evaluate(r))
-    r = gradients_impl.gradients(y, elems)[0]
-    self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r))
+    r_param = gradients_impl.gradients(y, param)[0]
+    r_elems = gradients_impl.gradients(y, elems)[0]
+    self.assertAllEqual(91.0, self.evaluate(r_param))
+    self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0],
+                        self.evaluate(r_elems))

   @test_util.run_in_graph_and_eager_modes
   def testMap_SimpleNotTensor(self):
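Aside (sketch, not part of the commit): the asserted values follow from the mapped function y_i = param * x_i**2 with param = 2 (inferred from the expected gradients). The gradient with respect to param is the sum of squares, 1 + 4 + 9 + 16 + 25 + 36 = 91, and the gradient with respect to elems is 2 * param * x_i = [4, 8, 12, 16, 20, 24]. A quick eager check of the same arithmetic:

import tensorflow as tf

param = tf.constant(2.0)  # value inferred from the asserted gradients
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
with tf.GradientTape() as tape:
  tape.watch([param, elems])
  y = tf.square(elems) * param  # same math as the mapped lambda
r_param, r_elems = tape.gradient(y, [param, elems])
# r_param -> 91.0, r_elems -> [4., 8., 12., 16., 20., 24.]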
@@ -33,6 +33,7 @@ from tensorflow.python.framework import function
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.ops import array_ops
@@ -42,6 +43,7 @@ from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_list_ops
 from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
@@ -1316,6 +1318,50 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):

     Fn()

+  def testDoNotAccumulateForwardTensorsForTensorListReductionOps(self):
+
+    @def_function.function
+    def Fn():
+      with backprop.GradientTape() as tape:
+        e = constant_op.constant(2.)
+        x = list_ops.empty_tensor_list(
+            element_dtype=dtypes.float32, element_shape=e.shape)
+        x = list_ops.tensor_list_push_back(x, e)
+        tape.watch(x)
+
+        def Body(i, x):
+          forward_graph = ops.get_default_graph()
+
+          @custom_gradient.custom_gradient
+          def IdentityWithZeroGrad(x):
+
+            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
+              del variables
+              gradient_graph = ops.get_default_graph()
+              shape = gen_list_ops.tensor_list_element_shape(
+                  x, shape_type=dtypes.int32)
+              assert shape.graph is forward_graph
+              size = gen_list_ops.tensor_list_length(x)
+              assert size.graph is forward_graph
+              zeros = gen_list_ops.tensor_list_reserve(shape, size,
+                                                       dtypes.float32)
+              assert zeros.graph is gradient_graph
+              return zeros
+
+            return x, Grad
+
+          return i + 1, IdentityWithZeroGrad(x)
+
+        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
+        ones_like = list_ops.tensor_list_from_tensor(
+            array_ops.ones_like(
+                list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)),
+            element_shape=tensor_shape.TensorShape([]))
+      grad = tape.gradient(result, x, output_gradients=[ones_like])
+      return grad
+
+    Fn()
+
   @test_util.run_v2_only
   def testInheritParentNameScope(self):
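For readers unfamiliar with the custom-gradient pattern used in the test above, here is a minimal standalone sketch (public tf.custom_gradient API, not from the commit): the decorated function returns its forward value plus a gradient callback that receives the upstream gradient, which is exactly how the test intercepts the backward pass to check which graph the TensorList shape/length ops land in.

import tensorflow as tf

@tf.custom_gradient
def scale_grad_identity(x):
  """Forward pass is the identity; backward pass halves the upstream gradient."""

  def grad(upstream):
    return 0.5 * upstream

  return x, grad

x = tf.constant([3.0, 4.0])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.reduce_sum(scale_grad_identity(x) ** 2)
print(tape.gradient(y, x))  # upstream 2*x = [6., 8.], halved -> [3., 4.]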
@@ -23,6 +23,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.client import pywrap_tf_session as c_api
 from tensorflow.python.eager import backprop_util
@@ -862,6 +864,19 @@ def _get_accumulator(tensor):
   return None


+OptimizedReductionOpsCacheKey = collections.namedtuple(
+    "OptimizedReductionOpsCacheKey", [
+        "op_type",
+        "inputs",
+        "dtypes",
+        "input_types",
+        "name",
+        "attrs",
+        "op_def",
+        "compute_device",
+    ])
+
+
 class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
   """FuncGraph for the gradient function of the body of a While op.

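Design note (plain-Python sketch with made-up values): a namedtuple hashes and compares by value, so two keys built from the same op parameters hit the same dictionary entry. That is what lets _move_op_to_forward_graph below return an already-moved reduction op instead of creating it a second time:

import collections

Key = collections.namedtuple("Key", ["op_type", "inputs", "name"])

cache = {}
cache[Key("TensorListLength", ("list_ref",), "grad/len")] = "op_in_forward_graph"

# A key rebuilt from the same parameters is equal (and hashes equally), so the
# cached op is found rather than recreated.
assert Key("TensorListLength", ("list_ref",), "grad/len") in cache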
@@ -957,29 +972,25 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
     # This optimization is currently also disabled when under a persistent tape,
     # since it leads to an unbounded number of side outputs. With caching it may
     # be possible to re-enable it.
-    if (op_type in {"Shape", "Size", "Rank"} and
+    optimized_reduction_ops = {
+        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
+    }
+    if (op_type in optimized_reduction_ops and
         not util.output_all_intermediates() and
         all(input.graph is self._forward_graph for input in inputs) and
         all(_get_accumulator(input) is None for input in inputs) and
         not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
         not util.graph_wrapped_for_higher_order_tape_gradients(
             self._forward_graph)):
-      with self._forward_graph.as_default():
-        # `name` was built using name_scope stack of gradient graph and may not
-        # be unique in the forward graph. `Graph.create_op` does not uniquify
-        # names which are name scopes i.e. end in `/`. To ensure that the op
-        # created gets a unique name in the forward graph we get rid of the
-        # trailing slash.
-        name = ops.name_from_scope_name(name)
-        result = self._forward_graph._create_op_internal(
-            op_type,
-            inputs,
-            dtypes=dtypes,
-            input_types=input_types,
-            name=name,
-            attrs=attrs,
-            op_def=op_def,
-            compute_device=compute_device)
-        return result
+      return self._move_op_to_forward_graph(
+          op_type,
+          inputs,
+          dtypes=dtypes,
+          input_types=input_types,
+          name=name,
+          attrs=attrs,
+          op_def=op_def,
+          compute_device=compute_device)

     return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
         op_type,
@@ -991,6 +1002,83 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
         op_def=op_def,
         compute_device=compute_device)

+  def _move_op_to_forward_graph(
+      self,
+      op_type,
+      inputs,
+      dtypes=None,  # pylint: disable=redefined-outer-name
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_device=True):
+    # We have a cache of reduction ops that have already been moved to the
+    # forward graph, and we will check it first to avoid moving an op twice.
+    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
+      self._forward_graph._optimized_reduction_ops_cache = {}
+    cache_key = self._get_optimized_reduction_ops_cache_key(
+        op_type, inputs, dtypes, input_types, name, attrs, op_def,
+        compute_device)
+    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
+        cache_key)
+    if cached_op is not None:
+      # This op has already been moved to the forward graph and we have it in
+      # the cache.
+      return cached_op
+
+    with self._forward_graph.as_default():
+      # `name` was built using name_scope stack of gradient graph and may not
+      # be unique in the forward graph. `Graph.create_op` does not uniquify
+      # names which are name scopes i.e. end in `/`. To ensure that the op
+      # created gets a unique name in the forward graph we get rid of the
+      # trailing slash.
+      name = ops.name_from_scope_name(name)
+      result = self._forward_graph._create_op_internal(
+          op_type,
+          inputs,
+          dtypes=dtypes,
+          input_types=input_types,
+          name=name,
+          attrs=attrs,
+          op_def=op_def,
+          compute_device=compute_device)
+
+      # Store the op we just moved to the forward graph so that it does
+      # not need to be added there again.
+      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
+      return result
+
+  def _get_optimized_reduction_ops_cache_key(
+      self,
+      op_type,
+      inputs,
+      dtypes=None,  # pylint: disable=redefined-outer-name
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_device=True):
+    # We need all elements of CacheKey to be hashable.
+    inputs = tuple(map(lambda t: t.ref(), inputs))
+
+    if dtypes is not None:
+      dtypes = tuple(dtypes)
+
+    if input_types is not None:
+      input_types = tuple(input_types)
+
+    if attrs is not None:
+      hashable_attrs = []
+      for attr_name, attr_value in sorted(attrs.items()):
+        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
+      attrs = tuple(hashable_attrs)
+
+    if op_def is not None:
+      op_def = op_def.SerializeToString()
+
+    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
+                                         name, attrs, op_def, compute_device)
+
   def _capture_helper(self, tensor, name):
     """Implements the capturing described in the class docstring."""
     captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
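Sketch of the hashability requirement (values assumed for illustration): dicts and protocol buffers cannot serve as dict keys, so the cache key converts sequences to tuples and serializes attrs and op_def to bytes, mirroring _get_optimized_reduction_ops_cache_key above:

from tensorflow.core.framework import attr_value_pb2

attrs = {"depth": attr_value_pb2.AttrValue(i=32)}  # hypothetical example attr
hashable_attrs = tuple(
    (name, value.SerializeToString()) for name, value in sorted(attrs.items()))
# Tuples of (str, bytes) are hashable, so the result can live inside a
# namedtuple that is used as a dictionary key.
hash(hashable_attrs)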