This change adds WIP support for using KerasTensor objects in Keras's functional API instead of symbolic graph tf.Tensors. It is controlled by an internal behavior flag that is disabled by default and is not yet exposed in TF's APIs.

`KerasTensor`s are an alternative representation for Keras `Inputs`
  and for intermediate outputs of layers during Functional API construction of
  models. They are a lightweight data structure comprised of only the
  `tf.TypeSpec` of the Tensor that will be consumed/produced in the
  corresponding position of the model.

  They implement just small subset of `tf.Tensor`'s attributes and
  methods, and also overload
  the same operators as `tf.Tensor` and automatically turn them into
  Keras layers in the model.

  `KerasTensor`s are still internal-only and are a work in progress, but they
  have several advantages over using a graph `tf.Tensor` to represent
  symbolic values in functional models.
  - Unlike symbolic tensors, they do not need to refer to a graph. This means
    Keras does not need to maintain a never-deleted global background graph
    containing all layers ever called during functional model construction when
    constructing Functional Models with KerasTensors. These memory savings
    can be significant.

  - Triggering Keras functional model construction is simpler
    when it just has to check whether something is a KerasTensor, rather
    than trying to infer if a tensor was meant to be a symbolic keras
    representation or just a value produced during function tracing. This means we can add support for cases where values in nest.flatten(*args, **kwargs) are a completely arbitrary mix of KerasTensors and objects that are not KerasTensors, as long as any value is a KerasTensor.

  - Autolambda layers (converting tf ops on symbolic Keras tensors to lambda
    Keras layers in the model) use TF's internal dispatching mechanism, instead
    of trying to manually walk a graph and extract nodes from it.
    The dispatching mechanism is simpler, works more reliably, and is less
    likely to run into issues with composite tensors or strange tf ops/nodes.

    (And when it fails, it's by design: because dispatch is explicitly not
    supported on the op & it's more obvious that dispatch doesn't support the
    setting).

  - Because they support arbitrary typespecs, models/layers that use
    KerasTensors are generally more friendly to composite tensors of different
    types than using symbolic graph tensors (which must have a TensorSpec and
    can't have arbitrary typespecs)

  To experiment with using KerasTensors instead of symbolic graph `tf.Tensors`,
  import keras_tensor directly and call `keras_tensor.enable_keras_tensors()`

PiperOrigin-RevId: 315009281
Change-Id: I6765f3a44da43f965ec261b6b193df26598cffae
This commit is contained in:
Tomer Kaftan 2020-06-05 15:38:47 -07:00 committed by TensorFlower Gardener
parent 52028d8d95
commit 0e1f3de50a
27 changed files with 907 additions and 97 deletions

View File

@ -86,6 +86,7 @@ py_library(
"//tensorflow/python/distribute:distribute_coordinator",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/keras/engine:keras_tensor",
],
)

View File

@ -49,8 +49,10 @@ from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
@ -1070,6 +1072,8 @@ def is_keras_tensor(x):
True
"""
if keras_tensor.keras_tensors_enabled():
return isinstance(x, keras_tensor.KerasTensor)
if not isinstance(x,
(ops.Tensor, variables_module.Variable,
sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
@ -1120,35 +1124,56 @@ def placeholder(shape=None,
raise ValueError(
'Cannot set both sparse and ragged to True when creating a placeholder.'
)
if dtype is None:
dtype = floatx()
if not shape:
if ndim:
shape = (None,) * ndim
with get_graph().as_default():
if keras_tensor.keras_tensors_enabled():
spec = tensor_spec.TensorSpec(
shape=shape, dtype=dtype, name=name)
if sparse:
x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
spec = sparse_tensor.SparseTensorSpec(
shape=shape, dtype=dtype)
elif ragged:
ragged_rank = 0
for i in range(1, len(shape)):
if shape[i] is None:
# Hacky because could be tensorshape or tuple maybe?
# Or just tensorshape?
if shape[i] is None or (
hasattr(shape[i], 'value') and
shape[i].value is None):
ragged_rank = i
type_spec = ragged_tensor.RaggedTensorSpec(
spec = ragged_tensor.RaggedTensorSpec(
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
def tensor_spec_to_placeholder(tensorspec):
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
expand_composites=True)
else:
x = array_ops.placeholder(dtype, shape=shape, name=name)
x = keras_tensor.KerasTensor(spec, name=name)
else:
with get_graph().as_default():
if sparse:
x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
elif ragged:
ragged_rank = 0
for i in range(1, len(shape)):
if shape[i] is None:
ragged_rank = i
type_spec = ragged_tensor.RaggedTensorSpec(
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
def tensor_spec_to_placeholder(tensorspec):
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
expand_composites=True)
else:
x = array_ops.placeholder(dtype, shape=shape, name=name)
if context.executing_eagerly():
# Add keras_history connectivity information to the placeholder
# when the placeholder is built in a top-level eager context
# (intended to be used with keras.backend.function)
from tensorflow.python.keras.engine import input_layer # pylint: disable=g-import-not-at-top
return input_layer.Input(tensor=x)
x = input_layer.Input(tensor=x)
if keras_tensor.keras_tensors_enabled():
x._is_backend_placeholder = True
return x
@ -1163,6 +1188,8 @@ def is_placeholder(x):
Boolean.
"""
try:
if keras_tensor.keras_tensors_enabled():
return hasattr(x, '_is_backend_placeholder')
if isinstance(x, composite_tensor.CompositeTensor):
flat_components = nest.flatten(x, expand_composites=True)
return py_any(is_placeholder(c) for c in flat_components)

View File

