Switches keras.backend.placeholder + keras.backend.function to build a keras model when running eagerly (instead of trying to directly lift ops out of a graph into a concretefunction). Allows us to strip most of EagerDefinedExecutionFunction from the keras backend.

This has the effect of making keras.backend.placeholder + backend.function use the same codepaths as the rest of Keras.

This may have the following impact on user code:
- keras.backend.function no longer supports the `updates` argument when eager execution is enabled.
- keras.backend.placeholder + keras.backend.function now have the same limitations as TF op layers when manipulating the placeholders directly with tf ops. This means  no support outside of a layer for sparse ops & ops that operate on composite tensors.

PiperOrigin-RevId: 312711373
Change-Id: Ie4bab440b83ea2becf1c83b83837771eba185ff5
This commit is contained in:
Tomer Kaftan 2020-05-21 11:50:37 -07:00 committed by TensorFlower Gardener
parent 5d3ad40d11
commit 1d8bc7222d
9 changed files with 113 additions and 162 deletions

View File

@ -584,6 +584,7 @@ tf_py_test(
deps = [ deps = [
":backend", ":backend",
":combinations", ":combinations",
":engine",
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",

View File

@ -294,7 +294,6 @@ def clear_session():
global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned
global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned
global _GRAPH global _GRAPH
global _FREEZABLE_VARS
_GRAPH.graph = None _GRAPH.graph = None
ops.reset_default_graph() ops.reset_default_graph()
reset_uids() reset_uids()
@ -307,7 +306,6 @@ def clear_session():
_GRAPH_LEARNING_PHASES.setdefault(graph) _GRAPH_LEARNING_PHASES.setdefault(graph)
_GRAPH_VARIABLES.pop(graph, None) _GRAPH_VARIABLES.pop(graph, None)
_GRAPH_TF_OPTIMIZERS.pop(graph, None) _GRAPH_TF_OPTIMIZERS.pop(graph, None)
_FREEZABLE_VARS.pop(graph, None)
@keras_export('keras.backend.manual_variable_initialization') @keras_export('keras.backend.manual_variable_initialization')
@ -1059,9 +1057,9 @@ def is_keras_tensor(x):
>>> tf.keras.backend.is_keras_tensor(keras_var) >>> tf.keras.backend.is_keras_tensor(keras_var)
False False
>>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5)) >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
>>> # A placeholder is not a Keras tensor. >>> # A placeholder is a Keras tensor.
>>> tf.keras.backend.is_keras_tensor(keras_placeholder) >>> tf.keras.backend.is_keras_tensor(keras_placeholder)
False True
>>> keras_input = tf.keras.layers.Input([10]) >>> keras_input = tf.keras.layers.Input([10])
>>> # An Input is a Keras tensor. >>> # An Input is a Keras tensor.
>>> tf.keras.backend.is_keras_tensor(keras_input) >>> tf.keras.backend.is_keras_tensor(keras_input)
@ -1144,6 +1142,14 @@ def placeholder(shape=None,
expand_composites=True) expand_composites=True)
else: else:
x = array_ops.placeholder(dtype, shape=shape, name=name) x = array_ops.placeholder(dtype, shape=shape, name=name)
if context.executing_eagerly():
# Add keras_history connectivity information to the placeholder
# when the placeholder is built in a top-level eager context
# (intended to be used with keras.backend.function)
from tensorflow.python.keras.engine import input_layer # pylint: disable=g-import-not-at-top
return input_layer.Input(tensor=x)
return x return x
@ -3379,7 +3385,7 @@ def get_value(x):
if ops.executing_eagerly_outside_functions(): if ops.executing_eagerly_outside_functions():
# This method of evaluating works inside the Keras FuncGraph. # This method of evaluating works inside the Keras FuncGraph.
return function([], x)(x) return eval_in_eager_or_function(x)
with x.graph.as_default(): with x.graph.as_default():
return x.eval(session=get_session((x,))) return x.eval(session=get_session((x,)))
@ -3722,161 +3728,74 @@ class GraphExecutionFunction(object):
return nest.map_structure(self._eval_if_composite, output_structure) return nest.map_structure(self._eval_if_composite, output_structure)
class EagerExecutionFunction(object): def eval_in_eager_or_function(outputs):
"""Helper class for constructing a TF graph function from the Keras graph. """Method to evaluate a tensor in eager or in a tf.function.
In the case of a tf.function, it will lift the tensor out of the function
and try to evaluate that piece of the graph.
Warning: Do not add new usages of this function.
TODO(b/150169018): delete this function once _keras_history_helper is no
longer needed, after Keras switches to KerasTensors and op layers
work via dispatch.
Arguments: Arguments:
inputs: Feed placeholders to the computation graph. outputs: tensors to fetch.
outputs: Output tensors to fetch. Returns:
updates: Additional update ops to be run at function call. The value of the tensors (as numpy arrays).
name: A name to help users identify what this function does.
session_kwargs: Unsupported.
""" """
outputs_structure = outputs
outputs = nest.flatten(outputs, expand_composites=True)
def __init__(self, inputs, outputs, updates=None, name=None): graphs = {
self.name = name i.graph
self._inputs_structure = inputs for i in nest.flatten([outputs])
inputs = nest.flatten(inputs, expand_composites=True) if hasattr(i, 'graph')
self._outputs_structure = outputs }
outputs = nest.flatten(outputs, expand_composites=True) if len(graphs) > 1:
raise ValueError('Cannot create an execution function which is comprised '
'of elements from multiple graphs.')
updates = updates or [] source_graph = graphs.pop()
if not isinstance(updates, (list, tuple)):
raise TypeError('`updates` in a Keras backend function '
'should be a list or tuple.')
if updates and not outputs: with _scratch_graph() as exec_graph:
# Edge case; never happens in practice
raise ValueError('Cannot create a Keras backend function with updates'
' but no outputs during eager execution.')
graphs = {
i.graph
for i in nest.flatten([inputs, outputs, updates])
if hasattr(i, 'graph')
}
if len(graphs) > 1:
raise ValueError('Cannot create an execution function which is comprised '
'of elements from multiple graphs.')
source_graph = graphs.pop()
global_graph = get_graph() global_graph = get_graph()
if source_graph not in (exec_graph, global_graph):
raise ValueError('Unknown graph. Aborting.')
updates_ops = [] if source_graph is global_graph and exec_graph is not global_graph:
legacy_update_ops = [] init_tensors = outputs
for update in updates: lifted_map = lift_to_graph.lift_to_graph(
# For legacy reasons it is allowed to pass an update as a tuple tensors=init_tensors,
# `(variable, new_value)` (this maps to an assign op). Otherwise it graph=exec_graph,
# is assumed to already be an op -- we cannot control its execution sources=[],
# order. add_sources=True,
if isinstance(update, tuple): handle_captures=True,
legacy_update_ops.append(update) base_graph=source_graph)
else:
if hasattr(update, 'op'):
update = update.op
if update is not None:
# `update.op` may have been None in certain cases.
updates_ops.append(update)
self._freezable_vars_to_feed = [] outputs = [lifted_map[i] for i in outputs]
self._freezable_vars_values = []
freezable_vars_from_keras_graph = object_identity.ObjectIdentitySet(
_FREEZABLE_VARS.get(global_graph, {}))
with _scratch_graph() as exec_graph:
global_graph = get_graph()
if source_graph not in (exec_graph, global_graph):
raise ValueError('Unknown graph. Aborting.')
if source_graph is global_graph and exec_graph is not global_graph: # Consolidate updates
init_tensors = ( with exec_graph.as_default():
outputs + updates_ops + [p for [p, _] in legacy_update_ops] + outputs = cast_variables_to_tensor(outputs)
[p_new for [_, p_new] in legacy_update_ops
if isinstance(p_new, ops.Tensor)])
lifted_map = lift_to_graph.lift_to_graph(
tensors=init_tensors,
graph=exec_graph,
sources=inputs,
add_sources=True,
handle_captures=True,
base_graph=source_graph)
inputs = [lifted_map[i] for i in inputs] exec_graph.inputs = exec_graph.internal_captures
outputs = [lifted_map[i] for i in outputs] exec_graph.outputs = outputs
updates_ops = [lifted_map[i] for i in updates_ops] graph_fn = eager_function.ConcreteFunction(exec_graph)
legacy_update_ops = [(lifted_map[p], lifted_map.get(p_new, p_new))
for p, p_new in legacy_update_ops]
# Keep track of the value to feed to any "freezable variables" graph_fn._num_positional_args = 0
# created in this graph. graph_fn._arg_keywords = []
for old_op, new_op in lifted_map.items():
if old_op in freezable_vars_from_keras_graph:
frozen_var = old_op
if frozen_var._initial_value != frozen_var._current_value:
# We only feed a frozen_variable if its value has changed;
# otherwise it can rely on the default value of the
# underlying placeholder_with_default.
self._freezable_vars_to_feed.append(new_op)
self._freezable_vars_values.append(frozen_var._current_value)
# Consolidate updates outputs = graph_fn()
with exec_graph.as_default():
outputs = cast_variables_to_tensor(outputs)
with ops.control_dependencies(outputs):
for p, p_new in legacy_update_ops:
updates_ops.append(state_ops.assign(p, p_new))
self.inputs, self.outputs = inputs, outputs # EagerTensor.numpy() will often make a copy to ensure memory safety.
self._input_references = self.inputs + self._freezable_vars_to_feed # However in this case `outputs` is not directly returned, so it is always
with ops.control_dependencies(updates_ops): # safe to reuse the underlying buffer without checking. In such a case the
self.outputs[0] = array_ops.identity(self.outputs[0]) # private numpy conversion method is preferred to guarantee performance.
return nest.pack_sequence_as(
exec_graph.inputs = self._input_references + exec_graph.internal_captures outputs_structure,
exec_graph.outputs = self.outputs [x._numpy() for x in outputs], # pylint: disable=protected-access
graph_fn = eager_function.ConcreteFunction(exec_graph) expand_composites=True)
graph_fn._num_positional_args = len(self._input_references)
graph_fn._arg_keywords = []
self._graph_fn = graph_fn
# Handle placeholders with default
# (treated as required placeholder by graph functions)
self._placeholder_default_values = {}
with exec_graph.as_default():
for x in self.inputs:
if x.op.type == 'PlaceholderWithDefault':
self._placeholder_default_values[ops.tensor_id(
x)] = tensor_util.constant_value(x.op.inputs[0])
def __call__(self, inputs):
input_values = nest.flatten(inputs, expand_composites=True)
if self._freezable_vars_values:
input_values = input_values + self._freezable_vars_values
converted_inputs = []
for tensor, value in zip(self._input_references, input_values):
if value is None:
# Assume `value` is a placeholder with default
value = self._placeholder_default_values.get(
ops.tensor_id(tensor), None)
if value is None:
raise ValueError(
'You must feed a value for placeholder %s' % (tensor,))
if not isinstance(value, ops.Tensor):
value = ops.convert_to_tensor_v2(value, dtype=tensor.dtype)
if value.dtype != tensor.dtype:
# Temporary workaround due to `convert_to_tensor` not casting floats.
# See b/119637405
value = math_ops.cast(value, tensor.dtype)
converted_inputs.append(value)
outputs = self._graph_fn(*converted_inputs)
# EagerTensor.numpy() will often make a copy to ensure memory safety.
# However in this case `outputs` is not directly returned, so it is always
# safe to reuse the underlying buffer without checking. In such a case the
# private numpy conversion method is preferred to guarantee performance.
return nest.pack_sequence_as(
self._outputs_structure,
[x._numpy() for x in outputs], # pylint: disable=protected-access
expand_composites=True)
@keras_export('keras.backend.function') @keras_export('keras.backend.function')
@ -3900,7 +3819,20 @@ def function(inputs, outputs, updates=None, name=None, **kwargs):
if kwargs: if kwargs:
raise ValueError('Session keyword arguments are not support during ' raise ValueError('Session keyword arguments are not support during '
'eager execution. You passed: %s' % (kwargs,)) 'eager execution. You passed: %s' % (kwargs,))
return EagerExecutionFunction(inputs, outputs, updates=updates, name=name) if updates:
raise ValueError('`updates` argument is not support during '
'eager execution. You passed: %s' % (updates,))
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
model = models.Model(inputs=inputs, outputs=outputs)
wrap_outputs = isinstance(outputs, list) and len(outputs) == 1
def func(model_inputs):
outs = model(model_inputs)
if wrap_outputs:
outs = [outs]
return tf_utils.to_numpy_or_python_type(outs)
return func
if kwargs: if kwargs:
for key in kwargs: for key in kwargs:
@ -6344,10 +6276,6 @@ class ContextValueCache(weakref.WeakKeyDictionary):
# either train mode (learning_phase == 1) or test mode (learning_phase == 0). # either train mode (learning_phase == 1) or test mode (learning_phase == 0).
_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase) _GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
# This dictionary holds a mapping {graph: set_of_freezable_variables}.
# Each set tracks objects created via `freezable_variable` in the graph.
_FREEZABLE_VARS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
# This dictionary holds a mapping between a graph and variables to initialize # This dictionary holds a mapping between a graph and variables to initialize
# in the graph. # in the graph.
_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet) _GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)

