Switches keras.backend.placeholder + keras.backend.function to build a Keras model when running eagerly (instead of trying to directly lift ops out of a graph into a ConcreteFunction). This allows us to strip most of EagerDefinedExecutionFunction from the Keras backend.
This has the effect of making keras.backend.placeholder + keras.backend.function use the same codepaths as the rest of Keras. This may have the following impact on user code:
- keras.backend.function no longer supports the `updates` argument when eager execution is enabled.
- keras.backend.placeholder + keras.backend.function now have the same limitations as TF op layers when manipulating the placeholders directly with tf ops. This means no support outside of a layer for sparse ops & ops that operate on composite tensors.

PiperOrigin-RevId: 312711373
Change-Id: Ie4bab440b83ea2becf1c83b83837771eba185ff5
This commit is contained in:
parent
5d3ad40d11
commit
1d8bc7222d
|
@ -584,6 +584,7 @@ tf_py_test(
|
||||||
deps = [
|
deps = [
|
||||||
":backend",
|
":backend",
|
||||||
":combinations",
|
":combinations",
|
||||||
|
":engine",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
|
|
@ -294,7 +294,6 @@ def clear_session():
|
||||||
global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned
|
global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned
|
||||||
global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned
|
global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned
|
||||||
global _GRAPH
|
global _GRAPH
|
||||||
global _FREEZABLE_VARS
|
|
||||||
_GRAPH.graph = None
|
_GRAPH.graph = None
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
reset_uids()
|
reset_uids()
|
||||||
|
@ -307,7 +306,6 @@ def clear_session():
|
||||||
_GRAPH_LEARNING_PHASES.setdefault(graph)
|
_GRAPH_LEARNING_PHASES.setdefault(graph)
|
||||||
_GRAPH_VARIABLES.pop(graph, None)
|
_GRAPH_VARIABLES.pop(graph, None)
|
||||||
_GRAPH_TF_OPTIMIZERS.pop(graph, None)
|
_GRAPH_TF_OPTIMIZERS.pop(graph, None)
|
||||||
_FREEZABLE_VARS.pop(graph, None)
|
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.backend.manual_variable_initialization')
|
@keras_export('keras.backend.manual_variable_initialization')
|
||||||
|
@ -1059,9 +1057,9 @@ def is_keras_tensor(x):
|
||||||
>>> tf.keras.backend.is_keras_tensor(keras_var)
|
>>> tf.keras.backend.is_keras_tensor(keras_var)
|
||||||
False
|
False
|
||||||
>>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
|
>>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
|
||||||
>>> # A placeholder is not a Keras tensor.
|
>>> # A placeholder is a Keras tensor.
|
||||||
>>> tf.keras.backend.is_keras_tensor(keras_placeholder)
|
>>> tf.keras.backend.is_keras_tensor(keras_placeholder)
|
||||||
False
|
True
|
||||||
>>> keras_input = tf.keras.layers.Input([10])
|
>>> keras_input = tf.keras.layers.Input([10])
|
||||||
>>> # An Input is a Keras tensor.
|
>>> # An Input is a Keras tensor.
|
||||||
>>> tf.keras.backend.is_keras_tensor(keras_input)
|
>>> tf.keras.backend.is_keras_tensor(keras_input)
|
||||||
|
@ -1144,6 +1142,14 @@ def placeholder(shape=None,
|
||||||
expand_composites=True)
|
expand_composites=True)
|
||||||
else:
|
else:
|
||||||
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
||||||
|
|
||||||
|
if context.executing_eagerly():
|
||||||
|
# Add keras_history connectivity information to the placeholder
|
||||||
|
# when the placeholder is built in a top-level eager context
|
||||||
|
# (intended to be used with keras.backend.function)
|
||||||
|
from tensorflow.python.keras.engine import input_layer # pylint: disable=g-import-not-at-top
|
||||||
|
return input_layer.Input(tensor=x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -3379,7 +3385,7 @@ def get_value(x):
|
||||||
|
|
||||||
if ops.executing_eagerly_outside_functions():
|
if ops.executing_eagerly_outside_functions():
|
||||||
# This method of evaluating works inside the Keras FuncGraph.
|
# This method of evaluating works inside the Keras FuncGraph.
|
||||||
return function([], x)(x)
|
return eval_in_eager_or_function(x)
|
||||||
|
|
||||||
with x.graph.as_default():
|
with x.graph.as_default():
|
||||||
return x.eval(session=get_session((x,)))
|
return x.eval(session=get_session((x,)))
|
||||||
|
@ -3722,161 +3728,74 @@ class GraphExecutionFunction(object):
|
||||||
return nest.map_structure(self._eval_if_composite, output_structure)
|
return nest.map_structure(self._eval_if_composite, output_structure)
|
||||||
|
|
||||||
|
|
||||||
class EagerExecutionFunction(object):
|
def eval_in_eager_or_function(outputs):
|
||||||
"""Helper class for constructing a TF graph function from the Keras graph.
|
"""Method to evaluate a tensor in eager or in a tf.function.
|
||||||
|
|
||||||
|
In the case of a tf.function, it will lift the tensor out of the function
|
||||||
|
and try to evaluate that piece of the graph.
|
||||||
|
|
||||||
|
Warning: Do not add new usages of this function.
|
||||||
|
TODO(b/150169018): delete this function once _keras_history_helper is no
|
||||||
|
longer needed, after Keras switches to KerasTensors and op layers
|
||||||
|
work via dispatch.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
inputs: Feed placeholders to the computation graph.
|
outputs: tensors to fetch.
|
||||||
outputs: Output tensors to fetch.
|
Returns:
|
||||||
updates: Additional update ops to be run at function call.
|
The value of the tensors (as numpy arrays).
|
||||||
name: A name to help users identify what this function does.
|
|
||||||
session_kwargs: Unsupported.
|
|
||||||
"""
|
"""
|
||||||
|
outputs_structure = outputs
|
||||||
|
outputs = nest.flatten(outputs, expand_composites=True)
|
||||||
|
|
||||||
def __init__(self, inputs, outputs, updates=None, name=None):
|
graphs = {
|
||||||
self.name = name
|
i.graph
|
||||||
self._inputs_structure = inputs
|
for i in nest.flatten([outputs])
|
||||||
inputs = nest.flatten(inputs, expand_composites=True)
|
if hasattr(i, 'graph')
|
||||||
self._outputs_structure = outputs
|
}
|
||||||
outputs = nest.flatten(outputs, expand_composites=True)
|
if len(graphs) > 1:
|
||||||
|
raise ValueError('Cannot create an execution function which is comprised '
|
||||||
|
'of elements from multiple graphs.')
|
||||||
|
|
||||||
updates = updates or []
|
source_graph = graphs.pop()
|
||||||
if not isinstance(updates, (list, tuple)):
|
|
||||||
raise TypeError('`updates` in a Keras backend function '
|
|
||||||
'should be a list or tuple.')
|
|
||||||
|
|
||||||
if updates and not outputs:
|
with _scratch_graph() as exec_graph:
|
||||||
# Edge case; never happens in practice
|
|
||||||
raise ValueError('Cannot create a Keras backend function with updates'
|
|
||||||
' but no outputs during eager execution.')
|
|
||||||
graphs = {
|
|
||||||
i.graph
|
|
||||||
for i in nest.flatten([inputs, outputs, updates])
|
|
||||||
if hasattr(i, 'graph')
|
|
||||||
}
|
|
||||||
if len(graphs) > 1:
|
|
||||||
raise ValueError('Cannot create an execution function which is comprised '
|
|
||||||
'of elements from multiple graphs.')
|
|
||||||
|
|
||||||
source_graph = graphs.pop()
|
|
||||||
global_graph = get_graph()
|
global_graph = get_graph()
|
||||||
|
if source_graph not in (exec_graph, global_graph):
|
||||||
|
raise ValueError('Unknown graph. Aborting.')
|
||||||
|
|
||||||
updates_ops = []
|
if source_graph is global_graph and exec_graph is not global_graph:
|
||||||
legacy_update_ops = []
|
init_tensors = outputs
|
||||||
for update in updates:
|
lifted_map = lift_to_graph.lift_to_graph(
|
||||||
# For legacy reasons it is allowed to pass an update as a tuple
|
tensors=init_tensors,
|
||||||
# `(variable, new_value)` (this maps to an assign op). Otherwise it
|
graph=exec_graph,
|
||||||
# is assumed to already be an op -- we cannot control its execution
|
sources=[],
|
||||||
# order.
|
add_sources=True,
|
||||||
if isinstance(update, tuple):
|
handle_captures=True,
|
||||||
legacy_update_ops.append(update)
|
base_graph=source_graph)
|
||||||
else:
|
|
||||||
if hasattr(update, 'op'):
|
|
||||||
update = update.op
|
|
||||||
if update is not None:
|
|
||||||
# `update.op` may have been None in certain cases.
|
|
||||||
updates_ops.append(update)
|
|
||||||
|
|
||||||
self._freezable_vars_to_feed = []
|
outputs = [lifted_map[i] for i in outputs]
|
||||||
self._freezable_vars_values = []
|
|
||||||
freezable_vars_from_keras_graph = object_identity.ObjectIdentitySet(
|
|
||||||
_FREEZABLE_VARS.get(global_graph, {}))
|
|
||||||
with _scratch_graph() as exec_graph:
|
|
||||||
global_graph = get_graph()
|
|
||||||
if source_graph not in (exec_graph, global_graph):
|
|
||||||
raise ValueError('Unknown graph. Aborting.')
|
|
||||||
|
|
||||||
if source_graph is global_graph and exec_graph is not global_graph:
|
# Consolidate updates
|
||||||
init_tensors = (
|
with exec_graph.as_default():
|
||||||
outputs + updates_ops + [p for [p, _] in legacy_update_ops] +
|
outputs = cast_variables_to_tensor(outputs)
|
||||||
[p_new for [_, p_new] in legacy_update_ops
|
|
||||||
if isinstance(p_new, ops.Tensor)])
|
|
||||||
lifted_map = lift_to_graph.lift_to_graph(
|
|
||||||
tensors=init_tensors,
|
|
||||||
graph=exec_graph,
|
|
||||||
sources=inputs,
|
|
||||||
add_sources=True,
|
|
||||||
handle_captures=True,
|
|
||||||
base_graph=source_graph)
|
|
||||||
|
|
||||||
inputs = [lifted_map[i] for i in inputs]
|
exec_graph.inputs = exec_graph.internal_captures
|
||||||
outputs = [lifted_map[i] for i in outputs]
|
exec_graph.outputs = outputs
|
||||||
updates_ops = [lifted_map[i] for i in updates_ops]
|
graph_fn = eager_function.ConcreteFunction(exec_graph)
|
||||||
legacy_update_ops = [(lifted_map[p], lifted_map.get(p_new, p_new))
|
|
||||||
for p, p_new in legacy_update_ops]
|
|
||||||
|
|
||||||
# Keep track of the value to feed to any "freezable variables"
|
graph_fn._num_positional_args = 0
|
||||||
# created in this graph.
|
graph_fn._arg_keywords = []
|
||||||
for old_op, new_op in lifted_map.items():
|
|
||||||
if old_op in freezable_vars_from_keras_graph:
|
|
||||||
frozen_var = old_op
|
|
||||||
if frozen_var._initial_value != frozen_var._current_value:
|
|
||||||
# We only feed a frozen_variable if its value has changed;
|
|
||||||
# otherwise it can rely on the default value of the
|
|
||||||
# underlying placeholder_with_default.
|
|
||||||
self._freezable_vars_to_feed.append(new_op)
|
|
||||||
self._freezable_vars_values.append(frozen_var._current_value)
|
|
||||||
|
|
||||||
# Consolidate updates
|
outputs = graph_fn()
|
||||||
with exec_graph.as_default():
|
|
||||||
outputs = cast_variables_to_tensor(outputs)
|
|
||||||
with ops.control_dependencies(outputs):
|
|
||||||
for p, p_new in legacy_update_ops:
|
|
||||||
updates_ops.append(state_ops.assign(p, p_new))
|
|
||||||
|
|
||||||
self.inputs, self.outputs = inputs, outputs
|
# EagerTensor.numpy() will often make a copy to ensure memory safety.
|
||||||
self._input_references = self.inputs + self._freezable_vars_to_feed
|
# However in this case `outputs` is not directly returned, so it is always
|
||||||
with ops.control_dependencies(updates_ops):
|
# safe to reuse the underlying buffer without checking. In such a case the
|
||||||
self.outputs[0] = array_ops.identity(self.outputs[0])
|
# private numpy conversion method is preferred to guarantee performance.
|
||||||
|
return nest.pack_sequence_as(
|
||||||
exec_graph.inputs = self._input_references + exec_graph.internal_captures
|
outputs_structure,
|
||||||
exec_graph.outputs = self.outputs
|
[x._numpy() for x in outputs], # pylint: disable=protected-access
|
||||||
graph_fn = eager_function.ConcreteFunction(exec_graph)
|
expand_composites=True)
|
||||||
|
|
||||||
graph_fn._num_positional_args = len(self._input_references)
|
|
||||||
graph_fn._arg_keywords = []
|
|
||||||
self._graph_fn = graph_fn
|
|
||||||
|
|
||||||
# Handle placeholders with default
|
|
||||||
# (treated as required placeholder by graph functions)
|
|
||||||
self._placeholder_default_values = {}
|
|
||||||
with exec_graph.as_default():
|
|
||||||
for x in self.inputs:
|
|
||||||
if x.op.type == 'PlaceholderWithDefault':
|
|
||||||
self._placeholder_default_values[ops.tensor_id(
|
|
||||||
x)] = tensor_util.constant_value(x.op.inputs[0])
|
|
||||||
|
|
||||||
def __call__(self, inputs):
|
|
||||||
input_values = nest.flatten(inputs, expand_composites=True)
|
|
||||||
|
|
||||||
if self._freezable_vars_values:
|
|
||||||
input_values = input_values + self._freezable_vars_values
|
|
||||||
converted_inputs = []
|
|
||||||
for tensor, value in zip(self._input_references, input_values):
|
|
||||||
if value is None:
|
|
||||||
# Assume `value` is a placeholder with default
|
|
||||||
value = self._placeholder_default_values.get(
|
|
||||||
ops.tensor_id(tensor), None)
|
|
||||||
if value is None:
|
|
||||||
raise ValueError(
|
|
||||||
'You must feed a value for placeholder %s' % (tensor,))
|
|
||||||
if not isinstance(value, ops.Tensor):
|
|
||||||
value = ops.convert_to_tensor_v2(value, dtype=tensor.dtype)
|
|
||||||
if value.dtype != tensor.dtype:
|
|
||||||
# Temporary workaround due to `convert_to_tensor` not casting floats.
|
|
||||||
# See b/119637405
|
|
||||||
value = math_ops.cast(value, tensor.dtype)
|
|
||||||
converted_inputs.append(value)
|
|
||||||
outputs = self._graph_fn(*converted_inputs)
|
|
||||||
|
|
||||||
# EagerTensor.numpy() will often make a copy to ensure memory safety.
|
|
||||||
# However in this case `outputs` is not directly returned, so it is always
|
|
||||||
# safe to reuse the underlying buffer without checking. In such a case the
|
|
||||||
# private numpy conversion method is preferred to guarantee performance.
|
|
||||||
return nest.pack_sequence_as(
|
|
||||||
self._outputs_structure,
|
|
||||||
[x._numpy() for x in outputs], # pylint: disable=protected-access
|
|
||||||
expand_composites=True)
|
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.backend.function')
|
@keras_export('keras.backend.function')
|
||||||
|
@ -3900,7 +3819,20 @@ def function(inputs, outputs, updates=None, name=None, **kwargs):
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError('Session keyword arguments are not support during '
|
raise ValueError('Session keyword arguments are not support during '
|
||||||
'eager execution. You passed: %s' % (kwargs,))
|
'eager execution. You passed: %s' % (kwargs,))
|
||||||
return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
|
if updates:
|
||||||
|
raise ValueError('`updates` argument is not support during '
|
||||||
|
'eager execution. You passed: %s' % (updates,))
|
||||||
|
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
|
||||||
|
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
|
||||||
|
model = models.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
|
wrap_outputs = isinstance(outputs, list) and len(outputs) == 1
|
||||||
|
def func(model_inputs):
|
||||||
|
outs = model(model_inputs)
|
||||||
|
if wrap_outputs:
|
||||||
|
outs = [outs]
|
||||||
|
return tf_utils.to_numpy_or_python_type(outs)
|
||||||
|
return func
|
||||||
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
for key in kwargs:
|
for key in kwargs:
|
||||||
|
@ -6344,10 +6276,6 @@ class ContextValueCache(weakref.WeakKeyDictionary):
|
||||||
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
|
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
|
||||||
_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
|
_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
|
||||||
|
|
||||||
# This dictionary holds a mapping {graph: set_of_freezable_variables}.
|
|
||||||
# Each set tracks objects created via `freezable_variable` in the graph.
|
|
||||||
_FREEZABLE_VARS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
|
|
||||||
|
|
||||||
# This dictionary holds a mapping between a graph and variables to initialize
|
# This dictionary holds a mapping between a graph and variables to initialize
|
||||||
# in the graph.
|
# in the graph.
|
||||||
_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
|
_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
|
||||||
|
|
|
@ -1677,8 +1677,10 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase):
|
||||||
t, p, from_logits=True, axis=0),
|
t, p, from_logits=True, axis=0),
|
||||||
self.assertArrayNear(self.evaluate(result)[0], [.002, 0, .17], 1e-3)
|
self.assertArrayNear(self.evaluate(result)[0], [.002, 0, .17], 1e-3)
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
@combinations.generate(combinations.combine(mode=['graph']))
|
||||||
def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
|
def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
|
||||||
|
# This test only runs in graph because the TF op layer is not supported yet
|
||||||
|
# for sparse ops.
|
||||||
t = backend.placeholder()
|
t = backend.placeholder()
|
||||||
p = backend.placeholder()
|
p = backend.placeholder()
|
||||||
o = backend.sparse_categorical_crossentropy(t, p)
|
o = backend.sparse_categorical_crossentropy(t, p)
|
||||||
|
@ -1870,6 +1872,8 @@ class TestRandomOps(test.TestCase):
|
||||||
class FunctionTest(test.TestCase):
|
class FunctionTest(test.TestCase):
|
||||||
|
|
||||||
def test_function_basics(self):
|
def test_function_basics(self):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
self.skipTest('eager backend.function does not support updates')
|
||||||
x1 = backend.placeholder(shape=(), dtype='float32')
|
x1 = backend.placeholder(shape=(), dtype='float32')
|
||||||
x2 = backend.placeholder(shape=(), dtype='int32')
|
x2 = backend.placeholder(shape=(), dtype='int32')
|
||||||
v = backend.variable(10.)
|
v = backend.variable(10.)
|
||||||
|
@ -1916,6 +1920,9 @@ class FunctionTest(test.TestCase):
|
||||||
self.assertEqual(result, 4.)
|
self.assertEqual(result, 4.)
|
||||||
|
|
||||||
def test_tuple_updates(self):
|
def test_tuple_updates(self):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
self.skipTest('eager backend.function does not support updates')
|
||||||
|
|
||||||
x_ph = backend.placeholder(ndim=2)
|
x_ph = backend.placeholder(ndim=2)
|
||||||
v = backend.variable(np.ones((4, 2)))
|
v = backend.variable(np.ones((4, 2)))
|
||||||
output = x_ph ** 2 + v
|
output = x_ph ** 2 + v
|
||||||
|
@ -1929,7 +1936,7 @@ class FunctionTest(test.TestCase):
|
||||||
|
|
||||||
class BackendGraphTests(test.TestCase, parameterized.TestCase):
|
class BackendGraphTests(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
@combinations.generate(combinations.combine(mode=['graph']))
|
||||||
def test_function_placeholder_with_default(self):
|
def test_function_placeholder_with_default(self):
|
||||||
with backend.get_graph().as_default():
|
with backend.get_graph().as_default():
|
||||||
x1 = array_ops.placeholder_with_default(
|
x1 = array_ops.placeholder_with_default(
|
||||||
|
|
|
@ -248,7 +248,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
|
||||||
constants[i] = op_input
|
constants[i] = op_input
|
||||||
else:
|
else:
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
constants[i] = backend.function([], op_input)([])
|
if ops.executing_eagerly_outside_functions():
|
||||||
|
constants[i] = backend.eval_in_eager_or_function(op_input)
|
||||||
|
else:
|
||||||
|
constants[i] = backend.function([], op_input)([])
|
||||||
layer_inputs = unnest_if_single_tensor(layer_inputs)
|
layer_inputs = unnest_if_single_tensor(layer_inputs)
|
||||||
processed_ops, created_layers = _create_keras_history_helper(
|
processed_ops, created_layers = _create_keras_history_helper(
|
||||||
layer_inputs, processed_ops, created_layers)
|
layer_inputs, processed_ops, created_layers)
|
||||||
|
|
|
@ -161,8 +161,11 @@ class InputLayer(base_layer.Layer):
|
||||||
'InputLayer, you should instantiate your model and '
|
'InputLayer, you should instantiate your model and '
|
||||||
'directly call it on your input.')
|
'directly call it on your input.')
|
||||||
self.is_placeholder = False
|
self.is_placeholder = False
|
||||||
self._batch_input_shape = tuple(input_tensor.shape.as_list())
|
try:
|
||||||
|
self._batch_input_shape = tuple(input_tensor.shape.as_list())
|
||||||
|
except ValueError:
|
||||||
|
# If the shape cannot be represented as a tuple (e.g. unknown rank)
|
||||||
|
self._batch_input_shape = None
|
||||||
# Create an input node.
|
# Create an input node.
|
||||||
input_tensor._keras_mask = None
|
input_tensor._keras_mask = None
|
||||||
node_module.Node(layer=self, outputs=input_tensor)
|
node_module.Node(layer=self, outputs=input_tensor)
|
||||||
|
|
|
@ -1935,7 +1935,7 @@ def get_input_shape_and_dtype(layer):
|
||||||
raise ValueError('An empty Model cannot be used as a Layer.')
|
raise ValueError('An empty Model cannot be used as a Layer.')
|
||||||
layer = layer.layers[0]
|
layer = layer.layers[0]
|
||||||
|
|
||||||
if hasattr(layer, '_batch_input_shape'):
|
if getattr(layer, '_batch_input_shape', None):
|
||||||
return layer._batch_input_shape, layer.dtype
|
return layer._batch_input_shape, layer.dtype
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
|
@ -288,9 +288,10 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||||
constant_op.constant(40.0, shape=(1, 1)))
|
constant_op.constant(40.0, shape=(1, 1)))
|
||||||
|
|
||||||
def test_no_tracking(self):
|
def test_no_tracking(self):
|
||||||
x = keras.backend.placeholder((10, 10))
|
if not context.executing_eagerly():
|
||||||
keras.layers.Dense(1)(x)
|
x = constant_op.constant(1.0, shape=(10, 10))
|
||||||
self.assertTrue(x._keras_history_checked)
|
keras.layers.Dense(1)(x)
|
||||||
|
self.assertTrue(x._keras_history_checked)
|
||||||
|
|
||||||
def test_timing_scales_linearly(self):
|
def test_timing_scales_linearly(self):
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,7 @@ from tensorflow.python.keras import combinations
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.keras.engine import base_layer_utils
|
from tensorflow.python.keras.engine import base_layer_utils
|
||||||
|
from tensorflow.python.keras.layers import core
|
||||||
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
|
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
@ -1213,9 +1214,14 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
|
||||||
f_merged = keras.backend.function([inputs], layer(inputs))
|
f_merged = keras.backend.function([inputs], layer(inputs))
|
||||||
f_forward = keras.backend.function([inputs],
|
f_forward = keras.backend.function([inputs],
|
||||||
layer.forward_layer(inputs))
|
layer.forward_layer(inputs))
|
||||||
|
|
||||||
|
# TODO(kaftan): after KerasTensor refactor TF op layers should work
|
||||||
|
# with many composite tensors, and this shouldn't need to be a lambda
|
||||||
|
# layer.
|
||||||
|
reverse_layer = core.Lambda(array_ops.reverse, arguments=dict(axis=[1]))
|
||||||
f_backward = keras.backend.function(
|
f_backward = keras.backend.function(
|
||||||
[inputs],
|
[inputs],
|
||||||
array_ops.reverse(layer.backward_layer(inputs), axis=[1]))
|
reverse_layer(layer.backward_layer(inputs)))
|
||||||
|
|
||||||
y_merged = f_merged(x)
|
y_merged = f_merged(x)
|
||||||
y_expected = merge_func(
|
y_expected = merge_func(
|
||||||
|
|
|
@ -125,8 +125,10 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase):
|
||||||
backend.eval(output_from_softmax),
|
backend.eval(output_from_softmax),
|
||||||
atol=1e-5)
|
atol=1e-5)
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
@combinations.generate(combinations.combine(mode=['graph']))
|
||||||
def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
|
def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
|
||||||
|
# This test only runs in graph because the TF op layer is not supported yet
|
||||||
|
# for sparse ops.
|
||||||
t = backend.placeholder()
|
t = backend.placeholder()
|
||||||
p = backend.placeholder()
|
p = backend.placeholder()
|
||||||
o = losses.sparse_categorical_crossentropy(t, p)
|
o = losses.sparse_categorical_crossentropy(t, p)
|
||||||
|
|
Loading…
Reference in New Issue