Reenable autograph in Layer.__call__
and add safeguards to prevent users from writing invalid static control flow.
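
A minimal sketch of the kind of layer this enables (hypothetical layer name; assumes only the public `tf.keras` API):

```python
import tensorflow as tf

class SignSwitch(tf.keras.layers.Layer):
  """Subclassed layer with data-dependent Python control flow in `call`."""

  def call(self, inputs):
    # Autograph converts this tensor-dependent `if` into `tf.cond` when the
    # layer is traced into a graph; without autograph it raised a TypeError
    # unless the layer was constructed with `dynamic=True`.
    if tf.reduce_sum(inputs) > 0:
      return tf.sqrt(inputs)
    return tf.square(inputs)
```

The safeguard side: calling `add_loss`/`add_update`/`add_metric` inside such a
branch is rejected with a `RuntimeError` whose message shows the corrected
pattern (compute the value in both branches, then call the method
unconditionally).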
PiperOrigin-RevId: 245151330
commit cf31e9001c (parent 8be60a9ce8)
@@ -1357,6 +1357,19 @@ tf_py_test(
     tags = ["no_rocm"],
 )
 
+tf_py_test(
+    name = "control_flow_test",
+    size = "medium",
+    srcs = ["engine/control_flow_test.py"],
+    additional_deps = [
+        ":keras",
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client_testlib",
+    ],
+    shard_count = 8,
+)
+
 tf_py_test(
     name = "hdf5_format_test",
     size = "medium",
@@ -26,6 +26,7 @@ import numpy as np
 from six.moves import zip  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import node_def_pb2
+from tensorflow.python import autograph
 from tensorflow.python.distribute import values as distribute_values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
@@ -594,6 +595,21 @@ class Layer(module.Module):
           # Build layer if applicable (if the `build` method has been
           # overridden).
           self._maybe_build(inputs)
+
+          # Wrapping `call` function in autograph to allow for dynamic control
+          # dependencies in call. We are limiting this to subclassed layers as
+          # autograph is strictly needed only for subclassed layers.
+          if base_layer_utils.is_subclassed(self):
+            decorators, original_func = tf_decorator.unwrap(self.call)
+            converted_func = autograph.convert(recursive=True)(original_func)
+            if decorators:
+              call_fn = tf_decorator.rewrap(self.call, original_func,
+                                            converted_func)
+            else:
+              call_fn = converted_func
+          else:
+            call_fn = self.call
+
           # Explicitly pass the learning phase placeholder to `call` if
           # the `training` argument was left unspecified by the user.
           # This behavior is restricted to the managed Keras FuncGraph.
@@ -613,20 +629,18 @@ class Layer(module.Module):
                 self._mixed_precision_policy.should_cast_variables), (
                     base_layer_utils.AutoAddUpdates(self,
                                                     inputs)) as auto_updater:
-              outputs = self.call(inputs, *args, **kwargs)
+              outputs = call_fn(inputs, *args, **kwargs)
               auto_updater.set_outputs(outputs)
 
           except TypeError as e:
+            messages = ('`tf.Tensor` as a Python `bool` is not allowed',
+                        'Tensor objects are only iterable when eager')
             exception_str = str(e)
+            for msg in messages:
+              if msg in exception_str:
+                raise TypeError('You are attempting to use Python control '
+                                'flow in a layer that was not declared to be '
+                                'dynamic. Pass `dynamic=True` to the class '
+                                'constructor.\nEncountered error:\n"""\n' +
+                                exception_str + '\n"""')
-            exception_msg = 'Tensor objects are only iterable when eager'
-            if exception_msg in exception_str:
-              raise TypeError('You are attempting to use Python control '
-                              'flow in a layer that was not declared to be '
-                              'dynamic. Pass `dynamic=True` to the class '
-                              'constructor.\nEncountered error:\n"""\n' +
-                              exception_str + '\n"""')
             raise
           else:
             # We will use static shape inference to return symbolic tensors
@@ -775,7 +789,7 @@ class Layer(module.Module):
     class MyLayer(tf.keras.layers.Layer):
       def call(inputs, self):
         self.add_loss(tf.abs(tf.reduce_mean(inputs)), inputs=True)
-        return 2*inputs
+        return inputs
     ```
 
     This method can also be called directly on a Functional Model during
@@ -831,6 +845,7 @@ class Layer(module.Module):
         return None  # Will be filtered out when computing the .losses property
       if not tensor_util.is_tensor(loss):
         loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+      base_layer_utils.check_graph_consistency(loss, method='add_loss')
       loss._unconditional_loss = (inputs is None)  # pylint: disable=protected-access
       return loss
 
@@ -842,12 +857,15 @@ class Layer(module.Module):
     for loss in losses:
       if callable(loss):
         callable_losses.append(functools.partial(_tag_unconditional, loss))
-      elif tf_utils.is_symbolic_tensor(loss):
+        continue
+      if loss is None:
+        continue
+      if not tensor_util.is_tensor(loss):
+        loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+      if tf_utils.is_symbolic_tensor(loss):
         symbolic_losses.append(_tag_unconditional(loss))
       elif tensor_util.is_tensor(loss):
         eager_losses.append(_tag_unconditional(loss))
-      elif loss is not None:  # `None` is valid but should be ignored.
-        raise ValueError('Found non-Tensor loss: ' + str(loss))
 
     self._callable_losses += callable_losses
@@ -1005,14 +1023,24 @@ class Layer(module.Module):
       return  # Updates already applied when in eager mode.
 
     def process_update(x):
+      """Standardize update ops.
+
+      Arguments:
+        x: Tensor, op, or callable.
+
+      Returns:
+        An update op.
+      """
       if callable(x):
         x = x()
       if isinstance(x, ops.Operation):
-        return x
+        update = x
       elif hasattr(x, 'op'):
-        return x.op
+        update = x.op
       else:
-        return ops.convert_to_tensor(x)
+        update = ops.convert_to_tensor(x)
+      base_layer_utils.check_graph_consistency(update, method='add_update')
+      return update
 
     updates = [process_update(x) for x in updates]
     self._updates += updates
@@ -1502,6 +1530,7 @@ class Layer(module.Module):
     self._metrics.append(metric_obj)
 
   def _symbolic_add_metric(self, value, aggregation=None, name=None):
+    base_layer_utils.check_graph_consistency(value, method='add_metric')
     match = self._get_existing_metric(name)
     if aggregation is None:
       # Iterate over the metrics and check if the given metric exists already.
@@ -28,6 +28,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
@@ -40,35 +41,22 @@ from tensorflow.python.layers import core as legacy_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
-class DynamicLayer1(base_layer.Layer):
+class DynamicLayer(base_layer.Layer):
 
   def __init__(self, dynamic=False, **kwargs):
-    super(DynamicLayer1, self).__init__(dynamic=dynamic, **kwargs)
+    super(DynamicLayer, self).__init__(dynamic=dynamic, **kwargs)
 
   def call(self, inputs):
-    if math_ops.reduce_sum(inputs) > 0:
-      return math_ops.sqrt(inputs)
-    else:
-      return math_ops.square(inputs)
-
-  def compute_output_shape(self, input_shape):
-    return input_shape
-
-
-class DynamicLayer2(base_layer.Layer):
-
-  def __init__(self, dynamic=False, **kwargs):
-    super(DynamicLayer2, self).__init__(dynamic=dynamic, **kwargs)
-
-  def call(self, inputs):
-    samples = []
-    for sample in inputs:
-      samples.append(math_ops.square(sample))
-    return array_ops.stack(samples, axis=0)
+    samples = tensor_array_ops.TensorArray(
+        dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
+    for idx, sample in enumerate(inputs):
+      samples = samples.write(idx, math_ops.square(sample))
+    return samples.stack()
 
   def compute_output_shape(self, input_shape):
     return input_shape
@@ -82,34 +70,39 @@ class InvalidLayer(base_layer.Layer):
 
 class BaseLayerTest(keras_parameterized.TestCase):
 
-  @parameterized.parameters(DynamicLayer1, DynamicLayer2)
-  def test_dynamic_layer_in_functional_model_in_graph_mode(self, layer_class):
+  @keras_parameterized.run_with_all_model_types
+  def test_dynamic_layer(self):
+    model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
+                                                input_shape=(3,))
+    self.assertEqual(model.dynamic, True)
+    model.compile(rmsprop.RMSprop(0.001), loss='mse')
+    self.assertEqual(model.run_eagerly, True)
+    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+  @keras_parameterized.run_with_all_model_types
+  def test_dynamic_layer_error(self):
+    with self.assertRaisesRegexp(TypeError,
+                                 'attempting to use Python control flow'):
+      model = testing_utils.get_model_from_layers([DynamicLayer()],
+                                                  input_shape=(3,))
+      model.compile(rmsprop.RMSprop(0.001), loss='mse')
+      model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+  @keras_parameterized.run_with_all_model_types
+  def test_dynamic_layer_error_running_in_graph_mode(self):
     with context.graph_mode():
-      inputs = keras.Input((3,))
-      # Works when `dynamic=True` is declared.
-      outputs = layer_class(dynamic=True)(inputs)
-      model = keras.Model(inputs, outputs)
+      model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
+                                                  input_shape=(3,))
       self.assertEqual(model.dynamic, True)
       # But then you cannot run the model since you're in a graph scope.
       with self.assertRaisesRegexp(
           ValueError, 'You must enable eager execution'):
         model.compile(rmsprop.RMSprop(0.001), loss='mse')
 
-      # Fails when `dynamic=True` not declared.
-      with self.assertRaisesRegexp(
-          TypeError, 'attempting to use Python control flow'):
-        _ = layer_class()(inputs)
-
-  @parameterized.parameters(DynamicLayer1, DynamicLayer2)
-  def test_dynamic_layer_in_functional_model_in_eager_mode(self, layer_class):
-    inputs = keras.Input((3,))
-    # Fails when `dynamic=True` not declared.
-    with self.assertRaisesRegexp(
-        TypeError, 'attempting to use Python control flow'):
-      _ = layer_class()(inputs)
-    # Works when `dynamic=True` is declared.
-    outputs = layer_class(dynamic=True)(inputs)
-    model = keras.Model(inputs, outputs)
+  def test_dynamic_layer_with_deferred_sequential_model(self):
+    model = keras.Sequential(
+        [DynamicLayer(dynamic=True),
+         keras.layers.Dense(3)])
     self.assertEqual(model.dynamic, True)
     model.compile(rmsprop.RMSprop(0.001), loss='mse')
     self.assertEqual(model.run_eagerly, True)
@@ -117,12 +110,12 @@ class BaseLayerTest(keras_parameterized.TestCase):
 
   def test_nested_dynamic_layers_in_eager_mode(self):
     inputs = keras.Input((3,))
-    outputs = DynamicLayer1(dynamic=True)(inputs)
+    outputs = DynamicLayer(dynamic=True)(inputs)
     inner_model = keras.Model(inputs, outputs)
     self.assertEqual(inner_model.dynamic, True)
 
     inputs = keras.Input((3,))
-    x = DynamicLayer2(dynamic=True)(inputs)
+    x = DynamicLayer(dynamic=True)(inputs)
     outputs = inner_model(x)
 
     model = keras.Model(inputs, outputs)
@@ -131,41 +124,6 @@ class BaseLayerTest(keras_parameterized.TestCase):
     self.assertEqual(model.run_eagerly, True)
     model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
 
-  def test_dynamic_layers_in_sequential_model(self):
-    # Without input_shape argument
-    model = keras.Sequential([DynamicLayer1(dynamic=True),
-                              keras.layers.Dense(3),
-                              DynamicLayer2(dynamic=True)])
-    self.assertEqual(model.dynamic, True)
-    model.compile(rmsprop.RMSprop(0.001), loss='mse')
-    self.assertEqual(model.run_eagerly, True)
-    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
-
-    # With input_shape argument
-    model = keras.Sequential([DynamicLayer1(dynamic=True, input_shape=(3,)),
-                              DynamicLayer2(dynamic=True)])
-    self.assertEqual(model.dynamic, True)
-    model.compile(rmsprop.RMSprop(0.001), loss='mse')
-    self.assertEqual(model.run_eagerly, True)
-    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
-
-  def test_dynamic_layers_in_subclassed_model(self):
-
-    class MyModel(keras.Model):
-
-      def __init__(self):
-        super(MyModel, self).__init__()
-        self.layer1 = DynamicLayer1(dynamic=True)
-
-      def call(self, inputs):
-        return self.layer1(inputs)
-
-    model = MyModel()
-    self.assertEqual(model.dynamic, True)
-    model.compile(rmsprop.RMSprop(0.001), loss='mse')
-    self.assertEqual(model.run_eagerly, True)
-    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
-
   def test_dynamic_subclassed_model_no_shape_inference(self):
 
     class MyModel(keras.Model):
@@ -213,6 +171,24 @@ class BaseLayerTest(keras_parameterized.TestCase):
     model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
     self.assertEqual(model.outputs[0].shape.as_list(), [None, 3])
 
+  @keras_parameterized.run_all_keras_modes
+  def test_add_loss_correctness(self):
+
+    class MyLayer(keras.layers.Layer):
+
+      def call(self, inputs, training=None):
+        self.add_loss(math_ops.reduce_sum(inputs))
+        return inputs
+
+    inputs = keras.Input((3,))
+    layer = MyLayer()
+    outputs = layer(inputs)
+    model = keras.Model(inputs, outputs)
+    self.assertEqual(len(model.losses), 1)
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(loss, 2 * 3)
+
   @test_util.run_in_graph_and_eager_modes
   def test_invalid_forward_pass(self):
     inputs = keras.Input((3,))
@@ -676,6 +652,201 @@ class NameScopingTest(keras_parameterized.TestCase):
     self.assertEqual(layer.kernel.name, 'MyName3/kernel:0')
 
 
+class AutographControlFlowTest(keras_parameterized.TestCase):
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_if_training_pattern_output(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def call(self, inputs, training=None):
+        if training:
+          return inputs * 1.
+        return inputs * 0.
+
+    inputs = keras.Input((3,))
+    outputs = MyLayer()(inputs)
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', 'mse', run_eagerly=eager)
+    train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(train_loss, 0.)
+    test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(test_loss, 1.)
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_if_training_pattern_loss(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def call(self, inputs, training=None):
+        if training:
+          loss = math_ops.reduce_sum(inputs)
+        else:
+          loss = 0.
+        self.add_loss(loss)
+        return inputs
+
+    inputs = keras.Input((3,))
+    outputs = MyLayer()(inputs)
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', 'mse', run_eagerly=eager)
+    train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(train_loss, 2 * 3)
+    test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(test_loss, 0)
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_if_training_pattern_metric(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def call(self, inputs, training=None):
+        if training:
+          metric = math_ops.reduce_sum(inputs)
+        else:
+          metric = 0.
+        self.add_metric(metric, name='my_metric', aggregation='mean')
+        return inputs
+
+    inputs = keras.Input((3,))
+    outputs = MyLayer()(inputs)
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', 'mse', run_eagerly=eager)
+    _, train_metric = model.train_on_batch(np.ones((2, 3)),
+                                           np.ones((2, 3)))
+    self.assertEqual(train_metric, 2 * 3)
+    _, test_metric = model.test_on_batch(np.ones((2, 3)),
+                                         np.ones((2, 3)))
+    self.assertEqual(test_metric, 0)
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_if_training_pattern_update(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def build(self, input_shape):
+        self.counter = self.add_weight(
+            shape=(), trainable=False, initializer='zeros')
+
+      def call(self, inputs, training=None):
+        if training:
+          increment = 1.
+        else:
+          increment = 0.
+        self.counter.assign_add(increment)
+        return inputs
+
+    inputs = keras.Input((3,))
+    layer = MyLayer()
+    outputs = layer(inputs)
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', 'mse', run_eagerly=eager)
+    model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(keras.backend.get_value(layer.counter), 1.)
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_conditional_updates_in_call(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyLayer, self).__init__(self, dynamic=eager)
+
+      def build(self, input_shape):
+        self.counter = self.add_weight(
+            shape=(), trainable=False, initializer='zeros')
+
+      def call(self, inputs, training=None):
+        if training:
+          self.add_update(self.counter.assign_add(math_ops.reduce_sum(inputs)))
+        return inputs
+
+      def compute_output_shape(self, input_shape):
+        return input_shape
+
+    if eager:
+      inputs = keras.Input((3,))
+      layer = MyLayer()
+      outputs = layer(inputs)
+      model = keras.Model(inputs, outputs)
+      model.compile('sgd', 'mse', run_eagerly=eager)
+      model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+      self.assertEqual(keras.backend.get_value(layer.counter), 6.)
+    else:
+      # TODO(fchollet): support the same workflow in graph mode.
+      with self.assertRaisesRegexp(RuntimeError,
+                                   '`add_update` in a control flow branch'):
+        layer = MyLayer()(keras.Input((3,)))
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_conditional_losses_in_call(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyLayer, self).__init__(self, dynamic=eager)
+
+      def call(self, inputs, training=None):
+        if training:
+          self.add_loss(math_ops.reduce_sum(inputs))
+        return inputs
+
+      def compute_output_shape(self, input_shape):
+        return input_shape
+
+    if eager:
+      inputs = keras.Input((3,))
+      layer = MyLayer()
+      outputs = layer(inputs)
+      model = keras.Model(inputs, outputs)
+      model.compile('sgd', 'mse')
+      loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+      self.assertEqual(loss, 2 * 3)
+    else:
+      with self.assertRaisesRegexp(RuntimeError,
+                                   '`add_loss` in a control flow branch'):
+        layer = MyLayer()(keras.Input((3,)))
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_conditional_metrics_in_call(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyLayer, self).__init__(self, dynamic=eager)
+
+      def call(self, inputs, training=None):
+        if training:
+          self.add_metric(math_ops.reduce_sum(inputs),
+                          name='sum',
+                          aggregation='mean')
+        return inputs
+
+      def compute_output_shape(self, input_shape):
+        return input_shape
+
+    if eager:
+      inputs = keras.Input((3,))
+      layer = MyLayer()
+      outputs = layer(inputs)
+      model = keras.Model(inputs, outputs)
+      model.compile('sgd', 'mse')
+      history = model.fit(np.ones((2, 3)), np.ones((2, 3)))
+      self.assertEqual(history.history['sum'][-1], 2 * 3)
+    else:
+      # TODO(fchollet): support the same workflow in graph mode.
+      with self.assertRaisesRegexp(RuntimeError,
+                                   '`add_metric` in a control flow branch'):
+        layer = MyLayer()(keras.Input((3,)))
+
+
 _LAYERS_TO_TEST = [
     (keras.layers.Dense, (1,), collections.OrderedDict(units=[1])),
     (keras.layers.Activation, (2, 2),
@@ -30,6 +30,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import init_ops_v2
 from tensorflow.python.ops import variables as tf_variables
@@ -571,3 +572,96 @@ def autocast_context_manager(input_list, should_cast):
   var_read_dtype = _get_var_read_dtype(input_list, should_cast)
   return ops.get_default_graph()._enable_auto_casting_variables(  # pylint: disable=protected-access
       var_read_dtype)
+
+
+def is_subclassed(layer):
+  return (layer.__module__.find('keras.engine') == -1 and
+          layer.__module__.find('keras.layers') == -1)
+
+
+def check_graph_consistency(tensor, method):
+  """Checks that tensors passed to `add_*` method match the Keras graph.
+
+  When one of the `add_*` method is called inside a V2 conditional branch,
+  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
+  We need to raise clear error messages in such cases.
+
+  Arguments:
+    tensor: Tensor to check.
+    method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
+
+  Raises:
+    RuntimeError: In case of an out-of-graph tensor.
+  """
+  if ops.executing_eagerly_outside_functions() and hasattr(tensor, 'graph'):
+    if isinstance(tensor.graph,
+                  (control_flow_util_v2.CondBranchFuncGraph,
+                   control_flow_util_v2.WhileCondFuncGraph,
+                   control_flow_util_v2.WhileBodyFuncGraph)):
+      if method == 'add_metric':
+        bad_example = """
+        def call(self, inputs, training=None):
+          if training:
+            metric = compute_metric(inputs)
+            self.add_metric(metric, name='my_metric', aggregation='mean')
+          return inputs
+        """
+        correct_example = """
+        def call(self, inputs, training=None):
+          if training:
+            metric = compute_metric(inputs)
+          else:
+            metric = 0.
+          self.add_metric(metric, name='my_metric', aggregation='mean')
+          return inputs
+        """
+      elif method == 'add_loss':
+        bad_example = """
+        def call(self, inputs, training=None):
+          if training:
+            loss = compute_loss(inputs)
+            self.add_loss(loss)
+          return inputs
+        """
+        correct_example = """
+        def call(self, inputs, training=None):
+          if training:
+            loss = compute_loss(inputs)
+          else:
+            loss = 0.
+          self.add_loss(loss)
+          return inputs
+        """
+      else:
+        bad_example = """
+        def call(self, inputs, training=None):
+          if training:
+            self.add_update(self.w.assign_add(1))
+          return inputs
+        """
+        correct_example = """
+        def call(self, inputs, training=None):
+          if training:
+            increment = 1
+          else:
+            increment = 0
+          self.add_update(self.w.assign_add(increment))
+          return inputs
+        """
+      raise RuntimeError(
+          'You are using the method `{method}` in a control flow branch '
+          'in your layer, e.g.:\n{bad_example}\n'
+          'This is not currently supported. '
+          'You should either use static control flow (`tf.cond`) '
+          'or move your call to {method} out of the control flow branch, '
+          'e.g.:\n{correct_example}\n'
+          'You can also resolve this by marking your layer '
+          'as dynamic (eager-only) by passing '
+          '`dynamic=True` to the layer constructor. '
+          'Any kind of control flow is supported with dynamic layers. '
+          'Note that using `dynamic=True` requires you '
+          'to implement static shape inference '
+          'in the `compute_output_shape(input_shape)` method.'.format(
+              method=method,
+              bad_example=bad_example,
+              correct_example=correct_example))
tensorflow/python/keras/engine/control_flow_test.py (new file, 138 lines)
@@ -0,0 +1,138 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dynamic control flow behavior with Keras."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class ControlFlowLayer1(base_layer.Layer):
+  """Layer with an `if` condition in call."""
+
+  def call(self, inputs):
+    if math_ops.reduce_sum(inputs) > 0:
+      return math_ops.sqrt(inputs)
+    else:
+      return math_ops.square(inputs)
+
+
+class ControlFlowLayer2(base_layer.Layer):
+  """Layer with a `for` loop in call."""
+
+  def call(self, inputs):
+    samples = tensor_array_ops.TensorArray(
+        dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
+    i = 0
+    for sample in inputs:
+      samples = samples.write(i, math_ops.square(sample))
+      i += 1
+    return samples.stack()
+
+
+class NestedControlFlowLayer(base_layer.Layer):
+  """Layer nested with a control flow layer."""
+
+  def __init__(self, **kwargs):
+    super(NestedControlFlowLayer, self).__init__(**kwargs)
+    self.layer = ControlFlowLayer1()
+
+  def call(self, inputs):
+    return self.layer(inputs)
+
+
+class ControlFlowModel(keras.Model):
+  """Model with an `if` condition in call."""
+
+  def call(self, inputs):
+    if math_ops.reduce_sum(inputs) > 0:
+      return math_ops.sqrt(inputs)
+    else:
+      return math_ops.square(inputs)
+
+
+class NestedControlFlowModel(keras.Model):
+  """Model with an `if` condition in call using a control flow layer."""
+
+  def __init__(self, **kwargs):
+    super(NestedControlFlowModel, self).__init__(**kwargs)
+    self.layer = NestedControlFlowLayer()
+
+  def call(self, inputs):
+    inputs = self.layer(inputs)
+    if math_ops.reduce_sum(inputs) > 0:
+      return math_ops.sqrt(inputs)
+    else:
+      return math_ops.square(inputs)
+
+
+class FunctionControlFlowModel(keras.Model):
+  """Model with control flow where `call` is wrapped in function already."""
+
+  @def_function.function
+  def call(self, inputs):
+    if math_ops.reduce_sum(inputs) > 0:
+      return math_ops.sqrt(inputs)
+    else:
+      return math_ops.square(inputs)
+
+
+@keras_parameterized.run_all_keras_modes
+class AutographWrapperTest(keras_parameterized.TestCase):
+
+  @keras_parameterized.run_with_all_model_types
+  @parameterized.named_parameters(('with_if', ControlFlowLayer1),
+                                  ('with_for', ControlFlowLayer2),
+                                  ('nested', NestedControlFlowLayer))
+  def test_control_flow_layer(self, layer_class):
+    model = testing_utils.get_model_from_layers([layer_class()],
+                                                input_shape=(3,))
+    model.compile(rmsprop.RMSprop(0.001), loss='mse')
+    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+  @parameterized.named_parameters(
+      ('with_if', ControlFlowModel), ('nested', NestedControlFlowModel),
+      ('wrapped_in_function', FunctionControlFlowModel))
+  def test_control_flow_model(self, model_class):
+    model = model_class()
+    model.compile(rmsprop.RMSprop(0.001), loss='mse')
+    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+  def test_control_flow_in_deferred_sequential_model(self):
+    model = keras.Sequential(
+        [ControlFlowLayer1(),
+         keras.layers.Dense(3),
+         ControlFlowLayer2()])
+    model.compile(rmsprop.RMSprop(0.001), loss='mse')
+    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+
+if __name__ == '__main__':
+  test.main()