@ -61,6 +61,10 @@ def keras_model_type_combinations():
return combinations.combine(model_type=KERAS_MODEL_TYPES)
def keras_tensor_combinations():
return combinations.combine(use_keras_tensors=['True', 'False'])
class KerasModeCombination(test_combinations.TestCombination):
"""Combination for Keras test mode.
@ -100,11 +104,32 @@ class KerasModelTypeCombination(test_combinations.TestCombination):
return [test_combinations.OptionalParameter('model_type')]
class KerasTensorCombination(test_combinations.TestCombination):
"""Combination for whether KerasTensors are being used or not.
It by default includes `True` and `False`:
running Keras's functional API with KerasTensors
as the inputs, and without.
"""
def context_managers(self, kwargs):
use_keras_tensors = kwargs.pop('use_keras_tensors', None)
if use_keras_tensors is not None:
return [testing_utils.use_keras_tensors_scope(use_keras_tensors)]
else:
return []
def parameter_modifiers(self):
return [test_combinations.OptionalParameter('use_keras_tensors')]
_defaults = combinations.generate.keywords['test_combinations']
generate = functools.partial(
combinations.generate,
test_combinations=_defaults +
(KerasModeCombination(), KerasModelTypeCombination()))
(KerasModeCombination(), KerasModelTypeCombination(),
KerasTensorCombination()))
combine = test_combinations.combine
times = test_combinations.times
NamedObject = test_combinations.NamedObject

View File

@ -41,6 +41,7 @@ py_library(
":base_preprocessing_layer",
":data_adapter",
":input_spec",
":keras_tensor",
"//tensorflow/python:composite_tensor_utils",
"//tensorflow/python:py_checkpoint_reader",
"//tensorflow/python/data",
@ -164,6 +165,18 @@ py_library(
],
)
py_library(
name = "keras_tensor",
srcs = ["keras_tensor.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:tensor_spec",
"@six_archive//:six",
],
)
py_library(
name = "base_preprocessing_layer",
srcs = [

View File

@ -55,6 +55,7 @@ from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
@ -700,7 +701,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# with the shape the Layer will be called on (these users will have to
# implement `compute_output_shape` themselves).
self._maybe_build(input_shape)
with func_graph.FuncGraph('graph').as_default():
with func_graph.FuncGraph(str(self.name) + '_scratch_graph').as_default():
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
def _make_placeholder_like(shape):
ph = backend.placeholder(shape=shape, dtype=self.dtype)
@ -759,6 +760,76 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
output_shape)
def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs):
if self.dynamic:
# We will use static shape inference to return symbolic tensors
# matching the specifications of the layer outputs.
# Since `self.dynamic` is True, we will never attempt to
# run the underlying TF graph (which is disconnected).
# TODO(fchollet): consider py_func as an alternative, which
# would enable us to run the underlying graph if needed.
input_signature = nest.map_structure(
lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype),
inputs)
output_signature = self.compute_output_signature(input_signature)
return nest.map_structure(keras_tensor.KerasTensor, output_signature)
else:
return self._infer_output_signature(inputs, args, kwargs, input_masks)
def _infer_output_signature(self, inputs, args, kwargs, input_masks):
"""TODO(kaftan): Docstring."""
call_fn = self.call
# Wrapping `call` function in autograph to allow for dynamic control
# flow and control dependencies in call. We are limiting this to
# subclassed layers as autograph is strictly needed only for
# subclassed layers and models.
# tf_convert will respect the value of autograph setting in the
# enclosing tf.function, if any.
if (base_layer_utils.is_subclassed(self) and
not base_layer_utils.from_saved_model(self)):
call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
# We enter a scratch graph and build placeholder inputs inside of it that
# match the input args.
# We then call the layer inside of the scratch graph to identify the
# output signatures, then we build KerasTensors corresponding to those
# outputs.
scratch_graph = func_graph.FuncGraph(str(self.name) + '_scratch_graph')
with scratch_graph.as_default():
inputs = nest.map_structure(
keras_tensor.keras_tensor_to_placeholder, inputs)
args = nest.map_structure(
keras_tensor.keras_tensor_to_placeholder, args)
kwargs = nest.map_structure(
keras_tensor.keras_tensor_to_placeholder, kwargs)
input_masks = nest.map_structure(
keras_tensor.keras_tensor_to_placeholder, input_masks)
inputs = self._maybe_cast_inputs(inputs)
# try:
with backend.name_scope(self._name_scope()):
with ops.enable_auto_cast_variables(self._compute_dtype_object):
# Build layer if applicable (if the `build` method has been
# overridden).
# TODO(kaftan): do we maybe_build here, or have we already done it?
self._maybe_build(inputs)
outputs = call_fn(inputs, *args, **kwargs)
self._handle_activity_regularization(inputs, outputs)
self._set_mask_metadata(inputs, outputs, input_masks,
build_graph=False)
outputs = nest.map_structure(keras_tensor.keras_tensor_from_tensor, outputs)
if hasattr(self, '_set_inputs') and not self.inputs:
# TODO(kaftan): figure out if we ned to do this at all
# Subclassed network: explicitly set metadata normally set by
# a call to self._set_inputs().
self._set_inputs(inputs, outputs)
del scratch_graph
return outputs
@generic_utils.default
def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
"""Computes an output mask tensor.
@ -954,6 +1025,27 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
args, kwargs)
training_arg_passed_by_framework = True
if keras_tensor.keras_tensors_enabled():
with call_context.enter(
layer=self, inputs=inputs, build_graph=True, training=training_value):
# Check input assumptions set after layer building, e.g. input shape.
outputs = self._keras_tensor_symbolic_call(
inputs, input_masks, args, kwargs)
if outputs is None:
raise ValueError('A layer\'s `call` method should return a '
'Tensor or a list of Tensors, not None '
'(layer: ' + self.name + ').')
if training_arg_passed_by_framework:
args, kwargs = self._set_call_arg_value(
'training', None, args, kwargs, pop_kwarg_if_none=True)
if mask_arg_passed_by_framework:
kwargs.pop('mask')
# Node connectivity does not special-case the first argument.
outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
outputs)
return outputs
# Only create Keras history if at least one tensor originates from a
# `keras.Input`. Otherwise this Layer may be being used outside the Keras
# framework.
@ -1237,6 +1329,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
@property
@doc_controls.do_not_doc_inheritable
def updates(self):
if (keras_tensor.keras_tensors_enabled()
and ops.executing_eagerly_outside_functions()):
return []
collected_updates = []
all_layers = self._flatten_layers()
with backend.get_graph().as_default():
@ -1399,10 +1495,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
continue
if loss is None:
continue
if not tensor_util.is_tensor(loss):
if not tensor_util.is_tensor(loss) and not isinstance(
loss, keras_tensor.KerasTensor):
loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx())
# TF Functions should take the eager path.
if (tf_utils.is_symbolic_tensor(loss) and
if ((tf_utils.is_symbolic_tensor(loss) or
isinstance(loss, keras_tensor.KerasTensor)) and
not base_layer_utils.is_in_tf_function()):
symbolic_losses.append(loss)
elif tensor_util.is_tensor(loss):
@ -1418,7 +1516,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._eager_losses.extend(eager_losses)
if in_call_context:
if in_call_context and not keras_tensor.keras_tensors_enabled():
for symbolic_loss in symbolic_losses:
self._losses.append(symbolic_loss)
else:
@ -1520,7 +1618,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
raise TypeError('Unknown keyword arguments: ', str(kwargs.keys()))
from_metric_obj = hasattr(value, '_metric_obj')
is_symbolic = tf_utils.is_symbolic_tensor(value)
if keras_tensor.keras_tensors_enabled():
is_symbolic = isinstance(value, keras_tensor.KerasTensor)
else:
is_symbolic = tf_utils.is_symbolic_tensor(value)
in_call_context = base_layer_utils.call_context().in_call
if name is None and not from_metric_obj:
@ -2217,7 +2318,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
"""
return self._dtype_policy.compute_dtype
def _maybe_cast_inputs(self, inputs, input_list):
def _maybe_cast_inputs(self, inputs, input_list=None):
"""Maybe casts the inputs to the compute dtype.
If self._compute_dtype is floating-point, and self_autocast is True,
@ -2230,6 +2331,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Returns:
`inputs`, but tensors may have been casted to self._compute_dtype
"""
if not input_list:
input_list = nest.flatten(inputs)
compute_dtype_object = self._compute_dtype_object
should_autocast = (
self._autocast and compute_dtype_object and
@ -3094,11 +3198,19 @@ class AddMetric(Layer):
def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylint: disable=unused-argument
if context.executing_eagerly():
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
"""Check the arguments to see if we are constructing a functional model."""
if keras_tensor.keras_tensors_enabled():
# We are constructing a functional model if any of the inputs
# are KerasTensors
return any(
isinstance(tensor, keras_tensor.KerasTensor)
for tensor in nest.flatten([inputs, args, kwargs]))
else:
return (base_layer_utils.is_in_keras_graph() or
all(hasattr(t, '_keras_history') for t in input_list))
if context.executing_eagerly():
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
else:
return (base_layer_utils.is_in_keras_graph() or
all(hasattr(t, '_keras_history') for t in input_list))
def _convert_numpy_or_python_types(x):

View File

@ -48,7 +48,6 @@ from tensorflow.python.keras.optimizer_v2 import rmsprop
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.layers import core as legacy_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2
@ -86,7 +85,9 @@ class InvalidLayer(base_layer.Layer):
class BaseLayerTest(keras_parameterized.TestCase):
@combinations.generate(combinations.keras_model_type_combinations())
@combinations.generate(combinations.times(
combinations.keras_model_type_combinations(),
combinations.keras_tensor_combinations()))
def test_dynamic_layer(self):
model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
input_shape=(3,))
@ -95,26 +96,30 @@ class BaseLayerTest(keras_parameterized.TestCase):
self.assertEqual(model.run_eagerly, True)
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
@combinations.generate(combinations.keras_model_type_combinations())
@combinations.generate(combinations.times(
combinations.keras_model_type_combinations(),
combinations.keras_tensor_combinations()))
def test_dynamic_layer_error(self):
# Functional Models hit the `dyanamic=True` error during construction.
# Subclass Models should just throw the original autograph error during
# execution.
model_type = testing_utils.get_model_type()
if 'subclass' in model_type and context.executing_eagerly():
error_type = errors_impl.OperatorNotAllowedInGraphError
error_message = 'iterating over `tf.Tensor` is not allowed'
else:
error_type = TypeError
error_message = 'attempting to use Python control flow'
with self.assertRaisesRegexp(error_type, error_message):
raised_error = False
try:
model = testing_utils.get_model_from_layers([DynamicLayer()],
input_shape=(3,))
model.compile(rmsprop.RMSprop(0.001), loss='mse')
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
except errors_impl.OperatorNotAllowedInGraphError as e:
if 'iterating over `tf.Tensor` is not allowed' in str(e):
raised_error = True
except TypeError as e:
if 'attempting to use Python control flow' in str(e):
raised_error = True
self.assertTrue(raised_error)
@combinations.generate(combinations.keras_model_type_combinations())
@combinations.generate(combinations.times(
combinations.keras_model_type_combinations(),
combinations.keras_tensor_combinations()))
def test_dynamic_layer_error_running_in_graph_mode(self):
with ops.get_default_graph().as_default():
model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
@ -155,11 +160,6 @@ class BaseLayerTest(keras_parameterized.TestCase):
self.assertEqual(layer.build_counter, 1)
self.assertEqual(layer.build_shape.as_list(), [None, 10])
def test_eager_switch_case_input(self):
task = input_layer.Input(shape=(), dtype=dtypes.int32)
control_flow_ops.switch_case(
task[0], [lambda: constant_op.constant(1.0) for _ in range(10)])
def test_dynamic_layer_with_deferred_sequential_model(self):
model = sequential.Sequential([DynamicLayer(dynamic=True), layers.Dense(3)])
self.assertEqual(model.dynamic, True)
@ -284,6 +284,7 @@ class BaseLayerTest(keras_parameterized.TestCase):
@combinations.generate(
combinations.times(
combinations.keras_model_type_combinations(),
combinations.keras_tensor_combinations(),
combinations.combine(mode=['graph', 'eager'])))
def test_build_with_numpy_data(self):
model_layers = [
@ -384,7 +385,8 @@ class BaseLayerTest(keras_parameterized.TestCase):
# b/124459427: can't test with `run_eagerly=True` for now.
@combinations.generate(
combinations.times(combinations.keras_mode_combinations(),
combinations.keras_model_type_combinations()))
combinations.keras_model_type_combinations(),
combinations.keras_tensor_combinations()))
def test_training_arg_in_defun(self):
layer = self._get_layer_with_training_arg()
model = testing_utils.get_model_from_layers([layer], input_shape=(1,))
@ -409,7 +411,8 @@ class BaseLayerTest(keras_parameterized.TestCase):
@combinations.generate(
combinations.times(combinations.keras_mode_combinations(),
combinations.keras_model_type_combinations()))
combinations.keras_model_type_combinations(),
combinations.keras_tensor_combinations()))
def test_raw_variable_assignment(self):
class RawVariableLayer(base_layer.Layer):

View File

@ -20,12 +20,17 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
@ -70,7 +75,7 @@ class TrackableWeightHandlerTest(keras_parameterized.TestCase):
_ = backend.batch_get_value(table_handler.get_tensors())
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
@combinations.generate(combinations.combine(mode=['eager']))
class OpLayerTest(keras_parameterized.TestCase):
def test_tensor_op_layer(self):
@ -96,6 +101,35 @@ class OpLayerTest(keras_parameterized.TestCase):
float_values = math_ops.cast(int_values, dtypes.float32)
_ = keras.Model(int_values, float_values)
def test_ragged_op_layer_keras_tensors(self):
with testing_utils.use_keras_tensors_scope(True):
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
float_values = math_ops.cast(int_values, dtypes.float32)
model = keras.Model(int_values, float_values)
model.compile(loss='mse')
input_data = ragged_factory_ops.constant(
[[1, 2], [3, 4]], dtype=np.int32)
expected = [[1.0, 2.0], [3.0, 4.0]]
output = model.predict(input_data)
self.assertIsInstance(output, ragged_tensor.RaggedTensor)
self.assertAllClose(expected, output)
def test_sparse_op_layer_keras_tensors(self):
with testing_utils.use_keras_tensors_scope(True):
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True)
float_values = math_ops.cast(int_values, dtypes.float32)
_ = keras.Model(int_values, float_values)
model = keras.Model(int_values, float_values)
model.compile(loss='mse')
input_data = sparse_ops.from_dense(
np.array([[1, 2], [3, 4]], dtype=np.int32))
expected = [[1.0, 2.0], [3.0, 4.0]]
output = model.predict(input_data)
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertAllClose(expected, sparse_ops.sparse_tensor_to_dense(output))
if __name__ == '__main__':
test.main()

View File

@ -32,6 +32,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_utils
@ -994,7 +995,8 @@ def _map_subgraph_network(inputs, outputs):
Returns:
A tuple of List{Node] and List[Layer].
"""
base_layer_utils.create_keras_history(outputs)
if not keras_tensor.keras_tensors_enabled():
base_layer_utils.create_keras_history(outputs)
# Keep only nodes and layers in the topology between inputs and outputs.
_, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers

View File

@ -965,6 +965,45 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.times(
combinations.keras_mode_combinations(),
combinations.keras_tensor_combinations()))
def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(self):
class MaybeAdd(layers.Layer):
def call(self, x1, x2=None):
if x2 is not None:
return x1 + x2
return x1
input2 = input_layer_lib.Input(10)
outputs = MaybeAdd()(3., x2=input2)
model = training_lib.Model([input2], outputs)
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=7 * np.ones((10, 10)),
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
model = training_lib.Model.from_config(
model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=7 * np.ones((10, 10)),
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations())
def test_composite_call_kwarg_derived_from_keras_layer(self):
@ -1006,6 +1045,58 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.times(
combinations.keras_mode_combinations(mode='eager'),
combinations.keras_tensor_combinations()))
def test_call_some_not_all_nested_in_first_arg_derived_from_keras_layer(self):
# This functionality is unsupported in v1 graphs
class AddAll(layers.Layer):
def call(self, x1_x2, x3):
x1, x2 = x1_x2
out = x1 + x2
if x3 is not None:
for t in x3.values():
out += t
return out
input1 = input_layer_lib.Input(10)
input2 = input_layer_lib.Input(10)
input3 = input_layer_lib.Input(10)
outputs = AddAll()(
[input1, 4 * array_ops.ones((1, 10))],
x3={
'a': input2,
'b': input3,
'c': 5 * array_ops.ones((1, 10))
})
model = training_lib.Model([input1, input2, input3], outputs)
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
y=15 * np.ones((10, 10)),
batch_size=2)
# Check that all inputs were correctly added.
self.assertEqual(history.history['loss'][0], 0.0)
model = training_lib.Model.from_config(
model.get_config(), custom_objects={'AddAll': AddAll})
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
y=15 * np.ones((10, 10)),
batch_size=2)
# Check that all inputs were correctly added.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations())
def test_call_nested_arg_derived_from_keras_layer(self):

