This change adds WIP support for using KerasTensor
objects in Keras's functional API instead of symbolic graph tf.Tensors. It is controlled by an internal behavior flag that is disabled by default and is not yet exposed in TF's APIs.
`KerasTensor`s are an alternative representation for Keras `Inputs` and for intermediate outputs of layers during Functional API construction of models. They are a lightweight data structure comprised of only the `tf.TypeSpec` of the Tensor that will be consumed/produced in the corresponding position of the model. They implement just small subset of `tf.Tensor`'s attributes and methods, and also overload the same operators as `tf.Tensor` and automatically turn them into Keras layers in the model. `KerasTensor`s are still internal-only and are a work in progress, but they have several advantages over using a graph `tf.Tensor` to represent symbolic values in functional models. - Unlike symbolic tensors, they do not need to refer to a graph. This means Keras does not need to maintain a never-deleted global background graph containing all layers ever called during functional model construction when constructing Functional Models with KerasTensors. These memory savings can be significant. - Triggering Keras functional model construction is simpler when it just has to check whether something is a KerasTensor, rather than trying to infer if a tensor was meant to be a symbolic keras representation or just a value produced during function tracing. This means we can add support for cases where values in nest.flatten(*args, **kwargs) are a completely arbitrary mix of KerasTensors and objects that are not KerasTensors, as long as any value is a KerasTensor. - Autolambda layers (converting tf ops on symbolic Keras tensors to lambda Keras layers in the model) use TF's internal dispatching mechanism, instead of trying to manually walk a graph and extract nodes from it. The dispatching mechanism is simpler, works more reliably, and is less likely to run into issues with composite tensors or strange tf ops/nodes. (And when it fails, it's by design: because dispatch is explicitly not supported on the op & it's more obvious that dispatch doesn't support the setting). - Because they support arbitrary typespecs, models/layers that use KerasTensors are generally more friendly to composite tensors of different types than using symbolic graph tensors (which must have a TensorSpec and can't have arbitrary typespecs) To experiment with using KerasTensors instead of symbolic graph `tf.Tensors`, import keras_tensor directly and call `keras_tensor.enable_keras_tensors()` PiperOrigin-RevId: 315009281 Change-Id: I6765f3a44da43f965ec261b6b193df26598cffae
This commit is contained in:
parent
52028d8d95
commit
0e1f3de50a
@ -86,6 +86,7 @@ py_library(
|
||||
"//tensorflow/python/distribute:distribute_coordinator",
|
||||
"//tensorflow/python/distribute:distribute_lib",
|
||||
"//tensorflow/python/distribute:multi_worker_util",
|
||||
"//tensorflow/python/keras/engine:keras_tensor",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -49,8 +49,10 @@ from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend_config
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -1070,6 +1072,8 @@ def is_keras_tensor(x):
|
||||
True
|
||||
|
||||
"""
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
return isinstance(x, keras_tensor.KerasTensor)
|
||||
if not isinstance(x,
|
||||
(ops.Tensor, variables_module.Variable,
|
||||
sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
|
||||
@ -1120,35 +1124,56 @@ def placeholder(shape=None,
|
||||
raise ValueError(
|
||||
'Cannot set both sparse and ragged to True when creating a placeholder.'
|
||||
)
|
||||
|
||||
if dtype is None:
|
||||
dtype = floatx()
|
||||
if not shape:
|
||||
if ndim:
|
||||
shape = (None,) * ndim
|
||||
with get_graph().as_default():
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
spec = tensor_spec.TensorSpec(
|
||||
shape=shape, dtype=dtype, name=name)
|
||||
if sparse:
|
||||
x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
|
||||
spec = sparse_tensor.SparseTensorSpec(
|
||||
shape=shape, dtype=dtype)
|
||||
elif ragged:
|
||||
ragged_rank = 0
|
||||
for i in range(1, len(shape)):
|
||||
if shape[i] is None:
|
||||
# Hacky because could be tensorshape or tuple maybe?
|
||||
# Or just tensorshape?
|
||||
if shape[i] is None or (
|
||||
hasattr(shape[i], 'value') and
|
||||
shape[i].value is None):
|
||||
ragged_rank = i
|
||||
type_spec = ragged_tensor.RaggedTensorSpec(
|
||||
spec = ragged_tensor.RaggedTensorSpec(
|
||||
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
|
||||
def tensor_spec_to_placeholder(tensorspec):
|
||||
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
|
||||
x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
|
||||
expand_composites=True)
|
||||
else:
|
||||
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
||||
|
||||
x = keras_tensor.KerasTensor(spec, name=name)
|
||||
else:
|
||||
with get_graph().as_default():
|
||||
if sparse:
|
||||
x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
|
||||
elif ragged:
|
||||
ragged_rank = 0
|
||||
for i in range(1, len(shape)):
|
||||
if shape[i] is None:
|
||||
ragged_rank = i
|
||||
type_spec = ragged_tensor.RaggedTensorSpec(
|
||||
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
|
||||
def tensor_spec_to_placeholder(tensorspec):
|
||||
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
|
||||
x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
|
||||
expand_composites=True)
|
||||
else:
|
||||
x = array_ops.placeholder(dtype, shape=shape, name=name)
|
||||
|
||||
if context.executing_eagerly():
|
||||
# Add keras_history connectivity information to the placeholder
|
||||
# when the placeholder is built in a top-level eager context
|
||||
# (intended to be used with keras.backend.function)
|
||||
from tensorflow.python.keras.engine import input_layer # pylint: disable=g-import-not-at-top
|
||||
return input_layer.Input(tensor=x)
|
||||
x = input_layer.Input(tensor=x)
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
x._is_backend_placeholder = True
|
||||
|
||||
return x
|
||||
|
||||
@ -1163,6 +1188,8 @@ def is_placeholder(x):
|
||||
Boolean.
|
||||
"""
|
||||
try:
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
return hasattr(x, '_is_backend_placeholder')
|
||||
if isinstance(x, composite_tensor.CompositeTensor):
|
||||
flat_components = nest.flatten(x, expand_composites=True)
|
||||
return py_any(is_placeholder(c) for c in flat_components)
|
||||
|
@ -61,6 +61,10 @@ def keras_model_type_combinations():
|
||||
return combinations.combine(model_type=KERAS_MODEL_TYPES)
|
||||
|
||||
|
||||
def keras_tensor_combinations():
|
||||
return combinations.combine(use_keras_tensors=['True', 'False'])
|
||||
|
||||
|
||||
class KerasModeCombination(test_combinations.TestCombination):
|
||||
"""Combination for Keras test mode.
|
||||
|
||||
@ -100,11 +104,32 @@ class KerasModelTypeCombination(test_combinations.TestCombination):
|
||||
return [test_combinations.OptionalParameter('model_type')]
|
||||
|
||||
|
||||
class KerasTensorCombination(test_combinations.TestCombination):
|
||||
"""Combination for whether KerasTensors are being used or not.
|
||||
|
||||
It by default includes `True` and `False`:
|
||||
running Keras's functional API with KerasTensors
|
||||
as the inputs, and without.
|
||||
"""
|
||||
|
||||
def context_managers(self, kwargs):
|
||||
use_keras_tensors = kwargs.pop('use_keras_tensors', None)
|
||||
|
||||
if use_keras_tensors is not None:
|
||||
return [testing_utils.use_keras_tensors_scope(use_keras_tensors)]
|
||||
else:
|
||||
return []
|
||||
|
||||
def parameter_modifiers(self):
|
||||
return [test_combinations.OptionalParameter('use_keras_tensors')]
|
||||
|
||||
|
||||
_defaults = combinations.generate.keywords['test_combinations']
|
||||
generate = functools.partial(
|
||||
combinations.generate,
|
||||
test_combinations=_defaults +
|
||||
(KerasModeCombination(), KerasModelTypeCombination()))
|
||||
(KerasModeCombination(), KerasModelTypeCombination(),
|
||||
KerasTensorCombination()))
|
||||
combine = test_combinations.combine
|
||||
times = test_combinations.times
|
||||
NamedObject = test_combinations.NamedObject
|
||||
|
@ -41,6 +41,7 @@ py_library(
|
||||
":base_preprocessing_layer",
|
||||
":data_adapter",
|
||||
":input_spec",
|
||||
":keras_tensor",
|
||||
"//tensorflow/python:composite_tensor_utils",
|
||||
"//tensorflow/python:py_checkpoint_reader",
|
||||
"//tensorflow/python/data",
|
||||
@ -164,6 +165,18 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "keras_tensor",
|
||||
srcs = ["keras_tensor.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "base_preprocessing_layer",
|
||||
srcs = [
|
||||
|
@ -55,6 +55,7 @@ from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.engine import node as node_module
|
||||
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
|
||||
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
|
||||
@ -700,7 +701,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# with the shape the Layer will be called on (these users will have to
|
||||
# implement `compute_output_shape` themselves).
|
||||
self._maybe_build(input_shape)
|
||||
with func_graph.FuncGraph('graph').as_default():
|
||||
with func_graph.FuncGraph(str(self.name) + '_scratch_graph').as_default():
|
||||
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
|
||||
def _make_placeholder_like(shape):
|
||||
ph = backend.placeholder(shape=shape, dtype=self.dtype)
|
||||
@ -759,6 +760,76 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
|
||||
output_shape)
|
||||
|
||||
def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs):
|
||||
if self.dynamic:
|
||||
# We will use static shape inference to return symbolic tensors
|
||||
# matching the specifications of the layer outputs.
|
||||
# Since `self.dynamic` is True, we will never attempt to
|
||||
# run the underlying TF graph (which is disconnected).
|
||||
# TODO(fchollet): consider py_func as an alternative, which
|
||||
# would enable us to run the underlying graph if needed.
|
||||
input_signature = nest.map_structure(
|
||||
lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype),
|
||||
inputs)
|
||||
output_signature = self.compute_output_signature(input_signature)
|
||||
return nest.map_structure(keras_tensor.KerasTensor, output_signature)
|
||||
else:
|
||||
return self._infer_output_signature(inputs, args, kwargs, input_masks)
|
||||
|
||||
def _infer_output_signature(self, inputs, args, kwargs, input_masks):
|
||||
"""TODO(kaftan): Docstring."""
|
||||
|
||||
call_fn = self.call
|
||||
# Wrapping `call` function in autograph to allow for dynamic control
|
||||
# flow and control dependencies in call. We are limiting this to
|
||||
# subclassed layers as autograph is strictly needed only for
|
||||
# subclassed layers and models.
|
||||
# tf_convert will respect the value of autograph setting in the
|
||||
# enclosing tf.function, if any.
|
||||
if (base_layer_utils.is_subclassed(self) and
|
||||
not base_layer_utils.from_saved_model(self)):
|
||||
call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
|
||||
|
||||
# We enter a scratch graph and build placeholder inputs inside of it that
|
||||
# match the input args.
|
||||
# We then call the layer inside of the scratch graph to identify the
|
||||
# output signatures, then we build KerasTensors corresponding to those
|
||||
# outputs.
|
||||
scratch_graph = func_graph.FuncGraph(str(self.name) + '_scratch_graph')
|
||||
with scratch_graph.as_default():
|
||||
inputs = nest.map_structure(
|
||||
keras_tensor.keras_tensor_to_placeholder, inputs)
|
||||
args = nest.map_structure(
|
||||
keras_tensor.keras_tensor_to_placeholder, args)
|
||||
kwargs = nest.map_structure(
|
||||
keras_tensor.keras_tensor_to_placeholder, kwargs)
|
||||
input_masks = nest.map_structure(
|
||||
keras_tensor.keras_tensor_to_placeholder, input_masks)
|
||||
|
||||
inputs = self._maybe_cast_inputs(inputs)
|
||||
|
||||
# try:
|
||||
with backend.name_scope(self._name_scope()):
|
||||
with ops.enable_auto_cast_variables(self._compute_dtype_object):
|
||||
# Build layer if applicable (if the `build` method has been
|
||||
# overridden).
|
||||
# TODO(kaftan): do we maybe_build here, or have we already done it?
|
||||
self._maybe_build(inputs)
|
||||
outputs = call_fn(inputs, *args, **kwargs)
|
||||
|
||||
self._handle_activity_regularization(inputs, outputs)
|
||||
self._set_mask_metadata(inputs, outputs, input_masks,
|
||||
build_graph=False)
|
||||
|
||||
outputs = nest.map_structure(keras_tensor.keras_tensor_from_tensor, outputs)
|
||||
if hasattr(self, '_set_inputs') and not self.inputs:
|
||||
# TODO(kaftan): figure out if we ned to do this at all
|
||||
# Subclassed network: explicitly set metadata normally set by
|
||||
# a call to self._set_inputs().
|
||||
self._set_inputs(inputs, outputs)
|
||||
del scratch_graph
|
||||
return outputs
|
||||
|
||||
@generic_utils.default
|
||||
def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
|
||||
"""Computes an output mask tensor.
|
||||
@ -954,6 +1025,27 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
args, kwargs)
|
||||
training_arg_passed_by_framework = True
|
||||
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
with call_context.enter(
|
||||
layer=self, inputs=inputs, build_graph=True, training=training_value):
|
||||
# Check input assumptions set after layer building, e.g. input shape.
|
||||
outputs = self._keras_tensor_symbolic_call(
|
||||
inputs, input_masks, args, kwargs)
|
||||
|
||||
if outputs is None:
|
||||
raise ValueError('A layer\'s `call` method should return a '
|
||||
'Tensor or a list of Tensors, not None '
|
||||
'(layer: ' + self.name + ').')
|
||||
if training_arg_passed_by_framework:
|
||||
args, kwargs = self._set_call_arg_value(
|
||||
'training', None, args, kwargs, pop_kwarg_if_none=True)
|
||||
if mask_arg_passed_by_framework:
|
||||
kwargs.pop('mask')
|
||||
# Node connectivity does not special-case the first argument.
|
||||
outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
|
||||
outputs)
|
||||
return outputs
|
||||
|
||||
# Only create Keras history if at least one tensor originates from a
|
||||
# `keras.Input`. Otherwise this Layer may be being used outside the Keras
|
||||
# framework.
|
||||
@ -1237,6 +1329,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
@property
|
||||
@doc_controls.do_not_doc_inheritable
|
||||
def updates(self):
|
||||
if (keras_tensor.keras_tensors_enabled()
|
||||
and ops.executing_eagerly_outside_functions()):
|
||||
return []
|
||||
|
||||
collected_updates = []
|
||||
all_layers = self._flatten_layers()
|
||||
with backend.get_graph().as_default():
|
||||
@ -1399,10 +1495,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
continue
|
||||
if loss is None:
|
||||
continue
|
||||
if not tensor_util.is_tensor(loss):
|
||||
if not tensor_util.is_tensor(loss) and not isinstance(
|
||||
loss, keras_tensor.KerasTensor):
|
||||
loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx())
|
||||
# TF Functions should take the eager path.
|
||||
if (tf_utils.is_symbolic_tensor(loss) and
|
||||
if ((tf_utils.is_symbolic_tensor(loss) or
|
||||
isinstance(loss, keras_tensor.KerasTensor)) and
|
||||
not base_layer_utils.is_in_tf_function()):
|
||||
symbolic_losses.append(loss)
|
||||
elif tensor_util.is_tensor(loss):
|
||||
@ -1418,7 +1516,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
|
||||
self._eager_losses.extend(eager_losses)
|
||||
|
||||
if in_call_context:
|
||||
if in_call_context and not keras_tensor.keras_tensors_enabled():
|
||||
for symbolic_loss in symbolic_losses:
|
||||
self._losses.append(symbolic_loss)
|
||||
else:
|
||||
@ -1520,7 +1618,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
raise TypeError('Unknown keyword arguments: ', str(kwargs.keys()))
|
||||
|
||||
from_metric_obj = hasattr(value, '_metric_obj')
|
||||
is_symbolic = tf_utils.is_symbolic_tensor(value)
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
is_symbolic = isinstance(value, keras_tensor.KerasTensor)
|
||||
else:
|
||||
is_symbolic = tf_utils.is_symbolic_tensor(value)
|
||||
in_call_context = base_layer_utils.call_context().in_call
|
||||
|
||||
if name is None and not from_metric_obj:
|
||||
@ -2217,7 +2318,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
"""
|
||||
return self._dtype_policy.compute_dtype
|
||||
|
||||
def _maybe_cast_inputs(self, inputs, input_list):
|
||||
def _maybe_cast_inputs(self, inputs, input_list=None):
|
||||
"""Maybe casts the inputs to the compute dtype.
|
||||
|
||||
If self._compute_dtype is floating-point, and self_autocast is True,
|
||||
@ -2230,6 +2331,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
Returns:
|
||||
`inputs`, but tensors may have been casted to self._compute_dtype
|
||||
"""
|
||||
if not input_list:
|
||||
input_list = nest.flatten(inputs)
|
||||
|
||||
compute_dtype_object = self._compute_dtype_object
|
||||
should_autocast = (
|
||||
self._autocast and compute_dtype_object and
|
||||
@ -3094,11 +3198,19 @@ class AddMetric(Layer):
|
||||
|
||||
|
||||
def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylint: disable=unused-argument
|
||||
if context.executing_eagerly():
|
||||
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
|
||||
"""Check the arguments to see if we are constructing a functional model."""
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
# We are constructing a functional model if any of the inputs
|
||||
# are KerasTensors
|
||||
return any(
|
||||
isinstance(tensor, keras_tensor.KerasTensor)
|
||||
for tensor in nest.flatten([inputs, args, kwargs]))
|
||||
else:
|
||||
return (base_layer_utils.is_in_keras_graph() or
|
||||
all(hasattr(t, '_keras_history') for t in input_list))
|
||||
if context.executing_eagerly():
|
||||
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
|
||||
else:
|
||||
return (base_layer_utils.is_in_keras_graph() or
|
||||
all(hasattr(t, '_keras_history') for t in input_list))
|
||||
|
||||
|
||||
def _convert_numpy_or_python_types(x):
|
||||
|
@ -48,7 +48,6 @@ from tensorflow.python.keras.optimizer_v2 import rmsprop
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.layers import core as legacy_core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
@ -86,7 +85,9 @@ class InvalidLayer(base_layer.Layer):
|
||||
|
||||
class BaseLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
@combinations.generate(combinations.keras_model_type_combinations())
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_model_type_combinations(),
|
||||
combinations.keras_tensor_combinations()))
|
||||
def test_dynamic_layer(self):
|
||||
model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
|
||||
input_shape=(3,))
|
||||
@ -95,26 +96,30 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(model.run_eagerly, True)
|
||||
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
|
||||
|
||||
@combinations.generate(combinations.keras_model_type_combinations())
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_model_type_combinations(),
|
||||
combinations.keras_tensor_combinations()))
|
||||
def test_dynamic_layer_error(self):
|
||||
# Functional Models hit the `dyanamic=True` error during construction.
|
||||
# Subclass Models should just throw the original autograph error during
|
||||
# execution.
|
||||
model_type = testing_utils.get_model_type()
|
||||
if 'subclass' in model_type and context.executing_eagerly():
|
||||
error_type = errors_impl.OperatorNotAllowedInGraphError
|
||||
error_message = 'iterating over `tf.Tensor` is not allowed'
|
||||
else:
|
||||
error_type = TypeError
|
||||
error_message = 'attempting to use Python control flow'
|
||||
|
||||
with self.assertRaisesRegexp(error_type, error_message):
|
||||
raised_error = False
|
||||
try:
|
||||
model = testing_utils.get_model_from_layers([DynamicLayer()],
|
||||
input_shape=(3,))
|
||||
model.compile(rmsprop.RMSprop(0.001), loss='mse')
|
||||
model.train_on_batch(np.random.random((2, 3)), np.random.random((2, 3)))
|
||||
except errors_impl.OperatorNotAllowedInGraphError as e:
|
||||
if 'iterating over `tf.Tensor` is not allowed' in str(e):
|
||||
raised_error = True
|
||||
except TypeError as e:
|
||||
if 'attempting to use Python control flow' in str(e):
|
||||
raised_error = True
|
||||
self.assertTrue(raised_error)
|
||||
|
||||
@combinations.generate(combinations.keras_model_type_combinations())
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_model_type_combinations(),
|
||||
combinations.keras_tensor_combinations()))
|
||||
def test_dynamic_layer_error_running_in_graph_mode(self):
|
||||
with ops.get_default_graph().as_default():
|
||||
model = testing_utils.get_model_from_layers([DynamicLayer(dynamic=True)],
|
||||
@ -155,11 +160,6 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(layer.build_counter, 1)
|
||||
self.assertEqual(layer.build_shape.as_list(), [None, 10])
|
||||
|
||||
def test_eager_switch_case_input(self):
|
||||
task = input_layer.Input(shape=(), dtype=dtypes.int32)
|
||||
control_flow_ops.switch_case(
|
||||
task[0], [lambda: constant_op.constant(1.0) for _ in range(10)])
|
||||
|
||||
def test_dynamic_layer_with_deferred_sequential_model(self):
|
||||
model = sequential.Sequential([DynamicLayer(dynamic=True), layers.Dense(3)])
|
||||
self.assertEqual(model.dynamic, True)
|
||||
@ -284,6 +284,7 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.keras_model_type_combinations(),
|
||||
combinations.keras_tensor_combinations(),
|
||||
combinations.combine(mode=['graph', 'eager'])))
|
||||
def test_build_with_numpy_data(self):
|
||||
model_layers = [
|
||||
@ -384,7 +385,8 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||
# b/124459427: can't test with `run_eagerly=True` for now.
|
||||
@combinations.generate(
|
||||
combinations.times(combinations.keras_mode_combinations(),
|
||||
combinations.keras_model_type_combinations()))
|
||||
combinations.keras_model_type_combinations(),
|
||||
combinations.keras_tensor_combinations()))
|
||||
def test_training_arg_in_defun(self):
|
||||
layer = self._get_layer_with_training_arg()
|
||||
model = testing_utils.get_model_from_layers([layer], input_shape=(1,))
|
||||
@ -409,7 +411,8 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(combinations.keras_mode_combinations(),
|
||||
combinations.keras_model_type_combinations()))
|
||||
combinations.keras_model_type_combinations(),
|
||||
combinations.keras_tensor_combinations()))
|
||||
def test_raw_variable_assignment(self):
|
||||
|
||||
class RawVariableLayer(base_layer.Layer):
|
||||
|
@ -20,12 +20,17 @@ import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -70,7 +75,7 @@ class TrackableWeightHandlerTest(keras_parameterized.TestCase):
|
||||
_ = backend.batch_get_value(table_handler.get_tensors())
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
class OpLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_tensor_op_layer(self):
|
||||
@ -96,6 +101,35 @@ class OpLayerTest(keras_parameterized.TestCase):
|
||||
float_values = math_ops.cast(int_values, dtypes.float32)
|
||||
_ = keras.Model(int_values, float_values)
|
||||
|
||||
def test_ragged_op_layer_keras_tensors(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
|
||||
float_values = math_ops.cast(int_values, dtypes.float32)
|
||||
model = keras.Model(int_values, float_values)
|
||||
model.compile(loss='mse')
|
||||
|
||||
input_data = ragged_factory_ops.constant(
|
||||
[[1, 2], [3, 4]], dtype=np.int32)
|
||||
expected = [[1.0, 2.0], [3.0, 4.0]]
|
||||
output = model.predict(input_data)
|
||||
self.assertIsInstance(output, ragged_tensor.RaggedTensor)
|
||||
self.assertAllClose(expected, output)
|
||||
|
||||
def test_sparse_op_layer_keras_tensors(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
int_values = keras.Input(shape=(None,), dtype=dtypes.int32, sparse=True)
|
||||
float_values = math_ops.cast(int_values, dtypes.float32)
|
||||
_ = keras.Model(int_values, float_values)
|
||||
model = keras.Model(int_values, float_values)
|
||||
model.compile(loss='mse')
|
||||
|
||||
input_data = sparse_ops.from_dense(
|
||||
np.array([[1, 2], [3, 4]], dtype=np.int32))
|
||||
expected = [[1.0, 2.0], [3.0, 4.0]]
|
||||
output = model.predict(input_data)
|
||||
self.assertIsInstance(output, sparse_tensor.SparseTensor)
|
||||
self.assertAllClose(expected, sparse_ops.sparse_tensor_to_dense(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_layer as input_layer_module
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.engine import node as node_module
|
||||
from tensorflow.python.keras.engine import training as training_lib
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
@ -994,7 +995,8 @@ def _map_subgraph_network(inputs, outputs):
|
||||
Returns:
|
||||
A tuple of List{Node] and List[Layer].
|
||||
"""
|
||||
base_layer_utils.create_keras_history(outputs)
|
||||
if not keras_tensor.keras_tensors_enabled():
|
||||
base_layer_utils.create_keras_history(outputs)
|
||||
# Keep only nodes and layers in the topology between inputs and outputs.
|
||||
_, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
|
||||
return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers
|
||||
|
@ -965,6 +965,45 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_mode_combinations(),
|
||||
combinations.keras_tensor_combinations()))
|
||||
def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(self):
|
||||
|
||||
class MaybeAdd(layers.Layer):
|
||||
|
||||
def call(self, x1, x2=None):
|
||||
if x2 is not None:
|
||||
return x1 + x2
|
||||
return x1
|
||||
|
||||
input2 = input_layer_lib.Input(10)
|
||||
outputs = MaybeAdd()(3., x2=input2)
|
||||
model = training_lib.Model([input2], outputs)
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=7 * np.ones((10, 10)),
|
||||
y=10 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
model = training_lib.Model.from_config(
|
||||
model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=7 * np.ones((10, 10)),
|
||||
y=10 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def test_composite_call_kwarg_derived_from_keras_layer(self):
|
||||
|
||||
@ -1006,6 +1045,58 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_mode_combinations(mode='eager'),
|
||||
combinations.keras_tensor_combinations()))
|
||||
def test_call_some_not_all_nested_in_first_arg_derived_from_keras_layer(self):
|
||||
# This functionality is unsupported in v1 graphs
|
||||
|
||||
class AddAll(layers.Layer):
|
||||
|
||||
def call(self, x1_x2, x3):
|
||||
x1, x2 = x1_x2
|
||||
out = x1 + x2
|
||||
if x3 is not None:
|
||||
for t in x3.values():
|
||||
out += t
|
||||
return out
|
||||
|
||||
input1 = input_layer_lib.Input(10)
|
||||
input2 = input_layer_lib.Input(10)
|
||||
input3 = input_layer_lib.Input(10)
|
||||
|
||||
outputs = AddAll()(
|
||||
[input1, 4 * array_ops.ones((1, 10))],
|
||||
x3={
|
||||
'a': input2,
|
||||
'b': input3,
|
||||
'c': 5 * array_ops.ones((1, 10))
|
||||
})
|
||||
model = training_lib.Model([input1, input2, input3], outputs)
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
|
||||
y=15 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that all inputs were correctly added.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
model = training_lib.Model.from_config(
|
||||
model.get_config(), custom_objects={'AddAll': AddAll})
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
|
||||
y=15 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that all inputs were correctly added.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def test_call_nested_arg_derived_from_keras_layer(self):
|
||||
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.distribute import distributed_training_utils
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.engine import node as node_module
|
||||
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
@ -116,6 +117,10 @@ class InputLayer(base_layer.Layer):
|
||||
if kwargs:
|
||||
raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
|
||||
|
||||
if sparse and ragged:
|
||||
raise ValueError(
|
||||
'Cannot set both sparse and ragged to True in a Keras input.')
|
||||
|
||||
if not name:
|
||||
prefix = 'input'
|
||||
name = prefix + '_' + str(backend.get_uid(prefix))
|
||||
@ -157,7 +162,14 @@ class InputLayer(base_layer.Layer):
|
||||
self.is_placeholder = True
|
||||
self._batch_input_shape = batch_input_shape
|
||||
else:
|
||||
if not tf_utils.is_symbolic_tensor(input_tensor):
|
||||
raise_eager_tensor_error = False
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
if not isinstance(input_tensor, keras_tensor.keras_tensors_enabled()):
|
||||
raise_eager_tensor_error = True
|
||||
else:
|
||||
if not tf_utils.is_symbolic_tensor(input_tensor):
|
||||
raise_eager_tensor_error = True
|
||||
if raise_eager_tensor_error:
|
||||
raise ValueError('You should not pass an EagerTensor to `Input`. '
|
||||
'For example, instead of creating an '
|
||||
'InputLayer, you should instantiate your model and '
|
||||
@ -173,7 +185,8 @@ class InputLayer(base_layer.Layer):
|
||||
node_module.Node(layer=self, outputs=input_tensor)
|
||||
|
||||
# Store type spec
|
||||
if isinstance(input_tensor, composite_tensor.CompositeTensor):
|
||||
if isinstance(input_tensor, (
|
||||
composite_tensor.CompositeTensor, keras_tensor.KerasTensor)):
|
||||
self._type_spec = input_tensor._type_spec # pylint: disable=protected-access
|
||||
else:
|
||||
self._type_spec = tensor_spec.TensorSpec(
|
||||
|
229
tensorflow/python/keras/engine/keras_tensor.py
Normal file
229
tensorflow/python/keras/engine/keras_tensor.py
Normal file
@ -0,0 +1,229 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Keras Input Tensor used to track functional API Topology."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import type_spec as type_spec_module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
|
||||
_KERAS_TENSORS_ENABLED = False
|
||||
|
||||
|
||||
def enable_keras_tensors():
|
||||
"""Enable using KerasTensors in Keras's functional API."""
|
||||
global _KERAS_TENSORS_ENABLED
|
||||
_KERAS_TENSORS_ENABLED = True
|
||||
|
||||
|
||||
def disable_keras_tensors():
|
||||
"""Disable using KerasTensors in Keras's functional API."""
|
||||
global _KERAS_TENSORS_ENABLED
|
||||
_KERAS_TENSORS_ENABLED = False
|
||||
|
||||
|
||||
def keras_tensors_enabled():
|
||||
"""Return a bool specifying if KerasTensors are enabled."""
|
||||
return _KERAS_TENSORS_ENABLED and ops.executing_eagerly_outside_functions()
|
||||
|
||||
|
||||
class KerasTensor(object):
|
||||
"""A representation of a Keras in/output during Functional API construction.
|
||||
|
||||
`KerasTensor`s are an alternative representation for Keras `Inputs`
|
||||
and for intermediate outputs of layers during Functional API construction of
|
||||
models. They are a lightweight data structure comprised of only the
|
||||
`tf.TypeSpec` of the Tensor that will be consumed/produced in the
|
||||
corresponding position of the model.
|
||||
|
||||
They implement just small subset of `tf.Tensor`'s attributes and
|
||||
methods, and also overload
|
||||
the same operators as `tf.Tensor` and automatically turn them into
|
||||
Keras layers in the model.
|
||||
|
||||
`KerasTensor`s are still internal-only and are a work in progress, but they
|
||||
have several advantages over using a graph `tf.Tensor` to represent
|
||||
symbolic values in functional models.
|
||||
- Unlike symbolic tensors, they do not need to refer to a graph. This means
|
||||
Keras does not need to maintain a never-deleted global background graph
|
||||
containing all layers ever called during functional model construction when
|
||||
constructing Functional Models with KerasTensors. These memory savings
|
||||
can be significant.
|
||||
|
||||
- Triggering Keras functional model construction is simpler
|
||||
when it just has to check whether something is a KerasTensor, rather
|
||||
than trying to infer if a tensor was meant to be a symbolic keras
|
||||
representation or just a value produced during function tracing.
|
||||
|
||||
- Autolambda layers (converting tf ops on symbolic Keras tensors to lambda
|
||||
Keras layers in the model) use TF's internal dispatching mechanism, instead
|
||||
of trying to manually walk a graph and extract nodes from it.
|
||||
The dispatching mechanism is simpler, works more reliably, and is less
|
||||
likely to run into issues with composite tensors or strange tf ops/nodes.
|
||||
|
||||
(And when it fails, it's by design: because dispatch is explicitly not
|
||||
supported on the op & it's more obvious that dispatch doesn't support the
|
||||
setting).
|
||||
|
||||
- Because they support arbitrary typespecs, models/layers that use
|
||||
KerasTensors are generally more friendly to composite tensors of different
|
||||
types than using symbolic graph tensors (which must have a TensorSpec and
|
||||
can't have arbitrary typespecs)
|
||||
|
||||
To experiment with using KerasTensors instead of symbolic graph `tf.Tensors`,
|
||||
import keras_tensor directly and call `keras_tensor.enable_keras_tensors()`
|
||||
"""
|
||||
|
||||
def __init__(self, type_spec, name=None):
|
||||
"""Construct a KerasTensor from a type_spec and an optional name."""
|
||||
if not isinstance(type_spec, type_spec_module.TypeSpec):
|
||||
raise ValueError('KerasTensors must be constructed with a `tf.TypeSpec`.')
|
||||
|
||||
self._type_spec = type_spec
|
||||
if name is None and hasattr(type_spec, 'name'):
|
||||
name = type_spec.name
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def type_spec(self):
|
||||
"""Returns the `TypeSpec` that represents this Tensor."""
|
||||
return self._type_spec
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""Returns the `TensorShape` that represents the shape of the tensor."""
|
||||
# TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
|
||||
# may need to raise an error when it's not valid for a type_spec,
|
||||
# but some keras code (e.g. build-related stuff) will likely fail when
|
||||
# it can't access shape or dtype
|
||||
return self._type_spec._shape # pylint: disable=protected-access
|
||||
|
||||
def get_shape(self):
|
||||
return self.shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""Returns the `dtype` of elements in the tensor."""
|
||||
# TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
|
||||
# may need to raise an error when it's not valid for a type_spec,
|
||||
# but some keras code (e.g. build-related stuff) will likely fail when
|
||||
# it can't access shape or dtype
|
||||
return self._type_spec._dtype # pylint: disable=protected-access
|
||||
|
||||
def ref(self):
|
||||
"""Returns a hashable reference object to this KerasTensor.
|
||||
|
||||
The primary use case for this API is to put KerasTensors in a
|
||||
set/dictionary. We can't put tensors in a set/dictionary as
|
||||
`tensor.__hash__()` is not available and tensor equality (`==`) is supposed
|
||||
to produce a tensor representing if the two inputs are equal.
|
||||
|
||||
See the documentation of `tf.Tensor.ref()` for more info.
|
||||
"""
|
||||
return object_identity.Reference(self)
|
||||
|
||||
def __iter__(self):
|
||||
shape = None
|
||||
if self.shape.ndims is not None:
|
||||
shape = [dim.value for dim in self.shape.dims]
|
||||
|
||||
if shape is None:
|
||||
raise TypeError('Cannot iterate over a KerasTensor with unknown shape.')
|
||||
if not shape:
|
||||
raise TypeError('Cannot iterate over a scalar.')
|
||||
if shape[0] is None:
|
||||
raise TypeError(
|
||||
'Cannot iterate over a KerasTensor with unknown first dimension.')
|
||||
return _KerasTensorIterator(self, shape[0])
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Returns the (optionally provided) name of the described tensor."""
|
||||
return self._name
|
||||
|
||||
@classmethod
|
||||
def _overload_all_operators(cls): # pylint: disable=invalid-name
|
||||
"""Register overloads for all operators."""
|
||||
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
|
||||
cls._overload_operator(operator)
|
||||
|
||||
@classmethod
|
||||
def _overload_operator(cls, operator): # pylint: disable=invalid-name
|
||||
"""Overload an operator with the same overloading as `ops.Tensor`.
|
||||
|
||||
We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
|
||||
|
||||
Args:
|
||||
operator: string. The operator name.
|
||||
"""
|
||||
tensor_oper = getattr(ops.Tensor, operator)
|
||||
|
||||
# Compatibility with Python 2:
|
||||
# Python 2 unbound methods have type checks for the first arg,
|
||||
# so we need to extract the underlying function
|
||||
tensor_oper = getattr(tensor_oper, '__func__', tensor_oper)
|
||||
|
||||
setattr(cls, operator, tensor_oper)
|
||||
|
||||
|
||||
KerasTensor._overload_all_operators() # pylint: disable=protected-access
|
||||
|
||||
|
||||
class _KerasTensorIterator(object):
|
||||
"""Iterates over the leading dim of a KerasTensor. Performs 0 error checks."""
|
||||
|
||||
def __init__(self, tensor, dim0):
|
||||
self._tensor = tensor
|
||||
self._index = 0
|
||||
self._limit = dim0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._index == self._limit:
|
||||
raise StopIteration
|
||||
result = self._tensor[self._index]
|
||||
self._index += 1
|
||||
return result
|
||||
|
||||
next = __next__ # python2.x compatibility.
|
||||
|
||||
|
||||
def keras_tensor_to_placeholder(x):
|
||||
"""TODO(kaftan): Docstring."""
|
||||
if isinstance(x, KerasTensor):
|
||||
def tensor_spec_to_placeholder(tensorspec):
|
||||
return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
|
||||
ph = nest.map_structure(tensor_spec_to_placeholder, x.type_spec,
|
||||
expand_composites=True)
|
||||
return ph
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def keras_tensor_from_tensor(x):
|
||||
name = getattr(x, 'name', None)
|
||||
out = KerasTensor(type_spec_module.type_spec_from_value(x), name=name)
|
||||
if hasattr(x, '_keras_mask'):
|
||||
out._keras_mask = KerasTensor( # pylint: disable=protected-access
|
||||
type_spec_module.type_spec_from_value(x._keras_mask)) # pylint: disable=protected-access
|
||||
|
||||
return out
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
@ -80,11 +81,13 @@ class Node(object):
|
||||
self._single_positional_tensor_passed = (not self.call_kwargs and len(
|
||||
self.call_args) == 1 and tensor_util.is_tensor(self.call_args[0]))
|
||||
|
||||
# Create TensorFlowOpLayers if needed.
|
||||
for obj in self._flat_arguments:
|
||||
if (isinstance(obj, ops.Tensor) and
|
||||
base_layer_utils.needs_keras_history(obj, ignore_call_context=True)):
|
||||
base_layer_utils.create_keras_history(obj)
|
||||
if not keras_tensor.keras_tensors_enabled():
|
||||
# Create TensorFlowOpLayers if needed.
|
||||
for obj in self._flat_arguments:
|
||||
if (isinstance(obj, ops.Tensor) and
|
||||
base_layer_utils.needs_keras_history(
|
||||
obj, ignore_call_context=True)):
|
||||
base_layer_utils.create_keras_history(obj)
|
||||
|
||||
self._keras_inputs = []
|
||||
self._keras_inputs_ids_and_indices = []
|
||||
|
@ -690,7 +690,7 @@ class TrainingTest(keras_parameterized.TestCase):
|
||||
metrics=['accuracy'],
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_that_trainable_disables_updates(self):
|
||||
val_a = np.random.random((10, 4))
|
||||
val_out = np.random.random((10, 4))
|
||||
@ -1584,7 +1584,7 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_sparse_op_with_op_layer(self):
|
||||
inputs = layers_module.Input(shape=(2,), sparse=True, name='sparse_tensor')
|
||||
output = sparse_ops.sparse_minimum(inputs, inputs)
|
||||
@ -2817,7 +2817,7 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
|
||||
scores = model.train_on_batch(x, y, sample_weight=w)
|
||||
self.assertArrayNear(scores, [0.3328, 0.8], 0.001)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_add_metric_with_tensor_on_model(self):
|
||||
x = layers_module.Input(shape=(1,))
|
||||
y = layers_module.Dense(1, kernel_initializer='ones')(x)
|
||||
@ -2932,7 +2932,8 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
|
||||
self.assertEqual(history.history['metric_1'][-1], 5)
|
||||
self.assertAlmostEqual(history.history['val_metric_1'][-1], 5, 0)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
|
||||
skip_keras_tensors=True)
|
||||
def test_model_metrics_list(self):
|
||||
|
||||
class LayerWithAddMetric(layers_module.Layer):
|
||||
|
@ -303,7 +303,8 @@ def _test_sequential_model_type(f, test_or_class, *args, **kwargs):
|
||||
def run_all_keras_modes(test_or_class=None,
|
||||
config=None,
|
||||
always_skip_v1=False,
|
||||
always_skip_eager=False):
|
||||
always_skip_eager=False,
|
||||
**kwargs):
|
||||
"""Execute the decorated test with all keras execution modes.
|
||||
|
||||
This decorator is intended to be applied either to individual test methods in
|
||||
@ -361,6 +362,9 @@ def run_all_keras_modes(test_or_class=None,
|
||||
when Tensorflow v2 behavior is not enabled.
|
||||
always_skip_eager: If True, does not execute the decorated test
|
||||
with eager execution modes.
|
||||
**kwargs: Additional kwargs for configuring tests for
|
||||
in-progress Keras behaviors/ refactorings that we haven't fully
|
||||
rolled out yet
|
||||
|
||||
Returns:
|
||||
Returns a decorator that will run the decorated test method multiple times.
|
||||
@ -369,8 +373,14 @@ def run_all_keras_modes(test_or_class=None,
|
||||
ImportError: If abseil parameterized is not installed or not included as
|
||||
a target dependency.
|
||||
"""
|
||||
skip_keras_tensors = kwargs.pop('skip_keras_tensors', False)
|
||||
if kwargs:
|
||||
raise ValueError('Unrecognized keyword args: {}'.format(kwargs))
|
||||
|
||||
params = [('_v2_function', 'v2_function')]
|
||||
if not skip_keras_tensors:
|
||||
params.append(('_v2_function_use_keras_tensors',
|
||||
'v2_function_use_keras_tensors'))
|
||||
if not always_skip_eager:
|
||||
params.append(('_v2_eager', 'v2_eager'))
|
||||
if not (always_skip_v1 or tf2.enabled()):
|
||||
@ -390,6 +400,8 @@ def run_all_keras_modes(test_or_class=None,
|
||||
_v2_eager_test(f, self, *args, **kwargs)
|
||||
elif run_mode == 'v2_function':
|
||||
_v2_function_test(f, self, *args, **kwargs)
|
||||
elif run_mode == 'v2_function_use_keras_tensors':
|
||||
_v2_function_and_kerastensors_test(f, self, *args, **kwargs)
|
||||
else:
|
||||
return ValueError('Unknown run mode %s' % run_mode)
|
||||
|
||||
@ -417,6 +429,13 @@ def _v2_function_test(f, test_or_class, *args, **kwargs):
|
||||
f(test_or_class, *args, **kwargs)
|
||||
|
||||
|
||||
def _v2_function_and_kerastensors_test(f, test_or_class, *args, **kwargs):
|
||||
with context.eager_mode():
|
||||
with testing_utils.run_eagerly_scope(False):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
f(test_or_class, *args, **kwargs)
|
||||
|
||||
|
||||
def _test_or_class_decorator(test_or_class, single_method_decorator):
|
||||
"""Decorate a test or class with a decorator intended for one method.
|
||||
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@ -206,7 +207,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def testBody(self):
|
||||
mode = "eager" if context.executing_eagerly() else "graph"
|
||||
should_run_eagerly = testing_utils.should_run_eagerly()
|
||||
@ -242,6 +243,54 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
ts.run(res)
|
||||
self.assertLen(l, 4)
|
||||
|
||||
def test_run_all_keras_modes_include_keras_tensors(self):
|
||||
l = []
|
||||
|
||||
class ExampleTest(keras_parameterized.TestCase):
|
||||
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
@keras_parameterized.run_all_keras_modes()
|
||||
def testBody(self):
|
||||
mode = "eager" if context.executing_eagerly() else "graph"
|
||||
should_run_eagerly = testing_utils.should_run_eagerly()
|
||||
l.append((mode, should_run_eagerly,
|
||||
keras_tensor.keras_tensors_enabled()))
|
||||
|
||||
e = ExampleTest()
|
||||
if not tf2.enabled():
|
||||
e.testBody_v1_session()
|
||||
e.testBody_v2_eager()
|
||||
e.testBody_v2_function()
|
||||
e.testBody_v2_function_use_keras_tensors()
|
||||
|
||||
if not tf2.enabled():
|
||||
self.assertLen(l, 4)
|
||||
self.assertAllEqual(l, [
|
||||
("graph", False, False),
|
||||
("eager", True, False),
|
||||
("eager", False, False),
|
||||
("eager", False, True),
|
||||
])
|
||||
|
||||
ts = unittest.makeSuite(ExampleTest)
|
||||
res = unittest.TestResult()
|
||||
ts.run(res)
|
||||
self.assertLen(l, 8)
|
||||
else:
|
||||
self.assertLen(l, 3)
|
||||
self.assertAllEqual(l, [
|
||||
("eager", True, False),
|
||||
("eager", False, False),
|
||||
("eager", False, True),
|
||||
])
|
||||
|
||||
ts = unittest.makeSuite(ExampleTest)
|
||||
res = unittest.TestResult()
|
||||
ts.run(res)
|
||||
self.assertLen(l, 6)
|
||||
|
||||
def test_run_all_keras_modes_extra_params(self):
|
||||
l = []
|
||||
|
||||
@ -250,7 +299,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@parameterized.named_parameters(
|
||||
[dict(testcase_name="_0", with_brackets=True),
|
||||
dict(testcase_name="_1", with_brackets=False)])
|
||||
@ -300,7 +349,8 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
|
||||
skip_keras_tensors=True)
|
||||
def testBody(self):
|
||||
mode = "eager" if context.executing_eagerly() else "graph"
|
||||
should_run_eagerly = testing_utils.should_run_eagerly()
|
||||
@ -330,7 +380,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
pass
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def testBody(self):
|
||||
mode = "eager" if context.executing_eagerly() else "graph"
|
||||
should_run_eagerly = testing_utils.should_run_eagerly()
|
||||
@ -382,7 +432,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def testBody(self):
|
||||
mode = "eager" if context.executing_eagerly() else "graph"
|
||||
@ -431,7 +481,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
l = []
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
class ExampleTest(keras_parameterized.TestCase):
|
||||
|
||||
def runTest(self):
|
||||
@ -491,7 +541,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@parameterized.named_parameters(dict(testcase_name="_arg",
|
||||
arg=True))
|
||||
def testBody(self, arg):
|
||||
@ -537,7 +587,7 @@ class KerasParameterizedTest(keras_parameterized.TestCase):
|
||||
|
||||
self.assertLen(l, len(expected_combinations) * 2)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@parameterized.named_parameters(dict(testcase_name="argument",
|
||||
arg=True))
|
||||
def test_run_all_keras_modes_extra_params_2(self, arg):
|
||||
|
@ -139,8 +139,10 @@ py_library(
|
||||
"//tensorflow/python/keras:base_layer",
|
||||
"//tensorflow/python/keras:constraints",
|
||||
"//tensorflow/python/keras:initializers",
|
||||
"//tensorflow/python/keras:losses",
|
||||
"//tensorflow/python/keras:regularizers",
|
||||
"//tensorflow/python/keras/engine:input_spec",
|
||||
"//tensorflow/python/keras/engine:keras_tensor",
|
||||
"//tensorflow/python/keras/layers/ops:core",
|
||||
"//tensorflow/python/keras/utils:engine_utils",
|
||||
"//tensorflow/python/keras/utils:generic_utils",
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import constraints
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.keras.engine.input_spec import InputSpec
|
||||
from tensorflow.python.keras.layers.ops import core as core_ops
|
||||
@ -52,8 +53,12 @@ from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
|
||||
from tensorflow.python.util.tf_export import get_symbol_from_name
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@ -1255,3 +1260,155 @@ class ActivityRegularization(Layer):
|
||||
config = {'l1': self.l1, 'l2': self.l2}
|
||||
base_config = super(ActivityRegularization, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class TFOpLambda(Layer):
|
||||
"""Wraps TF API symbols in a `Layer` object.
|
||||
|
||||
It is inserted by the Functional API construction whenever users call
|
||||
a supported TF symbol on KerasTensors.
|
||||
|
||||
Like Lambda layers, this layer tries to raise warnings when it detects users
|
||||
explicitly use variables in the call. (To let them know
|
||||
that the layer will not capture the variables).
|
||||
|
||||
This is useful in the case where users do something like:
|
||||
x = keras.Input(...)
|
||||
y = tf.Variable(...)
|
||||
out = x * tf_variable
|
||||
"""
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, function, **kwargs):
|
||||
self.function = function
|
||||
self.symbol = (
|
||||
get_canonical_name_for_symbol(self.function,
|
||||
add_prefix_to_v1_names=True) or
|
||||
get_canonical_name_for_symbol(self.function,
|
||||
api_name='keras',
|
||||
add_prefix_to_v1_names=True))
|
||||
kwargs['autocast'] = False
|
||||
|
||||
# Decorate the function to produce this layer's call method
|
||||
def _call_wrapper(*args, **kwargs):
|
||||
return self._call_wrapper(*args, **kwargs)
|
||||
self.call = tf_decorator.make_decorator(function, _call_wrapper)
|
||||
|
||||
super(TFOpLambda, self).__init__(**kwargs)
|
||||
|
||||
# Warning on every invocation will be quite irksome in Eager mode.
|
||||
self._already_warned = False
|
||||
|
||||
self._expects_training_arg = False
|
||||
self._expects_mask_arg = False
|
||||
|
||||
def _call_wrapper(self, *args, **kwargs):
|
||||
created_variables = []
|
||||
def _variable_creator(next_creator, **creator_kwargs):
|
||||
var = next_creator(**creator_kwargs)
|
||||
created_variables.append(var)
|
||||
return var
|
||||
|
||||
with backprop.GradientTape(watch_accessed_variables=True) as tape, \
|
||||
variable_scope.variable_creator_scope(_variable_creator):
|
||||
# We explicitly drop `name` arguments here,
|
||||
# to guard against the case where an op explicitly has a
|
||||
# `name` passed (which is susceptible to producing
|
||||
# multiple ops w/ the same name when the layer is reused)
|
||||
kwargs.pop('name', None)
|
||||
result = self.function(*args, **kwargs)
|
||||
self._check_variables(created_variables, tape.watched_variables())
|
||||
return result
|
||||
|
||||
def _check_variables(self, created_variables, accessed_variables):
|
||||
if not created_variables and not accessed_variables:
|
||||
# In the common case that a Lambda layer does not touch a Variable, we
|
||||
# don't want to incur the runtime cost of assembling any state used for
|
||||
# checking only to immediately discard it.
|
||||
return
|
||||
|
||||
tracked_weights = set(v.ref() for v in self.weights)
|
||||
untracked_new_vars = [
|
||||
v for v in created_variables if v.ref() not in tracked_weights
|
||||
]
|
||||
if untracked_new_vars:
|
||||
variable_str = '\n'.join(' {}'.format(i) for i in untracked_new_vars)
|
||||
error_str = textwrap.dedent(
|
||||
'''
|
||||
The following Variables were created within a Lambda layer ({name})
|
||||
but are not tracked by said layer:
|
||||
{variable_str}
|
||||
The layer cannot safely ensure proper Variable reuse across multiple
|
||||
calls, and consquently this behavior is disallowed for safety. Lambda
|
||||
layers are not well suited to stateful computation; instead, writing a
|
||||
subclassed Layer is the recommend way to define layers with
|
||||
Variables.'''
|
||||
).format(name=self.name, variable_str=variable_str)
|
||||
raise ValueError(error_str)
|
||||
|
||||
untracked_used_vars = [
|
||||
v for v in accessed_variables if v.ref() not in tracked_weights
|
||||
]
|
||||
if untracked_used_vars and not self._already_warned:
|
||||
variable_str = '\n'.join(' {}'.format(i) for i in untracked_used_vars)
|
||||
self._warn(textwrap.dedent(
|
||||
'''
|
||||
The following Variables were used a Lambda layer's call ({name}), but
|
||||
are not present in its tracked objects:
|
||||
{variable_str}
|
||||
It is possible that this is intended behavior, but it is more likely
|
||||
an omission. This is a strong indication that this layer should be
|
||||
formulated as a subclassed Layer rather than a Lambda layer.'''
|
||||
).format(name=self.name, variable_str=variable_str))
|
||||
self._already_warned = True
|
||||
|
||||
def _warn(self, msg):
|
||||
# This method will be overridden in a unit test to raise an error, because
|
||||
# self.assertWarns is not universally implemented.
|
||||
return tf_logging.warn(msg)
|
||||
|
||||
def get_config(self):
|
||||
if not self.symbol:
|
||||
raise ValueError('This Keras op layer was generated from %s, a method '
|
||||
'that is not an exposed in the TensorFlow API. This '
|
||||
'may have happened if the method was explicitly '
|
||||
'decorated to add dispatching support, and it was used '
|
||||
'during Functional model construction. '
|
||||
'To ensure cross-version compatibility of Keras models '
|
||||
'that use op layers, only op layers produced from '
|
||||
'exported TF API symbols can be serialized.'
|
||||
% self.function)
|
||||
config = {
|
||||
'function': self.symbol
|
||||
}
|
||||
|
||||
base_config = super(TFOpLambda, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
config = config.copy()
|
||||
symbol_name = config['function']
|
||||
function = get_symbol_from_name(symbol_name)
|
||||
if not function:
|
||||
raise ValueError(
|
||||
'TF symbol `tf.%s` could not be found.' % symbol_name)
|
||||
|
||||
config['function'] = function
|
||||
|
||||
return cls(**config)
|
||||
|
||||
|
||||
class KerasOpDispatcher(dispatch.GlobalOpDispatcher):
|
||||
"""A global dispatcher that allows building a functional model with TF Ops."""
|
||||
|
||||
def handle(self, op, args, kwargs):
|
||||
"""Handle the specified operation with the specified arguments."""
|
||||
if any(
|
||||
isinstance(x, keras_tensor.KerasTensor)
|
||||
for x in nest.flatten([args, kwargs])):
|
||||
return TFOpLambda(op)(*args, **kwargs)
|
||||
else:
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
KerasOpDispatcher().register()
|
||||
|
@ -523,7 +523,11 @@ class Concatenate(_Merge):
|
||||
|
||||
@tf_utils.shape_type_conversion
|
||||
def compute_output_shape(self, input_shape):
|
||||
if not isinstance(input_shape, (tuple, list)):
|
||||
if ((not isinstance(input_shape, (tuple, list))) or
|
||||
(not isinstance(input_shape[0], (tuple, list)))):
|
||||
# The tf_utils.shape_type_conversion decorator turns tensorshapes
|
||||
# into tuples, so we need to verify that `input_shape` is a list/tuple,
|
||||
# *and* that the individual elements are themselves shape tuples.
|
||||
raise ValueError('A `Concatenate` layer should be called '
|
||||
'on a list of inputs.')
|
||||
input_shapes = input_shape
|
||||
|
@ -50,7 +50,6 @@ def _single_identity_op_at_end():
|
||||
inputs = keras.Input(shape=(10,))
|
||||
x = keras.layers.Dense(10)(inputs)
|
||||
outputs = array_ops.identity(x)
|
||||
assert 'Identity' in outputs.name
|
||||
return keras.Model(inputs, outputs)
|
||||
|
||||
|
||||
@ -186,7 +185,7 @@ def _reuse_ancillary_layer():
|
||||
return model
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -313,11 +312,6 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
e = 3 # Fudge factor to prevent flakiness.
|
||||
self.assertLess(size_500, (10 * e) * size_50)
|
||||
|
||||
def test_no_mask_tracking(self):
|
||||
x = keras.backend.placeholder((10, 10))
|
||||
y = keras.layers.Masking(0.)(x)
|
||||
self.assertTrue(y._keras_mask._keras_history_checked)
|
||||
|
||||
def test_built(self):
|
||||
inputs = keras.Input(shape=(10,))
|
||||
outputs = gen_nn_ops.relu(inputs)
|
||||
|
@ -79,7 +79,7 @@ def _get_model(input_shape=(4,)):
|
||||
|
||||
class TestModelCloning(keras_parameterized.TestCase):
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'has_input_layer',
|
||||
'input_shape': (4,),
|
||||
@ -142,7 +142,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
||||
self.assertIsInstance(new_model._layers[0], keras.layers.InputLayer)
|
||||
self.assertTrue(new_model._is_graph_network)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'clone_weights', 'share_weights': False},
|
||||
{'testcase_name': 'share_weights', 'share_weights': True},
|
||||
|
@ -40,7 +40,8 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
|
||||
skip_keras_tensors=True)
|
||||
class LinearModelTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_linear_model_with_single_input(self):
|
||||
|
@ -37,7 +37,8 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
|
||||
skip_keras_tensors=True)
|
||||
class WideDeepModelTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_wide_deep_model(self):
|
||||
|
@ -83,7 +83,7 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
|
||||
self.assertEqual(len(model.losses), 1)
|
||||
model.fit(x_train, y_train, batch_size=10, epochs=1, verbose=0)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@parameterized.named_parameters([
|
||||
('l1', regularizers.l1()),
|
||||
('l2', regularizers.l2()),
|
||||
@ -126,7 +126,7 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
|
||||
model.get_config(), custom_objects={'my_regularizer': my_regularizer})
|
||||
self.assertEqual(model2.layers[1].kernel_regularizer, my_regularizer)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@parameterized.named_parameters([
|
||||
('l1', regularizers.l1()),
|
||||
('l2', regularizers.l2()),
|
||||
@ -144,7 +144,7 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
self.assertLen(model.losses, 5)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@parameterized.named_parameters([
|
||||
('l1', regularizers.l1()),
|
||||
('l2', regularizers.l2()),
|
||||
@ -166,7 +166,7 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
self.assertLen(model.losses, 6)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@parameterized.named_parameters([
|
||||
('l1', regularizers.l1()),
|
||||
('l2', regularizers.l2()),
|
||||
|
@ -147,6 +147,7 @@ tf_py_test(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/distribute:mirrored_strategy",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras import layers
|
||||
from tensorflow.python.keras import models
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
|
||||
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
|
||||
from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
|
||||
@ -331,6 +332,29 @@ def run_eagerly_scope(value):
|
||||
_thread_local_data.run_eagerly = previous_value
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def use_keras_tensors_scope(value):
|
||||
"""Provides a scope within which we use KerasTensors in the func. API or not.
|
||||
|
||||
The boolean gets restored to its original value upon exiting the scope.
|
||||
|
||||
Arguments:
|
||||
value: Bool specifying if we should build functional models
|
||||
using KerasTensors in the active test.
|
||||
Should be True or False.
|
||||
|
||||
Yields:
|
||||
The provided value.
|
||||
"""
|
||||
previous_value = keras_tensor._KERAS_TENSORS_ENABLED # pylint: disable=protected-access
|
||||
try:
|
||||
keras_tensor._KERAS_TENSORS_ENABLED = value # pylint: disable=protected-access
|
||||
yield value
|
||||
finally:
|
||||
# Restore KerasTensor usage to initial value.
|
||||
keras_tensor._KERAS_TENSORS_ENABLED = previous_value # pylint: disable=protected-access
|
||||
|
||||
|
||||
def should_run_eagerly():
|
||||
"""Returns whether the models we are testing should be run eagerly."""
|
||||
if _thread_local_data.run_eagerly is None:
|
||||
|
@ -69,7 +69,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
|
||||
self.w = np.array([[1.25], [0.5], [1.25]], dtype='float32')
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_loss_on_model_fit(self):
|
||||
inputs = Input(shape=(1,))
|
||||
targets = Input(shape=(1,))
|
||||
@ -85,7 +85,8 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
self.assertAllClose(history.history['loss'], [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True,
|
||||
always_skip_v1=True)
|
||||
def test_loss_callable_on_model_fit(self):
|
||||
model = testing_utils.get_model_from_layers([testing_utils.Bias()],
|
||||
input_shape=(1,))
|
||||
@ -144,7 +145,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
loss = [train_step(self.x, self.y) for _ in range(5)]
|
||||
self.assertAllClose(loss, [0., -0.05, -0.1, -0.15, -0.2], 1e-3)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_loss_with_sample_weight_on_model_fit(self):
|
||||
inputs = Input(shape=(1,))
|
||||
targets = Input(shape=(1,))
|
||||
@ -181,7 +182,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
loss = [train_step(self.x, self.y, self.w) for _ in range(5)]
|
||||
self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_loss_with_sample_weight_in_model_call(self):
|
||||
|
||||
class MyModel(Model):
|
||||
@ -209,7 +210,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
eval_out = model.evaluate([self.x, self.y, self.w])
|
||||
self.assertAlmostEqual(eval_out, 1.0, 3)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_loss_with_sample_weight_in_layer_call(self):
|
||||
|
||||
class MyLayer(layers.Layer):
|
||||
@ -244,7 +245,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
output = model.test_on_batch([self.x, self.y, self.w])
|
||||
self.assertAlmostEqual(output, 1.0, 3)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_loss_on_layer(self):
|
||||
|
||||
class MyLayer(layers.Layer):
|
||||
@ -265,7 +266,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
||||
self.assertEqual(loss, 2 * 3)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_activity_regularizer(self):
|
||||
loss = {}
|
||||
@ -299,7 +300,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
loss[reg] = model.evaluate(x, y)
|
||||
self.assertLess(loss[None], loss['l2'])
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_activity_regularizer_loss_value(self):
|
||||
layer = layers.Dense(
|
||||
@ -318,7 +319,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
loss = model.test_on_batch(x)
|
||||
self.assertAlmostEqual(0.01, loss, places=4)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_activity_regularizer_batch_independent(self):
|
||||
inputs = layers.Input(shape=(10,))
|
||||
x = layers.Dense(10, activation='relu', activity_regularizer='l2')(inputs)
|
||||
@ -334,7 +335,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
loss_big_batch = model.test_on_batch(np.ones((20, 10), 'float32'))
|
||||
self.assertAlmostEqual(loss_small_batch, loss_big_batch, places=4)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_with_shared_layer(self):
|
||||
|
||||
class LayerWithLoss(layers.Layer):
|
||||
@ -351,7 +352,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
self.assertEqual(len(m2.losses), 2)
|
||||
self.assertAllClose(m2.losses, [6, 12])
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_with_shared_nested_layer(self):
|
||||
|
||||
class LayerWithLoss(layers.Layer):
|
||||
@ -377,7 +378,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
self.assertEqual(len(m2.losses), 2)
|
||||
self.assertAllClose(m2.losses, [6, 12])
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_clear_losses(self):
|
||||
|
||||
class LayerWithSharedNestedLossLayer(layers.Layer):
|
||||
@ -428,7 +429,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
self.assertEqual(len(model.get_losses_for(x4)), 2)
|
||||
self.assertEqual(len(model.get_losses_for(None)), 1)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_invalid_constant_input(self):
|
||||
with context.eager_mode():
|
||||
inputs = Input(shape=(1,))
|
||||
@ -439,7 +440,7 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
'Expected a symbolic Tensors or a callable for the loss value'):
|
||||
model.add_loss(1.)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@keras_parameterized.run_all_keras_modes(skip_keras_tensors=True)
|
||||
def test_invalid_variable_input(self):
|
||||
with context.eager_mode():
|
||||
inputs = Input(shape=(1,))
|
||||
|
@ -506,7 +506,8 @@ class RaggedTensorInputTest(keras_parameterized.TestCase,
|
||||
model_input = input_layer.Input(
|
||||
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32,
|
||||
batch_size=2)
|
||||
self.assertIsInstance(model_input, ragged_tensor.RaggedTensor)
|
||||
self.assertIsInstance(model_input._type_spec,
|
||||
ragged_tensor.RaggedTensorSpec)
|
||||
self.assertEqual(model_input.shape.as_list(), [2, None, None])
|
||||
layers = [ToDense(default_value=-1)]
|
||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||
@ -602,7 +603,8 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
||||
|
||||
|
||||
@keras_parameterized.run_with_all_model_types()
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True,
|
||||
skip_keras_tensors=True)
|
||||
class CompositeTensorModelPredictTest(keras_parameterized.TestCase):
|
||||
|
||||
def _normalize_shape(self, shape):
|
||||
|
Loading…
x
Reference in New Issue
Block a user