View File

@ -1677,8 +1677,10 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase):
t, p, from_logits=True, axis=0), t, p, from_logits=True, axis=0),
self.assertArrayNear(self.evaluate(result)[0], [.002, 0, .17], 1e-3) self.assertArrayNear(self.evaluate(result)[0], [.002, 0, .17], 1e-3)
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @combinations.generate(combinations.combine(mode=['graph']))
def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self): def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
# This test only runs in graph because the TF op layer is not supported yet
# for sparse ops.
t = backend.placeholder() t = backend.placeholder()
p = backend.placeholder() p = backend.placeholder()
o = backend.sparse_categorical_crossentropy(t, p) o = backend.sparse_categorical_crossentropy(t, p)
@ -1870,6 +1872,8 @@ class TestRandomOps(test.TestCase):
class FunctionTest(test.TestCase): class FunctionTest(test.TestCase):
def test_function_basics(self): def test_function_basics(self):
if context.executing_eagerly():
self.skipTest('eager backend.function does not support updates')
x1 = backend.placeholder(shape=(), dtype='float32') x1 = backend.placeholder(shape=(), dtype='float32')
x2 = backend.placeholder(shape=(), dtype='int32') x2 = backend.placeholder(shape=(), dtype='int32')
v = backend.variable(10.) v = backend.variable(10.)
@ -1916,6 +1920,9 @@ class FunctionTest(test.TestCase):
self.assertEqual(result, 4.) self.assertEqual(result, 4.)
def test_tuple_updates(self): def test_tuple_updates(self):
if context.executing_eagerly():
self.skipTest('eager backend.function does not support updates')
x_ph = backend.placeholder(ndim=2) x_ph = backend.placeholder(ndim=2)
v = backend.variable(np.ones((4, 2))) v = backend.variable(np.ones((4, 2)))
output = x_ph ** 2 + v output = x_ph ** 2 + v
@ -1929,7 +1936,7 @@ class FunctionTest(test.TestCase):
class BackendGraphTests(test.TestCase, parameterized.TestCase): class BackendGraphTests(test.TestCase, parameterized.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @combinations.generate(combinations.combine(mode=['graph']))
def test_function_placeholder_with_default(self): def test_function_placeholder_with_default(self):
with backend.get_graph().as_default(): with backend.get_graph().as_default():
x1 = array_ops.placeholder_with_default( x1 = array_ops.placeholder_with_default(

View File

@ -248,7 +248,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
constants[i] = op_input constants[i] = op_input
else: else:
with ops.init_scope(): with ops.init_scope():
constants[i] = backend.function([], op_input)([]) if ops.executing_eagerly_outside_functions():
constants[i] = backend.eval_in_eager_or_function(op_input)
else:
constants[i] = backend.function([], op_input)([])
layer_inputs = unnest_if_single_tensor(layer_inputs) layer_inputs = unnest_if_single_tensor(layer_inputs)
processed_ops, created_layers = _create_keras_history_helper( processed_ops, created_layers = _create_keras_history_helper(
layer_inputs, processed_ops, created_layers) layer_inputs, processed_ops, created_layers)

View File

@ -161,8 +161,11 @@ class InputLayer(base_layer.Layer):
'InputLayer, you should instantiate your model and ' 'InputLayer, you should instantiate your model and '
'directly call it on your input.') 'directly call it on your input.')
self.is_placeholder = False self.is_placeholder = False
self._batch_input_shape = tuple(input_tensor.shape.as_list()) try:
self._batch_input_shape = tuple(input_tensor.shape.as_list())
except ValueError:
# If the shape cannot be represented as a tuple (e.g. unknown rank)
self._batch_input_shape = None
# Create an input node. # Create an input node.
input_tensor._keras_mask = None input_tensor._keras_mask = None
node_module.Node(layer=self, outputs=input_tensor) node_module.Node(layer=self, outputs=input_tensor)

View File

@ -1935,7 +1935,7 @@ def get_input_shape_and_dtype(layer):
raise ValueError('An empty Model cannot be used as a Layer.') raise ValueError('An empty Model cannot be used as a Layer.')
layer = layer.layers[0] layer = layer.layers[0]
if hasattr(layer, '_batch_input_shape'): if getattr(layer, '_batch_input_shape', None):
return layer._batch_input_shape, layer.dtype return layer._batch_input_shape, layer.dtype
return None, None return None, None

View File

@ -288,9 +288,10 @@ class AutoLambdaTest(keras_parameterized.TestCase):
constant_op.constant(40.0, shape=(1, 1))) constant_op.constant(40.0, shape=(1, 1)))
def test_no_tracking(self): def test_no_tracking(self):
x = keras.backend.placeholder((10, 10)) if not context.executing_eagerly():
keras.layers.Dense(1)(x) x = constant_op.constant(1.0, shape=(10, 10))
self.assertTrue(x._keras_history_checked) keras.layers.Dense(1)(x)
self.assertTrue(x._keras_history_checked)
def test_timing_scales_linearly(self): def test_timing_scales_linearly(self):