View File

@ -26,6 +26,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.saving.saved_model import layer_serialization
from tensorflow.python.keras.utils import tf_utils
@ -116,6 +117,10 @@ class InputLayer(base_layer.Layer):
if kwargs:
raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
if sparse and ragged:
raise ValueError(
'Cannot set both sparse and ragged to True in a Keras input.')
if not name:
prefix = 'input'
name = prefix + '_' + str(backend.get_uid(prefix))
@ -157,7 +162,14 @@ class InputLayer(base_layer.Layer):
self.is_placeholder = True
self._batch_input_shape = batch_input_shape
else:
if not tf_utils.is_symbolic_tensor(input_tensor):
raise_eager_tensor_error = False
if keras_tensor.keras_tensors_enabled():
if not isinstance(input_tensor, keras_tensor.keras_tensors_enabled()):
raise_eager_tensor_error = True
else:
if not tf_utils.is_symbolic_tensor(input_tensor):
raise_eager_tensor_error = True
if raise_eager_tensor_error:
raise ValueError('You should not pass an EagerTensor to `Input`. '
'For example, instead of creating an '
'InputLayer, you should instantiate your model and '
@ -173,7 +185,8 @@ class InputLayer(base_layer.Layer):
node_module.Node(layer=self, outputs=input_tensor)
# Store type spec
if isinstance(input_tensor, composite_tensor.CompositeTensor):
if isinstance(input_tensor, (
composite_tensor.CompositeTensor, keras_tensor.KerasTensor)):
self._type_spec = input_tensor._type_spec # pylint: disable=protected-access
else:
self._type_spec = tensor_spec.TensorSpec(

View File

@ -0,0 +1,229 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras Input Tensor used to track functional API Topology."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec as type_spec_module
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
_KERAS_TENSORS_ENABLED = False
def enable_keras_tensors():
"""Enable using KerasTensors in Keras's functional API."""
global _KERAS_TENSORS_ENABLED
_KERAS_TENSORS_ENABLED = True
def disable_keras_tensors():
"""Disable using KerasTensors in Keras's functional API."""
global _KERAS_TENSORS_ENABLED
_KERAS_TENSORS_ENABLED = False
def keras_tensors_enabled():
"""Return a bool specifying if KerasTensors are enabled."""
return _KERAS_TENSORS_ENABLED and ops.executing_eagerly_outside_functions()
class KerasTensor(object):
"""A representation of a Keras in/output during Functional API construction.
`KerasTensor`s are an alternative representation for Keras `Inputs`
and for intermediate outputs of layers during Functional API construction of
models. They are a lightweight data structure comprised of only the
`tf.TypeSpec` of the Tensor that will be consumed/produced in the
corresponding position of the model.
They implement just small subset of `tf.Tensor`'s attributes and
methods, and also overload
the same operators as `tf.Tensor` and automatically turn them into
Keras layers in the model.
`KerasTensor`s are still internal-only and are a work in progress, but they
have several advantages over using a graph `tf.Tensor` to represent
symbolic values in functional models.
- Unlike symbolic tensors, they do not need to refer to a graph. This means
Keras does not need to maintain a never-deleted global background graph
containing all layers ever called during functional model construction when
constructing Functional Models with KerasTensors. These memory savings
can be significant.
- Triggering Keras functional model construction is simpler
when it just has to check whether something is a KerasTensor, rather
than trying to infer if a tensor was meant to be a symbolic keras
representation or just a value produced during function tracing.
- Autolambda layers (converting tf ops on symbolic Keras tensors to lambda
Keras layers in the model) use TF's internal dispatching mechanism, instead
of trying to manually walk a graph and extract nodes from it.
The dispatching mechanism is simpler, works more reliably, and is less
likely to run into issues with composite tensors or strange tf ops/nodes.
(And when it fails, it's by design: because dispatch is explicitly not
supported on the op & it's more obvious that dispatch doesn't support the
setting).
- Because they support arbitrary typespecs, models/layers that use
KerasTensors are generally more friendly to composite tensors of different
types than using symbolic graph tensors (which must have a TensorSpec and
can't have arbitrary typespecs)
To experiment with using KerasTensors instead of symbolic graph `tf.Tensors`,
import keras_tensor directly and call `keras_tensor.enable_keras_tensors()`
"""
def __init__(self, type_spec, name=None):
"""Construct a KerasTensor from a type_spec and an optional name."""
if not isinstance(type_spec, type_spec_module.TypeSpec):
raise ValueError('KerasTensors must be constructed with a `tf.TypeSpec`.')
self._type_spec = type_spec
if name is None and hasattr(type_spec, 'name'):
name = type_spec.name
self._name = name
@property
def type_spec(self):
"""Returns the `TypeSpec` that represents this Tensor."""
return self._type_spec
@property
def shape(self):
"""Returns the `TensorShape` that represents the shape of the tensor."""
# TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
# may need to raise an error when it's not valid for a type_spec,
# but some keras code (e.g. build-related stuff) will likely fail when
# it can't access shape or dtype
return self._type_spec._shape # pylint: disable=protected-access
def get_shape(self):
return self.shape
@property
def dtype(self):
"""Returns the `dtype` of elements in the tensor."""
# TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
# may need to raise an error when it's not valid for a type_spec,
# but some keras code (e.g. build-related stuff) will likely fail when
# it can't access shape or dtype
return self._type_spec._dtype # pylint: disable=protected-access
def ref(self):
"""Returns a hashable reference object to this KerasTensor.
The primary use case for this API is to put KerasTensors in a
set/dictionary. We can't put tensors in a set/dictionary as
`tensor.__hash__()` is not available and tensor equality (`==`) is supposed
to produce a tensor representing if the two inputs are equal.
See the documentation of `tf.Tensor.ref()` for more info.
"""
return object_identity.Reference(self)
def __iter__(self):
shape = None
if self.shape.ndims is not None:
shape = [dim.value for dim in self.shape.dims]
if shape is None:
raise TypeError('Cannot iterate over a KerasTensor with unknown shape.')
if not shape:
raise TypeError('Cannot iterate over a scalar.')
if shape[0] is None:
raise TypeError(
'Cannot iterate over a KerasTensor with unknown first dimension.')
return _KerasTensorIterator(self, shape[0])
@property
def name(self):
"""Returns the (optionally provided) name of the described tensor."""
return self._name
@classmethod
def _overload_all_operators(cls): # pylint: disable=invalid-name
"""Register overloads for all operators."""
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
cls._overload_operator(operator)
@classmethod
def _overload_operator(cls, operator): # pylint: disable=invalid-name
"""Overload an operator with the same overloading as `ops.Tensor`.
We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
Args:
operator: string. The operator name.
"""
tensor_oper = getattr(ops.Tensor, operator)
# Compatibility with Python 2:
# Python 2 unbound methods have type checks for the first arg,
# so we need to extract the underlying function
tensor_oper = getattr(tensor_oper, '__func__', tensor_oper)
setattr(cls, operator, tensor_oper)
KerasTensor._overload_all_operators() # pylint: disable=protected-access
class _KerasTensorIterator(object):
"""Iterates over the leading dim of a KerasTensor. Performs 0 error checks."""
def __init__(self, tensor, dim0):
self._tensor = tensor
self._index = 0
self._limit = dim0
def __iter__(self):
return self
def __next__(self):
if self._index == self._limit:
raise StopIteration
result = self._tensor[self._index]
self._index += 1
return result
next = __next__ # python2.x compatibility.
def keras_tensor_to_placeholder(x):
"""TODO(kaftan): Docstring."""
if isinstance(x, KerasTensor):
def tensor_spec_to_placeholder(tensorspec):
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
ph = nest.map_structure(tensor_spec_to_placeholder, x.type_spec,
expand_composites=True)
return ph
else:
return x
def keras_tensor_from_tensor(x):
name = getattr(x, 'name', None)
out = KerasTensor(type_spec_module.type_spec_from_value(x), name=name)
if hasattr(x, '_keras_mask'):
out._keras_mask = KerasTensor( # pylint: disable=protected-access
type_spec_module.type_spec_from_value(x._keras_mask)) # pylint: disable=protected-access
return out

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@ -80,11 +81,13 @@ class Node(object):
self._single_positional_tensor_passed = (not self.call_kwargs and len(
self.call_args) == 1 and tensor_util.is_tensor(self.call_args[0]))
# Create TensorFlowOpLayers if needed.
for obj in self._flat_arguments:
if (isinstance(obj, ops.Tensor) and
base_layer_utils.needs_keras_history(obj, ignore_call_context=True)):
base_layer_utils.create_keras_history(obj)
if not keras_tensor.keras_tensors_enabled():
# Create TensorFlowOpLayers if needed.
for obj in self._flat_arguments:
if (isinstance(obj, ops.Tensor) and
base_layer_utils.needs_keras_history(
obj, ignore_call_context=True)):
base_layer_utils.create_keras_history(obj)
self._keras_inputs = []
self._keras_inputs_ids_and_indices = []

View File

@ -690,7 +690,7 @@ class TrainingTest(keras_parameterized.TestCase):
metrics=['accuracy'],
run_eagerly=testing_utils.should_run_eagerly())
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_that_trainable_disables_updates(self):
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
@ -1584,7 +1584,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
run_eagerly=testing_utils.should_run_eagerly())
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_sparse_op_with_op_layer(self):
inputs = layers_module.Input(shape=(2,), sparse=True, name='sparse_tensor')
output = sparse_ops.sparse_minimum(inputs, inputs)
@ -2817,7 +2817,7 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
scores = model.train_on_batch(x, y, sample_weight=w)
self.assertArrayNear(scores, [0.3328, 0.8], 0.001)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_add_metric_with_tensor_on_model(self):
x = layers_module.Input(shape=(1,))
y = layers_module.Dense(1, kernel_initializer='ones')(x)
@ -2932,7 +2932,8 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
self.assertEqual(history.history['metric_1'][-1], 5)
self.assertAlmostEqual(history.history['val_metric_1'][-1], 5, 0)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
skip_keras_tensors=True)
def test_model_metrics_list(self):
class LayerWithAddMetric(layers_module.Layer):

