Reenable autograph in Layer.__call__ and add safeguards to prevent users from writing invalid static control flow.
PiperOrigin-RevId: 245151330
parent 8be60a9ce8
commit cf31e9001c
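In practice, the change means a subclassed layer can use data-dependent Python control flow in `call` and have it staged into the graph by autograph, instead of failing during functional-model construction. A minimal sketch of the kind of layer this enables, written against the public TF 2.x API (the layer name and shapes are illustrative, not part of this commit):

```python
import tensorflow as tf


class SqrtOrSquare(tf.keras.layers.Layer):
  """Data-dependent `if` on a tensor: previously invalid symbolically."""

  def call(self, inputs):
    # Autograph rewrites this branch into a `tf.cond` when `call` is
    # traced symbolically.
    if tf.reduce_sum(inputs) > 0:
      return tf.sqrt(inputs)
    return tf.square(inputs)


inputs = tf.keras.Input((3,))
outputs = SqrtOrSquare()(inputs)  # builds thanks to autograph conversion
model = tf.keras.Model(inputs, outputs)
```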
tensorflow/python/keras/BUILD

@@ -1357,6 +1357,19 @@ tf_py_test(
     tags = ["no_rocm"],
 )
 
+tf_py_test(
+    name = "control_flow_test",
+    size = "medium",
+    srcs = ["engine/control_flow_test.py"],
+    additional_deps = [
+        ":keras",
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client_testlib",
+    ],
+    shard_count = 8,
+)
+
 tf_py_test(
     name = "hdf5_format_test",
     size = "medium",
tensorflow/python/keras/engine/base_layer.py

@@ -26,6 +26,7 @@ import numpy as np
 from six.moves import zip  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import node_def_pb2
+from tensorflow.python import autograph
 from tensorflow.python.distribute import values as distribute_values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
@@ -594,6 +595,21 @@ class Layer(module.Module):
         # Build layer if applicable (if the `build` method has been
         # overridden).
         self._maybe_build(inputs)
+
+        # Wrapping `call` function in autograph to allow for dynamic control
+        # dependencies in call. We are limiting this to subclassed layers as
+        # autograph is strictly needed only for subclassed layers.
+        if base_layer_utils.is_subclassed(self):
+          decorators, original_func = tf_decorator.unwrap(self.call)
+          converted_func = autograph.convert(recursive=True)(original_func)
+          if decorators:
+            call_fn = tf_decorator.rewrap(self.call, original_func,
+                                          converted_func)
+          else:
+            call_fn = converted_func
+        else:
+          call_fn = self.call
+
         # Explicitly pass the learning phase placeholder to `call` if
         # the `training` argument was left unspecified by the user.
         # This behavior is restricted to the managed Keras FuncGraph.
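The wrapping relies on TF-internal helpers (`tf_decorator.unwrap`/`rewrap`) so that any decorators already applied to `call` survive conversion. Roughly the same effect can be sketched with the public autograph API; `branchy` below is an illustrative stand-in for a layer's `call`, not code from this commit:

```python
import tensorflow as tf


def branchy(inputs):
  # Plain Python control flow on a tensor value.
  if tf.reduce_sum(inputs) > 0:
    return tf.sqrt(inputs)
  return tf.square(inputs)


# Approximately what `autograph.convert(recursive=True)(original_func)`
# achieves internally: rewrite the branch into graph-compatible control flow.
converted = tf.autograph.to_graph(branchy)

x = tf.constant([1., 2., 3.])
print(converted(x))
```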
@@ -613,20 +629,18 @@ class Layer(module.Module):
               self._mixed_precision_policy.should_cast_variables), (
                   base_layer_utils.AutoAddUpdates(self,
                                                   inputs)) as auto_updater:
-            outputs = self.call(inputs, *args, **kwargs)
+            outputs = call_fn(inputs, *args, **kwargs)
             auto_updater.set_outputs(outputs)
 
         except TypeError as e:
-          messages = ('`tf.Tensor` as a Python `bool` is not allowed',
-                      'Tensor objects are only iterable when eager')
           exception_str = str(e)
-          for msg in messages:
-            if msg in exception_str:
+          exception_msg = 'Tensor objects are only iterable when eager'
+          if exception_msg in exception_str:
             raise TypeError('You are attempting to use Python control '
                             'flow in a layer that was not declared to be '
                             'dynamic. Pass `dynamic=True` to the class '
                             'constructor.\nEncountered error:\n"""\n' +
                             exception_str + '\n"""')
           raise
         else:
           # We will use static shape inference to return symbolic tensors
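With autograph handling convertible control flow, this `TypeError` path now fires mainly for constructs that genuinely require eager execution, such as `enumerate()` over a symbolic tensor (exactly what the rewritten `DynamicLayer` test below does). A hedged sketch of the failure and the documented fix, as of this commit (class name illustrative):

```python
import tensorflow as tf


class EnumeratingLayer(tf.keras.layers.Layer):
  """`enumerate()` needs concrete values, so it only works eagerly."""

  def call(self, inputs):
    outputs = [tf.square(v) for _, v in enumerate(inputs)]
    return tf.stack(outputs)

  def compute_output_shape(self, input_shape):
    # Required for dynamic layers: shapes must be inferred statically.
    return input_shape


# EnumeratingLayer()(tf.keras.Input((3,)))  # TypeError: "...Python control flow..."
layer = EnumeratingLayer(dynamic=True)  # opts the layer into eager-only runs
```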
@@ -775,7 +789,7 @@ class Layer(module.Module):
     class MyLayer(tf.keras.layers.Layer):
       def call(inputs, self):
         self.add_loss(tf.abs(tf.reduce_mean(inputs)), inputs=True)
-        return 2*inputs
+        return inputs
     ```
 
     This method can also be called directly on a Functional Model during
@@ -831,6 +845,7 @@ class Layer(module.Module):
         return None  # Will be filtered out when computing the .losses property
       if not tensor_util.is_tensor(loss):
         loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+      base_layer_utils.check_graph_consistency(loss, method='add_loss')
       loss._unconditional_loss = (inputs is None)  # pylint: disable=protected-access
       return loss
 
@@ -842,12 +857,15 @@ class Layer(module.Module):
     for loss in losses:
       if callable(loss):
         callable_losses.append(functools.partial(_tag_unconditional, loss))
-      elif tf_utils.is_symbolic_tensor(loss):
+        continue
+      if loss is None:
+        continue
+      if not tensor_util.is_tensor(loss):
+        loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+      if tf_utils.is_symbolic_tensor(loss):
         symbolic_losses.append(_tag_unconditional(loss))
       elif tensor_util.is_tensor(loss):
         eager_losses.append(_tag_unconditional(loss))
-      elif loss is not None:  # `None` is valid but should be ignored.
-        raise ValueError('Found non-Tensor loss: ' + str(loss))
 
     self._callable_losses += callable_losses
 
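After this reordering, everything handed to `add_loss` is normalized up front: callables are deferred, `None` is skipped, and a plain Python number is converted to a tensor instead of raising `'Found non-Tensor loss'`. A small sketch of the three accepted forms, assuming eager execution and the behavior as of this commit:

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer.build((None, 3))

# 1. A callable: stored and re-evaluated on each forward pass.
layer.add_loss(lambda: 0.01 * tf.reduce_sum(tf.square(layer.kernel)))

# 2. A tensor (symbolic or eager): tagged and stored directly.
layer.add_loss(tf.constant(0.5))

# 3. New here: a plain float is converted via `convert_to_tensor`.
layer.add_loss(0.25)

print(len(layer.losses))  # -> 3
```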
@@ -1005,14 +1023,24 @@ class Layer(module.Module):
       return  # Updates already applied when in eager mode.
 
     def process_update(x):
+      """Standardize update ops.
+
+      Arguments:
+        x: Tensor, op, or callable.
+
+      Returns:
+        An update op.
+      """
       if callable(x):
         x = x()
       if isinstance(x, ops.Operation):
-        return x
+        update = x
       elif hasattr(x, 'op'):
-        return x.op
+        update = x.op
       else:
-        return ops.convert_to_tensor(x)
+        update = ops.convert_to_tensor(x)
+      base_layer_utils.check_graph_consistency(update, method='add_update')
+      return update
 
     updates = [process_update(x) for x in updates]
     self._updates += updates
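`process_update` keeps accepting the same three input forms but now routes every standardized result through `check_graph_consistency` before it is stored (in eager mode, updates are applied immediately and this path is skipped). A short sketch of the accepted forms, with an illustrative variable `v`:

```python
import tensorflow as tf

v = tf.Variable(0.)
layer = tf.keras.layers.Layer()

# An op, a tensor with an `.op` attribute, or a callable are all valid;
# in graph construction each is reduced to an update op by `process_update`.
layer.add_update(v.assign_add(1.))
layer.add_update(lambda: v.assign_add(1.))
```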
@@ -1502,6 +1530,7 @@ class Layer(module.Module):
       self._metrics.append(metric_obj)
 
   def _symbolic_add_metric(self, value, aggregation=None, name=None):
+    base_layer_utils.check_graph_consistency(value, method='add_metric')
     match = self._get_existing_metric(name)
     if aggregation is None:
       # Iterate over the metrics and check if the given metric exists already.
tensorflow/python/keras/engine/base_layer_test.py

@@ -28,6 +28,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
@@ -40,35 +41,22 @@ from tensorflow.python.layers import core as legacy_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
-class DynamicLayer1(base_layer.Layer):
+class DynamicLayer(base_layer.Layer):
 
   def __init__(self, dynamic=False, **kwargs):
-    super(DynamicLayer1, self).__init__(dynamic=dynamic, **kwargs)
+    super(DynamicLayer, self).__init__(dynamic=dynamic, **kwargs)
 
   def call(self, inputs):
-    if math_ops.reduce_sum(inputs) > 0:
-      return math_ops.sqrt(inputs)
-    else:
-      return math_ops.square(inputs)
-
-  def compute_output_shape(self, input_shape):
-    return input_shape
-
-
-class DynamicLayer2(base_layer.Layer):
-
-  def __init__(self, dynamic=False, **kwargs):
-    super(DynamicLayer2, self).__init__(dynamic=dynamic, **kwargs)
-
-  def call(self, inputs):
-    samples = []
-    for sample in inputs:
-      samples.append(math_ops.square(sample))
-    return array_ops.stack(samples, axis=0)
+    samples = tensor_array_ops.TensorArray(
+        dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
+    for idx, sample in enumerate(inputs):
+      samples = samples.write(idx, math_ops.square(sample))
+    return samples.stack()
 
   def compute_output_shape(self, input_shape):
     return input_shape
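The rewritten `DynamicLayer` loops with a `TensorArray` rather than a Python list: element-wise iteration over a tensor is an eager-only operation, and `TensorArray` is the container such a loop can write into. A standalone sketch of the same pattern in plain eager code (values illustrative):

```python
import tensorflow as tf

x = tf.constant([1., 2., 3.])
samples = tf.TensorArray(tf.float32, size=3)
for idx, sample in enumerate(x):  # eager-only iteration over tensor rows
  samples = samples.write(idx, tf.square(sample))
print(samples.stack())  # [1., 4., 9.]
```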
@@ -82,34 +70,39 @@ class InvalidLayer(base_layer.Layer):
 
 class BaseLayerTest(keras_parameterized.TestCase):
 
-  @parameterized.parameters(DynamicLayer1, DynamicLayer2)
-  def test_dynamic_layer_in_functional_model_in_graph_mode(self, layer_class):
+  @keras_parameterized.run_with_all_model_types
+  def test_dynamic_layer(self):
+    model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
+                                                input_shape=(3,))
+    self.assertEqual(model.dynamic, True)
+    model.compile(rmsprop.RMSprop(0.001), loss='mse')
+    self.assertEqual(model.run_eagerly, True)
+    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+  @keras_parameterized.run_with_all_model_types
+  def test_dynamic_layer_error(self):
+    with self.assertRaisesRegexp(TypeError,
+                                 'attempting to use Python control flow'):
+      model = testing_utils.get_model_from_layers([DynamicLayer()],
+                                                  input_shape=(3,))
+      model.compile(rmsprop.RMSprop(0.001), loss='mse')
+      model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+  @keras_parameterized.run_with_all_model_types
+  def test_dynamic_layer_error_running_in_graph_mode(self):
     with context.graph_mode():
-      inputs = keras.Input((3,))
-      # Works when `dynamic=True` is declared.
-      outputs = layer_class(dynamic=True)(inputs)
-      model = keras.Model(inputs, outputs)
+      model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
+                                                  input_shape=(3,))
       self.assertEqual(model.dynamic, True)
       # But then you cannot run the model since you're in a graph scope.
       with self.assertRaisesRegexp(
           ValueError, 'You must enable eager execution'):
         model.compile(rmsprop.RMSprop(0.001), loss='mse')
 
-      # Fails when `dynamic=True` not declared.
-      with self.assertRaisesRegexp(
-          TypeError, 'attempting to use Python control flow'):
-        _ = layer_class()(inputs)
-
-  @parameterized.parameters(DynamicLayer1, DynamicLayer2)
-  def test_dynamic_layer_in_functional_model_in_eager_mode(self, layer_class):
-    inputs = keras.Input((3,))
-    # Fails when `dynamic=True` not declared.
-    with self.assertRaisesRegexp(
-        TypeError, 'attempting to use Python control flow'):
-      _ = layer_class()(inputs)
-    # Works when `dynamic=True` is declared.
-    outputs = layer_class(dynamic=True)(inputs)
-    model = keras.Model(inputs, outputs)
+  def test_dynamic_layer_with_deferred_sequential_model(self):
+    model = keras.Sequential(
+        [DynamicLayer(dynamic=True),
+         keras.layers.Dense(3)])
     self.assertEqual(model.dynamic, True)
     model.compile(rmsprop.RMSprop(0.001), loss='mse')
     self.assertEqual(model.run_eagerly, True)
@@ -117,12 +110,12 @@ class BaseLayerTest(keras_parameterized.TestCase):
 
   def test_nested_dynamic_layers_in_eager_mode(self):
     inputs = keras.Input((3,))
-    outputs = DynamicLayer1(dynamic=True)(inputs)
+    outputs = DynamicLayer(dynamic=True)(inputs)
     inner_model = keras.Model(inputs, outputs)
     self.assertEqual(inner_model.dynamic, True)
 
     inputs = keras.Input((3,))
-    x = DynamicLayer2(dynamic=True)(inputs)
+    x = DynamicLayer(dynamic=True)(inputs)
     outputs = inner_model(x)
 
     model = keras.Model(inputs, outputs)
@@ -131,41 +124,6 @@ class BaseLayerTest(keras_parameterized.TestCase):
     self.assertEqual(model.run_eagerly, True)
     model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
 
-  def test_dynamic_layers_in_sequential_model(self):
-    # Without input_shape argument
-    model = keras.Sequential([DynamicLayer1(dynamic=True),
-                              keras.layers.Dense(3),
-                              DynamicLayer2(dynamic=True)])
-    self.assertEqual(model.dynamic, True)
-    model.compile(rmsprop.RMSprop(0.001), loss='mse')
-    self.assertEqual(model.run_eagerly, True)
-    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
-
-    # With input_shape argument
-    model = keras.Sequential([DynamicLayer1(dynamic=True, input_shape=(3,)),
-                              DynamicLayer2(dynamic=True)])
-    self.assertEqual(model.dynamic, True)
-    model.compile(rmsprop.RMSprop(0.001), loss='mse')
-    self.assertEqual(model.run_eagerly, True)
-    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
-
-  def test_dynamic_layers_in_subclassed_model(self):
-
-    class MyModel(keras.Model):
-
-      def __init__(self):
-        super(MyModel, self).__init__()
-        self.layer1 = DynamicLayer1(dynamic=True)
-
-      def call(self, inputs):
-        return self.layer1(inputs)
-
-    model = MyModel()
-    self.assertEqual(model.dynamic, True)
-    model.compile(rmsprop.RMSprop(0.001), loss='mse')
-    self.assertEqual(model.run_eagerly, True)
-    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
-
   def test_dynamic_subclassed_model_no_shape_inference(self):
 
     class MyModel(keras.Model):
@@ -213,6 +171,24 @@ class BaseLayerTest(keras_parameterized.TestCase):
     model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
     self.assertEqual(model.outputs[0].shape.as_list(), [None, 3])
 
+  @keras_parameterized.run_all_keras_modes
+  def test_add_loss_correctness(self):
+
+    class MyLayer(keras.layers.Layer):
+
+      def call(self, inputs, training=None):
+        self.add_loss(math_ops.reduce_sum(inputs))
+        return inputs
+
+    inputs = keras.Input((3,))
+    layer = MyLayer()
+    outputs = layer(inputs)
+    model = keras.Model(inputs, outputs)
+    self.assertEqual(len(model.losses), 1)
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(loss, 2 * 3)
+
   @test_util.run_in_graph_and_eager_modes
   def test_invalid_forward_pass(self):
     inputs = keras.Input((3,))
@@ -676,6 +652,201 @@ class NameScopingTest(keras_parameterized.TestCase):
     self.assertEqual(layer.kernel.name, 'MyName3/kernel:0')
 
 
+class AutographControlFlowTest(keras_parameterized.TestCase):
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_if_training_pattern_output(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def call(self, inputs, training=None):
+        if training:
+          return inputs * 1.
+        return inputs * 0.
+
+    inputs = keras.Input((3,))
+    outputs = MyLayer()(inputs)
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', 'mse', run_eagerly=eager)
+    train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(train_loss, 0.)
+    test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(test_loss, 1.)
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_if_training_pattern_loss(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def call(self, inputs, training=None):
+        if training:
+          loss = math_ops.reduce_sum(inputs)
+        else:
+          loss = 0.
+        self.add_loss(loss)
+        return inputs
+
+    inputs = keras.Input((3,))
+    outputs = MyLayer()(inputs)
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', 'mse', run_eagerly=eager)
+    train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(train_loss, 2 * 3)
+    test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(test_loss, 0)
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_if_training_pattern_metric(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def call(self, inputs, training=None):
+        if training:
+          metric = math_ops.reduce_sum(inputs)
+        else:
+          metric = 0.
+        self.add_metric(metric, name='my_metric', aggregation='mean')
+        return inputs
+
+    inputs = keras.Input((3,))
+    outputs = MyLayer()(inputs)
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', 'mse', run_eagerly=eager)
+    _, train_metric = model.train_on_batch(np.ones((2, 3)),
+                                           np.ones((2, 3)))
+    self.assertEqual(train_metric, 2 * 3)
+    _, test_metric = model.test_on_batch(np.ones((2, 3)),
+                                         np.ones((2, 3)))
+    self.assertEqual(test_metric, 0)
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_if_training_pattern_update(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def build(self, input_shape):
+        self.counter = self.add_weight(
+            shape=(), trainable=False, initializer='zeros')
+
+      def call(self, inputs, training=None):
+        if training:
+          increment = 1.
+        else:
+          increment = 0.
+        self.counter.assign_add(increment)
+        return inputs
+
+    inputs = keras.Input((3,))
+    layer = MyLayer()
+    outputs = layer(inputs)
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', 'mse', run_eagerly=eager)
+    model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(keras.backend.get_value(layer.counter), 1.)
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_conditional_updates_in_call(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyLayer, self).__init__(self, dynamic=eager)
+
+      def build(self, input_shape):
+        self.counter = self.add_weight(
+            shape=(), trainable=False, initializer='zeros')
+
+      def call(self, inputs, training=None):
+        if training:
+          self.add_update(self.counter.assign_add(math_ops.reduce_sum(inputs)))
+        return inputs
+
+      def compute_output_shape(self, input_shape):
+        return input_shape
+
+    if eager:
+      inputs = keras.Input((3,))
+      layer = MyLayer()
+      outputs = layer(inputs)
+      model = keras.Model(inputs, outputs)
+      model.compile('sgd', 'mse', run_eagerly=eager)
+      model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+      self.assertEqual(keras.backend.get_value(layer.counter), 6.)
+    else:
+      # TODO(fchollet): support the same workflow in graph mode.
+      with self.assertRaisesRegexp(RuntimeError,
+                                   '`add_update` in a control flow branch'):
+        layer = MyLayer()(keras.Input((3,)))
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_conditional_losses_in_call(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyLayer, self).__init__(self, dynamic=eager)
+
+      def call(self, inputs, training=None):
+        if training:
+          self.add_loss(math_ops.reduce_sum(inputs))
+        return inputs
+
+      def compute_output_shape(self, input_shape):
+        return input_shape
+
+    if eager:
+      inputs = keras.Input((3,))
+      layer = MyLayer()
+      outputs = layer(inputs)
+      model = keras.Model(inputs, outputs)
+      model.compile('sgd', 'mse')
+      loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+      self.assertEqual(loss, 2 * 3)
+    else:
+      with self.assertRaisesRegexp(RuntimeError,
+                                   '`add_loss` in a control flow branch'):
+        layer = MyLayer()(keras.Input((3,)))
+
+  @parameterized.named_parameters(('eager', True),
+                                  ('symbolic', False))
+  def test_conditional_metrics_in_call(self, eager):
+
+    class MyLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyLayer, self).__init__(self, dynamic=eager)
+
+      def call(self, inputs, training=None):
+        if training:
+          self.add_metric(math_ops.reduce_sum(inputs),
+                          name='sum',
+                          aggregation='mean')
+        return inputs
+
+      def compute_output_shape(self, input_shape):
+        return input_shape
+
+    if eager:
+      inputs = keras.Input((3,))
+      layer = MyLayer()
+      outputs = layer(inputs)
+      model = keras.Model(inputs, outputs)
+      model.compile('sgd', 'mse')
+      history = model.fit(np.ones((2, 3)), np.ones((2, 3)))
+      self.assertEqual(history.history['sum'][-1], 2 * 3)
+    else:
+      # TODO(fchollet): support the same workflow in graph mode.
+      with self.assertRaisesRegexp(RuntimeError,
+                                   '`add_metric` in a control flow branch'):
+        layer = MyLayer()(keras.Input((3,)))
+
+
 _LAYERS_TO_TEST = [
     (keras.layers.Dense, (1,), collections.OrderedDict(units=[1])),
     (keras.layers.Activation, (2, 2),
tensorflow/python/keras/engine/base_layer_utils.py

@@ -30,6 +30,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import init_ops_v2
 from tensorflow.python.ops import variables as tf_variables
@@ -571,3 +572,96 @@ def autocast_context_manager(input_list, should_cast):
   var_read_dtype = _get_var_read_dtype(input_list, should_cast)
   return ops.get_default_graph()._enable_auto_casting_variables(  # pylint: disable=protected-access
       var_read_dtype)
+
+
+def is_subclassed(layer):
+  return (layer.__module__.find('keras.engine') == -1 and
+          layer.__module__.find('keras.layers') == -1)
+
+
+def check_graph_consistency(tensor, method):
+  """Checks that tensors passed to `add_*` method match the Keras graph.
+
+  When one of the `add_*` method is called inside a V2 conditional branch,
+  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
+  We need to raise clear error messages in such cases.
+
+  Arguments:
+    tensor: Tensor to check.
+    method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
+
+  Raises:
+    RuntimeError: In case of an out-of-graph tensor.
+  """
+  if ops.executing_eagerly_outside_functions() and hasattr(tensor, 'graph'):
+    if isinstance(tensor.graph,
+                  (control_flow_util_v2.CondBranchFuncGraph,
+                   control_flow_util_v2.WhileCondFuncGraph,
+                   control_flow_util_v2.WhileBodyFuncGraph)):
+      if method == 'add_metric':
+        bad_example = """
+        def call(self, inputs, training=None):
+          if training:
+            metric = compute_metric(inputs)
+            self.add_metric(metric, name='my_metric', aggregation='mean')
+          return inputs
+        """
+        correct_example = """
+        def call(self, inputs, training=None):
+          if training:
+            metric = compute_metric(inputs)
+          else:
+            metric = 0.
+          self.add_metric(metric, name='my_metric', aggregation='mean')
+          return inputs
+        """
+      elif method == 'add_loss':
+        bad_example = """
+        def call(self, inputs, training=None):
+          if training:
+            loss = compute_loss(inputs)
+            self.add_loss(loss)
+          return inputs
+        """
+        correct_example = """
+        def call(self, inputs, training=None):
+          if training:
+            loss = compute_loss(inputs)
+          else:
+            loss = 0.
+          self.add_loss(loss)
+          return inputs
+        """
+      else:
+        bad_example = """
+        def call(self, inputs, training=None):
+          if training:
+            self.add_update(self.w.assign_add(1))
+          return inputs
+        """
+        correct_example = """
+        def call(self, inputs, training=None):
+          if training:
+            increment = 1
+          else:
+            increment = 0
+          self.add_update(self.w.assign_add(increment))
+          return inputs
+        """
+      raise RuntimeError(
+          'You are using the method `{method}` in a control flow branch '
+          'in your layer, e.g.:\n{bad_example}\n'
+          'This is not currently supported. '
+          'You should either use static control flow (`tf.cond`) '
+          'or move your call to {method} out of the control flow branch, '
+          'e.g.:\n{correct_example}\n'
+          'You can also resolve this by marking your layer '
+          'as dynamic (eager-only) by passing '
+          '`dynamic=True` to the layer constructor. '
+          'Any kind of control flow is supported with dynamic layers. '
+          'Note that using `dynamic=True` requires you '
+          'to implement static shape inference '
+          'in the `compute_output_shape(input_shape)` method.'.format(
+              method=method,
+              bad_example=bad_example,
+              correct_example=correct_example))
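The error text points to static control flow as the graph-friendly alternative: compute the conditional value with both branches staged (via `tf.cond` or an autograph-converted `if`/`else`), then call the `add_*` method outside the branch. A hedged sketch of the `tf.cond` variant for the `add_update` case (layer and weight names illustrative, not from this commit):

```python
import tensorflow as tf


class CounterLayer(tf.keras.layers.Layer):

  def build(self, input_shape):
    self.counter = self.add_weight(
        shape=(), trainable=False, initializer='zeros')

  def call(self, inputs, training=None):
    if training is None:
      training = tf.keras.backend.learning_phase()
    # Both branches live in the main Keras graph, so the update tensor
    # passes `check_graph_consistency`.
    increment = tf.cond(tf.cast(training, tf.bool),
                        lambda: tf.constant(1.),
                        lambda: tf.constant(0.))
    self.add_update(self.counter.assign_add(increment))
    return inputs
```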
tensorflow/python/keras/engine/control_flow_test.py (new file, 138 lines)
@@ -0,0 +1,138 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dynamic control flow behavior with Keras."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class ControlFlowLayer1(base_layer.Layer):
+  """Layer with an `if` condition in call."""
+
+  def call(self, inputs):
+    if math_ops.reduce_sum(inputs) > 0:
+      return math_ops.sqrt(inputs)
+    else:
+      return math_ops.square(inputs)
+
+
+class ControlFlowLayer2(base_layer.Layer):
+  """Layer with a `for` loop in call."""
+
+  def call(self, inputs):
+    samples = tensor_array_ops.TensorArray(
+        dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
+    i = 0
+    for sample in inputs:
+      samples = samples.write(i, math_ops.square(sample))
+      i += 1
+    return samples.stack()
+
+
+class NestedControlFlowLayer(base_layer.Layer):
+  """Layer nested with a control flow layer."""
+
+  def __init__(self, **kwargs):
+    super(NestedControlFlowLayer, self).__init__(**kwargs)
+    self.layer = ControlFlowLayer1()
+
+  def call(self, inputs):
+    return self.layer(inputs)
+
+
+class ControlFlowModel(keras.Model):
+  """Model with an `if` condition in call."""
+
+  def call(self, inputs):
+    if math_ops.reduce_sum(inputs) > 0:
+      return math_ops.sqrt(inputs)
+    else:
+      return math_ops.square(inputs)
+
+
+class NestedControlFlowModel(keras.Model):
+  """Model with an `if` condition in call using a control flow layer."""
+
+  def __init__(self, **kwargs):
+    super(NestedControlFlowModel, self).__init__(**kwargs)
+    self.layer = NestedControlFlowLayer()
+
+  def call(self, inputs):
+    inputs = self.layer(inputs)
+    if math_ops.reduce_sum(inputs) > 0:
+      return math_ops.sqrt(inputs)
+    else:
+      return math_ops.square(inputs)
+
+
+class FunctionControlFlowModel(keras.Model):
+  """Model with control flow where `call` is wrapped in function already."""
+
+  @def_function.function
+  def call(self, inputs):
+    if math_ops.reduce_sum(inputs) > 0:
+      return math_ops.sqrt(inputs)
+    else:
+      return math_ops.square(inputs)
+
+
+@keras_parameterized.run_all_keras_modes
+class AutographWrapperTest(keras_parameterized.TestCase):
+
+  @keras_parameterized.run_with_all_model_types
+  @parameterized.named_parameters(('with_if', ControlFlowLayer1),
+                                  ('with_for', ControlFlowLayer2),
+                                  ('nested', NestedControlFlowLayer))
+  def test_control_flow_layer(self, layer_class):
+    model = testing_utils.get_model_from_layers([layer_class()],
+                                                input_shape=(3,))
+    model.compile(rmsprop.RMSprop(0.001), loss='mse')
+    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+  @parameterized.named_parameters(
+      ('with_if', ControlFlowModel), ('nested', NestedControlFlowModel),
+      ('wrapped_in_function', FunctionControlFlowModel))
+  def test_control_flow_model(self, model_class):
+    model = model_class()
+    model.compile(rmsprop.RMSprop(0.001), loss='mse')
+    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+  def test_control_flow_in_deferred_sequential_model(self):
+    model = keras.Sequential(
+        [ControlFlowLayer1(),
+         keras.layers.Dense(3),
+         ControlFlowLayer2()])
+    model.compile(rmsprop.RMSprop(0.001), loss='mse')
+    model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
+
+
+if __name__ == '__main__':
+  test.main()