View File

@ -33,6 +33,7 @@ from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -1213,9 +1214,14 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
f_merged = keras.backend.function([inputs], layer(inputs)) f_merged = keras.backend.function([inputs], layer(inputs))
f_forward = keras.backend.function([inputs], f_forward = keras.backend.function([inputs],
layer.forward_layer(inputs)) layer.forward_layer(inputs))
# TODO(kaftan): after KerasTensor refactor TF op layers should work
# with many composite tensors, and this shouldn't need to be a lambda
# layer.
reverse_layer = core.Lambda(array_ops.reverse, arguments=dict(axis=[1]))
f_backward = keras.backend.function( f_backward = keras.backend.function(
[inputs], [inputs],
array_ops.reverse(layer.backward_layer(inputs), axis=[1])) reverse_layer(layer.backward_layer(inputs)))
y_merged = f_merged(x) y_merged = f_merged(x)
y_expected = merge_func( y_expected = merge_func(

View File

@ -125,8 +125,10 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase):
backend.eval(output_from_softmax), backend.eval(output_from_softmax),
atol=1e-5) atol=1e-5)
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @combinations.generate(combinations.combine(mode=['graph']))
def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self): def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
# This test only runs in graph because the TF op layer is not supported yet
# for sparse ops.
t = backend.placeholder() t = backend.placeholder()
p = backend.placeholder() p = backend.placeholder()
o = losses.sparse_categorical_crossentropy(t, p) o = losses.sparse_categorical_crossentropy(t, p)