Restructure the Keras class hierarchy for Network, Model and Sequential.

The intention of this change is to reduce the code complexity within Keras class, especially for Network, which currently contains logic for both subclass Model and functional Model.

After this change, the subclass model and functional model become individual class and become self contained.

1. Model is now the base class for subclass model. It doesn't contains network structure management, and the topology will be created within __init__ and __call__, which is for user to implement. It also contains compile/fit/eval/predict, which is the basic functionality for model training.

2. Functional is created based on existing Network class. It extends the Model, which allows it leverage compile/fit/eval/predict. In addition, it also take input/output as init parameter and manage the network topology.

3. Sequential model is now a subclass of Functional, since it will use Functional's method to manage it topology (layer stacking).

Model(input, output) will create a Functional under the hood, and behave the same way as before.

PiperOrigin-RevId: 311232972
Change-Id: I6dd32e089cd294d35d5a1f3684e1a1ae1a0ab320
This commit is contained in:
Scott Zhu 2020-05-12 17:16:41 -07:00 committed by TensorFlower Gardener
parent ce5488f85f
commit bb15c97379
29 changed files with 1023 additions and 1052 deletions

View File

@ -21,8 +21,8 @@ py_library(
srcs = [
"__init__.py",
"compile_utils.py",
"functional.py",
"input_layer.py",
"network.py",
"node.py",
"partial_batch_padding_handler.py",
"saving.py",
@ -460,9 +460,9 @@ tf_py_test(
)
tf_py_test(
name = "network_test",
name = "functional_test",
size = "medium",
srcs = ["network_test.py"],
srcs = ["functional_test.py"],
python_version = "PY3",
shard_count = 8,
tags = [

View File

@ -1006,13 +1006,23 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
"""Whether the layer is dynamic (eager-only); set in the constructor."""
# NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
# then this cache logic must be updated.
return self._dynamic
return self._dynamic or any(layer.dynamic
for layer in self._unique_sublayers())
def _unique_sublayers(self):
# Model.layers will use this as implementation, but we can't expose this
# one as the public property since it might conflict with subclass layers
# which also have user defined layers property.
self._maybe_create_attribute('_layers', [])
return list(
trackable_layer_utils.filter_empty_layer_containers(self._layers))
@property
@doc_controls.do_not_doc_inheritable
@trackable_layer_utils.cache_recursive_attribute('stateful')
def stateful(self):
return self._stateful
return self._stateful or any(
getattr(layer, 'stateful', False) for layer in self._unique_sublayers())
@stateful.setter
@trackable_layer_utils.invalidate_recursive_cache('stateful')

View File

@ -833,13 +833,15 @@ class Layer(base_layer.Layer):
def dynamic(self):
# NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
# then this cache logic must be updated.
return self._dynamic
return self._dynamic or any(layer.dynamic
for layer in self._unique_sublayers())
@property
@doc_controls.do_not_generate_docs
@trackable_layer_utils.cache_recursive_attribute('stateful')
def stateful(self):
return self._stateful
return self._stateful or any(
getattr(layer, 'stateful', False) for layer in self._unique_sublayers())
@stateful.setter
@trackable_layer_utils.invalidate_recursive_cache('stateful')

View File

@ -33,8 +33,8 @@ from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import input_layer as input_layer_lib
from tensorflow.python.keras.engine import network as network_lib
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.utils import layer_utils
@ -89,7 +89,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
self.assertEqual(len(layer.updates), 3)
network = network_lib.Network(x2, y2)
network = functional.Functional(x2, y2)
self.assertEqual(len(network.updates), 3)
x3 = input_layer_lib.Input(shape=(1,))
@ -120,7 +120,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
dense_a = layers.Dense(4, name='dense_a')
dense_b = layers.Dense(2, name='dense_b')
y = dense_b(dense_a(x))
network = network_lib.Network(x, y, name='dense_network')
network = functional.Functional(x, y, name='dense_network')
# test various get_layer by index
self.assertEqual(network.get_layer(index=1), dense_a)
@ -251,7 +251,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
x = input_layer_lib.Input(shape=(32,))
dense = layers.Dense(2)
y = dense(x)
network = network_lib.Network(x, y, name='dense_network')
network = functional.Functional(x, y, name='dense_network')
# test basic attributes
self.assertEqual(network.name, 'dense_network')
@ -740,7 +740,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
else:
x = input_layer_lib.Input(shape=(32,))
y = MaskedLayer()(x) # pylint: disable=not-callable
network = network_lib.Network(x, y)
network = functional.Functional(x, y)
# test callability on Input
x_2 = input_layer_lib.Input(shape=(32,))
@ -1102,7 +1102,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
def test_subclassed_error_if_init_not_called(self):
class MyNetwork(network_lib.Network):
class MyNetwork(training_lib.Model):
def __init__(self):
self._foo = [layers.Dense(10), layers.Dense(10)]
@ -1124,10 +1124,12 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
inputs = input_layer_lib.Input(shape=(32,))
outputs = layers.Dense(4)(inputs)
with self.assertRaisesRegexp(TypeError, 'unexpected argument'):
with self.assertRaisesRegexp(TypeError,
'got an unexpected keyword argument'):
model = training_lib.Model(
inputs, outputs, name='m', trainable=False, dtype='int64')
with self.assertRaisesRegexp(TypeError, 'unexpected argument'):
with self.assertRaisesRegexp(TypeError,
'got an unexpected keyword argument'):
model = training_lib.Model(
inputs, outputs, name='m', trainable=False, dynamic=False)
@ -1136,8 +1138,10 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
self.assertFalse(model.trainable)
self.assertFalse(model.dynamic)
class SubclassModel(training_lib.Model):
pass
# Subclassed model
model = training_lib.Model(
model = SubclassModel(
name='subclassed', trainable=True, dtype='int64', dynamic=True)
self.assertEqual('subclassed', model.name)
self.assertTrue(model.dynamic)
@ -1150,9 +1154,9 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
input_tensor2 = input_layer_lib.Input(shape=[10], name='b')
output_tensor1 = layers.Dense(units=10)(input_tensor1)
net = network_lib.Network(
net = functional.Functional(
inputs=[input_tensor1, input_tensor2], outputs=[output_tensor1])
net2 = network_lib.Network.from_config(net.get_config())
net2 = functional.Functional.from_config(net.get_config())
self.assertLen(net2.inputs, 2)
self.assertEqual('a', net2.layers[0].name)
self.assertEqual('b', net2.layers[1].name)
@ -1180,8 +1184,8 @@ class DeferredModeTest(keras_parameterized.TestCase):
self.assertEqual(x.shape.as_list(), [None, 2])
outputs = layers.Dense(4)(x)
network = network_lib.Network(inputs, outputs)
self.assertIsInstance(network, network_lib.Network)
network = functional.Functional(inputs, outputs)
self.assertIsInstance(network, functional.Functional)
if context.executing_eagerly():
# It should be possible to call such a network on EagerTensors.
@ -1204,7 +1208,7 @@ class DeferredModeTest(keras_parameterized.TestCase):
c = AddLayer()([a, input_b]) # pylint: disable=not-callable
c = layers.Dense(2)(c)
network = network_lib.Network([input_a, input_b], [a, c])
network = functional.Functional([input_a, input_b], [a, c])
if context.executing_eagerly():
a_val = constant_op.constant(
np.random.random((10, 32)).astype('float32'))
@ -1484,9 +1488,9 @@ class NestedNetworkTest(keras_parameterized.TestCase):
'x2': input_layer_lib.Input(shape=(1,))
}
outputs = layers.Add()([inputs['x1'], inputs['x2']])
network = network_lib.Network(inputs, outputs)
network = functional.Functional(inputs, outputs)
network = network_lib.Network.from_config(network.get_config())
network = functional.Functional.from_config(network.get_config())
result_tensor = network({
'x': array_ops.ones((1, 1), 'float32'),
@ -1509,9 +1513,9 @@ class NestedNetworkTest(keras_parameterized.TestCase):
'x*x': layers.Multiply()([inputs, inputs])
}
network = network_lib.Network(inputs, outputs)
network = functional.Functional(inputs, outputs)
network = network_lib.Network.from_config(network.get_config())
network = functional.Functional.from_config(network.get_config())
result_tensor = network(array_ops.ones((1, 1), 'float32'))
result = self.evaluate(result_tensor)
@ -1531,7 +1535,8 @@ class NestedNetworkTest(keras_parameterized.TestCase):
'x1+x2': layers.Add()([inner_inputs['x1'], inner_inputs['x2']]),
'x1*x2': layers.Multiply()([inner_inputs['x1'], inner_inputs['x2']])
}
inner_network = network_lib.Network(inner_inputs, inner_outputs)
inner_network = functional.Functional(
inner_inputs, inner_outputs)
inputs = [
input_layer_lib.Input(shape=(1,)),
@ -1539,9 +1544,9 @@ class NestedNetworkTest(keras_parameterized.TestCase):
]
middle = inner_network({'x1': inputs[0], 'x2': inputs[1]})
outputs = layers.Add()([middle['x1+x2'], middle['x1*x2']])
network = network_lib.Network(inputs, outputs)
network = functional.Functional(inputs, outputs)
network = network_lib.Network.from_config(network.get_config())
network = functional.Functional.from_config(network.get_config())
# Computes: `(x1+x2) + (x1*x2)`
result_tensor = network(
@ -1735,13 +1740,13 @@ class DTypeTest(keras_parameterized.TestCase):
def test_graph_network_dtype(self):
inputs = input_layer_lib.Input((10,))
outputs = layers.Dense(10)(inputs)
network = network_lib.Network(inputs, outputs)
network = functional.Functional(inputs, outputs)
self.assertEqual(network.dtype, 'float32')
@testing_utils.enable_v2_dtype_behavior
def test_subclassed_network_dtype(self):
class IdentityNetwork(network_lib.Network):
class IdentityNetwork(training_lib.Model):
def call(self, inputs):
return inputs
@ -1785,11 +1790,11 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
def layer_and_network_test(self):
# Top level layer
network = network_lib.Network()
network = functional.Functional()
layer_0 = AttrTrackingLayer()
sub_network = network_lib.Network()
sub_network = functional.Functional()
layer_1 = AttrTrackingLayer(dynamic=True)
layer_2 = AttrTrackingLayer()
sub_network.sub_layers = [layer_1, layer_2]
@ -1887,7 +1892,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
x = input_layer_lib.Input(shape=(None, 32))
dense = layers.Dense(2)
y = dense(x)
network = network_lib.Network(x, y, name='dense_network')
network = functional.Functional(x, y, name='dense_network')
for i in range(999, 1024):
self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))
@ -1895,7 +1900,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
def test_2d_inputs_squeezed_to_1d(self):
input_1d = input_layer_lib.Input(shape=())
outputs = input_1d * 2.
net = network_lib.Network(input_1d, outputs)
net = functional.Functional(input_1d, outputs)
x = np.ones((10, 1))
y = net(x)
@ -1904,7 +1909,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
def test_1d_inputs_expanded_to_2d(self):
input_1d = input_layer_lib.Input(shape=(1,))
outputs = input_1d * 2.
net = network_lib.Network(input_1d, outputs)
net = functional.Functional(input_1d, outputs)
x = np.ones((10,))
y = net(x)
@ -1927,14 +1932,14 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, training=True)
network = network_lib.Network(inputs, outputs)
network = functional.Functional(inputs, outputs)
# Hard-coded value passed during construction is respected.
self.assertAllEqual(network(x, training=False), x)
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, training=False)
network = network_lib.Network(inputs, outputs)
network = functional.Functional(inputs, outputs)
network(x, training=True)
# Hard-coded value passed during construction is respected.
@ -1942,7 +1947,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, training=None)
network = network_lib.Network(inputs, outputs)
network = functional.Functional(inputs, outputs)
# `None` value passed during construction is overridden.
self.assertAllEqual(network(x, training=True), x)

View File

@ -26,8 +26,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import layers as layer_module
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import generic_utils
@ -35,7 +35,6 @@ from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.deprecation import deprecated
@ -48,7 +47,7 @@ SINGLE_LAYER_OUTPUT_ERROR_MSG = ('All layers in a Sequential model should have '
@keras_export('keras.Sequential', 'keras.models.Sequential')
class Sequential(training.Model):
class Sequential(functional.Functional):
"""`Sequential` groups a linear stack of layers into a `tf.keras.Model`.
`Sequential` provides training and inference features on this model.
@ -113,7 +112,9 @@ class Sequential(training.Model):
layers: Optional list of layers to add to the model.
name: Optional name for the model.
"""
super(Sequential, self).__init__(name=name, autocast=False)
# Skip the init in FunctionalModel since model doesn't have input/output yet
super(functional.Functional, self).__init__( # pylint: disable=bad-super-call
name=name, autocast=False)
self.supports_masking = True
self._compute_output_and_mask_jointly = True
self._auto_track_sub_layers = False
@ -152,11 +153,6 @@ class Sequential(training.Model):
return layers[1:]
return layers[:]
@property
@trackable_layer_utils.cache_recursive_attribute('dynamic')
def dynamic(self):
return any(layer.dynamic for layer in self.layers)
@trackable.no_automatic_dependency_tracking
def add(self, layer):
"""Adds a layer instance on top of the layer stack.
@ -233,7 +229,7 @@ class Sequential(training.Model):
self.built = True
if set_inputs or self._graph_initialized:
self._init_graph_network(self.inputs, self.outputs, name=self.name)
self._init_graph_network(self.inputs, self.outputs)
self._graph_initialized = True
else:
self._layers.append(layer)
@ -267,7 +263,7 @@ class Sequential(training.Model):
elif self._graph_initialized:
self.layers[-1]._outbound_nodes = []
self.outputs = [self.layers[-1].output]
self._init_graph_network(self.inputs, self.outputs, name=self.name)
self._init_graph_network(self.inputs, self.outputs)
self.built = True
@trackable.no_automatic_dependency_tracking
@ -341,7 +337,7 @@ class Sequential(training.Model):
# case, we fall back to the legacy deferred behavior.
# TODO(fchollet): consider raising here, as we should not be
# supporting such layers.
self._init_graph_network(inputs, outputs, name=self.name)
self._init_graph_network(inputs, outputs)
self._graph_initialized = True
except: # pylint:disable=bare-except
self._use_legacy_deferred_behavior = True
@ -350,7 +346,7 @@ class Sequential(training.Model):
@generic_utils.default
def build(self, input_shape=None):
if self._graph_initialized:
self._init_graph_network(self.inputs, self.outputs, name=self.name)
self._init_graph_network(self.inputs, self.outputs)
else:
if input_shape is None:
raise ValueError('You must provide an `input_shape` argument.')
@ -380,7 +376,7 @@ class Sequential(training.Model):
if self._graph_initialized:
if not self.built:
self._init_graph_network(self.inputs, self.outputs, name=self.name)
self._init_graph_network(self.inputs, self.outputs)
return super(Sequential, self).call(inputs, training=training, mask=mask)
outputs = inputs # handle the corner case where self.layers is empty
@ -519,6 +515,13 @@ class Sequential(training.Model):
return False
return True
def _assert_weights_created(self):
if self._graph_initialized:
return
# When the graph has not been initialized, use the Model's implementation to
# to check if the weights has been created.
super(functional.Functional, self)._assert_weights_created() # pylint: disable=bad-super-call
def _get_shape_tuple(t):
if hasattr(t, 'shape'):

View File

@ -20,6 +20,9 @@ from __future__ import print_function
import copy
import itertools
import json
import os
import six
from tensorflow.python.autograph.lang import directives
from tensorflow.python.distribute import distribute_coordinator as dc
@ -31,19 +34,31 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as callbacks_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import compile_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine import network
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -52,12 +67,33 @@ from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
# pylint: disable=g-import-not-at-top
try:
import h5py
except ImportError:
h5py = None
try:
import yaml
except ImportError:
yaml = None
# pylint: enable=g-import-not-at-top
_keras_api_gauge = monitoring.BoolGauge('/tensorflow/api/keras',
@ -97,8 +133,25 @@ def disable_multi_worker(method):
target=method, decorator_func=_method_wrapper)
def inject_functional_model_class(cls):
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top
if cls == Model or cls == training_v1.Model:
return functional.Functional
cls.__bases__ = tuple(inject_functional_model_class(base)
for base in cls.__bases__)
return cls
def is_functional_model_init_params(args, kwargs):
return (len(args) == 2 or
len(args) == 1 and 'outputs' in kwargs or
'inputs' in kwargs and 'outputs' in kwargs)
@keras_export('keras.Model', 'keras.models.Model')
class Model(network.Network, version_utils.ModelVersionSelector):
class Model(base_layer.Layer, version_utils.ModelVersionSelector):
"""`Model` groups layers into an object with training and inference features.
Arguments:
@ -174,11 +227,61 @@ class Model(network.Network, version_utils.ModelVersionSelector):
_TF_MODULE_IGNORED_PROPERTIES = frozenset(
itertools.chain(('_train_counter', '_test_counter', '_predict_counter',
'_steps_per_execution'),
network.Network._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access
base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access
def __new__(cls, *args, **kwargs):
# Signature detection
if is_functional_model_init_params(args, kwargs) and cls == Model:
# Functional model
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
return functional.Functional(*args, **kwargs)
else:
return super(Model, cls).__new__(cls, *args, **kwargs)
@trackable.no_automatic_dependency_tracking
def __init__(self, *args, **kwargs):
super(Model, self).__init__(*args, **kwargs)
_keras_api_gauge.get_cell('model').set(True)
# Special case for Subclassed Functional Model, which we couldn't detect
# when __new__ is called. We only realize it is a functional model when it
# calls super.__init__ with input and output tensor.
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
if (is_functional_model_init_params(args, kwargs) and
not isinstance(self, functional.Functional)):
inject_functional_model_class(self.__class__)
functional.Functional.__init__(self, *args, **kwargs)
return
# The following are implemented as property functions:
# self.trainable_weights
# self.non_trainable_weights
generic_utils.validate_kwargs(kwargs, {'trainable', 'dtype', 'dynamic',
'name', 'autocast'})
super(Model, self).__init__(**kwargs)
# By default, Model is a subclass model, which is not in graph network.
self._is_graph_network = False
self.inputs = None
self.outputs = None
self.input_names = None
self.output_names = None
# stop_training is used by callback to stop training when error happens
self.stop_training = False
self.history = None
# These objects are used in the default `Model.compile`. They are not
# guaranteed to be set after `Model.compile` is called, as users can
# override compile with custom logic.
self.compiled_loss = None
self.compiled_metrics = None
# This is True for Sequential networks and Functional networks.
self._compute_output_and_mask_jointly = False
# Don't reset compilation if already done. This may occur if calling
# `__init__` (or `_init_graph_network`) on an already-compiled model
# such as a Sequential model. Sequential models may need to rebuild
# themselves after compilation.
self._maybe_create_attribute('_is_compiled', False)
self._maybe_create_attribute('optimizer', None)
# Model must be created under scope of DistStrat it will be trained with.
if ds_context.has_strategy():
self._distribution_strategy = ds_context.get_strategy()
@ -186,23 +289,20 @@ class Model(network.Network, version_utils.ModelVersionSelector):
self._distribution_strategy = None
# Defaults to value of `tf.config.experimental_functions_run_eagerly`.
self._run_eagerly = None
self.stop_training = False
# Initialize cache attrs.
self._reset_compile_cache()
# Fault-tolerance handler. Set in `ModelCheckpoint`.
self._training_state = None
self.history = None
# These objects are used in the default `Model.compile`. They are not
# guaranteed to be set after `Model.compile` is called, as users can
# override compile with custom logic.
self.compiled_loss = None
self.compiled_metrics = None
self._saved_model_inputs_spec = None
self._trackable_saver = (
trackable_utils.saver_with_op_caching(self))
self._steps_per_execution = None
self._init_batch_counters()
self._base_model_initialized = True
_keras_api_gauge.get_cell('model').set(True)
@trackable.no_automatic_dependency_tracking
def _init_batch_counters(self):
@ -214,67 +314,153 @@ class Model(network.Network, version_utils.ModelVersionSelector):
self._predict_counter = variables.Variable(
0, dtype='int64', aggregation=agg)
def get_weights(self):
"""Retrieves the weights of the model.
def __setattr__(self, name, value):
if not getattr(self, '_self_setattr_tracking', True):
super(Model, self).__setattr__(name, value)
return
Returns:
A flat list of Numpy arrays.
"""
with self.distribute_strategy.scope():
return super(Model, self).get_weights()
if all(
isinstance(v, (base_layer.Layer,
data_structures.TrackableDataStructure)) or
trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
try:
self._base_model_initialized
except AttributeError:
# six.raise_from supresses the original AttributeError from being raised
six.raise_from(
RuntimeError('It looks like you are subclassing `Model` and you '
'forgot to call `super(YourClass, self).__init__()`.'
' Always start with this line.'), None)
def load_weights(self, filepath, by_name=False, skip_mismatch=False):
"""Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
super(Model, self).__setattr__(name, value)
If `by_name` is False weights are loaded based on the network's
topology. This means the architecture should be the same as when the weights
were saved. Note that layers that don't have weights are not taken into
account in the topological ordering, so adding or removing layers is fine as
long as they don't have weights.
# Keep track of metric instance created in subclassed model/layer.
# We do this so that we can maintain the correct order of metrics by adding
# the instance to the `metrics` list as soon as it is created.
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
if isinstance(value, metrics_module.Metric):
self._metrics.append(value)
If `by_name` is True, weights are loaded into layers only if they share the
same name. This is useful for fine-tuning or transfer-learning models where
some of the layers have changed.
@generic_utils.default
def build(self, input_shape):
"""Builds the model based on input shapes received.
Only topological loading (`by_name=False`) is supported when loading weights
from the TensorFlow format. Note that topological loading differs slightly
between TensorFlow and HDF5 formats for user-defined classes inheriting from
`tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
TensorFlow format loads based on the object-local names of attributes to
which layers are assigned in the `Model`'s constructor.
This is to be used for subclassed models, which do not know at instantiation
time what their inputs look like.
Arguments:
filepath: String, path to the weights file to load. For weight files in
TensorFlow format, this is the file prefix (the same as was passed
to `save_weights`).
by_name: Boolean, whether to load weights by name or by topological
order. Only topological loading is supported for weight files in
TensorFlow format.
skip_mismatch: Boolean, whether to skip loading of layers where there is
a mismatch in the number of weights, or a mismatch in the shape of
the weight (only valid when `by_name=True`).
This method only exists for users who want to call `model.build()` in a
standalone way (as a substitute for calling the model on real data to
build it). It will never be called by the framework (and thus it will
never throw unexpected errors in an unrelated workflow).
Returns:
When loading a weight file in TensorFlow format, returns the same status
object as `tf.train.Checkpoint.restore`. When graph building, restore
ops are run automatically as soon as the network is built (on first call
for user-defined classes inheriting from `Model`, immediately if it is
already built).
When loading weights in HDF5 format, returns `None`.
Args:
input_shape: Single tuple, TensorShape, or list of shapes, where shapes
are tuples, integers, or TensorShapes.
Raises:
ImportError: If h5py is not available and the weight file is in HDF5
format.
ValueError: If `skip_mismatch` is set to `True` when `by_name` is
`False`.
ValueError:
1. In case of invalid user-provided data (not of type tuple,
list, or TensorShape).
2. If the model requires call arguments that are agnostic
to the input shapes (positional or kwarg in call signature).
3. If not all layers were properly built.
4. If float type inputs are not supported within the layers.
In each of these cases, the user should build their model by calling it
on real tensor data.
"""
if dist_utils.is_tpu_strategy(self._distribution_strategy):
if (self._distribution_strategy.extended.steps_per_run > 1 and
(not network._is_hdf5_filepath(filepath))): # pylint: disable=protected-access
raise ValueError('Load weights is not yet supported with TPUStrategy '
'with steps_per_run greater than 1.')
return super(Model, self).load_weights(filepath, by_name, skip_mismatch)
if self._is_graph_network:
super(Model, self).build(input_shape)
return
if input_shape is None:
raise ValueError('Input shape must be defined when calling build on a '
'model subclass network.')
valid_types = (tuple, list, tensor_shape.TensorShape)
if not isinstance(input_shape, valid_types):
raise ValueError('Specified input shape is not one of the valid types. '
'Please specify a batch input shape of type tuple or '
'list of input shapes. User provided '
'input type: {}'.format(type(input_shape)))
if input_shape and not self.inputs:
# We create placeholders for the `None`s in the shape and build the model
# in a Graph. Since tf.Variable is compatible with both eager execution
# and graph building, the variables created after building the model in
# a Graph are still valid when executing eagerly.
if context.executing_eagerly():
graph = func_graph.FuncGraph('build_graph')
else:
graph = backend.get_graph()
with graph.as_default():
if isinstance(input_shape, list):
x = [base_layer_utils.generate_placeholders_from_shape(shape)
for shape in input_shape]
elif isinstance(input_shape, dict):
x = {
k: base_layer_utils.generate_placeholders_from_shape(shape)
for k, shape in input_shape.items()
}
else:
x = base_layer_utils.generate_placeholders_from_shape(input_shape)
kwargs = {}
call_signature = self._call_full_argspec
call_args = call_signature.args
# Exclude `self`, `inputs`, and any argument with a default value.
if len(call_args) > 2:
if call_signature.defaults:
call_args = call_args[2:-len(call_signature.defaults)]
else:
call_args = call_args[2:]
for arg in call_args:
if arg == 'training':
# Case where `training` is a positional arg with no default.
kwargs['training'] = False
else:
# Has invalid call signature with unknown positional arguments.
raise ValueError(
'Currently, you cannot build your model if it has '
'positional or keyword arguments that are not '
'inputs to the model, but are required for its '
'`call` method. Instead, in order to instantiate '
'and build your model, `call` your model on real '
'tensor data with all expected call arguments.')
elif len(call_args) < 2:
# Signature without `inputs`.
raise ValueError('You can only call `build` on a model if its `call` '
'method accepts an `inputs` argument.')
try:
self.call(x, **kwargs)
except (errors.InvalidArgumentError, TypeError):
raise ValueError('You cannot build your model by calling `build` '
'if your layers do not support float type inputs. '
'Instead, in order to instantiate and build your '
'model, `call` your model on real tensor data (of '
'the correct dtype).')
super(Model, self).build(input_shape)
def call(self, inputs, training=None, mask=None):
"""Calls the model on new inputs.
In this case `call` just reapplies
all ops in the graph to the new inputs
(e.g. build a new computational graph from the provided inputs).
Arguments:
inputs: A tensor or list of tensors.
training: Boolean or boolean scalar tensor, indicating whether to run
the `Network` in training mode or inference mode.
mask: A mask or list of masks. A mask can be
either a tensor or None (no mask).
Returns:
A tensor if there is a single output, or
a list of tensors if there are more than one outputs.
"""
raise NotImplementedError('When subclassing the `Model` class, you should '
'implement a `call` method.')
def compile(self,
optimizer='rmsprop',
@ -399,6 +585,10 @@ class Model(network.Network, version_utils.ModelVersionSelector):
dtype='int64',
aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)
@property
def _should_compute_mask(self):
return False
@property
def metrics(self):
"""Returns the model's metrics added using `compile`, `add_metric` APIs.
@ -1661,6 +1851,564 @@ class Model(network.Network, version_utils.ModelVersionSelector):
verbose=verbose,
callbacks=callbacks)
######################################################################
# Functions below are not training related. They are for model weights
# tracking, save/load, serialization, etc.
######################################################################
@property
def trainable_weights(self):
self._assert_weights_created()
return self._dedup_weights(
trackable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._trainable_weights))
@property
def non_trainable_weights(self):
self._assert_weights_created()
return self._dedup_weights(
trackable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._non_trainable_weights +
self._trainable_weights))
def get_weights(self):
"""Retrieves the weights of the model.
Returns:
A flat list of Numpy arrays.
"""
with self.distribute_strategy.scope():
return super(Model, self).get_weights()
def save(self,
filepath,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None,
options=None):
"""Saves the model to Tensorflow SavedModel or a single HDF5 file.
The savefile includes:
- The model architecture, allowing to re-instantiate the model.
- The model weights.
- The state of the optimizer, allowing to resume training
exactly where you left off.
This allows you to save the entirety of the state of a model
in a single file.
Saved models can be reinstantiated via `keras.models.load_model`.
The model returned by `load_model` is a compiled model ready to be used
(unless the saved model was never compiled in the first place).
Models built with the Sequential and Functional API can be saved to both the
HDF5 and SavedModel formats. Subclassed models can only be saved with the
SavedModel format.
Note that the model weights may have different scoped names after being
loaded. Scoped names include the model/layer names, such as
`"dense_1/kernel:0"`. It is recommended that you use the layer properties to
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
Arguments:
filepath: String, PathLike, path to SavedModel or H5 file to save the
model.
overwrite: Whether to silently overwrite any existing file at the
target location, or provide the user with a manual prompt.
include_optimizer: If True, save optimizer's state together.
save_format: Either `'tf'` or `'h5'`, indicating whether to save the
model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X,
and 'h5' in TF 1.X.
signatures: Signatures to save with the SavedModel. Applicable to the
'tf' format only. Please see the `signatures` argument in
`tf.saved_model.save` for details.
options: Optional `tf.saved_model.SaveOptions` object that specifies
options for saving to SavedModel.
Example:
```python
from keras.models import load_model
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
del model # deletes the existing model
# returns a compiled model
# identical to the previous one
model = load_model('my_model.h5')
```
"""
save.save_model(self, filepath, overwrite, include_optimizer, save_format,
signatures, options)
def save_weights(self, filepath, overwrite=True, save_format=None):
"""Saves all layer weights.
Either saves in HDF5 or in TensorFlow format based on the `save_format`
argument.
When saving in HDF5 format, the weight file has:
- `layer_names` (attribute), a list of strings
(ordered names of model layers).
- For every layer, a `group` named `layer.name`
- For every such layer group, a group attribute `weight_names`,
a list of strings
(ordered names of weights tensor of the layer).
- For every weight in the layer, a dataset
storing the weight value, named after the weight tensor.
When saving in TensorFlow format, all objects referenced by the network are
saved in the same format as `tf.train.Checkpoint`, including any `Layer`
instances or `Optimizer` instances assigned to object attributes. For
networks constructed from inputs and outputs using `tf.keras.Model(inputs,
outputs)`, `Layer` instances used by the network are tracked/saved
automatically. For user-defined classes which inherit from `tf.keras.Model`,
`Layer` instances must be assigned to object attributes, typically in the
constructor. See the documentation of `tf.train.Checkpoint` and
`tf.keras.Model` for details.
While the formats are the same, do not mix `save_weights` and
`tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be
loaded using `Model.load_weights`. Checkpoints saved using
`tf.train.Checkpoint.save` should be restored using the corresponding
`tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
`save_weights` for training checkpoints.
The TensorFlow format matches objects and variables by starting at a root
object, `self` for `save_weights`, and greedily matching attribute
names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this
is the `Checkpoint` even if the `Checkpoint` has a model attached. This
means saving a `tf.keras.Model` using `save_weights` and loading into a
`tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match
the `Model`'s variables. See the [guide to training
checkpoints](https://www.tensorflow.org/guide/checkpoint) for details
on the TensorFlow format.
Arguments:
filepath: String or PathLike, path to the file to save the weights to.
When saving in TensorFlow format, this is the prefix used for
checkpoint files (multiple files are generated). Note that the '.h5'
suffix causes weights to be saved in HDF5 format.
overwrite: Whether to silently overwrite any existing file at the
target location, or provide the user with a manual prompt.
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
'.keras' will default to HDF5 if `save_format` is `None`. Otherwise
`None` defaults to 'tf'.
Raises:
ImportError: If h5py is not available when attempting to save in HDF5
format.
ValueError: For invalid/unknown format arguments.
"""
self._assert_weights_created()
filepath = path_to_string(filepath)
filepath_is_h5 = _is_hdf5_filepath(filepath)
if save_format is None:
if filepath_is_h5:
save_format = 'h5'
else:
save_format = 'tf'
else:
user_format = save_format.lower().strip()
if user_format in ('tensorflow', 'tf'):
save_format = 'tf'
elif user_format in ('hdf5', 'h5', 'keras'):
save_format = 'h5'
else:
raise ValueError(
'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
save_format,))
if save_format == 'tf' and filepath_is_h5:
raise ValueError(
('save_weights got save_format="tf"/"tensorflow", but the '
'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
'when saving in TensorFlow format.')
% filepath)
if save_format == 'h5' and h5py is None:
raise ImportError(
'`save_weights` requires h5py when saving in hdf5.')
if save_format == 'tf':
check_filepath = filepath + '.index'
else:
check_filepath = filepath
# If file exists and should not be overwritten:
if not overwrite and os.path.isfile(check_filepath):
proceed = ask_to_proceed_with_overwrite(check_filepath)
if not proceed:
return
if save_format == 'h5':
with h5py.File(filepath, 'w') as f:
hdf5_format.save_weights_to_hdf5_group(f, self.layers)
else:
if context.executing_eagerly():
session = None
else:
session = backend.get_session()
optimizer = getattr(self, 'optimizer', None)
if (optimizer
and not isinstance(optimizer, trackable.Trackable)):
logging.warning(
('This model was compiled with a Keras optimizer (%s) but is being '
'saved in TensorFlow format with `save_weights`. The model\'s '
'weights will be saved, but unlike with TensorFlow optimizers in '
'the TensorFlow format the optimizer\'s state will not be '
'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
% (optimizer,))
self._trackable_saver.save(filepath, session=session)
# Record this checkpoint so it's visible from tf.train.latest_checkpoint.
checkpoint_management.update_checkpoint_state_internal(
save_dir=os.path.dirname(filepath),
model_checkpoint_path=filepath,
save_relative_paths=True,
all_model_checkpoint_paths=[filepath])
def load_weights(self, filepath, by_name=False, skip_mismatch=False):
"""Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
If `by_name` is False weights are loaded based on the network's
topology. This means the architecture should be the same as when the weights
were saved. Note that layers that don't have weights are not taken into
account in the topological ordering, so adding or removing layers is fine as
long as they don't have weights.
If `by_name` is True, weights are loaded into layers only if they share the
same name. This is useful for fine-tuning or transfer-learning models where
some of the layers have changed.
Only topological loading (`by_name=False`) is supported when loading weights
from the TensorFlow format. Note that topological loading differs slightly
between TensorFlow and HDF5 formats for user-defined classes inheriting from
`tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
TensorFlow format loads based on the object-local names of attributes to
which layers are assigned in the `Model`'s constructor.
Arguments:
filepath: String, path to the weights file to load. For weight files in
TensorFlow format, this is the file prefix (the same as was passed
to `save_weights`).
by_name: Boolean, whether to load weights by name or by topological
order. Only topological loading is supported for weight files in
TensorFlow format.
skip_mismatch: Boolean, whether to skip loading of layers where there is
a mismatch in the number of weights, or a mismatch in the shape of
the weight (only valid when `by_name=True`).
Returns:
When loading a weight file in TensorFlow format, returns the same status
object as `tf.train.Checkpoint.restore`. When graph building, restore
ops are run automatically as soon as the network is built (on first call
for user-defined classes inheriting from `Model`, immediately if it is
already built).
When loading weights in HDF5 format, returns `None`.
Raises:
ImportError: If h5py is not available and the weight file is in HDF5
format.
ValueError: If `skip_mismatch` is set to `True` when `by_name` is
`False`.
"""
if dist_utils.is_tpu_strategy(self._distribution_strategy):
if (self._distribution_strategy.extended.steps_per_run > 1 and
(not _is_hdf5_filepath(filepath))):
raise ValueError('Load weights is not yet supported with TPUStrategy '
'with steps_per_run greater than 1.')
if skip_mismatch and not by_name:
raise ValueError(
'When calling model.load_weights, skip_mismatch can only be set to '
'True when by_name is True.')
filepath = path_to_string(filepath)
if _is_hdf5_filepath(filepath):
save_format = 'h5'
else:
try:
py_checkpoint_reader.NewCheckpointReader(filepath)
save_format = 'tf'
except errors_impl.DataLossError:
# The checkpoint is not readable in TensorFlow format. Try HDF5.
save_format = 'h5'
if save_format == 'tf':
status = self._trackable_saver.restore(filepath)
if by_name:
raise NotImplementedError(
'Weights may only be loaded based on topology into Models when '
'loading TensorFlow-formatted weights (got by_name=True to '
'load_weights).')
if not context.executing_eagerly():
session = backend.get_session()
# Restore existing variables (if any) immediately, and set up a
# streaming restore for any variables created in the future.
trackable_utils.streaming_restore(status=status, session=session)
status.assert_nontrivial_match()
return status
if h5py is None:
raise ImportError(
'`load_weights` requires h5py when loading weights from HDF5.')
if not self._is_graph_network and not self.built:
raise ValueError(
'Unable to load weights saved in HDF5 format into a subclassed '
'Model which has not created its variables yet. Call the Model '
'first, then load the weights.')
self._assert_weights_created()
with h5py.File(filepath, 'r') as f:
if 'layer_names' not in f.attrs and 'model_weights' in f:
f = f['model_weights']
if by_name:
hdf5_format.load_weights_from_hdf5_group_by_name(
f, self.layers, skip_mismatch=skip_mismatch)
else:
hdf5_format.load_weights_from_hdf5_group(f, self.layers)
def _updated_config(self):
"""Util shared between different serialization methods.
Returns:
Model config with Keras version information added.
"""
from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
config = self.get_config()
model_config = {
'class_name': self.__class__.__name__,
'config': config,
'keras_version': keras_version,
'backend': backend.backend()
}
return model_config
def get_config(self):
raise NotImplementedError
@classmethod
def from_config(cls, config, custom_objects=None):
# Since only FunctionalModel produces config, the model can only
# be constructed for FunctionalModel
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
return functional.Functional.from_config(
config, custom_objects=custom_objects)
def to_json(self, **kwargs):
"""Returns a JSON string containing the network configuration.
To load a network from a JSON save file, use
`keras.models.model_from_json(json_string, custom_objects={})`.
Arguments:
**kwargs: Additional keyword arguments
to be passed to `json.dumps()`.
Returns:
A JSON string.
"""
model_config = self._updated_config()
return json.dumps(
model_config, default=serialization.get_json_type, **kwargs)
def to_yaml(self, **kwargs):
"""Returns a yaml string containing the network configuration.
To load a network from a yaml save file, use
`keras.models.model_from_yaml(yaml_string, custom_objects={})`.
`custom_objects` should be a dictionary mapping
the names of custom losses / layers / etc to the corresponding
functions / classes.
Arguments:
**kwargs: Additional keyword arguments
to be passed to `yaml.dump()`.
Returns:
A YAML string.
Raises:
ImportError: if yaml module is not found.
"""
if yaml is None:
raise ImportError(
'Requires yaml module installed (`pip install pyyaml`).')
return yaml.dump(self._updated_config(), **kwargs)
def reset_states(self):
for layer in self.layers:
if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
layer.reset_states()
@property
@deprecation.deprecated(
date=None,
instructions='This property should not be used in TensorFlow 2.0, '
'as updates are applied automatically.')
@doc_controls.do_not_generate_docs
def state_updates(self):
"""Deprecated, do NOT use!
Returns the `updates` from all layers that are stateful.
This is useful for separating training updates and
state updates, e.g. when we need to update a layer's internal state
during prediction.
Returns:
A list of update ops.
"""
state_updates = []
for layer in self.layers:
if getattr(layer, 'stateful', False):
if hasattr(layer, 'updates'):
state_updates += layer.updates
return state_updates
@property
def weights(self):
"""Returns the list of all layer variables/weights.
Returns:
A list of variables.
"""
return self._dedup_weights(self._undeduplicated_weights)
@property
def _undeduplicated_weights(self):
"""Returns the undeduplicated list of all layer variables/weights."""
self._assert_weights_created()
weights = []
for layer in self._layers:
weights += layer.weights
weights += (self._trainable_weights + self._non_trainable_weights)
return weights
def summary(self, line_length=None, positions=None, print_fn=None):
"""Prints a string summary of the network.
Arguments:
line_length: Total length of printed lines
(e.g. set this to adapt the display to different
terminal window sizes).
positions: Relative or absolute positions of log elements
in each line. If not provided,
defaults to `[.33, .55, .67, 1.]`.
print_fn: Print function to use. Defaults to `print`.
It will be called on each line of the summary.
You can set it to a custom function
in order to capture the string summary.
Raises:
ValueError: if `summary()` is called before the model is built.
"""
if not self.built:
raise ValueError('This model has not yet been built. '
'Build the model first by calling `build()` or calling '
'`fit()` with some data, or specify '
'an `input_shape` argument in the first layer(s) for '
'automatic build.')
layer_utils.print_summary(self,
line_length=line_length,
positions=positions,
print_fn=print_fn)
@property
def layers(self):
return self._unique_sublayers()
def get_layer(self, name=None, index=None):
"""Retrieves a layer based on either its name (unique) or index.
If `name` and `index` are both provided, `index` will take precedence.
Indices are based on order of horizontal graph traversal (bottom-up).
Arguments:
name: String, name of layer.
index: Integer, index of layer.
Returns:
A layer instance.
Raises:
ValueError: In case of invalid layer name or index.
"""
# TODO(fchollet): We could build a dictionary based on layer names
# since they are constant, but we have not done that yet.
if index is not None and name is not None:
raise ValueError('Provide only a layer name or a layer index.')
if index is not None:
if len(self.layers) <= index:
raise ValueError('Was asked to retrieve layer at index ' + str(index) +
' but model only has ' + str(len(self.layers)) +
' layers.')
else:
return self.layers[index]
if name is not None:
for layer in self.layers:
if layer.name == name:
return layer
raise ValueError('No such layer: ' + name + '.')
raise ValueError('Provide either a layer name or layer index.')
@trackable.no_automatic_dependency_tracking
def _set_save_spec(self, inputs):
if self._saved_model_inputs_spec is not None:
return # Already set.
input_names = self.input_names
if not input_names:
input_names = compile_utils.create_pseudo_input_names(inputs)
flat_inputs = nest.flatten(inputs)
specs = []
for name, tensor in zip(input_names, flat_inputs):
specs.append(
tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name))
specs = nest.pack_sequence_as(inputs, specs)
self._saved_model_inputs_spec = specs
def _get_save_spec(self, dynamic_batch=True):
if self._saved_model_inputs_spec is None:
return None
return nest.map_structure(
lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
self._saved_model_inputs_spec)
def _assert_weights_created(self):
"""Asserts that all the weights for the model have been created.
For a non-dynamic model, the weights must already be created after the
layer has been called. For a dynamic model, the exact list of weights can
never be known for certain since it may change at any time during execution.
We run this check right before accessing weights or getting the Numpy value
for the current weights. Otherwise, if the layer has never been called,
the user would just get an empty list, which is misleading.
Raises:
ValueError: if the weights of the network has not yet been created.
"""
if self.dynamic:
return
if ('build' in self.__class__.__dict__ and
self.__class__ != Model and
not self.built):
# For any model that has customized build() method but hasn't
# been invoked yet, this will cover both sequential and subclass model.
# Also make sure to exclude Model class itself which has build() defined.
raise ValueError('Weights for model %s have not yet been created. '
'Weights are created when the Model is first called on '
'inputs or `build()` is called with an `input_shape`.' %
self.name)
def _check_call_args(self, method_name):
"""Check that `call` has only one positional arg."""
# Always allow first arg, regardless of arg name.
@ -1990,3 +2738,8 @@ def _disallow_inside_tf_function(method_name):
'directly on `Tensor`s inside a `tf.function` like: `model(x)`.'
).format(method_name=method_name)
raise RuntimeError(error_msg)
def _is_hdf5_filepath(filepath):
return (filepath.endswith('.h5') or filepath.endswith('.keras') or
filepath.endswith('.hdf5'))

View File

@ -43,7 +43,7 @@ from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.engine import network
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_distributed
@ -181,8 +181,8 @@ class Model(training_lib.Model):
self._compile_time_distribution_strategy)
if strategy:
with strategy.scope():
return network.Network.get_weights(self)
return network.Network.get_weights(self)
return base_layer.Layer.get_weights(self)
return base_layer.Layer.get_weights(self)
def load_weights(self, filepath, by_name=False, skip_mismatch=False):
"""Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
@ -232,7 +232,7 @@ class Model(training_lib.Model):
"""
if distributed_training_utils.is_tpu_strategy(self._distribution_strategy):
if (self._distribution_strategy.extended.steps_per_run > 1 and
(not network._is_hdf5_filepath(filepath))): # pylint: disable=protected-access
(not training_lib._is_hdf5_filepath(filepath))): # pylint: disable=protected-access
raise ValueError('Load weights is not yet supported with TPUStrategy '
'with steps_per_run greater than 1.')
return super(Model, self).load_weights(filepath, by_name, skip_mismatch)
@ -491,6 +491,11 @@ class Model(training_lib.Model):
"""Returns the model's metrics added using `compile`, `add_metric` APIs."""
metrics = []
if self._is_compiled:
if not hasattr(self, '_v1_compile_was_called'):
# See b/155687393 for more details, the model is created as a v2
# instance but converted to v1. Fallback to use base Model to retrieve
# the metrics.
return super(Model, self).metrics
metrics += self._compile_metric_functions
metrics.extend(self._metrics)
metrics.extend(_get_metrics_from_layers(self._layers))
@ -504,6 +509,12 @@ class Model(training_lib.Model):
# losses for backward compatibility.
metrics_names = ['loss']
if self._is_compiled:
if not hasattr(self, '_v1_compile_was_called'):
# See b/155687393 for more details, the model is created as a v2
# instance but converted to v1. Fallback to use base Model to retrieve
# the metrics name
return super(Model, self).metrics_names
# Add output loss metric names to the metric names list.
if len(self._training_endpoints) > 1:
metrics_names.extend([

View File

@ -114,7 +114,7 @@ def populate_deserializable_objects():
LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
LOCAL.ALL_OBJECTS['InputSpec'] = input_spec.InputSpec
LOCAL.ALL_OBJECTS['Network'] = models.Network
LOCAL.ALL_OBJECTS['Functional'] = models.Functional
LOCAL.ALL_OBJECTS['Model'] = models.Model
LOCAL.ALL_OBJECTS['SequenceFeatures'] = SequenceFeatures
LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential

View File

@ -377,7 +377,8 @@ class TimeDistributedTest(keras_parameterized.TestCase):
input_layer.compute_output_shape([None, 2, 4]).as_list(),
[None, 2, 8])
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
# TODO(scottzhu): check why v1 session failed.
def test_TimeDistributed_with_mask_first_implementation(self):
np.random.seed(100)
rnn_layer = keras.layers.LSTM(4, return_sequences=True, stateful=True)

View File

@ -23,7 +23,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import network
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_v1
@ -31,7 +31,6 @@ from tensorflow.python.keras.engine.base_layer import AddMetric
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils import generic_utils
@ -45,6 +44,7 @@ from tensorflow.python.util.tf_export import keras_export
# API entries importable from `keras.models`:
Model = training.Model # pylint: disable=invalid-name
Sequential = sequential.Sequential # pylint: disable=invalid-name
Functional = functional.Functional # pylint: disable=invalid-name
save_model = save.save_model
load_model = save.load_model
model_from_config = model_config.model_from_config
@ -193,12 +193,12 @@ def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
if not callable(layer_fn):
raise ValueError('Expected `layer_fn` argument to be a callable.')
model_config, created_layers = _clone_layers_and_model_config(
model_configs, created_layers = _clone_layers_and_model_config(
model, new_input_layers, layer_fn)
# Reconstruct model from the config, using the cloned layers.
input_tensors, output_tensors, created_layers = (
network.reconstruct_from_config(model_config,
created_layers=created_layers))
functional.reconstruct_from_config(model_configs,
created_layers=created_layers))
metrics_names = model.metrics_names
model = Model(input_tensors, output_tensors, name=model.name)
# Layers not directly tied to outputs of the Model, such as loss layers
@ -209,8 +209,8 @@ def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
if ancillary_layers:
new_nodes = nest.flatten([
layer.inbound_nodes[1:]
if network._should_skip_first_node(layer) else layer.inbound_nodes
for layer in created_layers.values()
if functional._should_skip_first_node(layer)
else layer.inbound_nodes for layer in created_layers.values()
])
_insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes)
return model
@ -244,7 +244,8 @@ def _clone_layers_and_model_config(model, input_layers, layer_fn):
created_layers[layer.name] = layer_fn(layer)
return {}
config = network.get_network_config(model, serialize_layer_fn=_copy_layer)
config = functional.get_network_config(
model, serialize_layer_fn=_copy_layer)
return config, created_layers
@ -495,7 +496,7 @@ def _in_place_subclassed_model_reset(model):
# This will not work for nested subclassed models used as layers.
# This would be theoretically possible to support, but would add complexity.
# Only do it if users complain.
if isinstance(layer, Network) and not layer._is_graph_network:
if isinstance(layer, training.Model) and not layer._is_graph_network:
raise ValueError('We do not support the use of nested subclassed models '
'in `model_to_estimator` at this time. Found nested '
'model: %s' % layer)

View File

@ -1210,7 +1210,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
def test_incompatible_checkpoint(self):
save_path = trackable.Checkpoint().save(
os.path.join(self.get_temp_dir(), 'ckpt'))
m = keras.Model()
m = DummySubclassModel()
with self.assertRaisesRegexp(AssertionError, 'Nothing to load'):
m.load_weights(save_path)
m.dense = keras.layers.Dense(2)
@ -1222,7 +1222,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_directory_passed(self):
with self.cached_session():
m = keras.Model()
m = DummySubclassModel()
v = m.add_weight(name='v', shape=[])
self.evaluate(v.assign(42.))
prefix = os.path.join(self.get_temp_dir(),
@ -1235,7 +1235,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_relative_path(self):
with self.cached_session():
m = keras.Model()
m = DummySubclassModel()
v = m.add_weight(name='v', shape=[])
os.chdir(self.get_temp_dir())
@ -1266,7 +1266,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_nonexistent_prefix_directory(self):
with self.cached_session():
m = keras.Model()
m = DummySubclassModel()
v = m.add_weight(name='v', shape=[])
self.evaluate(v.assign(42.))
prefix = os.path.join(self.get_temp_dir(),
@ -1276,5 +1276,10 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
m.load_weights(prefix)
self.assertEqual(42., self.evaluate(v))
class DummySubclassModel(training.Model):
pass
if __name__ == '__main__':
test.main()

View File

@ -62,9 +62,9 @@ layers_module = LazyLoader(
input_layer = LazyLoader(
"input_layer", globals(),
"tensorflow.python.keras.engine.input_layer")
network_lib = LazyLoader(
"network_lib", globals(),
"tensorflow.python.keras.engine.network")
functional_lib = LazyLoader(
"functional_lib", globals(),
"tensorflow.python.keras.engine.functional")
training_lib = LazyLoader(
"training_lib", globals(),
"tensorflow.python.keras.engine.training")
@ -142,7 +142,7 @@ def _is_graph_network(layer):
# pylint: disable=protected-access
if isinstance(layer, RevivedNetwork):
return False
elif isinstance(layer, network_lib.Network):
elif isinstance(layer, functional_lib.Functional):
return (layer._is_graph_network or
isinstance(layer, models_lib.Sequential))
return False
@ -371,7 +371,8 @@ class KerasObjectLoader(tf_load.Loader):
# functional or Sequential model.
model_is_functional_or_sequential = (
metadata.get('is_graph_network', False) or
metadata['class_name'] == 'Sequential')
metadata['class_name'] == 'Sequential' or
metadata['class_name'] == 'Functional')
if not (generic_utils.validate_config(config) and
model_is_functional_or_sequential):
return None # Revive as custom model.
@ -383,7 +384,8 @@ class KerasObjectLoader(tf_load.Loader):
if class_name == 'Sequential':
model = models_lib.Sequential(name=config['name'])
else:
model = models_lib.Model(name=config['name'])
model = models_lib.Functional(
inputs=[], outputs=[], name=config['name'])
# Record this model and its layers. This will later be used to reconstruct
# the model.
@ -561,10 +563,11 @@ class KerasObjectLoader(tf_load.Loader):
if not model.built and not isinstance(input_specs, dict):
model.build(input_shapes)
else:
(inputs, outputs, created_layers) = network_lib.reconstruct_from_config(
config, created_layers={layer.name: layer for layer in layers})
(inputs, outputs,
created_layers) = functional_lib.reconstruct_from_config(
config, created_layers={layer.name: layer for layer in layers})
model.__init__(inputs, outputs, name=config['name'])
network_lib.connect_ancillary_layers(model, created_layers)
functional_lib.connect_ancillary_layers(model, created_layers)
# Set model dtype and trainable status.
_set_network_attributes_from_metadata(model)
@ -764,7 +767,7 @@ def revive_custom_object(identifier, metadata):
revived_classes = {
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
'_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
'_tf_keras_network': (RevivedNetwork, network_lib.Network),
'_tf_keras_network': (RevivedNetwork, functional_lib.Functional),
'_tf_keras_model': (RevivedNetwork, model_class),
'_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential),
}
@ -852,7 +855,7 @@ def _revive_setter(layer, name, value):
layer._track_trackable(value, name=name)
layer._serialized_attributes[name] = value
# pylint: enable=protected-access
elif (isinstance(layer, network_lib.Network) and
elif (isinstance(layer, functional_lib.Functional) and
re.match(r'^layer(_with_weights)?-[\d+]', name) is not None):
# Edges named "layer-n" or "layer_with_weights-n", which are tracked in
# network._track_layers, should not be added as an attribute.

View File

@ -20,11 +20,11 @@ from __future__ import print_function
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import network_serialization
from tensorflow.python.keras.saving.saved_model import layer_serialization
from tensorflow.python.keras.saving.saved_model import save_impl
class ModelSavedModelSaver(network_serialization.NetworkSavedModelSaver):
class ModelSavedModelSaver(layer_serialization.LayerSavedModelSaver):
"""Model SavedModel serialization."""
@property
@ -33,6 +33,10 @@ class ModelSavedModelSaver(network_serialization.NetworkSavedModelSaver):
def _python_properties_internal(self):
metadata = super(ModelSavedModelSaver, self)._python_properties_internal()
# Network stateful property is dependent on the child layers.
metadata.pop('stateful')
metadata['is_graph_network'] = self.obj._is_graph_network # pylint: disable=protected-access
metadata.update(
saving_utils.model_metadata(
self.obj, include_optimizer=True, require_config=False))

View File

@ -18,22 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.saving.saved_model import layer_serialization
from tensorflow.python.keras.saving.saved_model import model_serialization
# Network serialization is pretty much the same as layer serialization.
class NetworkSavedModelSaver(layer_serialization.LayerSavedModelSaver):
# FunctionalModel serialization is pretty much the same as Model serialization.
class NetworkSavedModelSaver(model_serialization.ModelSavedModelSaver):
"""Network serialization."""
@property
def object_identifier(self):
return '_tf_keras_network'
def _python_properties_internal(self):
metadata = super(NetworkSavedModelSaver, self)._python_properties_internal()
# Network stateful property is dependent on the child layers.
metadata.pop('stateful')
metadata['is_graph_network'] = self.obj._is_graph_network # pylint: disable=protected-access
return metadata

View File

@ -53,12 +53,12 @@ class SplitUtilsTest(keras_parameterized.TestCase):
inputs = keras.Input(10)
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
self._check_model_class(model.__class__)
self._check_model_class(model.__class__.__bases__[0])
self._check_layer_class(model)
def test_sequential_model(self):
model = keras.Sequential([keras.layers.Dense(1)])
model_class = model.__class__.__bases__[0]
model_class = model.__class__.__bases__[0].__bases__[0]
self._check_model_class(model_class)
self._check_layer_class(model)

View File

@ -55,10 +55,10 @@ def check_pydot():
def is_wrapped_model(layer):
from tensorflow.python.keras.engine import network
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.layers import wrappers
return (isinstance(layer, wrappers.Wrapper) and
isinstance(layer.layer, network.Network))
isinstance(layer.layer, functional.Functional))
def add_edge(dot, src, dst):
@ -98,7 +98,7 @@ def model_to_dot(model,
"""
from tensorflow.python.keras.layers import wrappers
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import network
from tensorflow.python.keras.engine import functional
if not check_pydot():
message = (
@ -147,7 +147,8 @@ def model_to_dot(model,
class_name = layer.__class__.__name__
if isinstance(layer, wrappers.Wrapper):
if expand_nested and isinstance(layer.layer, network.Network):
if expand_nested and isinstance(layer.layer,
functional.Functional):
submodel_wrapper = model_to_dot(layer.layer, show_shapes,
show_layer_names, rankdir,
expand_nested,
@ -162,7 +163,7 @@ def model_to_dot(model,
child_class_name = layer.layer.__class__.__name__
class_name = '{}({})'.format(class_name, child_class_name)
if expand_nested and isinstance(layer, network.Network):
if expand_nested and isinstance(layer, functional.Functional):
submodel_not_wrapper = model_to_dot(layer, show_shapes,
show_layer_names, rankdir,
expand_nested,
@ -200,7 +201,8 @@ def model_to_dot(model,
inputlabels,
outputlabels)
if not expand_nested or not isinstance(layer, network.Network):
if not expand_nested or not isinstance(
layer, functional.Functional):
node = pydot.Node(layer_id, label=label)
dot.add_node(node)
@ -218,16 +220,17 @@ def model_to_dot(model,
add_edge(dot, inbound_layer_id, layer_id)
else:
# if inbound_layer is not Model or wrapped Model
if (not isinstance(inbound_layer, network.Network) and
if (not isinstance(inbound_layer,
functional.Functional) and
not is_wrapped_model(inbound_layer)):
# if current layer is not Model or wrapped Model
if (not isinstance(layer, network.Network) and
if (not isinstance(layer, functional.Functional) and
not is_wrapped_model(layer)):
assert dot.get_node(inbound_layer_id)
assert dot.get_node(layer_id)
add_edge(dot, inbound_layer_id, layer_id)
# if current layer is Model
elif isinstance(layer, network.Network):
elif isinstance(layer, functional.Functional):
add_edge(dot, inbound_layer_id,
sub_n_first_node[layer.name].get_name())
# if current layer is wrapped Model
@ -236,9 +239,9 @@ def model_to_dot(model,
name = sub_w_first_node[layer.layer.name].get_name()
add_edge(dot, layer_id, name)
# if inbound_layer is Model
elif isinstance(inbound_layer, network.Network):
elif isinstance(inbound_layer, functional.Functional):
name = sub_n_last_node[inbound_layer.name].get_name()
if isinstance(layer, network.Network):
if isinstance(layer, functional.Functional):
output_name = sub_n_first_node[layer.name].get_name()
add_edge(dot, name, output_name)
else:

View File

@ -1,7 +1,6 @@
path: "tensorflow.keras.Model"
tf_class {
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -175,7 +174,7 @@ tf_class {
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"

View File

@ -1,8 +1,8 @@
path: "tensorflow.keras.Sequential"
tf_class {
is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
is_instance: "<class \'tensorflow.python.keras.engine.functional.Functional\'>"
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"

View File

@ -2,7 +2,6 @@ path: "tensorflow.keras.experimental.LinearModel"
tf_class {
is_instance: "<class \'tensorflow.python.keras.premade.linear.LinearModel\'>"
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -176,7 +175,7 @@ tf_class {
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"

View File

@ -2,7 +2,6 @@ path: "tensorflow.keras.experimental.WideDeepModel"
tf_class {
is_instance: "<class \'tensorflow.python.keras.premade.wide_deep.WideDeepModel\'>"
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -176,7 +175,7 @@ tf_class {
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"

View File

@ -1,7 +1,6 @@
path: "tensorflow.keras.models.Model"
tf_class {
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -175,7 +174,7 @@ tf_class {
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"

View File

@ -1,8 +1,8 @@
path: "tensorflow.keras.models.Sequential"
tf_class {
is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
is_instance: "<class \'tensorflow.python.keras.engine.functional.Functional\'>"
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"

View File

@ -1,7 +1,6 @@
path: "tensorflow.keras.Model"
tf_class {
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -175,7 +174,7 @@ tf_class {
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"

View File

@ -1,8 +1,8 @@
path: "tensorflow.keras.Sequential"
tf_class {
is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
is_instance: "<class \'tensorflow.python.keras.engine.functional.Functional\'>"
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"

View File

@ -2,7 +2,6 @@ path: "tensorflow.keras.experimental.LinearModel"
tf_class {
is_instance: "<class \'tensorflow.python.keras.premade.linear.LinearModel\'>"
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -176,7 +175,7 @@ tf_class {
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"

View File

@ -2,7 +2,6 @@ path: "tensorflow.keras.experimental.WideDeepModel"
tf_class {
is_instance: "<class \'tensorflow.python.keras.premade.wide_deep.WideDeepModel\'>"
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -176,7 +175,7 @@ tf_class {
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"

View File

@ -1,7 +1,6 @@
path: "tensorflow.keras.models.Model"
tf_class {
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -175,7 +174,7 @@ tf_class {
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"

View File

@ -1,8 +1,8 @@
path: "tensorflow.keras.models.Sequential"
tf_class {
is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
is_instance: "<class \'tensorflow.python.keras.engine.functional.Functional\'>"
is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"