From 1d8bc7222d341a28f0002589f910d432d2c2add0 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan Date: Thu, 21 May 2020 11:50:37 -0700 Subject: [PATCH] Switches keras.backend.placeholder + keras.backend.function to build a keras model when running eagerly (instead of trying to directly lift ops out of a graph into a concretefunction). Allows us to strip most of EagerDefinedExecutionFunction from the keras backend. This has the effect of making keras.backend.placeholder + backend.function use the same codepaths as the rest of Keras. This may have the following impact on user code: - keras.backend.function no longer supports the `updates` argument when eager execution is enabled. - keras.backend.placeholder + keras.backend.function now have the same limitations as TF op layers when manipulating the placeholders directly with tf ops. This means no support outside of a layer for sparse ops & ops that operate on composite tensors. PiperOrigin-RevId: 312711373 Change-Id: Ie4bab440b83ea2becf1c83b83837771eba185ff5 --- tensorflow/python/keras/BUILD | 1 + tensorflow/python/keras/backend.py | 230 ++++++------------ tensorflow/python/keras/backend_test.py | 11 +- .../python/keras/engine/base_layer_utils.py | 5 +- tensorflow/python/keras/engine/input_layer.py | 7 +- .../python/keras/engine/training_utils.py | 2 +- .../keras/layers/tensorflow_op_layer_test.py | 7 +- .../python/keras/layers/wrappers_test.py | 8 +- tensorflow/python/keras/losses_test.py | 4 +- 9 files changed, 113 insertions(+), 162 deletions(-) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 4cd0af07c74..78e360c8354 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -584,6 +584,7 @@ tf_py_test( deps = [ ":backend", ":combinations", + ":engine", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 
11795625d06..d0c3eb03342 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -294,7 +294,6 @@ def clear_session(): global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned global _GRAPH - global _FREEZABLE_VARS _GRAPH.graph = None ops.reset_default_graph() reset_uids() @@ -307,7 +306,6 @@ def clear_session(): _GRAPH_LEARNING_PHASES.setdefault(graph) _GRAPH_VARIABLES.pop(graph, None) _GRAPH_TF_OPTIMIZERS.pop(graph, None) - _FREEZABLE_VARS.pop(graph, None) @keras_export('keras.backend.manual_variable_initialization') @@ -1059,9 +1057,9 @@ def is_keras_tensor(x): >>> tf.keras.backend.is_keras_tensor(keras_var) False >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5)) - >>> # A placeholder is not a Keras tensor. + >>> # A placeholder is a Keras tensor. >>> tf.keras.backend.is_keras_tensor(keras_placeholder) - False + True >>> keras_input = tf.keras.layers.Input([10]) >>> # An Input is a Keras tensor. >>> tf.keras.backend.is_keras_tensor(keras_input) @@ -1144,6 +1142,14 @@ def placeholder(shape=None, expand_composites=True) else: x = array_ops.placeholder(dtype, shape=shape, name=name) + + if context.executing_eagerly(): + # Add keras_history connectivity information to the placeholder + # when the placeholder is built in a top-level eager context + # (intended to be used with keras.backend.function) + from tensorflow.python.keras.engine import input_layer # pylint: disable=g-import-not-at-top + return input_layer.Input(tensor=x) + return x @@ -3379,7 +3385,7 @@ def get_value(x): if ops.executing_eagerly_outside_functions(): # This method of evaluating works inside the Keras FuncGraph. 
- return function([], x)(x) + return eval_in_eager_or_function(x) with x.graph.as_default(): return x.eval(session=get_session((x,))) @@ -3722,161 +3728,74 @@ class GraphExecutionFunction(object): return nest.map_structure(self._eval_if_composite, output_structure) -class EagerExecutionFunction(object): - """Helper class for constructing a TF graph function from the Keras graph. +def eval_in_eager_or_function(outputs): + """Method to evaluate a tensor in eager or in a tf.function. + + In the case of a tf.function, it will lift the tensor out of the function + and try to evaluate that piece of the graph. + + Warning: Do not add new usages of this function. + TODO(b/150169018): delete this function once _keras_history_helper is no + longer needed, after Keras switches to KerasTensors and op layers + work via dispatch. Arguments: - inputs: Feed placeholders to the computation graph. - outputs: Output tensors to fetch. - updates: Additional update ops to be run at function call. - name: A name to help users identify what this function does. - session_kwargs: Unsupported. + outputs: tensors to fetch. + Returns: + The value of the tensors (as numpy arrays). 
""" + outputs_structure = outputs + outputs = nest.flatten(outputs, expand_composites=True) - def __init__(self, inputs, outputs, updates=None, name=None): - self.name = name - self._inputs_structure = inputs - inputs = nest.flatten(inputs, expand_composites=True) - self._outputs_structure = outputs - outputs = nest.flatten(outputs, expand_composites=True) + graphs = { + i.graph + for i in nest.flatten([outputs]) + if hasattr(i, 'graph') + } + if len(graphs) > 1: + raise ValueError('Cannot create an execution function which is comprised ' + 'of elements from multiple graphs.') - updates = updates or [] - if not isinstance(updates, (list, tuple)): - raise TypeError('`updates` in a Keras backend function ' - 'should be a list or tuple.') + source_graph = graphs.pop() - if updates and not outputs: - # Edge case; never happens in practice - raise ValueError('Cannot create a Keras backend function with updates' - ' but no outputs during eager execution.') - graphs = { - i.graph - for i in nest.flatten([inputs, outputs, updates]) - if hasattr(i, 'graph') - } - if len(graphs) > 1: - raise ValueError('Cannot create an execution function which is comprised ' - 'of elements from multiple graphs.') - - source_graph = graphs.pop() + with _scratch_graph() as exec_graph: global_graph = get_graph() + if source_graph not in (exec_graph, global_graph): + raise ValueError('Unknown graph. Aborting.') - updates_ops = [] - legacy_update_ops = [] - for update in updates: - # For legacy reasons it is allowed to pass an update as a tuple - # `(variable, new_value)` (this maps to an assign op). Otherwise it - # is assumed to already be an op -- we cannot control its execution - # order. - if isinstance(update, tuple): - legacy_update_ops.append(update) - else: - if hasattr(update, 'op'): - update = update.op - if update is not None: - # `update.op` may have been None in certain cases. 
- updates_ops.append(update) + if source_graph is global_graph and exec_graph is not global_graph: + init_tensors = outputs + lifted_map = lift_to_graph.lift_to_graph( + tensors=init_tensors, + graph=exec_graph, + sources=[], + add_sources=True, + handle_captures=True, + base_graph=source_graph) - self._freezable_vars_to_feed = [] - self._freezable_vars_values = [] - freezable_vars_from_keras_graph = object_identity.ObjectIdentitySet( - _FREEZABLE_VARS.get(global_graph, {})) - with _scratch_graph() as exec_graph: - global_graph = get_graph() - if source_graph not in (exec_graph, global_graph): - raise ValueError('Unknown graph. Aborting.') + outputs = [lifted_map[i] for i in outputs] - if source_graph is global_graph and exec_graph is not global_graph: - init_tensors = ( - outputs + updates_ops + [p for [p, _] in legacy_update_ops] + - [p_new for [_, p_new] in legacy_update_ops - if isinstance(p_new, ops.Tensor)]) - lifted_map = lift_to_graph.lift_to_graph( - tensors=init_tensors, - graph=exec_graph, - sources=inputs, - add_sources=True, - handle_captures=True, - base_graph=source_graph) + # Consolidate updates + with exec_graph.as_default(): + outputs = cast_variables_to_tensor(outputs) - inputs = [lifted_map[i] for i in inputs] - outputs = [lifted_map[i] for i in outputs] - updates_ops = [lifted_map[i] for i in updates_ops] - legacy_update_ops = [(lifted_map[p], lifted_map.get(p_new, p_new)) - for p, p_new in legacy_update_ops] + exec_graph.inputs = exec_graph.internal_captures + exec_graph.outputs = outputs + graph_fn = eager_function.ConcreteFunction(exec_graph) - # Keep track of the value to feed to any "freezable variables" - # created in this graph. 
- for old_op, new_op in lifted_map.items(): - if old_op in freezable_vars_from_keras_graph: - frozen_var = old_op - if frozen_var._initial_value != frozen_var._current_value: - # We only feed a frozen_variable if its value has changed; - # otherwise it can rely on the default value of the - # underlying placeholder_with_default. - self._freezable_vars_to_feed.append(new_op) - self._freezable_vars_values.append(frozen_var._current_value) + graph_fn._num_positional_args = 0 + graph_fn._arg_keywords = [] - # Consolidate updates - with exec_graph.as_default(): - outputs = cast_variables_to_tensor(outputs) - with ops.control_dependencies(outputs): - for p, p_new in legacy_update_ops: - updates_ops.append(state_ops.assign(p, p_new)) + outputs = graph_fn() - self.inputs, self.outputs = inputs, outputs - self._input_references = self.inputs + self._freezable_vars_to_feed - with ops.control_dependencies(updates_ops): - self.outputs[0] = array_ops.identity(self.outputs[0]) - - exec_graph.inputs = self._input_references + exec_graph.internal_captures - exec_graph.outputs = self.outputs - graph_fn = eager_function.ConcreteFunction(exec_graph) - - graph_fn._num_positional_args = len(self._input_references) - graph_fn._arg_keywords = [] - self._graph_fn = graph_fn - - # Handle placeholders with default - # (treated as required placeholder by graph functions) - self._placeholder_default_values = {} - with exec_graph.as_default(): - for x in self.inputs: - if x.op.type == 'PlaceholderWithDefault': - self._placeholder_default_values[ops.tensor_id( - x)] = tensor_util.constant_value(x.op.inputs[0]) - - def __call__(self, inputs): - input_values = nest.flatten(inputs, expand_composites=True) - - if self._freezable_vars_values: - input_values = input_values + self._freezable_vars_values - converted_inputs = [] - for tensor, value in zip(self._input_references, input_values): - if value is None: - # Assume `value` is a placeholder with default - value = 
self._placeholder_default_values.get( - ops.tensor_id(tensor), None) - if value is None: - raise ValueError( - 'You must feed a value for placeholder %s' % (tensor,)) - if not isinstance(value, ops.Tensor): - value = ops.convert_to_tensor_v2(value, dtype=tensor.dtype) - if value.dtype != tensor.dtype: - # Temporary workaround due to `convert_to_tensor` not casting floats. - # See b/119637405 - value = math_ops.cast(value, tensor.dtype) - converted_inputs.append(value) - outputs = self._graph_fn(*converted_inputs) - - # EagerTensor.numpy() will often make a copy to ensure memory safety. - # However in this case `outputs` is not directly returned, so it is always - # safe to reuse the underlying buffer without checking. In such a case the - # private numpy conversion method is preferred to guarantee performance. - return nest.pack_sequence_as( - self._outputs_structure, - [x._numpy() for x in outputs], # pylint: disable=protected-access - expand_composites=True) + # EagerTensor.numpy() will often make a copy to ensure memory safety. + # However in this case `outputs` is not directly returned, so it is always + # safe to reuse the underlying buffer without checking. In such a case the + # private numpy conversion method is preferred to guarantee performance. + return nest.pack_sequence_as( + outputs_structure, + [x._numpy() for x in outputs], # pylint: disable=protected-access + expand_composites=True) @keras_export('keras.backend.function') @@ -3900,7 +3819,20 @@ def function(inputs, outputs, updates=None, name=None, **kwargs): if kwargs: raise ValueError('Session keyword arguments are not support during ' 'eager execution. You passed: %s' % (kwargs,)) - return EagerExecutionFunction(inputs, outputs, updates=updates, name=name) + if updates: + raise ValueError('`updates` argument is not supported during ' + 'eager execution. 
You passed: %s' % (updates,)) + from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top + model = models.Model(inputs=inputs, outputs=outputs) + + wrap_outputs = isinstance(outputs, list) and len(outputs) == 1 + def func(model_inputs): + outs = model(model_inputs) + if wrap_outputs: + outs = [outs] + return tf_utils.to_numpy_or_python_type(outs) + return func if kwargs: for key in kwargs: @@ -6344,10 +6276,6 @@ class ContextValueCache(weakref.WeakKeyDictionary): # either train mode (learning_phase == 1) or test mode (learning_phase == 0). _GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase) -# This dictionary holds a mapping {graph: set_of_freezable_variables}. -# Each set tracks objects created via `freezable_variable` in the graph. -_FREEZABLE_VARS = ContextValueCache(object_identity.ObjectIdentityWeakSet) - # This dictionary holds a mapping between a graph and variables to initialize # in the graph. _GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet) diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index 1adc20652b2..20547c570c7 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -1677,8 +1677,10 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase): t, p, from_logits=True, axis=0), self.assertArrayNear(self.evaluate(result)[0], [.002, 0, .17], 1e-3) - @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + @combinations.generate(combinations.combine(mode=['graph'])) def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self): + # This test only runs in graph because the TF op layer is not supported yet + # for sparse ops. 
t = backend.placeholder() p = backend.placeholder() o = backend.sparse_categorical_crossentropy(t, p) @@ -1870,6 +1872,8 @@ class TestRandomOps(test.TestCase): class FunctionTest(test.TestCase): def test_function_basics(self): + if context.executing_eagerly(): + self.skipTest('eager backend.function does not support updates') x1 = backend.placeholder(shape=(), dtype='float32') x2 = backend.placeholder(shape=(), dtype='int32') v = backend.variable(10.) @@ -1916,6 +1920,9 @@ class FunctionTest(test.TestCase): self.assertEqual(result, 4.) def test_tuple_updates(self): + if context.executing_eagerly(): + self.skipTest('eager backend.function does not support updates') + x_ph = backend.placeholder(ndim=2) v = backend.variable(np.ones((4, 2))) output = x_ph ** 2 + v @@ -1929,7 +1936,7 @@ class FunctionTest(test.TestCase): class BackendGraphTests(test.TestCase, parameterized.TestCase): - @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + @combinations.generate(combinations.combine(mode=['graph'])) def test_function_placeholder_with_default(self): with backend.get_graph().as_default(): x1 = array_ops.placeholder_with_default( diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index 5980eeaf115..7e4e0e5da4a 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -248,7 +248,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers): constants[i] = op_input else: with ops.init_scope(): - constants[i] = backend.function([], op_input)([]) + if ops.executing_eagerly_outside_functions(): + constants[i] = backend.eval_in_eager_or_function(op_input) + else: + constants[i] = backend.function([], op_input)([]) layer_inputs = unnest_if_single_tensor(layer_inputs) processed_ops, created_layers = _create_keras_history_helper( layer_inputs, processed_ops, created_layers) diff --git 
a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index ed715f61897..1fa380815fc 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -161,8 +161,11 @@ class InputLayer(base_layer.Layer): 'InputLayer, you should instantiate your model and ' 'directly call it on your input.') self.is_placeholder = False - self._batch_input_shape = tuple(input_tensor.shape.as_list()) - + try: + self._batch_input_shape = tuple(input_tensor.shape.as_list()) + except ValueError: + # If the shape cannot be represented as a tuple (e.g. unknown rank) + self._batch_input_shape = None # Create an input node. input_tensor._keras_mask = None node_module.Node(layer=self, outputs=input_tensor) diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 680f33f75a5..0d7637cb98c 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -1935,7 +1935,7 @@ def get_input_shape_and_dtype(layer): raise ValueError('An empty Model cannot be used as a Layer.') layer = layer.layers[0] - if hasattr(layer, '_batch_input_shape'): + if getattr(layer, '_batch_input_shape', None): return layer._batch_input_shape, layer.dtype return None, None diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py index 73e395f5715..1a328995a80 100644 --- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py +++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py @@ -288,9 +288,10 @@ class AutoLambdaTest(keras_parameterized.TestCase): constant_op.constant(40.0, shape=(1, 1))) def test_no_tracking(self): - x = keras.backend.placeholder((10, 10)) - keras.layers.Dense(1)(x) - self.assertTrue(x._keras_history_checked) + if not context.executing_eagerly(): + x = constant_op.constant(1.0, shape=(10, 10)) + 
keras.layers.Dense(1)(x) + self.assertTrue(x._keras_history_checked) def test_timing_scales_linearly(self): diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index bb22db25591..a73177fff12 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -33,6 +33,7 @@ from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops @@ -1213,9 +1214,14 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): f_merged = keras.backend.function([inputs], layer(inputs)) f_forward = keras.backend.function([inputs], layer.forward_layer(inputs)) + + # TODO(kaftan): after KerasTensor refactor TF op layers should work + # with many composite tensors, and this shouldn't need to be a lambda + # layer. 
+ reverse_layer = core.Lambda(array_ops.reverse, arguments=dict(axis=[1])) f_backward = keras.backend.function( [inputs], - array_ops.reverse(layer.backward_layer(inputs), axis=[1])) + reverse_layer(layer.backward_layer(inputs))) y_merged = f_merged(x) y_expected = merge_func( diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py index 574d3d3f756..26a586b872b 100644 --- a/tensorflow/python/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -125,8 +125,10 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase): backend.eval(output_from_softmax), atol=1e-5) - @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + @combinations.generate(combinations.combine(mode=['graph'])) def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self): + # This test only runs in graph because the TF op layer is not supported yet + # for sparse ops. t = backend.placeholder() p = backend.placeholder() o = losses.sparse_categorical_crossentropy(t, p)