STT-tensorflow/tensorflow/python/keras/combinations.py
Tomer Kaftan 0e1f3de50a This change adds WIP support for using KerasTensor objects in Keras's functional API instead of symbolic graph tf.Tensors. It is controlled by an internal behavior flag that is disabled by default and is not yet exposed in TF's APIs.
`KerasTensor`s are an alternative representation for Keras `Inputs`
  and for intermediate outputs of layers during Functional API construction of
  models. They are a lightweight data structure comprised of only the
  `tf.TypeSpec` of the Tensor that will be consumed/produced in the
  corresponding position of the model.

  They implement just a small subset of `tf.Tensor`'s attributes and
  methods, and also overload
  the same operators as `tf.Tensor` and automatically turn them into
  Keras layers in the model.

  `KerasTensor`s are still internal-only and are a work in progress, but they
  have several advantages over using a graph `tf.Tensor` to represent
  symbolic values in functional models.
  - Unlike symbolic tensors, they do not need to refer to a graph. This means
    Keras does not need to maintain a never-deleted global background graph
    containing all layers ever called during functional model construction when
    constructing Functional Models with KerasTensors. These memory savings
    can be significant.

  - Triggering Keras functional model construction is simpler
    when it just has to check whether something is a KerasTensor, rather
    than trying to infer if a tensor was meant to be a symbolic keras
    representation or just a value produced during function tracing. This means we can add support for cases where values in nest.flatten(*args, **kwargs) are a completely arbitrary mix of KerasTensors and objects that are not KerasTensors, as long as any value is a KerasTensor.

  - Autolambda layers (converting tf ops on symbolic Keras tensors to lambda
    Keras layers in the model) use TF's internal dispatching mechanism, instead
    of trying to manually walk a graph and extract nodes from it.
    The dispatching mechanism is simpler, works more reliably, and is less
    likely to run into issues with composite tensors or strange tf ops/nodes.

    (And when it fails, it's by design: because dispatch is explicitly not
    supported on the op & it's more obvious that dispatch doesn't support the
    setting).

  - Because they support arbitrary typespecs, models/layers that use
    KerasTensors are generally more friendly to composite tensors of different
    types than using symbolic graph tensors (which must have a TensorSpec and
    can't have arbitrary typespecs)

  To experiment with using KerasTensors instead of symbolic graph `tf.Tensors`,
  import keras_tensor directly and call `keras_tensor.enable_keras_tensors()`

PiperOrigin-RevId: 315009281
Change-Id: I6765f3a44da43f965ec261b6b193df26598cffae
2020-06-05 15:47:28 -07:00

136 lines
4.5 KiB
Python

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module customizes `test_combinations` for `tf.keras` related tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python import tf2
from tensorflow.python.framework import combinations
from tensorflow.python.framework import test_combinations
from tensorflow.python.keras import testing_utils
# Model-type names used to parameterize Keras tests; consumed by
# `keras_model_type_combinations` and `KerasModelTypeCombination` below.
KERAS_MODEL_TYPES = ['functional', 'subclass', 'sequential']
def keras_mode_combinations(mode=None, run_eagerly=None):
  """Builds the default execution-mode combinations for tf.keras tests.

  When TF2 behavior is enabled and `mode` is not given, only 'eager' is used,
  so the v1 graph variant is skipped.

  Args:
    mode: Collection of modes to generate combinations for; 'graph' and
      'eager' are the recognized entries. When None, defaults to ['eager'] if
      `tf2.enabled()` and to ['graph', 'eager'] otherwise. An empty list
      produces no combinations.
    run_eagerly: List of `run_eagerly` values paired with 'eager' mode. When
      None, defaults to [True, False]. 'graph' mode is always paired with
      `run_eagerly=False`, regardless of this argument.

  Returns:
    A list of combinations: the 'eager' entries first, then the 'graph' ones.
  """
  if mode is None:
    mode = ['graph', 'eager'] if not tf2.enabled() else ['eager']
  eager_flags = [True, False] if run_eagerly is None else run_eagerly

  combos = []
  if 'eager' in mode:
    combos.extend(
        combinations.combine(mode=['eager'], run_eagerly=eager_flags))
  if 'graph' in mode:
    # Graph mode never runs eagerly, so the caller's values are ignored here.
    combos.extend(combinations.combine(mode=['graph'], run_eagerly=[False]))
  return combos
def keras_model_type_combinations():
  """Returns one combination per entry of `KERAS_MODEL_TYPES`.

  Each combination sets the `model_type` test parameter.
  """
  model_type_combos = combinations.combine(model_type=KERAS_MODEL_TYPES)
  return model_type_combos
def keras_tensor_combinations():
  """Returns combinations that toggle the `use_keras_tensors` test parameter.

  The values are real booleans rather than the strings 'True'/'False'. A
  non-empty string such as 'False' is truthy, so any consumer that branches on
  the truthiness of the value would treat both variants as "enabled".

  Returns:
    The result of `combinations.combine` with `use_keras_tensors` taking the
    values `True` and `False`.
  """
  return combinations.combine(use_keras_tensors=[True, False])
class KerasModeCombination(test_combinations.TestCombination):
  """Test combination that handles the Keras `run_eagerly` parameter.

  It by default includes v1_session, v2_eager and v2_tf_function.
  """

  def context_managers(self, kwargs):
    """Returns the context managers required by the `run_eagerly` parameter.

    Args:
      kwargs: Mapping of test parameters. The 'run_eagerly' entry, if any, is
        removed from it.

    Returns:
      A single-element list holding `testing_utils.run_eagerly_scope(value)`
      when a non-None `run_eagerly` was supplied, otherwise an empty list.
    """
    eager_flag = kwargs.pop('run_eagerly', None)
    if eager_flag is None:
      return []
    return [testing_utils.run_eagerly_scope(eager_flag)]

  def parameter_modifiers(self):
    """Declares `run_eagerly` as an optional test parameter."""
    optional_run_eagerly = test_combinations.OptionalParameter('run_eagerly')
    return [optional_run_eagerly]
class KerasModelTypeCombination(test_combinations.TestCombination):
  """Test combination that handles the Keras `model_type` parameter.

  It by default includes 'functional', 'subclass', 'sequential'.

  Various methods in `testing_utils` to get models will auto-generate a model
  of the currently active Keras model type, which lets unit tests confirm that
  the different kinds of Keras models behave equivalently.
  """

  def context_managers(self, kwargs):
    """Returns the context managers required by the `model_type` parameter.

    Args:
      kwargs: Mapping of test parameters. The 'model_type' entry, if any, is
        removed from it.

    Returns:
      A single-element list holding `testing_utils.model_type_scope(value)`
      when the supplied value is one of `KERAS_MODEL_TYPES`, otherwise an
      empty list.
    """
    requested_type = kwargs.pop('model_type', None)
    if requested_type not in KERAS_MODEL_TYPES:
      return []
    return [testing_utils.model_type_scope(requested_type)]

  def parameter_modifiers(self):
    """Declares `model_type` as an optional test parameter."""
    optional_model_type = test_combinations.OptionalParameter('model_type')
    return [optional_model_type]
class KerasTensorCombination(test_combinations.TestCombination):
  """Test combination that handles the `use_keras_tensors` parameter.

  It by default includes `True` and `False`: running Keras's functional API
  with KerasTensors as the inputs, and without.
  """

  def context_managers(self, kwargs):
    """Returns the context managers required by `use_keras_tensors`.

    Args:
      kwargs: Mapping of test parameters. The 'use_keras_tensors' entry, if
        any, is removed from it.

    Returns:
      A single-element list holding
      `testing_utils.use_keras_tensors_scope(value)` when a non-None value was
      supplied, otherwise an empty list.
    """
    keras_tensors_flag = kwargs.pop('use_keras_tensors', None)
    if keras_tensors_flag is None:
      return []
    return [testing_utils.use_keras_tensors_scope(keras_tensors_flag)]

  def parameter_modifiers(self):
    """Declares `use_keras_tensors` as an optional test parameter."""
    optional_flag = test_combinations.OptionalParameter('use_keras_tensors')
    return [optional_flag]
# Aliases for the generic helpers defined in `test_combinations`.
NamedObject = test_combinations.NamedObject
times = test_combinations.times
combine = test_combinations.combine

# Test combinations already bound on the framework-level `generate`; reading
# `.keywords` presumes `combinations.generate` is itself a `functools.partial`.
_defaults = combinations.generate.keywords['test_combinations']

# Keras flavour of `generate`: the framework defaults plus the three
# Keras-specific combinations defined in this module.
generate = functools.partial(
    combinations.generate,
    test_combinations=_defaults + (
        KerasModeCombination(),
        KerasModelTypeCombination(),
        KerasTensorCombination(),
    ))