From cf31e9001c5ab89fc2e92eb98443d48cd37a726b Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 24 Apr 2019 17:48:17 -0700 Subject: [PATCH] Reenable autograph in `Layer.__call__` and add safeguards to prevent users from writing invalid static control flow. PiperOrigin-RevId: 245151330 --- tensorflow/python/keras/BUILD | 13 + tensorflow/python/keras/engine/base_layer.py | 63 +++- .../python/keras/engine/base_layer_test.py | 329 +++++++++++++----- .../python/keras/engine/base_layer_utils.py | 94 +++++ .../python/keras/engine/control_flow_test.py | 138 ++++++++ 5 files changed, 541 insertions(+), 96 deletions(-) create mode 100644 tensorflow/python/keras/engine/control_flow_test.py diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index a8b33b9e798..81269ad81a7 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -1357,6 +1357,19 @@ tf_py_test( tags = ["no_rocm"], ) +tf_py_test( + name = "control_flow_test", + size = "medium", + srcs = ["engine/control_flow_test.py"], + additional_deps = [ + ":keras", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + ], + shard_count = 8, +) + tf_py_test( name = "hdf5_format_test", size = "medium", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 991589bf87c..3ca5ca09061 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -26,6 +26,7 @@ import numpy as np from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.core.framework import node_def_pb2 +from tensorflow.python import autograph from tensorflow.python.distribute import values as distribute_values from tensorflow.python.eager import context from tensorflow.python.eager import execute @@ -594,6 +595,21 @@ class Layer(module.Module): # Build layer if applicable (if the `build` method has been # overridden). self._maybe_build(inputs) + + # Wrapping `call` function in autograph to allow for dynamic control + # dependencies in call. We are limiting this to subclassed layers as + # autograph is strictly needed only for subclassed layers. + if base_layer_utils.is_subclassed(self): + decorators, original_func = tf_decorator.unwrap(self.call) + converted_func = autograph.convert(recursive=True)(original_func) + if decorators: + call_fn = tf_decorator.rewrap(self.call, original_func, + converted_func) + else: + call_fn = converted_func + else: + call_fn = self.call + # Explicitly pass the learning phase placeholder to `call` if # the `training` argument was left unspecified by the user. # This behavior is restricted to the managed Keras FuncGraph. @@ -613,20 +629,18 @@ class Layer(module.Module): self._mixed_precision_policy.should_cast_variables), ( base_layer_utils.AutoAddUpdates(self, inputs)) as auto_updater: - outputs = self.call(inputs, *args, **kwargs) + outputs = call_fn(inputs, *args, **kwargs) auto_updater.set_outputs(outputs) except TypeError as e: - messages = ('`tf.Tensor` as a Python `bool` is not allowed', - 'Tensor objects are only iterable when eager') exception_str = str(e) - for msg in messages: - if msg in exception_str: - raise TypeError('You are attempting to use Python control ' - 'flow in a layer that was not declared to be ' - 'dynamic. Pass `dynamic=True` to the class ' - 'constructor.\nEncountered error:\n"""\n' + - exception_str + '\n"""') + exception_msg = 'Tensor objects are only iterable when eager' + if exception_msg in exception_str: + raise TypeError('You are attempting to use Python control ' + 'flow in a layer that was not declared to be ' + 'dynamic. Pass `dynamic=True` to the class ' + 'constructor.\nEncountered error:\n"""\n' + + exception_str + '\n"""') raise else: # We will use static shape inference to return symbolic tensors @@ -775,7 +789,7 @@ class Layer(module.Module): class MyLayer(tf.keras.layers.Layer): def call(inputs, self): self.add_loss(tf.abs(tf.reduce_mean(inputs)), inputs=True) - return 2*inputs + return inputs ``` This method can also be called directly on a Functional Model during @@ -831,6 +845,7 @@ class Layer(module.Module): return None # Will be filtered out when computing the .losses property if not tensor_util.is_tensor(loss): loss = ops.convert_to_tensor(loss, dtype=backend.floatx()) + base_layer_utils.check_graph_consistency(loss, method='add_loss') loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access return loss @@ -842,12 +857,15 @@ class Layer(module.Module): for loss in losses: if callable(loss): callable_losses.append(functools.partial(_tag_unconditional, loss)) - elif tf_utils.is_symbolic_tensor(loss): + continue + if loss is None: + continue + if not tensor_util.is_tensor(loss): + loss = ops.convert_to_tensor(loss, dtype=backend.floatx()) + if tf_utils.is_symbolic_tensor(loss): symbolic_losses.append(_tag_unconditional(loss)) elif tensor_util.is_tensor(loss): eager_losses.append(_tag_unconditional(loss)) - elif loss is not None: # `None` is valid but should be ignored. - raise ValueError('Found non-Tensor loss: ' + str(loss)) self._callable_losses += callable_losses @@ -1005,14 +1023,24 @@ class Layer(module.Module): return # Updates already applied when in eager mode. def process_update(x): + """Standardize update ops. + + Arguments: + x: Tensor, op, or callable. + + Returns: + An update op. + """ if callable(x): x = x() if isinstance(x, ops.Operation): - return x + update = x elif hasattr(x, 'op'): - return x.op + update = x.op else: - return ops.convert_to_tensor(x) + update = ops.convert_to_tensor(x) + base_layer_utils.check_graph_consistency(update, method='add_update') + return update updates = [process_update(x) for x in updates] self._updates += updates @@ -1502,6 +1530,7 @@ class Layer(module.Module): self._metrics.append(metric_obj) def _symbolic_add_metric(self, value, aggregation=None, name=None): + base_layer_utils.check_graph_consistency(value, method='add_metric') match = self._get_existing_metric(name) if aggregation is None: # Iterate over the metrics and check if the given metric exists already. diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 81921d70b90..8d2ced75533 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -28,6 +28,7 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -40,35 +41,22 @@ from tensorflow.python.layers import core as legacy_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -class DynamicLayer1(base_layer.Layer): +class DynamicLayer(base_layer.Layer): def __init__(self, dynamic=False, **kwargs): - super(DynamicLayer1, self).__init__(dynamic=dynamic, **kwargs) + super(DynamicLayer, self).__init__(dynamic=dynamic, **kwargs) def call(self, inputs): - if math_ops.reduce_sum(inputs) > 0: - return math_ops.sqrt(inputs) - else: - return math_ops.square(inputs) - - def compute_output_shape(self, input_shape): - return input_shape - - -class DynamicLayer2(base_layer.Layer): - - def __init__(self, dynamic=False, **kwargs): - super(DynamicLayer2, self).__init__(dynamic=dynamic, **kwargs) - - def call(self, inputs): - samples = [] - for sample in inputs: - samples.append(math_ops.square(sample)) - return array_ops.stack(samples, axis=0) + samples = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) + for idx, sample in enumerate(inputs): + samples = samples.write(idx, math_ops.square(sample)) + return samples.stack() def compute_output_shape(self, input_shape): return input_shape @@ -82,34 +70,39 @@ class InvalidLayer(base_layer.Layer): class BaseLayerTest(keras_parameterized.TestCase): - @parameterized.parameters(DynamicLayer1, DynamicLayer2) - def test_dynamic_layer_in_functional_model_in_graph_mode(self, layer_class): + @keras_parameterized.run_with_all_model_types + def test_dynamic_layer(self): + model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)], + input_shape=(3,)) + self.assertEqual(model.dynamic, True) + model.compile(rmsprop.RMSprop(0.001), loss='mse') + self.assertEqual(model.run_eagerly, True) + model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) + + @keras_parameterized.run_with_all_model_types + def test_dynamic_layer_error(self): + with self.assertRaisesRegexp(TypeError, + 'attempting to use Python control flow'): + model = testing_utils.get_model_from_layers([DynamicLayer()], + input_shape=(3,)) + model.compile(rmsprop.RMSprop(0.001), loss='mse') + model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) + + @keras_parameterized.run_with_all_model_types + def test_dynamic_layer_error_running_in_graph_mode(self): with context.graph_mode(): - inputs = keras.Input((3,)) - # Works when `dynamic=True` is declared. - outputs = layer_class(dynamic=True)(inputs) - model = keras.Model(inputs, outputs) + model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)], + input_shape=(3,)) self.assertEqual(model.dynamic, True) # But then you cannot run the model since you're in a graph scope. with self.assertRaisesRegexp( ValueError, 'You must enable eager execution'): model.compile(rmsprop.RMSprop(0.001), loss='mse') - # Fails when `dynamic=True` not declared. - with self.assertRaisesRegexp( - TypeError, 'attempting to use Python control flow'): - _ = layer_class()(inputs) - - @parameterized.parameters(DynamicLayer1, DynamicLayer2) - def test_dynamic_layer_in_functional_model_in_eager_mode(self, layer_class): - inputs = keras.Input((3,)) - # Fails when `dynamic=True` not declared. - with self.assertRaisesRegexp( - TypeError, 'attempting to use Python control flow'): - _ = layer_class()(inputs) - # Works when `dynamic=True` is declared. - outputs = layer_class(dynamic=True)(inputs) - model = keras.Model(inputs, outputs) + def test_dynamic_layer_with_deferred_sequential_model(self): + model = keras.Sequential( + [DynamicLayer(dynamic=True), + keras.layers.Dense(3)]) self.assertEqual(model.dynamic, True) model.compile(rmsprop.RMSprop(0.001), loss='mse') self.assertEqual(model.run_eagerly, True) @@ -117,12 +110,12 @@ class BaseLayerTest(keras_parameterized.TestCase): def test_nested_dynamic_layers_in_eager_mode(self): inputs = keras.Input((3,)) - outputs = DynamicLayer1(dynamic=True)(inputs) + outputs = DynamicLayer(dynamic=True)(inputs) inner_model = keras.Model(inputs, outputs) self.assertEqual(inner_model.dynamic, True) inputs = keras.Input((3,)) - x = DynamicLayer2(dynamic=True)(inputs) + x = DynamicLayer(dynamic=True)(inputs) outputs = inner_model(x) model = keras.Model(inputs, outputs) @@ -131,41 +124,6 @@ class BaseLayerTest(keras_parameterized.TestCase): self.assertEqual(model.run_eagerly, True) model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) - def test_dynamic_layers_in_sequential_model(self): - # Without input_shape argument - model = keras.Sequential([DynamicLayer1(dynamic=True), - keras.layers.Dense(3), - DynamicLayer2(dynamic=True)]) - self.assertEqual(model.dynamic, True) - model.compile(rmsprop.RMSprop(0.001), loss='mse') - self.assertEqual(model.run_eagerly, True) - model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) - - # With input_shape argument - model = keras.Sequential([DynamicLayer1(dynamic=True, input_shape=(3,)), - DynamicLayer2(dynamic=True)]) - self.assertEqual(model.dynamic, True) - model.compile(rmsprop.RMSprop(0.001), loss='mse') - self.assertEqual(model.run_eagerly, True) - model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) - - def test_dynamic_layers_in_subclassed_model(self): - - class MyModel(keras.Model): - - def __init__(self): - super(MyModel, self).__init__() - self.layer1 = DynamicLayer1(dynamic=True) - - def call(self, inputs): - return self.layer1(inputs) - - model = MyModel() - self.assertEqual(model.dynamic, True) - model.compile(rmsprop.RMSprop(0.001), loss='mse') - self.assertEqual(model.run_eagerly, True) - model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) - def test_dynamic_subclassed_model_no_shape_inference(self): class MyModel(keras.Model): @@ -213,6 +171,24 @@ class BaseLayerTest(keras_parameterized.TestCase): model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) self.assertEqual(model.outputs[0].shape.as_list(), [None, 3]) + @keras_parameterized.run_all_keras_modes + def test_add_loss_correctness(self): + + class MyLayer(keras.layers.Layer): + + def call(self, inputs, training=None): + self.add_loss(math_ops.reduce_sum(inputs)) + return inputs + + inputs = keras.Input((3,)) + layer = MyLayer() + outputs = layer(inputs) + model = keras.Model(inputs, outputs) + self.assertEqual(len(model.losses), 1) + model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(loss, 2 * 3) + @test_util.run_in_graph_and_eager_modes def test_invalid_forward_pass(self): inputs = keras.Input((3,)) @@ -676,6 +652,201 @@ class NameScopingTest(keras_parameterized.TestCase): self.assertEqual(layer.kernel.name, 'MyName3/kernel:0') +class AutographControlFlowTest(keras_parameterized.TestCase): + + @parameterized.named_parameters(('eager', True), + ('symbolic', False)) + def test_if_training_pattern_output(self, eager): + + class MyLayer(keras.layers.Layer): + + def call(self, inputs, training=None): + if training: + return inputs * 1. + return inputs * 0. + + inputs = keras.Input((3,)) + outputs = MyLayer()(inputs) + model = keras.Model(inputs, outputs) + model.compile('sgd', 'mse', run_eagerly=eager) + train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(train_loss, 0.) + test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(test_loss, 1.) + + @parameterized.named_parameters(('eager', True), + ('symbolic', False)) + def test_if_training_pattern_loss(self, eager): + + class MyLayer(keras.layers.Layer): + + def call(self, inputs, training=None): + if training: + loss = math_ops.reduce_sum(inputs) + else: + loss = 0. + self.add_loss(loss) + return inputs + + inputs = keras.Input((3,)) + outputs = MyLayer()(inputs) + model = keras.Model(inputs, outputs) + model.compile('sgd', 'mse', run_eagerly=eager) + train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(train_loss, 2 * 3) + test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(test_loss, 0) + + @parameterized.named_parameters(('eager', True), + ('symbolic', False)) + def test_if_training_pattern_metric(self, eager): + + class MyLayer(keras.layers.Layer): + + def call(self, inputs, training=None): + if training: + metric = math_ops.reduce_sum(inputs) + else: + metric = 0. + self.add_metric(metric, name='my_metric', aggregation='mean') + return inputs + + inputs = keras.Input((3,)) + outputs = MyLayer()(inputs) + model = keras.Model(inputs, outputs) + model.compile('sgd', 'mse', run_eagerly=eager) + _, train_metric = model.train_on_batch(np.ones((2, 3)), + np.ones((2, 3))) + self.assertEqual(train_metric, 2 * 3) + _, test_metric = model.test_on_batch(np.ones((2, 3)), + np.ones((2, 3))) + self.assertEqual(test_metric, 0) + + @parameterized.named_parameters(('eager', True), + ('symbolic', False)) + def test_if_training_pattern_update(self, eager): + + class MyLayer(keras.layers.Layer): + + def build(self, input_shape): + self.counter = self.add_weight( + shape=(), trainable=False, initializer='zeros') + + def call(self, inputs, training=None): + if training: + increment = 1. + else: + increment = 0. + self.counter.assign_add(increment) + return inputs + + inputs = keras.Input((3,)) + layer = MyLayer() + outputs = layer(inputs) + model = keras.Model(inputs, outputs) + model.compile('sgd', 'mse', run_eagerly=eager) + model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(keras.backend.get_value(layer.counter), 1.) + + @parameterized.named_parameters(('eager', True), + ('symbolic', False)) + def test_conditional_updates_in_call(self, eager): + + class MyLayer(keras.layers.Layer): + + def __init__(self): + super(MyLayer, self).__init__(self, dynamic=eager) + + def build(self, input_shape): + self.counter = self.add_weight( + shape=(), trainable=False, initializer='zeros') + + def call(self, inputs, training=None): + if training: + self.add_update(self.counter.assign_add(math_ops.reduce_sum(inputs))) + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + if eager: + inputs = keras.Input((3,)) + layer = MyLayer() + outputs = layer(inputs) + model = keras.Model(inputs, outputs) + model.compile('sgd', 'mse', run_eagerly=eager) + model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(keras.backend.get_value(layer.counter), 6.) + else: + # TODO(fchollet): support the same workflow in graph mode. + with self.assertRaisesRegexp(RuntimeError, + '`add_update` in a control flow branch'): + layer = MyLayer()(keras.Input((3,))) + + @parameterized.named_parameters(('eager', True), + ('symbolic', False)) + def test_conditional_losses_in_call(self, eager): + + class MyLayer(keras.layers.Layer): + + def __init__(self): + super(MyLayer, self).__init__(self, dynamic=eager) + + def call(self, inputs, training=None): + if training: + self.add_loss(math_ops.reduce_sum(inputs)) + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + if eager: + inputs = keras.Input((3,)) + layer = MyLayer() + outputs = layer(inputs) + model = keras.Model(inputs, outputs) + model.compile('sgd', 'mse') + loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(loss, 2 * 3) + else: + with self.assertRaisesRegexp(RuntimeError, + '`add_loss` in a control flow branch'): + layer = MyLayer()(keras.Input((3,))) + + @parameterized.named_parameters(('eager', True), + ('symbolic', False)) + def test_conditional_metrics_in_call(self, eager): + + class MyLayer(keras.layers.Layer): + + def __init__(self): + super(MyLayer, self).__init__(self, dynamic=eager) + + def call(self, inputs, training=None): + if training: + self.add_metric(math_ops.reduce_sum(inputs), + name='sum', + aggregation='mean') + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + if eager: + inputs = keras.Input((3,)) + layer = MyLayer() + outputs = layer(inputs) + model = keras.Model(inputs, outputs) + model.compile('sgd', 'mse') + history = model.fit(np.ones((2, 3)), np.ones((2, 3))) + self.assertEqual(history.history['sum'][-1], 2 * 3) + else: + # TODO(fchollet): support the same workflow in graph mode. + with self.assertRaisesRegexp(RuntimeError, + '`add_metric` in a control flow branch'): + layer = MyLayer()(keras.Input((3,))) + + _LAYERS_TO_TEST = [ (keras.layers.Dense, (1,), collections.OrderedDict(units=[1])), (keras.layers.Activation, (2, 2), diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index eea8227f8f3..b97326eea6a 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -30,6 +30,7 @@ from tensorflow.python.keras import backend from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops_v2 from tensorflow.python.ops import variables as tf_variables @@ -571,3 +572,96 @@ def autocast_context_manager(input_list, should_cast): var_read_dtype = _get_var_read_dtype(input_list, should_cast) return ops.get_default_graph()._enable_auto_casting_variables( # pylint: disable=protected-access var_read_dtype) + + +def is_subclassed(layer): + return (layer.__module__.find('keras.engine') == -1 and + layer.__module__.find('keras.layers') == -1) + + +def check_graph_consistency(tensor, method): + """Checks that tensors passed to `add_*` method match the Keras graph. + + When one of the `add_*` method is called inside a V2 conditional branch, + the underlying tensor gets created in a FuncGraph managed by control_flow_v2. + We need to raise clear error messages in such cases. + + Arguments: + tensor: Tensor to check. + method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}. + + Raises: + RuntimeError: In case of an out-of-graph tensor. + """ + if ops.executing_eagerly_outside_functions() and hasattr(tensor, 'graph'): + if isinstance(tensor.graph, + (control_flow_util_v2.CondBranchFuncGraph, + control_flow_util_v2.WhileCondFuncGraph, + control_flow_util_v2.WhileBodyFuncGraph)): + if method == 'add_metric': + bad_example = """ + def call(self, inputs, training=None): + if training: + metric = compute_metric(inputs) + self.add_metric(metric, name='my_metric', aggregation='mean') + return inputs + """ + correct_example = """ + def call(self, inputs, training=None): + if training: + metric = compute_metric(inputs) + else: + metric = 0. + self.add_metric(metric, name='my_metric', aggregation='mean') + return inputs + """ + elif method == 'add_loss': + bad_example = """ + def call(self, inputs, training=None): + if training: + loss = compute_loss(inputs) + self.add_loss(loss) + return inputs + """ + correct_example = """ + def call(self, inputs, training=None): + if training: + loss = compute_loss(inputs) + else: + loss = 0. + self.add_loss(loss) + return inputs + """ + else: + bad_example = """ + def call(self, inputs, training=None): + if training: + self.add_update(self.w.assign_add(1)) + return inputs + """ + correct_example = """ + def call(self, inputs, training=None): + if training: + increment = 1 + else: + increment = 0 + self.add_update(self.w.assign_add(increment)) + return inputs + """ + raise RuntimeError( + 'You are using the method `{method}` in a control flow branch ' + 'in your layer, e.g.:\n{bad_example}\n' + 'This is not currently supported. ' + 'You should either use static control flow (`tf.cond`) ' + 'or move your call to {method} out of the control flow branch, ' + 'e.g.:\n{correct_example}\n' + 'You can also resolve this by marking your layer ' + 'as dynamic (eager-only) by passing ' + '`dynamic=True` to the layer constructor. ' + 'Any kind of control flow is supported with dynamic layers. ' + 'Note that using `dynamic=True` requires you ' + 'to implement static shape inference ' + 'in the `compute_output_shape(input_shape)` method.'.format( + method=method, + bad_example=bad_example, + correct_example=correct_example)) diff --git a/tensorflow/python/keras/engine/control_flow_test.py b/tensorflow/python/keras/engine/control_flow_test.py new file mode 100644 index 00000000000..b91f0362cc8 --- /dev/null +++ b/tensorflow/python/keras/engine/control_flow_test.py @@ -0,0 +1,138 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dynamic control flow behavior with Keras.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.optimizer_v2 import rmsprop +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class ControlFlowLayer1(base_layer.Layer): + """Layer with an `if` condition in call.""" + + def call(self, inputs): + if math_ops.reduce_sum(inputs) > 0: + return math_ops.sqrt(inputs) + else: + return math_ops.square(inputs) + + +class ControlFlowLayer2(base_layer.Layer): + """Layer with a `for` loop in call.""" + + def call(self, inputs): + samples = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) + i = 0 + for sample in inputs: + samples = samples.write(i, math_ops.square(sample)) + i += 1 + return samples.stack() + + +class NestedControlFlowLayer(base_layer.Layer): + """Layer nested with a control flow layer.""" + + def __init__(self, **kwargs): + super(NestedControlFlowLayer, self).__init__(**kwargs) + self.layer = ControlFlowLayer1() + + def call(self, inputs): + return self.layer(inputs) + + +class ControlFlowModel(keras.Model): + """Model with an `if` condition in call.""" + + def call(self, inputs): + if math_ops.reduce_sum(inputs) > 0: + return math_ops.sqrt(inputs) + else: + return math_ops.square(inputs) + + +class NestedControlFlowModel(keras.Model): + """Model with an `if` condition in call using a control flow layer.""" + + def __init__(self, **kwargs): + super(NestedControlFlowModel, self).__init__(**kwargs) + self.layer = NestedControlFlowLayer() + + def call(self, inputs): + inputs = self.layer(inputs) + if math_ops.reduce_sum(inputs) > 0: + return math_ops.sqrt(inputs) + else: + return math_ops.square(inputs) + + +class FunctionControlFlowModel(keras.Model): + """Model with control flow where `call` is wrapped in function already.""" + + @def_function.function + def call(self, inputs): + if math_ops.reduce_sum(inputs) > 0: + return math_ops.sqrt(inputs) + else: + return math_ops.square(inputs) + + +@keras_parameterized.run_all_keras_modes +class AutographWrapperTest(keras_parameterized.TestCase): + + @keras_parameterized.run_with_all_model_types + @parameterized.named_parameters(('with_if', ControlFlowLayer1), + ('with_for', ControlFlowLayer2), + ('nested', NestedControlFlowLayer)) + def test_control_flow_layer(self, layer_class): + model = testing_utils.get_model_from_layers([layer_class()], + input_shape=(3,)) + model.compile(rmsprop.RMSprop(0.001), loss='mse') + model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) + + @parameterized.named_parameters( + ('with_if', ControlFlowModel), ('nested', NestedControlFlowModel), + ('wrapped_in_function', FunctionControlFlowModel)) + def test_control_flow_model(self, model_class): + model = model_class() + model.compile(rmsprop.RMSprop(0.001), loss='mse') + model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) + + def test_control_flow_in_deferred_sequential_model(self): + model = keras.Sequential( + [ControlFlowLayer1(), + keras.layers.Dense(3), + ControlFlowLayer2()]) + model.compile(rmsprop.RMSprop(0.001), loss='mse') + model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3))) + + +if __name__ == '__main__': + test.main()