`KerasTensor`s are an alternative representation for Keras `Inputs` and for intermediate outputs of layers during Functional API construction of models. They are a lightweight data structure composed only of the `tf.TypeSpec` of the Tensor that will be consumed/produced in the corresponding position of the model. They implement only a small subset of `tf.Tensor`'s attributes and methods, and also overload the same operators as `tf.Tensor`, automatically turning them into Keras layers in the model. `KerasTensor`s are still internal-only and a work in progress, but they have several advantages over using a graph `tf.Tensor` to represent symbolic values in functional models. - Unlike symbolic tensors, they do not need to refer to a graph. This means Keras does not need to maintain a never-deleted global background graph containing all layers ever called during functional model construction when constructing Functional Models with KerasTensors. These memory savings can be significant. - Triggering Keras functional model construction is simpler when it just has to check whether something is a KerasTensor, rather than trying to infer whether a tensor was meant to be a symbolic Keras representation or just a value produced during function tracing. This means we can add support for cases where the values in nest.flatten(*args, **kwargs) are a completely arbitrary mix of KerasTensors and objects that are not KerasTensors, as long as at least one value is a KerasTensor. - Autolambda layers (converting tf ops on symbolic Keras tensors to lambda Keras layers in the model) use TF's internal dispatching mechanism, instead of trying to manually walk a graph and extract nodes from it. The dispatching mechanism is simpler, works more reliably, and is less likely to run into issues with composite tensors or strange tf ops/nodes. (And when it fails, it is by design: dispatch is explicitly not supported on the op, and it is more obvious that dispatch doesn't support the setting.)
- Because they support arbitrary typespecs, models/layers that use KerasTensors are generally more friendly to composite tensors of different types than models/layers that use symbolic graph tensors (which must have a TensorSpec and can't have arbitrary typespecs). To experiment with using KerasTensors instead of symbolic graph `tf.Tensors`, import keras_tensor directly and call `keras_tensor.enable_keras_tensors()`. PiperOrigin-RevId: 315009281 Change-Id: I6765f3a44da43f965ec261b6b193df26598cffae
136 lines
4.5 KiB
Python
136 lines
4.5 KiB
Python
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""This module customizes `test_combinations` for `tf.keras` related tests."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import functools
|
|
|
|
from tensorflow.python import tf2
|
|
from tensorflow.python.framework import combinations
|
|
from tensorflow.python.framework import test_combinations
|
|
from tensorflow.python.keras import testing_utils
|
|
|
|
# Model construction styles recognized by `KerasModelTypeCombination` and
# enumerated by `keras_model_type_combinations`. Kept as a list (not a tuple)
# because it is handed directly to `combinations.combine`.
KERAS_MODEL_TYPES = [
    'functional',
    'subclass',
    'sequential',
]
|
|
|
|
|
|
def keras_mode_combinations(mode=None, run_eagerly=None):
  """Builds the default execution-mode test combinations for tf.keras tests.

  Args:
    mode: Optional list of execution modes. Only the entries 'eager' and
      'graph' are acted upon; anything else is ignored. When None, this is
      ['eager'] if `tf2.enabled()` is true (so v1 graph-session tests are
      skipped) and ['graph', 'eager'] otherwise. An empty list produces an
      empty result.
    run_eagerly: Optional list of `run_eagerly` values to pair with 'eager'
      mode. Defaults to [True, False]. 'graph' mode is always paired with
      `run_eagerly=False`, independent of this argument.

  Returns:
    A list of combinations: the 'eager' ones (when requested) followed by the
    'graph' ones (when requested).
  """
  modes = mode
  if modes is None:
    modes = ['eager'] if tf2.enabled() else ['graph', 'eager']
  eager_values = [True, False] if run_eagerly is None else run_eagerly

  combos = []
  if 'eager' in modes:
    combos.extend(
        combinations.combine(mode=['eager'], run_eagerly=eager_values))
  if 'graph' in modes:
    # Graph mode cannot run eagerly, so only the False variant is generated.
    combos.extend(combinations.combine(mode=['graph'], run_eagerly=[False]))
  return combos
|
|
|
|
|
|
def keras_model_type_combinations():
  """Returns `combinations.combine` over every entry of `KERAS_MODEL_TYPES`.

  The generated combinations carry a `model_type` parameter, which
  `KerasModelTypeCombination` consumes.
  """
  model_types = KERAS_MODEL_TYPES
  return combinations.combine(model_type=model_types)
|
|
|
|
|
|
def keras_tensor_combinations():
  """Returns combinations toggling the `use_keras_tensors` parameter.

  The generated combinations carry a boolean `use_keras_tensors` parameter,
  which `KerasTensorCombination` consumes.

  Returns:
    The result of `combinations.combine` over `use_keras_tensors` in
    [True, False].
  """
  # Use real booleans: the strings 'True' and 'False' are both truthy, so a
  # 'False' string would not actually disable KerasTensors under any
  # truthiness check of the flag.
  return combinations.combine(use_keras_tensors=[True, False])
|
|
|
|
|
|
class KerasModeCombination(test_combinations.TestCombination):
  """Test combination that handles the Keras `run_eagerly` parameter.

  Combinations produced by `keras_mode_combinations` (for example v1 session,
  v2 eager, and v2 tf.function variants) carry a `run_eagerly` value. When
  that value is present, the test runs inside
  `testing_utils.run_eagerly_scope(run_eagerly)`.
  """

  def context_managers(self, kwargs):
    """Removes `run_eagerly` from `kwargs` and returns the scopes to enter.

    Args:
      kwargs: The test's parameter dict; mutated in place (the `run_eagerly`
        key is popped if present).

    Returns:
      A one-element list holding the `run_eagerly` scope, or an empty list
      when `run_eagerly` is absent or None.
    """
    run_eagerly = kwargs.pop('run_eagerly', None)
    if run_eagerly is None:
      return []
    return [testing_utils.run_eagerly_scope(run_eagerly)]

  def parameter_modifiers(self):
    """Declares `run_eagerly` as a `test_combinations.OptionalParameter`."""
    return [test_combinations.OptionalParameter('run_eagerly')]
|
|
|
|
|
|
class KerasModelTypeCombination(test_combinations.TestCombination):
  """Test combination that handles the Keras `model_type` parameter.

  Recognized values are those in `KERAS_MODEL_TYPES` ('functional',
  'subclass', 'sequential'). For a recognized value, the test runs inside
  `testing_utils.model_type_scope(model_type)`. Model-building helpers in
  `testing_utils` can then produce a model of the active type, so one test
  body can check equivalence across the different Keras model styles.
  """

  def context_managers(self, kwargs):
    """Removes `model_type` from `kwargs` and returns the scopes to enter.

    Args:
      kwargs: The test's parameter dict; mutated in place (the `model_type`
        key is popped if present).

    Returns:
      A one-element list holding the model-type scope when `model_type` is in
      `KERAS_MODEL_TYPES`; otherwise (absent, None, or unrecognized) an empty
      list.
    """
    model_type = kwargs.pop('model_type', None)
    if model_type not in KERAS_MODEL_TYPES:
      return []
    return [testing_utils.model_type_scope(model_type)]

  def parameter_modifiers(self):
    """Declares `model_type` as a `test_combinations.OptionalParameter`."""
    return [test_combinations.OptionalParameter('model_type')]
|
|
|
|
|
|
class KerasTensorCombination(test_combinations.TestCombination):
  """Test combination that handles the `use_keras_tensors` parameter.

  The defaults from `keras_tensor_combinations` cover `True` and `False`:
  running the Keras functional API with KerasTensors as the inputs, and
  without them. When a value is present, the test runs inside
  `testing_utils.use_keras_tensors_scope(use_keras_tensors)`.
  """

  def context_managers(self, kwargs):
    """Removes `use_keras_tensors` from `kwargs`; returns scopes to enter.

    Args:
      kwargs: The test's parameter dict; mutated in place (the
        `use_keras_tensors` key is popped if present).

    Returns:
      A one-element list holding the KerasTensor scope, or an empty list when
      `use_keras_tensors` is absent or None.
    """
    use_keras_tensors = kwargs.pop('use_keras_tensors', None)
    if use_keras_tensors is None:
      return []
    return [testing_utils.use_keras_tensors_scope(use_keras_tensors)]

  def parameter_modifiers(self):
    """Declares `use_keras_tensors` as an OptionalParameter."""
    return [test_combinations.OptionalParameter('use_keras_tensors')]
|
|
|
|
|
|
# `combinations.generate` is itself a partial; extend its default
# `test_combinations` tuple with the Keras-specific combinations defined above,
# so that `run_eagerly`, `model_type` and `use_keras_tensors` parameters are
# handled by tests decorated with this module's `generate`.
_defaults = combinations.generate.keywords['test_combinations']
generate = functools.partial(
    combinations.generate,
    test_combinations=_defaults + (
        KerasModeCombination(),
        KerasModelTypeCombination(),
        KerasTensorCombination(),
    ))

# Aliases of the generic helpers, exposed from this module for convenience.
combine, times, NamedObject = (
    test_combinations.combine,
    test_combinations.times,
    test_combinations.NamedObject,
)
|