Switches keras.backend.placeholder + keras.backend.function to build a keras model when running eagerly (instead of trying to directly lift ops out of a graph into a concretefunction). Allows us to strip most of EagerDefinedExecutionFunction from the keras backend.
This has the effect of making keras.backend.placeholder + backend.function use the same codepaths as the rest of Keras. This may have the following impact on user code: - keras.backend.function no longer supports the `updates` argument when eager execution is enabled. - keras.backend.placeholder + keras.backend.function now have the same limitations as TF op layers when manipulating the placeholders directly with tf ops. This means no support outside of a layer for sparse ops & ops that operate on composite tensors. PiperOrigin-RevId: 312711373 Change-Id: Ie4bab440b83ea2becf1c83b83837771eba185ff5
This commit is contained in:
parent
5d3ad40d11
commit
1d8bc7222d
|
@ -584,6 +584,7 @@ tf_py_test(
|
|||
deps = [
|
||||
":backend",
|
||||
":combinations",
|
||||
":engine",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
|
|
@ -294,7 +294,6 @@ def clear_session():
|
|||
global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned
|
||||
global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned
|
||||
global _GRAPH
|
||||
global _FREEZABLE_VARS
|
||||
_GRAPH.graph = None
|
||||
ops.reset_default_graph()
|
||||
reset_uids()
|
||||
|
@ -307,7 +306,6 @@ def clear_session():
|
|||
_GRAPH_LEARNING_PHASES.setdefault(graph)
|
||||
_GRAPH_VARIABLES.pop(graph, None)
|
||||
_GRAPH_TF_OPTIMIZERS.pop(graph, None)
|
||||
_FREEZABLE_VARS.pop(graph, None)
|
||||
|
||||
|
||||
@keras_export('keras.backend.manual_variable_initialization')
|
||||
|
@ -1059,9 +1057,9 @@ def is_keras_tensor(x):
|
|||
>>> tf.keras.backend.is_keras_tensor(keras_var)
|
||||
False
|
||||
>>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
|
||||
>>> # A placeholder is not a Keras tensor.
|
||||
>>> # A placeholder is a Keras tensor.
|
||||
>>> tf.keras.backend.is_keras_tensor(keras_placeholder)
|
||||
False
|
||||
True
|
||||
>>> keras_input = tf.keras.layers.Input([10])
|
||||
>>> # An Input is a Keras tensor.
|
||||
>>> tf.keras.backend.is_keras_tensor(keras_input)
|
||||
|
@ -1144,6 +1142,14 @@ def placeholder(shape=None,
|
|||
expand_composites=True)
|
||||
else:
|
||||
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
||||
|
||||
if context.executing_eagerly():
|
||||
# Add keras_history connectivity information to the placeholder
|
||||
# when the placeholder is built in a top-level eager context
|
||||
# (intended to be used with keras.backend.function)
|
||||
from tensorflow.python.keras.engine import input_layer # pylint: disable=g-import-not-at-top
|
||||
return input_layer.Input(tensor=x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
@ -3379,7 +3385,7 @@ def get_value(x):
|
|||
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
# This method of evaluating works inside the Keras FuncGraph.
|
||||
return function([], x)(x)
|
||||
return eval_in_eager_or_function(x)
|
||||
|
||||
with x.graph.as_default():
|
||||
return x.eval(session=get_session((x,)))
|
||||
|
@ -3722,161 +3728,74 @@ class GraphExecutionFunction(object):
|
|||
return nest.map_structure(self._eval_if_composite, output_structure)
|
||||
|
||||
|
||||
class EagerExecutionFunction(object):
|
||||
"""Helper class for constructing a TF graph function from the Keras graph.
|
||||
def eval_in_eager_or_function(outputs):
|
||||
"""Method to evaluate a tensor in eager or in a tf.function.
|
||||
|
||||
In the case of a tf.function, it will lift the tensor out of the function
|
||||
and try to evaluate that piece of the graph.
|
||||
|
||||
Warning: Do not add new usages of this function.
|
||||
TODO(b/150169018): delete this function once _keras_history_helper is no
|
||||
longer needed, after Keras switches to KerasTensors and op layers
|
||||
work via dispatch.
|
||||
|
||||
Arguments:
|
||||
inputs: Feed placeholders to the computation graph.
|
||||
outputs: Output tensors to fetch.
|
||||
updates: Additional update ops to be run at function call.
|
||||
name: A name to help users identify what this function does.
|
||||
session_kwargs: Unsupported.
|
||||
outputs: tensors to fetch.
|
||||
Returns:
|
||||
The value of the tensors (as numpy arrays).
|
||||
"""
|
||||
outputs_structure = outputs
|
||||
outputs = nest.flatten(outputs, expand_composites=True)
|
||||
|
||||
def __init__(self, inputs, outputs, updates=None, name=None):
|
||||
self.name = name
|
||||
self._inputs_structure = inputs
|
||||
inputs = nest.flatten(inputs, expand_composites=True)
|
||||
self._outputs_structure = outputs
|
||||
outputs = nest.flatten(outputs, expand_composites=True)
|
||||
graphs = {
|
||||
i.graph
|
||||
for i in nest.flatten([outputs])
|
||||
if hasattr(i, 'graph')
|
||||
}
|
||||
if len(graphs) > 1:
|
||||
raise ValueError('Cannot create an execution function which is comprised '
|
||||
'of elements from multiple graphs.')
|
||||
|
||||
updates = updates or []
|
||||
if not isinstance(updates, (list, tuple)):
|
||||
raise TypeError('`updates` in a Keras backend function '
|
||||
'should be a list or tuple.')
|
||||
source_graph = graphs.pop()
|
||||
|
||||
if updates and not outputs:
|
||||
# Edge case; never happens in practice
|
||||
raise ValueError('Cannot create a Keras backend function with updates'
|
||||
' but no outputs during eager execution.')
|
||||
graphs = {
|
||||
i.graph
|
||||
for i in nest.flatten([inputs, outputs, updates])
|
||||
if hasattr(i, 'graph')
|
||||
}
|
||||
if len(graphs) > 1:
|
||||
raise ValueError('Cannot create an execution function which is comprised '
|
||||
'of elements from multiple graphs.')
|
||||
|
||||
source_graph = graphs.pop()
|
||||
with _scratch_graph() as exec_graph:
|
||||
global_graph = get_graph()
|
||||
if source_graph not in (exec_graph, global_graph):
|
||||
raise ValueError('Unknown graph. Aborting.')
|
||||
|
||||
updates_ops = []
|
||||
legacy_update_ops = []
|
||||
for update in updates:
|
||||
# For legacy reasons it is allowed to pass an update as a tuple
|
||||
# `(variable, new_value)` (this maps to an assign op). Otherwise it
|
||||
# is assumed to already be an op -- we cannot control its execution
|
||||
# order.
|
||||
if isinstance(update, tuple):
|
||||
legacy_update_ops.append(update)
|
||||
else:
|
||||
if hasattr(update, 'op'):
|
||||
update = update.op
|
||||
if update is not None:
|
||||
# `update.op` may have been None in certain cases.
|
||||
updates_ops.append(update)
|
||||
if source_graph is global_graph and exec_graph is not global_graph:
|
||||
init_tensors = outputs
|
||||
lifted_map = lift_to_graph.lift_to_graph(
|
||||
tensors=init_tensors,
|
||||
graph=exec_graph,
|
||||
sources=[],
|
||||
add_sources=True,
|
||||
handle_captures=True,
|
||||
base_graph=source_graph)
|
||||
|
||||
self._freezable_vars_to_feed = []
|
||||
self._freezable_vars_values = []
|
||||
freezable_vars_from_keras_graph = object_identity.ObjectIdentitySet(
|
||||
_FREEZABLE_VARS.get(global_graph, {}))
|
||||
with _scratch_graph() as exec_graph:
|
||||
global_graph = get_graph()
|
||||
if source_graph not in (exec_graph, global_graph):
|
||||
raise ValueError('Unknown graph. Aborting.')
|
||||
outputs = [lifted_map[i] for i in outputs]
|
||||
|
||||
if source_graph is global_graph and exec_graph is not global_graph:
|
||||
init_tensors = (
|
||||
outputs + updates_ops + [p for [p, _] in legacy_update_ops] +
|
||||
[p_new for [_, p_new] in legacy_update_ops
|
||||
if isinstance(p_new, ops.Tensor)])
|
||||
lifted_map = lift_to_graph.lift_to_graph(
|
||||
tensors=init_tensors,
|
||||
graph=exec_graph,
|
||||
sources=inputs,
|
||||
add_sources=True,
|
||||
handle_captures=True,
|
||||
base_graph=source_graph)
|
||||
# Consolidate updates
|
||||
with exec_graph.as_default():
|
||||
outputs = cast_variables_to_tensor(outputs)
|
||||
|
||||
inputs = [lifted_map[i] for i in inputs]
|
||||
outputs = [lifted_map[i] for i in outputs]
|
||||
updates_ops = [lifted_map[i] for i in updates_ops]
|
||||
legacy_update_ops = [(lifted_map[p], lifted_map.get(p_new, p_new))
|
||||
for p, p_new in legacy_update_ops]
|
||||
exec_graph.inputs = exec_graph.internal_captures
|
||||
exec_graph.outputs = outputs
|
||||
graph_fn = eager_function.ConcreteFunction(exec_graph)
|
||||
|
||||
# Keep track of the value to feed to any "freezable variables"
|
||||
# created in this graph.
|
||||
for old_op, new_op in lifted_map.items():
|
||||
if old_op in freezable_vars_from_keras_graph:
|
||||
frozen_var = old_op
|
||||
if frozen_var._initial_value != frozen_var._current_value:
|
||||
# We only feed a frozen_variable if its value has changed;
|
||||
# otherwise it can rely on the default value of the
|
||||
# underlying placeholder_with_default.
|
||||
self._freezable_vars_to_feed.append(new_op)
|
||||
self._freezable_vars_values.append(frozen_var._current_value)
|
||||
graph_fn._num_positional_args = 0
|
||||
graph_fn._arg_keywords = []
|
||||
|
||||
# Consolidate updates
|
||||
with exec_graph.as_default():
|
||||
outputs = cast_variables_to_tensor(outputs)
|
||||
with ops.control_dependencies(outputs):
|
||||
for p, p_new in legacy_update_ops:
|
||||
updates_ops.append(state_ops.assign(p, p_new))
|
||||
outputs = graph_fn()
|
||||
|
||||
self.inputs, self.outputs = inputs, outputs
|
||||
self._input_references = self.inputs + self._freezable_vars_to_feed
|
||||
with ops.control_dependencies(updates_ops):
|
||||
self.outputs[0] = array_ops.identity(self.outputs[0])
|
||||
|
||||
exec_graph.inputs = self._input_references + exec_graph.internal_captures
|
||||
exec_graph.outputs = self.outputs
|
||||
graph_fn = eager_function.ConcreteFunction(exec_graph)
|
||||
|
||||
graph_fn._num_positional_args = len(self._input_references)
|
||||
graph_fn._arg_keywords = []
|
||||
self._graph_fn = graph_fn
|
||||
|
||||
# Handle placeholders with default
|
||||
# (treated as required placeholder by graph functions)
|
||||
self._placeholder_default_values = {}
|
||||
with exec_graph.as_default():
|
||||
for x in self.inputs:
|
||||
if x.op.type == 'PlaceholderWithDefault':
|
||||
self._placeholder_default_values[ops.tensor_id(
|
||||
x)] = tensor_util.constant_value(x.op.inputs[0])
|
||||
|
||||
def __call__(self, inputs):
|
||||
input_values = nest.flatten(inputs, expand_composites=True)
|
||||
|
||||
if self._freezable_vars_values:
|
||||
input_values = input_values + self._freezable_vars_values
|
||||
converted_inputs = []
|
||||
for tensor, value in zip(self._input_references, input_values):
|
||||
if value is None:
|
||||
# Assume `value` is a placeholder with default
|
||||
value = self._placeholder_default_values.get(
|
||||
ops.tensor_id(tensor), None)
|
||||
if value is None:
|
||||
raise ValueError(
|
||||
'You must feed a value for placeholder %s' % (tensor,))
|
||||
if not isinstance(value, ops.Tensor):
|
||||
value = ops.convert_to_tensor_v2(value, dtype=tensor.dtype)
|
||||
if value.dtype != tensor.dtype:
|
||||
# Temporary workaround due to `convert_to_tensor` not casting floats.
|
||||
# See b/119637405
|
||||
value = math_ops.cast(value, tensor.dtype)
|
||||
converted_inputs.append(value)
|
||||
outputs = self._graph_fn(*converted_inputs)
|
||||
|
||||
# EagerTensor.numpy() will often make a copy to ensure memory safety.
|
||||
# However in this case `outputs` is not directly returned, so it is always
|
||||
# safe to reuse the underlying buffer without checking. In such a case the
|
||||
# private numpy conversion method is preferred to guarantee performance.
|
||||
return nest.pack_sequence_as(
|
||||
self._outputs_structure,
|
||||
[x._numpy() for x in outputs], # pylint: disable=protected-access
|
||||
expand_composites=True)
|
||||
# EagerTensor.numpy() will often make a copy to ensure memory safety.
|
||||
# However in this case `outputs` is not directly returned, so it is always
|
||||
# safe to reuse the underlying buffer without checking. In such a case the
|
||||
# private numpy conversion method is preferred to guarantee performance.
|
||||
return nest.pack_sequence_as(
|
||||
outputs_structure,
|
||||
[x._numpy() for x in outputs], # pylint: disable=protected-access
|
||||
expand_composites=True)
|
||||
|
||||
|
||||
@keras_export('keras.backend.function')
|
||||
|
@ -3900,7 +3819,20 @@ def function(inputs, outputs, updates=None, name=None, **kwargs):
|
|||
if kwargs:
|
||||
raise ValueError('Session keyword arguments are not support during '
|
||||
'eager execution. You passed: %s' % (kwargs,))
|
||||
return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
|
||||
if updates:
|
||||
raise ValueError('`updates` argument is not support during '
|
||||
'eager execution. You passed: %s' % (updates,))
|
||||
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
|
||||
model = models.Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
wrap_outputs = isinstance(outputs, list) and len(outputs) == 1
|
||||
def func(model_inputs):
|
||||
outs = model(model_inputs)
|
||||
if wrap_outputs:
|
||||
outs = [outs]
|
||||
return tf_utils.to_numpy_or_python_type(outs)
|
||||
return func
|
||||
|
||||
if kwargs:
|
||||
for key in kwargs:
|
||||
|
@ -6344,10 +6276,6 @@ class ContextValueCache(weakref.WeakKeyDictionary):
|
|||
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
|
||||
_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
|
||||
|
||||
# This dictionary holds a mapping {graph: set_of_freezable_variables}.
|
||||
# Each set tracks objects created via `freezable_variable` in the graph.
|
||||
_FREEZABLE_VARS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
|
||||
|
||||
# This dictionary holds a mapping between a graph and variables to initialize
|
||||
# in the graph.
|
||||
_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
|
||||
|
|
|
@ -1677,8 +1677,10 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase):
|
|||
t, p, from_logits=True, axis=0),
|
||||
self.assertArrayNear(self.evaluate(result)[0], [.002, 0, .17], 1e-3)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
@combinations.generate(combinations.combine(mode=['graph']))
|
||||
def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
|
||||
# This test only runs in graph because the TF op layer is not supported yet
|
||||
# for sparse ops.
|
||||
t = backend.placeholder()
|
||||
p = backend.placeholder()
|
||||
o = backend.sparse_categorical_crossentropy(t, p)
|
||||
|
@ -1870,6 +1872,8 @@ class TestRandomOps(test.TestCase):
|
|||
class FunctionTest(test.TestCase):
|
||||
|
||||
def test_function_basics(self):
|
||||
if context.executing_eagerly():
|
||||
self.skipTest('eager backend.function does not support updates')
|
||||
x1 = backend.placeholder(shape=(), dtype='float32')
|
||||
x2 = backend.placeholder(shape=(), dtype='int32')
|
||||
v = backend.variable(10.)
|
||||
|
@ -1916,6 +1920,9 @@ class FunctionTest(test.TestCase):
|
|||
self.assertEqual(result, 4.)
|
||||
|
||||
def test_tuple_updates(self):
|
||||
if context.executing_eagerly():
|
||||
self.skipTest('eager backend.function does not support updates')
|
||||
|
||||
x_ph = backend.placeholder(ndim=2)
|
||||
v = backend.variable(np.ones((4, 2)))
|
||||
output = x_ph ** 2 + v
|
||||
|
@ -1929,7 +1936,7 @@ class FunctionTest(test.TestCase):
|
|||
|
||||
class BackendGraphTests(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
@combinations.generate(combinations.combine(mode=['graph']))
|
||||
def test_function_placeholder_with_default(self):
|
||||
with backend.get_graph().as_default():
|
||||
x1 = array_ops.placeholder_with_default(
|
||||
|
|
|
@ -248,7 +248,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
|
|||
constants[i] = op_input
|
||||
else:
|
||||
with ops.init_scope():
|
||||
constants[i] = backend.function([], op_input)([])
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
constants[i] = backend.eval_in_eager_or_function(op_input)
|
||||
else:
|
||||
constants[i] = backend.function([], op_input)([])
|
||||
layer_inputs = unnest_if_single_tensor(layer_inputs)
|
||||
processed_ops, created_layers = _create_keras_history_helper(
|
||||
layer_inputs, processed_ops, created_layers)
|
||||
|
|
|
@ -161,8 +161,11 @@ class InputLayer(base_layer.Layer):
|
|||
'InputLayer, you should instantiate your model and '
|
||||
'directly call it on your input.')
|
||||
self.is_placeholder = False
|
||||
self._batch_input_shape = tuple(input_tensor.shape.as_list())
|
||||
|
||||
try:
|
||||
self._batch_input_shape = tuple(input_tensor.shape.as_list())
|
||||
except ValueError:
|
||||
# If the shape cannot be represented as a tuple (e.g. unknown rank)
|
||||
self._batch_input_shape = None
|
||||
# Create an input node.
|
||||
input_tensor._keras_mask = None
|
||||
node_module.Node(layer=self, outputs=input_tensor)
|
||||
|
|
|
@ -1935,7 +1935,7 @@ def get_input_shape_and_dtype(layer):
|
|||
raise ValueError('An empty Model cannot be used as a Layer.')
|
||||
layer = layer.layers[0]
|
||||
|
||||
if hasattr(layer, '_batch_input_shape'):
|
||||
if getattr(layer, '_batch_input_shape', None):
|
||||
return layer._batch_input_shape, layer.dtype
|
||||
return None, None
|
||||
|
||||
|
|
|
@ -288,9 +288,10 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
|||
constant_op.constant(40.0, shape=(1, 1)))
|
||||
|
||||
def test_no_tracking(self):
|
||||
x = keras.backend.placeholder((10, 10))
|
||||
keras.layers.Dense(1)(x)
|
||||
self.assertTrue(x._keras_history_checked)
|
||||
if not context.executing_eagerly():
|
||||
x = constant_op.constant(1.0, shape=(10, 10))
|
||||
keras.layers.Dense(1)(x)
|
||||
self.assertTrue(x._keras_history_checked)
|
||||
|
||||
def test_timing_scales_linearly(self):
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ from tensorflow.python.keras import combinations
|
|||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
@ -1213,9 +1214,14 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
|
|||
f_merged = keras.backend.function([inputs], layer(inputs))
|
||||
f_forward = keras.backend.function([inputs],
|
||||
layer.forward_layer(inputs))
|
||||
|
||||
# TODO(kaftan): after KerasTensor refactor TF op layers should work
|
||||
# with many composite tensors, and this shouldn't need to be a lambda
|
||||
# layer.
|
||||
reverse_layer = core.Lambda(array_ops.reverse, arguments=dict(axis=[1]))
|
||||
f_backward = keras.backend.function(
|
||||
[inputs],
|
||||
array_ops.reverse(layer.backward_layer(inputs), axis=[1]))
|
||||
reverse_layer(layer.backward_layer(inputs)))
|
||||
|
||||
y_merged = f_merged(x)
|
||||
y_expected = merge_func(
|
||||
|
|
|
@ -125,8 +125,10 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase):
|
|||
backend.eval(output_from_softmax),
|
||||
atol=1e-5)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
@combinations.generate(combinations.combine(mode=['graph']))
|
||||
def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
|
||||
# This test only runs in graph because the TF op layer is not supported yet
|
||||
# for sparse ops.
|
||||
t = backend.placeholder()
|
||||
p = backend.placeholder()
|
||||
o = losses.sparse_categorical_crossentropy(t, p)
|
||||
|
|
Loading…
Reference in New Issue