View File

@ -303,7 +303,8 @@ def _test_sequential_model_type(f, test_or_class, *args, **kwargs):
def run_all_keras_modes(test_or_class=None,
config=None,
always_skip_v1=False,
always_skip_eager=False):
always_skip_eager=False,
**kwargs):
"""Execute the decorated test with all keras execution modes.
This decorator is intended to be applied either to individual test methods in
@ -361,6 +362,9 @@ def run_all_keras_modes(test_or_class=None,
when Tensorflow v2 behavior is not enabled.
always_skip_eager: If True, does not execute the decorated test
with eager execution modes.
**kwargs: Additional kwargs for configuring tests for
in-progress Keras behaviors/ refactorings that we haven't fully
rolled out yet
Returns:
Returns a decorator that will run the decorated test method multiple times.
@ -369,8 +373,14 @@ def run_all_keras_modes(test_or_class=None,
ImportError: If abseil parameterized is not installed or not included as
a target dependency.
"""
skip_keras_tensors = kwargs.pop('skip_keras_tensors', False)
if kwargs:
raise ValueError('Unrecognized keyword args: {}'.format(kwargs))
params = [('_v2_function', 'v2_function')]
if not skip_keras_tensors:
params.append(('_v2_function_use_keras_tensors',
'v2_function_use_keras_tensors'))
if not always_skip_eager:
params.append(('_v2_eager', 'v2_eager'))
if not (always_skip_v1 or tf2.enabled()):
@ -390,6 +400,8 @@ def run_all_keras_modes(test_or_class=None,
_v2_eager_test(f, self, *args, **kwargs)
elif run_mode == 'v2_function':
_v2_function_test(f, self, *args, **kwargs)
elif run_mode == 'v2_function_use_keras_tensors':
_v2_function_and_kerastensors_test(f, self, *args, **kwargs)
else:
return ValueError('Unknown run mode %s' % run_mode)
@ -417,6 +429,13 @@ def _v2_function_test(f, test_or_class, *args, **kwargs):
f(test_or_class, *args, **kwargs)
def _v2_function_and_kerastensors_test(f, test_or_class, *args, **kwargs):
with context.eager_mode():
with testing_utils.run_eagerly_scope(False):
with testing_utils.use_keras_tensors_scope(True):
f(test_or_class, *args, **kwargs)
def _test_or_class_decorator(test_or_class, single_method_decorator):
"""Decorate a test or class with a decorator intended for one method.

View File

@ -27,6 +27,7 @@ from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.platform import googletest
@ -206,7 +207,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
def runTest(self):
pass
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def testBody(self):
mode = "eager" if context.executing_eagerly() else "graph"
should_run_eagerly = testing_utils.should_run_eagerly()
@ -242,6 +243,54 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
ts.run(res)
self.assertLen(l, 4)
def test_run_all_keras_modes_include_keras_tensors(self):
l = []
class ExampleTest(keras_parameterized.TestCase):
def runTest(self):
pass
@keras_parameterized.run_all_keras_modes()
def testBody(self):
mode = "eager" if context.executing_eagerly() else "graph"
should_run_eagerly = testing_utils.should_run_eagerly()
l.append((mode, should_run_eagerly,
keras_tensor.keras_tensors_enabled()))
e = ExampleTest()
if not tf2.enabled():
e.testBody_v1_session()
e.testBody_v2_eager()
e.testBody_v2_function()
e.testBody_v2_function_use_keras_tensors()
if not tf2.enabled():
self.assertLen(l, 4)
self.assertAllEqual(l, [
("graph", False, False),
("eager", True, False),
("eager", False, False),
("eager", False, True),
])
ts = unittest.makeSuite(ExampleTest)
res = unittest.TestResult()
ts.run(res)
self.assertLen(l, 8)
else:
self.assertLen(l, 3)
self.assertAllEqual(l, [
("eager", True, False),
("eager", False, False),
("eager", False, True),
])
ts = unittest.makeSuite(ExampleTest)
res = unittest.TestResult()
ts.run(res)
self.assertLen(l, 6)
def test_run_all_keras_modes_extra_params(self):
l = []
@ -250,7 +299,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
def runTest(self):
pass
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@parameterized.named_parameters(
[dict(testcase_name="_0", with_brackets=True),
dict(testcase_name="_1", with_brackets=False)])
@ -300,7 +349,8 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
def runTest(self):
pass
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
skip_keras_tensors=True)
def testBody(self):
mode = "eager" if context.executing_eagerly() else "graph"
should_run_eagerly = testing_utils.should_run_eagerly()
@ -330,7 +380,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
pass
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def testBody(self):
mode = "eager" if context.executing_eagerly() else "graph"
should_run_eagerly = testing_utils.should_run_eagerly()
@ -382,7 +432,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
def runTest(self):
pass
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@keras_parameterized.run_with_all_model_types
def testBody(self):
mode = "eager" if context.executing_eagerly() else "graph"
@ -431,7 +481,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
l = []
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
class ExampleTest(keras_parameterized.TestCase):
def runTest(self):
@ -491,7 +541,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
def runTest(self):
pass
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@parameterized.named_parameters(dict(testcase_name="_arg",
arg=True))
def testBody(self, arg):
@ -537,7 +587,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
self.assertLen(l, len(expected_combinations) * 2)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@parameterized.named_parameters(dict(testcase_name="argument",
arg=True))
def test_run_all_keras_modes_extra_params_2(self, arg):

View File

@ -139,8 +139,10 @@ py_library(
"//tensorflow/python/keras:base_layer",
"//tensorflow/python/keras:constraints",
"//tensorflow/python/keras:initializers",
"//tensorflow/python/keras:losses",
"//tensorflow/python/keras:regularizers",
"//tensorflow/python/keras/engine:input_spec",
"//tensorflow/python/keras/engine:keras_tensor",
"//tensorflow/python/keras/layers/ops:core",
"//tensorflow/python/keras/utils:engine_utils",
"//tensorflow/python/keras/utils:generic_utils",

View File

@ -39,6 +39,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.ops import core as core_ops
@ -52,8 +53,12 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
from tensorflow.python.util.tf_export import get_symbol_from_name
from tensorflow.python.util.tf_export import keras_export
@ -1255,3 +1260,155 @@ class ActivityRegularization(Layer):
config = {'l1': self.l1, 'l2': self.l2}
base_config = super(ActivityRegularization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class TFOpLambda(Layer):
"""Wraps TF API symbols in a `Layer` object.
It is inserted by the Functional API construction whenever users call
a supported TF symbol on KerasTensors.
Like Lambda layers, this layer tries to raise warnings when it detects users
explicitly use variables in the call. (To let them know
that the layer will not capture the variables).
This is useful in the case where users do something like:
x = keras.Input(...)
y = tf.Variable(...)
out = x * tf_variable
"""
@trackable.no_automatic_dependency_tracking
def __init__(self, function, **kwargs):
self.function = function
self.symbol = (
get_canonical_name_for_symbol(self.function,
add_prefix_to_v1_names=True) or
get_canonical_name_for_symbol(self.function,
api_name='keras',
add_prefix_to_v1_names=True))
kwargs['autocast'] = False
# Decorate the function to produce this layer's call method
def _call_wrapper(*args, **kwargs):
return self._call_wrapper(*args, **kwargs)
self.call = tf_decorator.make_decorator(function, _call_wrapper)
super(TFOpLambda, self).__init__(**kwargs)
# Warning on every invocation will be quite irksome in Eager mode.
self._already_warned = False
self._expects_training_arg = False
self._expects_mask_arg = False
def _call_wrapper(self, *args, **kwargs):
created_variables = []
def _variable_creator(next_creator, **creator_kwargs):
var = next_creator(**creator_kwargs)
created_variables.append(var)
return var
with backprop.GradientTape(watch_accessed_variables=True) as tape, \
variable_scope.variable_creator_scope(_variable_creator):
# We explicitly drop `name` arguments here,
# to guard against the case where an op explicitly has a
# `name` passed (which is susceptible to producing
# multiple ops w/ the same name when the layer is reused)
kwargs.pop('name', None)
result = self.function(*args, **kwargs)
self._check_variables(created_variables, tape.watched_variables())
return result
def _check_variables(self, created_variables, accessed_variables):
if not created_variables and not accessed_variables:
# In the common case that a Lambda layer does not touch a Variable, we
# don't want to incur the runtime cost of assembling any state used for
# checking only to immediately discard it.
return
tracked_weights = set(v.ref() for v in self.weights)
untracked_new_vars = [
v for v in created_variables if v.ref() not in tracked_weights
]
if untracked_new_vars:
variable_str = '\n'.join(' {}'.format(i) for i in untracked_new_vars)
error_str = textwrap.dedent(
'''
The following Variables were created within a Lambda layer ({name})
but are not tracked by said layer:
{variable_str}
The layer cannot safely ensure proper Variable reuse across multiple
calls, and consquently this behavior is disallowed for safety. Lambda
layers are not well suited to stateful computation; instead, writing a
subclassed Layer is the recommend way to define layers with
Variables.'''
).format(name=self.name, variable_str=variable_str)
raise ValueError(error_str)
untracked_used_vars = [
v for v in accessed_variables if v.ref() not in tracked_weights
]
if untracked_used_vars and not self._already_warned:
variable_str = '\n'.join(' {}'.format(i) for i in untracked_used_vars)
self._warn(textwrap.dedent(
'''
The following Variables were used a Lambda layer's call ({name}), but
are not present in its tracked objects:
{variable_str}
It is possible that this is intended behavior, but it is more likely
an omission. This is a strong indication that this layer should be
formulated as a subclassed Layer rather than a Lambda layer.'''
).format(name=self.name, variable_str=variable_str))
self._already_warned = True
def _warn(self, msg):
# This method will be overridden in a unit test to raise an error, because
# self.assertWarns is not universally implemented.
return tf_logging.warn(msg)
def get_config(self):
if not self.symbol:
raise ValueError('This Keras op layer was generated from %s, a method '
'that is not an exposed in the TensorFlow API. This '
'may have happened if the method was explicitly '
'decorated to add dispatching support, and it was used '
'during Functional model construction. '
'To ensure cross-version compatibility of Keras models '
'that use op layers, only op layers produced from '
'exported TF API symbols can be serialized.'
% self.function)
config = {
'function': self.symbol
}
base_config = super(TFOpLambda, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
config = config.copy()
symbol_name = config['function']
function = get_symbol_from_name(symbol_name)
if not function:
raise ValueError(
'TF symbol `tf.%s` could not be found.' % symbol_name)
config['function'] = function
return cls(**config)
class KerasOpDispatcher(dispatch.GlobalOpDispatcher):
"""A global dispatcher that allows building a functional model with TF Ops."""
def handle(self, op, args, kwargs):
"""Handle the specified operation with the specified arguments."""
if any(
isinstance(x, keras_tensor.KerasTensor)
for x in nest.flatten([args, kwargs])):
return TFOpLambda(op)(*args, **kwargs)
else:
return self.NOT_SUPPORTED
KerasOpDispatcher().register()

View File

@ -523,7 +523,11 @@ class Concatenate(_Merge):
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if not isinstance(input_shape, (tuple, list)):
if ((not isinstance(input_shape, (tuple, list))) or
(not isinstance(input_shape[0], (tuple, list)))):
# The tf_utils.shape_type_conversion decorator turns tensorshapes
# into tuples, so we need to verify that `input_shape` is a list/tuple,
# *and* that the individual elements are themselves shape tuples.
raise ValueError('A `Concatenate` layer should be called '
'on a list of inputs.')
input_shapes = input_shape

View File

@ -50,7 +50,6 @@ def _single_identity_op_at_end():
inputs = keras.Input(shape=(10,))
x = keras.layers.Dense(10)(inputs)
outputs = array_ops.identity(x)
assert 'Identity' in outputs.name
return keras.Model(inputs, outputs)
@ -186,7 +185,7 @@ def _reuse_ancillary_layer():
return model
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
class AutoLambdaTest(keras_parameterized.TestCase):
@parameterized.named_parameters(
@ -313,11 +312,6 @@ class AutoLambdaTest(keras_parameterized.TestCase):
e = 3 # Fudge factor to prevent flakiness.
self.assertLess(size_500, (10 * e) * size_50)
def test_no_mask_tracking(self):
x = keras.backend.placeholder((10, 10))
y = keras.layers.Masking(0.)(x)
self.assertTrue(y._keras_mask._keras_history_checked)
def test_built(self):
inputs = keras.Input(shape=(10,))
outputs = gen_nn_ops.relu(inputs)

View File

@ -79,7 +79,7 @@ def _get_model(input_shape=(4,)):
class TestModelCloning(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@parameterized.named_parameters([
{'testcase_name': 'has_input_layer',
'input_shape': (4,),
@ -142,7 +142,7 @@ class TestModelCloning(keras_parameterized.TestCase):
self.assertIsInstance(new_model._layers[0], keras.layers.InputLayer)
self.assertTrue(new_model._is_graph_network)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@parameterized.named_parameters([
{'testcase_name': 'clone_weights', 'share_weights': False},
{'testcase_name': 'share_weights', 'share_weights': True},

View File

@ -40,7 +40,8 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
skip_keras_tensors=True)
class LinearModelTest(keras_parameterized.TestCase):
def test_linear_model_with_single_input(self):

View File

@ -37,7 +37,8 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
skip_keras_tensors=True)
class WideDeepModelTest(keras_parameterized.TestCase):
def test_wide_deep_model(self):

View File

@ -83,7 +83,7 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
self.assertEqual(len(model.losses), 1)
model.fit(x_train, y_train, batch_size=10, epochs=1, verbose=0)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@parameterized.named_parameters([
('l1', regularizers.l1()),
('l2', regularizers.l2()),
@ -126,7 +126,7 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
model.get_config(), custom_objects={'my_regularizer': my_regularizer})
self.assertEqual(model2.layers[1].kernel_regularizer, my_regularizer)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@parameterized.named_parameters([
('l1', regularizers.l1()),
('l2', regularizers.l2()),
@ -144,7 +144,7 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
run_eagerly=testing_utils.should_run_eagerly())
self.assertLen(model.losses, 5)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@parameterized.named_parameters([
('l1', regularizers.l1()),
('l2', regularizers.l2()),
@ -166,7 +166,7 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
run_eagerly=testing_utils.should_run_eagerly())
self.assertLen(model.losses, 6)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@parameterized.named_parameters([
('l1', regularizers.l1()),
('l2', regularizers.l2()),

View File

@ -147,6 +147,7 @@ tf_py_test(
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",

View File

@ -33,6 +33,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
@ -331,6 +332,29 @@ def run_eagerly_scope(value):
_thread_local_data.run_eagerly = previous_value
@tf_contextlib.contextmanager
def use_keras_tensors_scope(value):
"""Provides a scope within which we use KerasTensors in the func. API or not.
The boolean gets restored to its original value upon exiting the scope.
Arguments:
value: Bool specifying if we should build functional models
using KerasTensors in the active test.
Should be True or False.
Yields:
The provided value.
"""
previous_value = keras_tensor._KERAS_TENSORS_ENABLED # pylint: disable=protected-access
try:
keras_tensor._KERAS_TENSORS_ENABLED = value # pylint: disable=protected-access
yield value
finally:
# Restore KerasTensor usage to initial value.
keras_tensor._KERAS_TENSORS_ENABLED = previous_value # pylint: disable=protected-access
def should_run_eagerly():
"""Returns whether the models we are testing should be run eagerly."""
if _thread_local_data.run_eagerly is None:

View File

@ -69,7 +69,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
self.w = np.array([[1.25], [0.5], [1.25]], dtype='float32')
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_loss_on_model_fit(self):
inputs = Input(shape=(1,))
targets = Input(shape=(1,))
@ -85,7 +85,8 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.assertAllClose(history.history['loss'], [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True,
always_skip_v1=True)
def test_loss_callable_on_model_fit(self):
model = testing_utils.get_model_from_layers([testing_utils.Bias()],
input_shape=(1,))
@ -144,7 +145,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss = [train_step(self.x, self.y) for _ in range(5)]
self.assertAllClose(loss, [0., -0.05, -0.1, -0.15, -0.2], 1e-3)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_loss_with_sample_weight_on_model_fit(self):
inputs = Input(shape=(1,))
targets = Input(shape=(1,))
@ -181,7 +182,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss = [train_step(self.x, self.y, self.w) for _ in range(5)]
self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_loss_with_sample_weight_in_model_call(self):
class MyModel(Model):
@ -209,7 +210,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
eval_out = model.evaluate([self.x, self.y, self.w])
self.assertAlmostEqual(eval_out, 1.0, 3)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_loss_with_sample_weight_in_layer_call(self):
class MyLayer(layers.Layer):
@ -244,7 +245,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
output = model.test_on_batch([self.x, self.y, self.w])
self.assertAlmostEqual(output, 1.0, 3)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_loss_on_layer(self):
class MyLayer(layers.Layer):
@ -265,7 +266,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
self.assertEqual(loss, 2 * 3)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@keras_parameterized.run_with_all_model_types
def test_activity_regularizer(self):
loss = {}
@ -299,7 +300,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss[reg] = model.evaluate(x, y)
self.assertLess(loss[None], loss['l2'])
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
@keras_parameterized.run_with_all_model_types
def test_activity_regularizer_loss_value(self):
layer = layers.Dense(
@ -318,7 +319,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss = model.test_on_batch(x)
self.assertAlmostEqual(0.01, loss, places=4)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_activity_regularizer_batch_independent(self):
inputs = layers.Input(shape=(10,))
x = layers.Dense(10, activation='relu', activity_regularizer='l2')(inputs)
@ -334,7 +335,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
loss_big_batch = model.test_on_batch(np.ones((20, 10), 'float32'))
self.assertAlmostEqual(loss_small_batch, loss_big_batch, places=4)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_with_shared_layer(self):
class LayerWithLoss(layers.Layer):
@ -351,7 +352,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.assertEqual(len(m2.losses), 2)
self.assertAllClose(m2.losses, [6, 12])
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_with_shared_nested_layer(self):
class LayerWithLoss(layers.Layer):
@ -377,7 +378,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.assertEqual(len(m2.losses), 2)
self.assertAllClose(m2.losses, [6, 12])
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_clear_losses(self):
class LayerWithSharedNestedLossLayer(layers.Layer):
@ -428,7 +429,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
self.assertEqual(len(model.get_losses_for(x4)), 2)
self.assertEqual(len(model.get_losses_for(None)), 1)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_invalid_constant_input(self):
with context.eager_mode():
inputs = Input(shape=(1,))
@ -439,7 +440,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
'Expected a symbolic Tensors or a callable for the loss value'):
model.add_loss(1.)
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
def test_invalid_variable_input(self):
with context.eager_mode():
inputs = Input(shape=(1,))

View File

@ -506,7 +506,8 @@ class RaggedTensorInputTest(keras_parameterized.TestCase,
model_input = input_layer.Input(
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32,
batch_size=2)
self.assertIsInstance(model_input, ragged_tensor.RaggedTensor)
self.assertIsInstance(model_input._type_spec,
ragged_tensor.RaggedTensorSpec)
self.assertEqual(model_input.shape.as_list(), [2, None, None])
layers = [ToDense(default_value=-1)]
model = get_model_from_layers_with_input(layers, model_input=model_input)
@ -602,7 +603,8 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
@keras_parameterized.run_with_all_model_types()
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
skip_keras_tensors=True)
class CompositeTensorModelPredictTest(keras_parameterized.TestCase):
def _normalize_shape(self, shape):