Reenable autograph in Layer.__call__ and add safeguards to prevent users from writing invalid static control flow.

PiperOrigin-RevId: 245151330
Francois Chollet 2019-04-24 17:48:17 -07:00 committed by TensorFlower Gardener
parent 8be60a9ce8
commit cf31e9001c
5 changed files with 541 additions and 96 deletions
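For context, the "invalid static control flow" referred to above is an `add_loss`/`add_update`/`add_metric` call placed inside a tensor-dependent branch, which autograph lowers into a `tf.cond` branch graph. A minimal sketch of the now-rejected pattern (hypothetical layer, based on the `check_graph_consistency` examples added below):

```python
import tensorflow as tf

class BadLayer(tf.keras.layers.Layer):  # hypothetical example
  def call(self, inputs, training=None):
    # When `training` arrives as a symbolic tensor (the Keras learning-phase
    # placeholder), autograph rewrites this `if` into a tf.cond. The loss
    # tensor is then created inside the cond branch graph, and the new
    # safeguards raise a RuntimeError pointing at the recommended fix.
    if training:
      self.add_loss(tf.reduce_sum(inputs))
    return inputs
```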


@@ -1357,6 +1357,19 @@ tf_py_test(
tags = ["no_rocm"],
)
tf_py_test(
name = "control_flow_test",
size = "medium",
srcs = ["engine/control_flow_test.py"],
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
],
shard_count = 8,
)
tf_py_test(
name = "hdf5_format_test",
size = "medium",


@@ -26,6 +26,7 @@ import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.core.framework import node_def_pb2
from tensorflow.python import autograph
from tensorflow.python.distribute import values as distribute_values
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
@@ -594,6 +595,21 @@ class Layer(module.Module):
# Build layer if applicable (if the `build` method has been
# overridden).
self._maybe_build(inputs)
# Wrap the `call` method in autograph to allow for dynamic control
# flow in `call`. We limit this to subclassed layers, since autograph
# is only strictly needed for those.
if base_layer_utils.is_subclassed(self):
decorators, original_func = tf_decorator.unwrap(self.call)
converted_func = autograph.convert(recursive=True)(original_func)
if decorators:
call_fn = tf_decorator.rewrap(self.call, original_func,
converted_func)
else:
call_fn = converted_func
else:
call_fn = self.call
# Explicitly pass the learning phase placeholder to `call` if
# the `training` argument was left unspecified by the user.
# This behavior is restricted to the managed Keras FuncGraph.
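As a reference for what the conversion above does, autograph rewrites tensor-dependent Python control flow into graph ops. A standalone sketch (not part of this diff) using the public `tf.autograph.to_graph` API:

```python
import tensorflow as tf

def f(x):
  # A tensor-dependent Python `if`; invalid in graph mode as written.
  if tf.reduce_sum(x) > 0:
    return tf.sqrt(x)
  return tf.square(x)

# Autograph rewrites the `if` into a tf.cond, so `graph_f` can be traced
# symbolically. Layer.__call__ applies the same conversion via
# `autograph.convert(recursive=True)` above.
graph_f = tf.autograph.to_graph(f)
```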
@@ -613,20 +629,18 @@ class Layer(module.Module):
self._mixed_precision_policy.should_cast_variables), (
base_layer_utils.AutoAddUpdates(self,
inputs)) as auto_updater:
outputs = self.call(inputs, *args, **kwargs)
outputs = call_fn(inputs, *args, **kwargs)
auto_updater.set_outputs(outputs)
except TypeError as e:
messages = ('`tf.Tensor` as a Python `bool` is not allowed',
'Tensor objects are only iterable when eager')
exception_str = str(e)
for msg in messages:
if msg in exception_str:
raise TypeError('You are attempting to use Python control '
'flow in a layer that was not declared to be '
'dynamic. Pass `dynamic=True` to the class '
'constructor.\nEncountered error:\n"""\n' +
exception_str + '\n"""')
exception_msg = 'Tensor objects are only iterable when eager'
if exception_msg in exception_str:
raise TypeError('You are attempting to use Python control '
'flow in a layer that was not declared to be '
'dynamic. Pass `dynamic=True` to the class '
'constructor.\nEncountered error:\n"""\n' +
exception_str + '\n"""')
raise
else:
# We will use static shape inference to return symbolic tensors
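The error raised above points users at the `dynamic=True` escape hatch. A minimal sketch of declaring a dynamic (eager-only) layer, mirroring the `DynamicLayer` test class below:

```python
import tensorflow as tf

class MyDynamicLayer(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    # `dynamic=True` opts out of graph tracing; `call` always runs eagerly.
    super(MyDynamicLayer, self).__init__(dynamic=True, **kwargs)

  def call(self, inputs):
    if tf.reduce_sum(inputs) > 0:  # arbitrary Python control flow is fine here
      return tf.sqrt(inputs)
    return tf.square(inputs)

  def compute_output_shape(self, input_shape):
    # Dynamic layers must still provide static shape inference.
    return input_shape
```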
@@ -775,7 +789,7 @@ class Layer(module.Module):
class MyLayer(tf.keras.layers.Layer):
def call(self, inputs):
self.add_loss(tf.abs(tf.reduce_mean(inputs)), inputs=True)
return 2*inputs
return inputs
```
This method can also be called directly on a Functional Model during
@@ -831,6 +845,7 @@ class Layer(module.Module):
return None # Will be filtered out when computing the .losses property
if not tensor_util.is_tensor(loss):
loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
base_layer_utils.check_graph_consistency(loss, method='add_loss')
loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access
return loss
@@ -842,12 +857,15 @@ class Layer(module.Module):
for loss in losses:
if callable(loss):
callable_losses.append(functools.partial(_tag_unconditional, loss))
elif tf_utils.is_symbolic_tensor(loss):
continue
if loss is None:
continue
if not tensor_util.is_tensor(loss):
loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
if tf_utils.is_symbolic_tensor(loss):
symbolic_losses.append(_tag_unconditional(loss))
elif tensor_util.is_tensor(loss):
eager_losses.append(_tag_unconditional(loss))
elif loss is not None: # `None` is valid but should be ignored.
raise ValueError('Found non-Tensor loss: ' + str(loss))
self._callable_losses += callable_losses
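For reference, the branches above accept three loss forms. A usage sketch (hypothetical `layer`, `outputs`, and `kernel`):

```python
# Conditional (input-dependent) symbolic or eager tensor:
layer.add_loss(tf.reduce_sum(outputs), inputs=True)
# Unconditional callable, evaluated lazily (e.g. weight regularization):
layer.add_loss(lambda: tf.nn.l2_loss(layer.kernel))
# Plain Python floats are converted via ops.convert_to_tensor:
layer.add_loss(0.)
```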
@@ -1005,14 +1023,24 @@ class Layer(module.Module):
return # Updates already applied when in eager mode.
def process_update(x):
"""Standardize update ops.
Arguments:
x: Tensor, op, or callable.
Returns:
An update op.
"""
if callable(x):
x = x()
if isinstance(x, ops.Operation):
return x
update = x
elif hasattr(x, 'op'):
return x.op
update = x.op
else:
return ops.convert_to_tensor(x)
update = ops.convert_to_tensor(x)
base_layer_utils.check_graph_consistency(update, method='add_update')
return update
updates = [process_update(x) for x in updates]
self._updates += updates
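All of the forms handled by `process_update` above can be passed directly to `add_update`. A sketch (hypothetical `layer` with a `counter` variable):

```python
layer.add_update(layer.counter.assign_add(1))          # op, or tensor with .op
layer.add_update(lambda: layer.counter.assign_add(1))  # callable, called lazily
```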
@@ -1502,6 +1530,7 @@ class Layer(module.Module):
self._metrics.append(metric_obj)
def _symbolic_add_metric(self, value, aggregation=None, name=None):
base_layer_utils.check_graph_consistency(value, method='add_metric')
match = self._get_existing_metric(name)
if aggregation is None:
# Iterate over the metrics and check if the given metric exists already.


@@ -28,6 +28,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
@@ -40,35 +41,22 @@ from tensorflow.python.layers import core as legacy_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
class DynamicLayer1(base_layer.Layer):
class DynamicLayer(base_layer.Layer):
def __init__(self, dynamic=False, **kwargs):
super(DynamicLayer1, self).__init__(dynamic=dynamic, **kwargs)
super(DynamicLayer, self).__init__(dynamic=dynamic, **kwargs)
def call(self, inputs):
if math_ops.reduce_sum(inputs) > 0:
return math_ops.sqrt(inputs)
else:
return math_ops.square(inputs)
def compute_output_shape(self, input_shape):
return input_shape
class DynamicLayer2(base_layer.Layer):
def __init__(self, dynamic=False, **kwargs):
super(DynamicLayer2, self).__init__(dynamic=dynamic, **kwargs)
def call(self, inputs):
samples = []
for sample in inputs:
samples.append(math_ops.square(sample))
return array_ops.stack(samples, axis=0)
samples = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
for idx, sample in enumerate(inputs):
samples = samples.write(idx, math_ops.square(sample))
return samples.stack()
def compute_output_shape(self, input_shape):
return input_shape
@@ -82,34 +70,39 @@ class InvalidLayer(base_layer.Layer):
class BaseLayerTest(keras_parameterized.TestCase):
@parameterized.parameters(DynamicLayer1, DynamicLayer2)
def test_dynamic_layer_in_functional_model_in_graph_mode(self, layer_class):
@keras_parameterized.run_with_all_model_types
def test_dynamic_layer(self):
model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
input_shape=(3,))
self.assertEqual(model.dynamic, True)
model.compile(rmsprop.RMSprop(0.001), loss='mse')
self.assertEqual(model.run_eagerly, True)
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
@keras_parameterized.run_with_all_model_types
def test_dynamic_layer_error(self):
with self.assertRaisesRegexp(TypeError,
'attempting to use Python control flow'):
model = testing_utils.get_model_from_layers([DynamicLayer()],
input_shape=(3,))
model.compile(rmsprop.RMSprop(0.001), loss='mse')
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
@keras_parameterized.run_with_all_model_types
def test_dynamic_layer_error_running_in_graph_mode(self):
with context.graph_mode():
inputs = keras.Input((3,))
# Works when `dynamic=True` is declared.
outputs = layer_class(dynamic=True)(inputs)
model = keras.Model(inputs, outputs)
model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
input_shape=(3,))
self.assertEqual(model.dynamic, True)
# But then you cannot run the model since you're in a graph scope.
with self.assertRaisesRegexp(
ValueError, 'You must enable eager execution'):
model.compile(rmsprop.RMSprop(0.001), loss='mse')
# Fails when `dynamic=True` not declared.
with self.assertRaisesRegexp(
TypeError, 'attempting to use Python control flow'):
_ = layer_class()(inputs)
@parameterized.parameters(DynamicLayer1, DynamicLayer2)
def test_dynamic_layer_in_functional_model_in_eager_mode(self, layer_class):
inputs = keras.Input((3,))
# Fails when `dynamic=True` not declared.
with self.assertRaisesRegexp(
TypeError, 'attempting to use Python control flow'):
_ = layer_class()(inputs)
# Works when `dynamic=True` is declared.
outputs = layer_class(dynamic=True)(inputs)
model = keras.Model(inputs, outputs)
def test_dynamic_layer_with_deferred_sequential_model(self):
model = keras.Sequential(
[DynamicLayer(dynamic=True),
keras.layers.Dense(3)])
self.assertEqual(model.dynamic, True)
model.compile(rmsprop.RMSprop(0.001), loss='mse')
self.assertEqual(model.run_eagerly, True)
@@ -117,12 +110,12 @@ class BaseLayerTest(keras_parameterized.TestCase):
def test_nested_dynamic_layers_in_eager_mode(self):
inputs = keras.Input((3,))
outputs = DynamicLayer1(dynamic=True)(inputs)
outputs = DynamicLayer(dynamic=True)(inputs)
inner_model = keras.Model(inputs, outputs)
self.assertEqual(inner_model.dynamic, True)
inputs = keras.Input((3,))
x = DynamicLayer2(dynamic=True)(inputs)
x = DynamicLayer(dynamic=True)(inputs)
outputs = inner_model(x)
model = keras.Model(inputs, outputs)
@@ -131,41 +124,6 @@ class BaseLayerTest(keras_parameterized.TestCase):
self.assertEqual(model.run_eagerly, True)
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
def test_dynamic_layers_in_sequential_model(self):
# Without input_shape argument
model = keras.Sequential([DynamicLayer1(dynamic=True),
keras.layers.Dense(3),
DynamicLayer2(dynamic=True)])
self.assertEqual(model.dynamic, True)
model.compile(rmsprop.RMSprop(0.001), loss='mse')
self.assertEqual(model.run_eagerly, True)
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
# With input_shape argument
model = keras.Sequential([DynamicLayer1(dynamic=True, input_shape=(3,)),
DynamicLayer2(dynamic=True)])
self.assertEqual(model.dynamic, True)
model.compile(rmsprop.RMSprop(0.001), loss='mse')
self.assertEqual(model.run_eagerly, True)
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
def test_dynamic_layers_in_subclassed_model(self):
class MyModel(keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.layer1 = DynamicLayer1(dynamic=True)
def call(self, inputs):
return self.layer1(inputs)
model = MyModel()
self.assertEqual(model.dynamic, True)
model.compile(rmsprop.RMSprop(0.001), loss='mse')
self.assertEqual(model.run_eagerly, True)
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
def test_dynamic_subclassed_model_no_shape_inference(self):
class MyModel(keras.Model):
@@ -213,6 +171,24 @@ class BaseLayerTest(keras_parameterized.TestCase):
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
self.assertEqual(model.outputs[0].shape.as_list(), [None, 3])
@keras_parameterized.run_all_keras_modes
def test_add_loss_correctness(self):
class MyLayer(keras.layers.Layer):
def call(self, inputs, training=None):
self.add_loss(math_ops.reduce_sum(inputs))
return inputs
inputs = keras.Input((3,))
layer = MyLayer()
outputs = layer(inputs)
model = keras.Model(inputs, outputs)
self.assertEqual(len(model.losses), 1)
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
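# The MSE term is zero here (predictions equal targets), so the total loss
# equals the added term: sum(ones((2, 3))) = 6, i.e. 2 * 3.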
self.assertEqual(loss, 2 * 3)
@test_util.run_in_graph_and_eager_modes
def test_invalid_forward_pass(self):
inputs = keras.Input((3,))
@@ -676,6 +652,201 @@ class NameScopingTest(keras_parameterized.TestCase):
self.assertEqual(layer.kernel.name, 'MyName3/kernel:0')
class AutographControlFlowTest(keras_parameterized.TestCase):
@parameterized.named_parameters(('eager', True),
('symbolic', False))
def test_if_training_pattern_output(self, eager):
class MyLayer(keras.layers.Layer):
def call(self, inputs, training=None):
if training:
return inputs * 1.
return inputs * 0.
inputs = keras.Input((3,))
outputs = MyLayer()(inputs)
model = keras.Model(inputs, outputs)
model.compile('sgd', 'mse', run_eagerly=eager)
train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(train_loss, 0.)
test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(test_loss, 1.)
@parameterized.named_parameters(('eager', True),
('symbolic', False))
def test_if_training_pattern_loss(self, eager):
class MyLayer(keras.layers.Layer):
def call(self, inputs, training=None):
if training:
loss = math_ops.reduce_sum(inputs)
else:
loss = 0.
self.add_loss(loss)
return inputs
inputs = keras.Input((3,))
outputs = MyLayer()(inputs)
model = keras.Model(inputs, outputs)
model.compile('sgd', 'mse', run_eagerly=eager)
train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(train_loss, 2 * 3)
test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(test_loss, 0)
@parameterized.named_parameters(('eager', True),
('symbolic', False))
def test_if_training_pattern_metric(self, eager):
class MyLayer(keras.layers.Layer):
def call(self, inputs, training=None):
if training:
metric = math_ops.reduce_sum(inputs)
else:
metric = 0.
self.add_metric(metric, name='my_metric', aggregation='mean')
return inputs
inputs = keras.Input((3,))
outputs = MyLayer()(inputs)
model = keras.Model(inputs, outputs)
model.compile('sgd', 'mse', run_eagerly=eager)
_, train_metric = model.train_on_batch(np.ones((2, 3)),
np.ones((2, 3)))
self.assertEqual(train_metric, 2 * 3)
_, test_metric = model.test_on_batch(np.ones((2, 3)),
np.ones((2, 3)))
self.assertEqual(test_metric, 0)
@parameterized.named_parameters(('eager', True),
('symbolic', False))
def test_if_training_pattern_update(self, eager):
class MyLayer(keras.layers.Layer):
def build(self, input_shape):
self.counter = self.add_weight(
shape=(), trainable=False, initializer='zeros')
def call(self, inputs, training=None):
if training:
increment = 1.
else:
increment = 0.
self.counter.assign_add(increment)
return inputs
inputs = keras.Input((3,))
layer = MyLayer()
outputs = layer(inputs)
model = keras.Model(inputs, outputs)
model.compile('sgd', 'mse', run_eagerly=eager)
model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(keras.backend.get_value(layer.counter), 1.)
@parameterized.named_parameters(('eager', True),
('symbolic', False))
def test_conditional_updates_in_call(self, eager):
class MyLayer(keras.layers.Layer):
def __init__(self):
super(MyLayer, self).__init__(dynamic=eager)
def build(self, input_shape):
self.counter = self.add_weight(
shape=(), trainable=False, initializer='zeros')
def call(self, inputs, training=None):
if training:
self.add_update(self.counter.assign_add(math_ops.reduce_sum(inputs)))
return inputs
def compute_output_shape(self, input_shape):
return input_shape
if eager:
inputs = keras.Input((3,))
layer = MyLayer()
outputs = layer(inputs)
model = keras.Model(inputs, outputs)
model.compile('sgd', 'mse', run_eagerly=eager)
model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(keras.backend.get_value(layer.counter), 6.)
else:
# TODO(fchollet): support the same workflow in graph mode.
with self.assertRaisesRegexp(RuntimeError,
'`add_update` in a control flow branch'):
layer = MyLayer()(keras.Input((3,)))
@parameterized.named_parameters(('eager', True),
('symbolic', False))
def test_conditional_losses_in_call(self, eager):
class MyLayer(keras.layers.Layer):
def __init__(self):
super(MyLayer, self).__init__(dynamic=eager)
def call(self, inputs, training=None):
if training:
self.add_loss(math_ops.reduce_sum(inputs))
return inputs
def compute_output_shape(self, input_shape):
return input_shape
if eager:
inputs = keras.Input((3,))
layer = MyLayer()
outputs = layer(inputs)
model = keras.Model(inputs, outputs)
model.compile('sgd', 'mse')
loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(loss, 2 * 3)
else:
with self.assertRaisesRegexp(RuntimeError,
'`add_loss` in a control flow branch'):
layer = MyLayer()(keras.Input((3,)))
@parameterized.named_parameters(('eager', True),
('symbolic', False))
def test_conditional_metrics_in_call(self, eager):
class MyLayer(keras.layers.Layer):
def __init__(self):
super(MyLayer, self).__init__(dynamic=eager)
def call(self, inputs, training=None):
if training:
self.add_metric(math_ops.reduce_sum(inputs),
name='sum',
aggregation='mean')
return inputs
def compute_output_shape(self, input_shape):
return input_shape
if eager:
inputs = keras.Input((3,))
layer = MyLayer()
outputs = layer(inputs)
model = keras.Model(inputs, outputs)
model.compile('sgd', 'mse')
history = model.fit(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(history.history['sum'][-1], 2 * 3)
else:
# TODO(fchollet): support the same workflow in graph mode.
with self.assertRaisesRegexp(RuntimeError,
'`add_metric` in a control flow branch'):
layer = MyLayer()(keras.Input((3,)))
_LAYERS_TO_TEST = [
(keras.layers.Dense, (1,), collections.OrderedDict(units=[1])),
(keras.layers.Activation, (2, 2),


@@ -30,6 +30,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables as tf_variables
@@ -571,3 +572,96 @@ def autocast_context_manager(input_list, should_cast):
var_read_dtype = _get_var_read_dtype(input_list, should_cast)
return ops.get_default_graph()._enable_auto_casting_variables( # pylint: disable=protected-access
var_read_dtype)
def is_subclassed(layer):
return (layer.__module__.find('keras.engine') == -1 and
layer.__module__.find('keras.layers') == -1)
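A quick illustration of this module-path heuristic (sketch; `UserDefinedLayer` is a hypothetical layer defined outside the Keras codebase):

```python
is_subclassed(tf.keras.layers.Dense(1))  # False: defined under keras.layers
is_subclassed(UserDefinedLayer())        # True: defined in a user module
```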
def check_graph_consistency(tensor, method):
"""Checks that tensors passed to `add_*` method match the Keras graph.
When one of the `add_*` methods is called inside a V2 conditional branch,
the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
We need to raise clear error messages in such cases.
Arguments:
tensor: Tensor to check.
method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
Raises:
RuntimeError: In case of an out-of-graph tensor.
"""
if ops.executing_eagerly_outside_functions() and hasattr(tensor, 'graph'):
if isinstance(tensor.graph,
(control_flow_util_v2.CondBranchFuncGraph,
control_flow_util_v2.WhileCondFuncGraph,
control_flow_util_v2.WhileBodyFuncGraph)):
if method == 'add_metric':
bad_example = """
def call(self, inputs, training=None):
if training:
metric = compute_metric(inputs)
self.add_metric(metric, name='my_metric', aggregation='mean')
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
metric = compute_metric(inputs)
else:
metric = 0.
self.add_metric(metric, name='my_metric', aggregation='mean')
return inputs
"""
elif method == 'add_loss':
bad_example = """
def call(self, inputs, training=None):
if training:
loss = compute_loss(inputs)
self.add_loss(loss)
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
loss = compute_loss(inputs)
else:
loss = 0.
self.add_loss(loss)
return inputs
"""
else:
bad_example = """
def call(self, inputs, training=None):
if training:
self.add_update(self.w.assign_add(1))
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
increment = 1
else:
increment = 0
self.add_update(self.w.assign_add(increment))
return inputs
"""
raise RuntimeError(
'You are using the method `{method}` in a control flow branch '
'in your layer, e.g.:\n{bad_example}\n'
'This is not currently supported. '
'You should either use static control flow (`tf.cond`) '
'or move your call to {method} out of the control flow branch, '
'e.g.:\n{correct_example}\n'
'You can also resolve this by marking your layer '
'as dynamic (eager-only) by passing '
'`dynamic=True` to the layer constructor. '
'Any kind of control flow is supported with dynamic layers. '
'Note that using `dynamic=True` requires you '
'to implement static shape inference '
'in the `compute_output_shape(input_shape)` method.'.format(
method=method,
bad_example=bad_example,
correct_example=correct_example))
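Pulling the `add_update` variant together, a runnable sketch of the recommended pattern (it mirrors the `correct_example` string above and the `test_if_training_pattern_update` test):

```python
import tensorflow as tf

class CounterLayer(tf.keras.layers.Layer):
  def build(self, input_shape):
    self.w = self.add_weight(shape=(), trainable=False, initializer='zeros')

  def call(self, inputs, training=None):
    # Hoist the conditional value out of the branch so the assign op is
    # created in the outer Keras graph rather than inside a cond branch.
    if training:
      increment = 1.
    else:
      increment = 0.
    self.add_update(self.w.assign_add(increment))
    return inputs
```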


@@ -0,0 +1,138 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dynamic control flow behavior with Keras."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.optimizer_v2 import rmsprop
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
class ControlFlowLayer1(base_layer.Layer):
"""Layer with an `if` condition in call."""
def call(self, inputs):
if math_ops.reduce_sum(inputs) > 0:
return math_ops.sqrt(inputs)
else:
return math_ops.square(inputs)
class ControlFlowLayer2(base_layer.Layer):
"""Layer with a `for` loop in call."""
def call(self, inputs):
samples = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
i = 0
for sample in inputs:
samples = samples.write(i, math_ops.square(sample))
i += 1
return samples.stack()
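For reference, autograph lowers the `for` loop above into roughly the following (a simplified sketch, not the literal generated code):

```python
import tensorflow as tf

def squared_rows(inputs):
  n = tf.shape(inputs)[0]
  samples = tf.TensorArray(tf.float32, size=n)

  def body(i, samples):
    # Square one row per iteration, mirroring `samples.write(i, square(sample))`.
    return i + 1, samples.write(i, tf.square(inputs[i]))

  _, samples = tf.while_loop(lambda i, _: i < n, body, (0, samples))
  return samples.stack()
```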
class NestedControlFlowLayer(base_layer.Layer):
"""Layer nested with a control flow layer."""
def __init__(self, **kwargs):
super(NestedControlFlowLayer, self).__init__(**kwargs)
self.layer = ControlFlowLayer1()
def call(self, inputs):
return self.layer(inputs)
class ControlFlowModel(keras.Model):
"""Model with an `if` condition in call."""
def call(self, inputs):
if math_ops.reduce_sum(inputs) > 0:
return math_ops.sqrt(inputs)
else:
return math_ops.square(inputs)
class NestedControlFlowModel(keras.Model):
"""Model with an `if` condition in call using a control flow layer."""
def __init__(self, **kwargs):
super(NestedControlFlowModel, self).__init__(**kwargs)
self.layer = NestedControlFlowLayer()
def call(self, inputs):
inputs = self.layer(inputs)
if math_ops.reduce_sum(inputs) > 0:
return math_ops.sqrt(inputs)
else:
return math_ops.square(inputs)
class FunctionControlFlowModel(keras.Model):
"""Model with control flow where `call` is wrapped in function already."""
@def_function.function
def call(self, inputs):
if math_ops.reduce_sum(inputs) > 0:
return math_ops.sqrt(inputs)
else:
return math_ops.square(inputs)
@keras_parameterized.run_all_keras_modes
class AutographWrapperTest(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
@parameterized.named_parameters(('with_if', ControlFlowLayer1),
('with_for', ControlFlowLayer2),
('nested', NestedControlFlowLayer))
def test_control_flow_layer(self, layer_class):
model = testing_utils.get_model_from_layers([layer_class()],
input_shape=(3,))
model.compile(rmsprop.RMSprop(0.001), loss='mse')
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
@parameterized.named_parameters(
('with_if', ControlFlowModel), ('nested', NestedControlFlowModel),
('wrapped_in_function', FunctionControlFlowModel))
def test_control_flow_model(self, model_class):
model = model_class()
model.compile(rmsprop.RMSprop(0.001), loss='mse')
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
def test_control_flow_in_deferred_sequential_model(self):
model = keras.Sequential(
[ControlFlowLayer1(),
keras.layers.Dense(3),
ControlFlowLayer2()])
model.compile(rmsprop.RMSprop(0.001), loss='mse')
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
if __name__ == '__main__':
test.main()