Restructure the Keras class hierarchy for Network, Model and Sequential.
The intention of this change is to reduce code complexity within the Keras classes, especially Network, which currently contains logic for both the subclassed Model and the functional Model. After this change, the subclassed model and the functional model become individual, self-contained classes.

1. Model is now the base class for subclassed models. It does not contain any network-structure management; the topology is created within __init__ and __call__, which the user implements. It also provides compile/fit/evaluate/predict, the basic functionality for model training.
2. Functional is created from the existing Network class. It extends Model, which lets it reuse compile/fit/evaluate/predict. In addition, it takes inputs/outputs as __init__ parameters and manages the network topology.
3. Sequential is now a subclass of Functional, since it uses Functional's methods to manage its topology (layer stacking).

Model(inputs, outputs) will create a Functional under the hood and behave the same way as before.

PiperOrigin-RevId: 311232972
Change-Id: I6dd32e089cd294d35d5a1f3684e1a1ae1a0ab320
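For readers skimming the diff below, here is a rough user-level sketch of the resulting hierarchy. It is an illustration against the public `tf.keras` API, not the TensorFlow source itself:

```python
import tensorflow as tf

# Subclassed model: topology is defined by the user in __init__/call.
# tf.keras.Model is now the base class and no longer manages graph topology.
class MySubclassModel(tf.keras.Model):

  def __init__(self):
    super(MySubclassModel, self).__init__()
    self.dense = tf.keras.layers.Dense(4)

  def call(self, inputs):
    return self.dense(inputs)

# Functional model: Model(inputs, outputs) builds a Functional under the hood,
# which subclasses Model and owns input/output topology management.
inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dense(4)(inputs)
functional_model = tf.keras.Model(inputs, outputs)

# Sequential subclasses the functional model class and stacks layers for you.
sequential_model = tf.keras.Sequential([tf.keras.layers.Dense(4)])

# Both still expose compile/fit/evaluate/predict from the shared Model base.
functional_model.compile(optimizer='rmsprop', loss='mse')
```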
Parent: ce5488f85f
Commit: bb15c97379
@@ -21,8 +21,8 @@ py_library(
    srcs = [
        "__init__.py",
        "compile_utils.py",
        "functional.py",
        "input_layer.py",
        "network.py",
        "node.py",
        "partial_batch_padding_handler.py",
        "saving.py",
@@ -460,9 +460,9 @@ tf_py_test(
)

tf_py_test(
    name = "network_test",
    name = "functional_test",
    size = "medium",
    srcs = ["network_test.py"],
    srcs = ["functional_test.py"],
    python_version = "PY3",
    shard_count = 8,
    tags = [
@@ -1006,13 +1006,23 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
    """Whether the layer is dynamic (eager-only); set in the constructor."""
    # NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
    # then this cache logic must be updated.
    return self._dynamic
    return self._dynamic or any(layer.dynamic
                                for layer in self._unique_sublayers())

  def _unique_sublayers(self):
    # Model.layers will use this as implementation, but we can't expose this
    # one as the public property since it might conflict with subclass layers
    # which also have user defined layers property.
    self._maybe_create_attribute('_layers', [])
    return list(
        trackable_layer_utils.filter_empty_layer_containers(self._layers))

  @property
  @doc_controls.do_not_doc_inheritable
  @trackable_layer_utils.cache_recursive_attribute('stateful')
  def stateful(self):
    return self._stateful
    return self._stateful or any(
        getattr(layer, 'stateful', False) for layer in self._unique_sublayers())

  @stateful.setter
  @trackable_layer_utils.invalidate_recursive_cache('stateful')
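The `dynamic` and `stateful` properties above now aggregate over `_unique_sublayers()` instead of relying only on the layer's own flag. A minimal standalone sketch of that aggregation pattern (simplified toy classes, not the Keras implementation):

```python
class ToyLayer(object):
  """Bare-bones stand-in for a layer, tracking only sublayers and two flags."""

  def __init__(self, dynamic=False, stateful=False):
    self._dynamic = dynamic
    self._stateful = stateful
    self._layers = []  # sublayers, analogous to Layer._layers

  def _unique_sublayers(self):
    return list(self._layers)

  @property
  def dynamic(self):
    # A layer is dynamic if it, or any sublayer, is dynamic.
    return self._dynamic or any(l.dynamic for l in self._unique_sublayers())

  @property
  def stateful(self):
    return self._stateful or any(
        getattr(l, 'stateful', False) for l in self._unique_sublayers())


outer = ToyLayer()
outer._layers.append(ToyLayer(dynamic=True))
assert outer.dynamic  # aggregated from the sublayer
```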
@@ -833,13 +833,15 @@ class Layer(base_layer.Layer):
  def dynamic(self):
    # NOTE(taylorrobie): Currently self._dynamic is read-only. If that changes
    # then this cache logic must be updated.
    return self._dynamic
    return self._dynamic or any(layer.dynamic
                                for layer in self._unique_sublayers())

  @property
  @doc_controls.do_not_generate_docs
  @trackable_layer_utils.cache_recursive_attribute('stateful')
  def stateful(self):
    return self._stateful
    return self._stateful or any(
        getattr(layer, 'stateful', False) for layer in self._unique_sublayers())

  @stateful.setter
  @trackable_layer_utils.invalidate_recursive_cache('stateful')
File diff suppressed because it is too large
@@ -33,8 +33,8 @@ from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import input_layer as input_layer_lib
from tensorflow.python.keras.engine import network as network_lib
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.utils import layer_utils
@@ -89,7 +89,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):

    self.assertEqual(len(layer.updates), 3)

    network = network_lib.Network(x2, y2)
    network = functional.Functional(x2, y2)
    self.assertEqual(len(network.updates), 3)

    x3 = input_layer_lib.Input(shape=(1,))
@@ -120,7 +120,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
    dense_a = layers.Dense(4, name='dense_a')
    dense_b = layers.Dense(2, name='dense_b')
    y = dense_b(dense_a(x))
    network = network_lib.Network(x, y, name='dense_network')
    network = functional.Functional(x, y, name='dense_network')

    # test various get_layer by index
    self.assertEqual(network.get_layer(index=1), dense_a)
@@ -251,7 +251,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
    x = input_layer_lib.Input(shape=(32,))
    dense = layers.Dense(2)
    y = dense(x)
    network = network_lib.Network(x, y, name='dense_network')
    network = functional.Functional(x, y, name='dense_network')

    # test basic attributes
    self.assertEqual(network.name, 'dense_network')
@@ -740,7 +740,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
    else:
      x = input_layer_lib.Input(shape=(32,))
      y = MaskedLayer()(x)  # pylint: disable=not-callable
      network = network_lib.Network(x, y)
      network = functional.Functional(x, y)

    # test callability on Input
    x_2 = input_layer_lib.Input(shape=(32,))
@@ -1102,7 +1102,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):

  def test_subclassed_error_if_init_not_called(self):

    class MyNetwork(network_lib.Network):
    class MyNetwork(training_lib.Model):

      def __init__(self):
        self._foo = [layers.Dense(10), layers.Dense(10)]
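The test above now targets `training_lib.Model` directly. For context, the behavior it exercises looks roughly like this from the user's side (a sketch; the exact guard and error text live in `Model.__setattr__` further down in this diff):

```python
import tensorflow as tf

class BrokenModel(tf.keras.Model):

  def __init__(self):
    # Bug under test: super().__init__() is never called, so attribute
    # tracking is not set up before layers are assigned.
    self._foo = [tf.keras.layers.Dense(10), tf.keras.layers.Dense(10)]

try:
  BrokenModel()
except RuntimeError as e:
  print(e)  # hints that super(YourClass, self).__init__() must be called first
```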
@@ -1124,10 +1124,12 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
    inputs = input_layer_lib.Input(shape=(32,))
    outputs = layers.Dense(4)(inputs)

    with self.assertRaisesRegexp(TypeError, 'unexpected argument'):
    with self.assertRaisesRegexp(TypeError,
                                 'got an unexpected keyword argument'):
      model = training_lib.Model(
          inputs, outputs, name='m', trainable=False, dtype='int64')
    with self.assertRaisesRegexp(TypeError, 'unexpected argument'):
    with self.assertRaisesRegexp(TypeError,
                                 'got an unexpected keyword argument'):
      model = training_lib.Model(
          inputs, outputs, name='m', trainable=False, dynamic=False)
@@ -1136,8 +1138,10 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
    self.assertFalse(model.trainable)
    self.assertFalse(model.dynamic)

    class SubclassModel(training_lib.Model):
      pass
    # Subclassed model
    model = training_lib.Model(
    model = SubclassModel(
        name='subclassed', trainable=True, dtype='int64', dynamic=True)
    self.assertEqual('subclassed', model.name)
    self.assertTrue(model.dynamic)
@@ -1150,9 +1154,9 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
    input_tensor2 = input_layer_lib.Input(shape=[10], name='b')
    output_tensor1 = layers.Dense(units=10)(input_tensor1)

    net = network_lib.Network(
    net = functional.Functional(
        inputs=[input_tensor1, input_tensor2], outputs=[output_tensor1])
    net2 = network_lib.Network.from_config(net.get_config())
    net2 = functional.Functional.from_config(net.get_config())
    self.assertLen(net2.inputs, 2)
    self.assertEqual('a', net2.layers[0].name)
    self.assertEqual('b', net2.layers[1].name)
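The same `get_config()`/`from_config()` round trip is reachable through the public API; a small sketch of the usage the test mirrors (hedged: `Model.from_config` delegates to the functional implementation after this change):

```python
import tensorflow as tf

a = tf.keras.Input(shape=[10], name='a')
b = tf.keras.Input(shape=[10], name='b')
out = tf.keras.layers.Dense(units=10)(a)

net = tf.keras.Model(inputs=[a, b], outputs=[out])

# Serialize the graph topology and rebuild an equivalent model from it.
net2 = tf.keras.Model.from_config(net.get_config())
assert len(net2.inputs) == 2
```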
@@ -1180,8 +1184,8 @@ class DeferredModeTest(keras_parameterized.TestCase):
    self.assertEqual(x.shape.as_list(), [None, 2])

    outputs = layers.Dense(4)(x)
    network = network_lib.Network(inputs, outputs)
    self.assertIsInstance(network, network_lib.Network)
    network = functional.Functional(inputs, outputs)
    self.assertIsInstance(network, functional.Functional)

    if context.executing_eagerly():
      # It should be possible to call such a network on EagerTensors.
@@ -1204,7 +1208,7 @@ class DeferredModeTest(keras_parameterized.TestCase):
    c = AddLayer()([a, input_b])  # pylint: disable=not-callable
    c = layers.Dense(2)(c)

    network = network_lib.Network([input_a, input_b], [a, c])
    network = functional.Functional([input_a, input_b], [a, c])
    if context.executing_eagerly():
      a_val = constant_op.constant(
          np.random.random((10, 32)).astype('float32'))
@@ -1484,9 +1488,9 @@ class NestedNetworkTest(keras_parameterized.TestCase):
        'x2': input_layer_lib.Input(shape=(1,))
    }
    outputs = layers.Add()([inputs['x1'], inputs['x2']])
    network = network_lib.Network(inputs, outputs)
    network = functional.Functional(inputs, outputs)

    network = network_lib.Network.from_config(network.get_config())
    network = functional.Functional.from_config(network.get_config())

    result_tensor = network({
        'x': array_ops.ones((1, 1), 'float32'),
@@ -1509,9 +1513,9 @@ class NestedNetworkTest(keras_parameterized.TestCase):
        'x*x': layers.Multiply()([inputs, inputs])
    }

    network = network_lib.Network(inputs, outputs)
    network = functional.Functional(inputs, outputs)

    network = network_lib.Network.from_config(network.get_config())
    network = functional.Functional.from_config(network.get_config())

    result_tensor = network(array_ops.ones((1, 1), 'float32'))
    result = self.evaluate(result_tensor)
@@ -1531,7 +1535,8 @@ class NestedNetworkTest(keras_parameterized.TestCase):
        'x1+x2': layers.Add()([inner_inputs['x1'], inner_inputs['x2']]),
        'x1*x2': layers.Multiply()([inner_inputs['x1'], inner_inputs['x2']])
    }
    inner_network = network_lib.Network(inner_inputs, inner_outputs)
    inner_network = functional.Functional(
        inner_inputs, inner_outputs)

    inputs = [
        input_layer_lib.Input(shape=(1,)),
@@ -1539,9 +1544,9 @@ class NestedNetworkTest(keras_parameterized.TestCase):
    ]
    middle = inner_network({'x1': inputs[0], 'x2': inputs[1]})
    outputs = layers.Add()([middle['x1+x2'], middle['x1*x2']])
    network = network_lib.Network(inputs, outputs)
    network = functional.Functional(inputs, outputs)

    network = network_lib.Network.from_config(network.get_config())
    network = functional.Functional.from_config(network.get_config())

    # Computes: `(x1+x2) + (x1*x2)`
    result_tensor = network(
@@ -1735,13 +1740,13 @@ class DTypeTest(keras_parameterized.TestCase):
  def test_graph_network_dtype(self):
    inputs = input_layer_lib.Input((10,))
    outputs = layers.Dense(10)(inputs)
    network = network_lib.Network(inputs, outputs)
    network = functional.Functional(inputs, outputs)
    self.assertEqual(network.dtype, 'float32')

  @testing_utils.enable_v2_dtype_behavior
  def test_subclassed_network_dtype(self):

    class IdentityNetwork(network_lib.Network):
    class IdentityNetwork(training_lib.Model):

      def call(self, inputs):
        return inputs
@@ -1785,11 +1790,11 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):

  def layer_and_network_test(self):
    # Top level layer
    network = network_lib.Network()
    network = functional.Functional()

    layer_0 = AttrTrackingLayer()

    sub_network = network_lib.Network()
    sub_network = functional.Functional()
    layer_1 = AttrTrackingLayer(dynamic=True)
    layer_2 = AttrTrackingLayer()
    sub_network.sub_layers = [layer_1, layer_2]
@@ -1887,7 +1892,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
    x = input_layer_lib.Input(shape=(None, 32))
    dense = layers.Dense(2)
    y = dense(x)
    network = network_lib.Network(x, y, name='dense_network')
    network = functional.Functional(x, y, name='dense_network')

    for i in range(999, 1024):
      self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))
@@ -1895,7 +1900,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
  def test_2d_inputs_squeezed_to_1d(self):
    input_1d = input_layer_lib.Input(shape=())
    outputs = input_1d * 2.
    net = network_lib.Network(input_1d, outputs)
    net = functional.Functional(input_1d, outputs)

    x = np.ones((10, 1))
    y = net(x)
@@ -1904,7 +1909,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
  def test_1d_inputs_expanded_to_2d(self):
    input_1d = input_layer_lib.Input(shape=(1,))
    outputs = input_1d * 2.
    net = network_lib.Network(input_1d, outputs)
    net = functional.Functional(input_1d, outputs)

    x = np.ones((10,))
    y = net(x)
@@ -1927,14 +1932,14 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):

    inputs = input_layer_lib.Input(10)
    outputs = my_layer(inputs, training=True)
    network = network_lib.Network(inputs, outputs)
    network = functional.Functional(inputs, outputs)

    # Hard-coded value passed during construction is respected.
    self.assertAllEqual(network(x, training=False), x)

    inputs = input_layer_lib.Input(10)
    outputs = my_layer(inputs, training=False)
    network = network_lib.Network(inputs, outputs)
    network = functional.Functional(inputs, outputs)

    network(x, training=True)
    # Hard-coded value passed during construction is respected.
@@ -1942,7 +1947,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):

    inputs = input_layer_lib.Input(10)
    outputs = my_layer(inputs, training=None)
    network = network_lib.Network(inputs, outputs)
    network = functional.Functional(inputs, outputs)

    # `None` value passed during construction is overridden.
    self.assertAllEqual(network(x, training=True), x)
@@ -26,8 +26,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import layers as layer_module
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import generic_utils
@@ -35,7 +35,6 @@ from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.deprecation import deprecated
@@ -48,7 +47,7 @@ SINGLE_LAYER_OUTPUT_ERROR_MSG = ('All layers in a Sequential model should have '


@keras_export('keras.Sequential', 'keras.models.Sequential')
class Sequential(training.Model):
class Sequential(functional.Functional):
  """`Sequential` groups a linear stack of layers into a `tf.keras.Model`.

  `Sequential` provides training and inference features on this model.
@@ -113,7 +112,9 @@ class Sequential(training.Model):
      layers: Optional list of layers to add to the model.
      name: Optional name for the model.
    """
    super(Sequential, self).__init__(name=name, autocast=False)
    # Skip the init in FunctionalModel since model doesn't have input/output yet
    super(functional.Functional, self).__init__(  # pylint: disable=bad-super-call
        name=name, autocast=False)
    self.supports_masking = True
    self._compute_output_and_mask_jointly = True
    self._auto_track_sub_layers = False
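The `super(functional.Functional, self).__init__(...)` call above deliberately skips `Functional.__init__` (which would require `inputs`/`outputs` up front) and jumps to the next class in the MRO. A generic sketch of that Python pattern, outside Keras, with hypothetical class names:

```python
class Base(object):

  def __init__(self, name=None):
    self.name = name


class Graph(Base):

  def __init__(self, inputs, outputs, name=None):
    # Requires a finished topology up front.
    super(Graph, self).__init__(name=name)
    self.inputs, self.outputs = inputs, outputs


class Stack(Graph):

  def __init__(self, name=None):
    # Skip Graph.__init__ (no inputs/outputs exist yet) and initialize Base
    # directly; the graph is wired up later as items are appended.
    super(Graph, self).__init__(name=name)
    self._items = []


s = Stack(name='stack')
assert not hasattr(s, 'inputs')  # Graph.__init__ was intentionally bypassed
```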
@@ -152,11 +153,6 @@ class Sequential(training.Model):
      return layers[1:]
    return layers[:]

  @property
  @trackable_layer_utils.cache_recursive_attribute('dynamic')
  def dynamic(self):
    return any(layer.dynamic for layer in self.layers)

  @trackable.no_automatic_dependency_tracking
  def add(self, layer):
    """Adds a layer instance on top of the layer stack.
@@ -233,7 +229,7 @@ class Sequential(training.Model):
        self.built = True

      if set_inputs or self._graph_initialized:
        self._init_graph_network(self.inputs, self.outputs, name=self.name)
        self._init_graph_network(self.inputs, self.outputs)
        self._graph_initialized = True
      else:
        self._layers.append(layer)
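Once inputs are known, `add()` initializes the underlying graph network (now without passing `name=`, since the Functional base already holds it). A short usage sketch of the behavior this code path gives users:

```python
import tensorflow as tf

model = tf.keras.Sequential(name='stacked')
model.add(tf.keras.layers.Dense(4))   # deferred: no input shape known yet
assert not model.built

model.add(tf.keras.layers.Dense(2))
model.build(input_shape=(None, 8))    # or provide an Input/input_shape earlier
assert model.built and len(model.layers) == 2
```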
@@ -267,7 +263,7 @@ class Sequential(training.Model):
    elif self._graph_initialized:
      self.layers[-1]._outbound_nodes = []
      self.outputs = [self.layers[-1].output]
      self._init_graph_network(self.inputs, self.outputs, name=self.name)
      self._init_graph_network(self.inputs, self.outputs)
      self.built = True

  @trackable.no_automatic_dependency_tracking
@@ -341,7 +337,7 @@ class Sequential(training.Model):
        # case, we fall back to the legacy deferred behavior.
        # TODO(fchollet): consider raising here, as we should not be
        # supporting such layers.
        self._init_graph_network(inputs, outputs, name=self.name)
        self._init_graph_network(inputs, outputs)
        self._graph_initialized = True
      except:  # pylint:disable=bare-except
        self._use_legacy_deferred_behavior = True
@@ -350,7 +346,7 @@ class Sequential(training.Model):
  @generic_utils.default
  def build(self, input_shape=None):
    if self._graph_initialized:
      self._init_graph_network(self.inputs, self.outputs, name=self.name)
      self._init_graph_network(self.inputs, self.outputs)
    else:
      if input_shape is None:
        raise ValueError('You must provide an `input_shape` argument.')
@@ -380,7 +376,7 @@ class Sequential(training.Model):

    if self._graph_initialized:
      if not self.built:
        self._init_graph_network(self.inputs, self.outputs, name=self.name)
        self._init_graph_network(self.inputs, self.outputs)
      return super(Sequential, self).call(inputs, training=training, mask=mask)

    outputs = inputs  # handle the corner case where self.layers is empty
@@ -519,6 +515,13 @@ class Sequential(training.Model):
        return False
    return True

  def _assert_weights_created(self):
    if self._graph_initialized:
      return
    # When the graph has not been initialized, use the Model's implementation to
    # to check if the weights has been created.
    super(functional.Functional, self)._assert_weights_created()  # pylint: disable=bad-super-call


def _get_shape_tuple(t):
  if hasattr(t, 'shape'):
@@ -20,6 +20,9 @@ from __future__ import print_function

import copy
import itertools
import json
import os
import six

from tensorflow.python.autograph.lang import directives
from tensorflow.python.distribute import distribute_coordinator as dc
@@ -31,19 +34,31 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as callbacks_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import compile_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine import network
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -52,12 +67,33 @@ from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


# pylint: disable=g-import-not-at-top
try:
  import h5py
except ImportError:
  h5py = None

try:
  import yaml
except ImportError:
  yaml = None
# pylint: enable=g-import-not-at-top


_keras_api_gauge = monitoring.BoolGauge('/tensorflow/api/keras',
|
||||
target=method, decorator_func=_method_wrapper)
|
||||
|
||||
|
||||
def inject_functional_model_class(cls):
|
||||
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top
|
||||
if cls == Model or cls == training_v1.Model:
|
||||
return functional.Functional
|
||||
|
||||
cls.__bases__ = tuple(inject_functional_model_class(base)
|
||||
for base in cls.__bases__)
|
||||
return cls
|
||||
|
||||
|
||||
def is_functional_model_init_params(args, kwargs):
|
||||
return (len(args) == 2 or
|
||||
len(args) == 1 and 'outputs' in kwargs or
|
||||
'inputs' in kwargs and 'outputs' in kwargs)
|
||||
|
||||
|
||||
@keras_export('keras.Model', 'keras.models.Model')
|
||||
class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
"""`Model` groups layers into an object with training and inference features.
|
||||
|
||||
Arguments:
|
||||
@ -174,11 +227,61 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
_TF_MODULE_IGNORED_PROPERTIES = frozenset(
|
||||
itertools.chain(('_train_counter', '_test_counter', '_predict_counter',
|
||||
'_steps_per_execution'),
|
||||
network.Network._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access
|
||||
base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# Signature detection
|
||||
if is_functional_model_init_params(args, kwargs) and cls == Model:
|
||||
# Functional model
|
||||
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
|
||||
return functional.Functional(*args, **kwargs)
|
||||
else:
|
||||
return super(Model, cls).__new__(cls, *args, **kwargs)
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Model, self).__init__(*args, **kwargs)
|
||||
_keras_api_gauge.get_cell('model').set(True)
|
||||
# Special case for Subclassed Functional Model, which we couldn't detect
|
||||
# when __new__ is called. We only realize it is a functional model when it
|
||||
# calls super.__init__ with input and output tensor.
|
||||
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
|
||||
if (is_functional_model_init_params(args, kwargs) and
|
||||
not isinstance(self, functional.Functional)):
|
||||
inject_functional_model_class(self.__class__)
|
||||
functional.Functional.__init__(self, *args, **kwargs)
|
||||
return
|
||||
|
||||
# The following are implemented as property functions:
|
||||
# self.trainable_weights
|
||||
# self.non_trainable_weights
|
||||
generic_utils.validate_kwargs(kwargs, {'trainable', 'dtype', 'dynamic',
|
||||
'name', 'autocast'})
|
||||
super(Model, self).__init__(**kwargs)
|
||||
# By default, Model is a subclass model, which is not in graph network.
|
||||
self._is_graph_network = False
|
||||
|
||||
self.inputs = None
|
||||
self.outputs = None
|
||||
self.input_names = None
|
||||
self.output_names = None
|
||||
# stop_training is used by callback to stop training when error happens
|
||||
self.stop_training = False
|
||||
self.history = None
|
||||
# These objects are used in the default `Model.compile`. They are not
|
||||
# guaranteed to be set after `Model.compile` is called, as users can
|
||||
# override compile with custom logic.
|
||||
self.compiled_loss = None
|
||||
self.compiled_metrics = None
|
||||
|
||||
# This is True for Sequential networks and Functional networks.
|
||||
self._compute_output_and_mask_jointly = False
|
||||
|
||||
# Don't reset compilation if already done. This may occur if calling
|
||||
# `__init__` (or `_init_graph_network`) on an already-compiled model
|
||||
# such as a Sequential model. Sequential models may need to rebuild
|
||||
# themselves after compilation.
|
||||
self._maybe_create_attribute('_is_compiled', False)
|
||||
self._maybe_create_attribute('optimizer', None)
|
||||
|
||||
# Model must be created under scope of DistStrat it will be trained with.
|
||||
if ds_context.has_strategy():
|
||||
self._distribution_strategy = ds_context.get_strategy()
|
||||
@ -186,23 +289,20 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
self._distribution_strategy = None
|
||||
# Defaults to value of `tf.config.experimental_functions_run_eagerly`.
|
||||
self._run_eagerly = None
|
||||
self.stop_training = False
|
||||
# Initialize cache attrs.
|
||||
self._reset_compile_cache()
|
||||
|
||||
# Fault-tolerance handler. Set in `ModelCheckpoint`.
|
||||
self._training_state = None
|
||||
self.history = None
|
||||
|
||||
# These objects are used in the default `Model.compile`. They are not
|
||||
# guaranteed to be set after `Model.compile` is called, as users can
|
||||
# override compile with custom logic.
|
||||
self.compiled_loss = None
|
||||
self.compiled_metrics = None
|
||||
self._saved_model_inputs_spec = None
|
||||
self._trackable_saver = (
|
||||
trackable_utils.saver_with_op_caching(self))
|
||||
|
||||
self._steps_per_execution = None
|
||||
|
||||
self._init_batch_counters()
|
||||
self._base_model_initialized = True
|
||||
_keras_api_gauge.get_cell('model').set(True)
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def _init_batch_counters(self):
|
||||
@@ -214,67 +314,153 @@ class Model(network.Network, version_utils.ModelVersionSelector):
    self._predict_counter = variables.Variable(
        0, dtype='int64', aggregation=agg)

  def get_weights(self):
    """Retrieves the weights of the model.
  def __setattr__(self, name, value):
    if not getattr(self, '_self_setattr_tracking', True):
      super(Model, self).__setattr__(name, value)
      return

    Returns:
        A flat list of Numpy arrays.
    """
    with self.distribute_strategy.scope():
      return super(Model, self).get_weights()
    if all(
        isinstance(v, (base_layer.Layer,
                       data_structures.TrackableDataStructure)) or
        trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
      try:
        self._base_model_initialized
      except AttributeError:
        # six.raise_from supresses the original AttributeError from being raised
        six.raise_from(
            RuntimeError('It looks like you are subclassing `Model` and you '
                         'forgot to call `super(YourClass, self).__init__()`.'
                         ' Always start with this line.'), None)

  def load_weights(self, filepath, by_name=False, skip_mismatch=False):
    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
    super(Model, self).__setattr__(name, value)

    If `by_name` is False weights are loaded based on the network's
    topology. This means the architecture should be the same as when the weights
    were saved. Note that layers that don't have weights are not taken into
    account in the topological ordering, so adding or removing layers is fine as
    long as they don't have weights.
    # Keep track of metric instance created in subclassed model/layer.
    # We do this so that we can maintain the correct order of metrics by adding
    # the instance to the `metrics` list as soon as it is created.
    from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
    if isinstance(value, metrics_module.Metric):
      self._metrics.append(value)

    If `by_name` is True, weights are loaded into layers only if they share the
    same name. This is useful for fine-tuning or transfer-learning models where
    some of the layers have changed.
  @generic_utils.default
  def build(self, input_shape):
    """Builds the model based on input shapes received.

    Only topological loading (`by_name=False`) is supported when loading weights
    from the TensorFlow format. Note that topological loading differs slightly
    between TensorFlow and HDF5 formats for user-defined classes inheriting from
    `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
    TensorFlow format loads based on the object-local names of attributes to
    which layers are assigned in the `Model`'s constructor.
    This is to be used for subclassed models, which do not know at instantiation
    time what their inputs look like.

    Arguments:
        filepath: String, path to the weights file to load. For weight files in
            TensorFlow format, this is the file prefix (the same as was passed
            to `save_weights`).
        by_name: Boolean, whether to load weights by name or by topological
            order. Only topological loading is supported for weight files in
            TensorFlow format.
        skip_mismatch: Boolean, whether to skip loading of layers where there is
            a mismatch in the number of weights, or a mismatch in the shape of
            the weight (only valid when `by_name=True`).
    This method only exists for users who want to call `model.build()` in a
    standalone way (as a substitute for calling the model on real data to
    build it). It will never be called by the framework (and thus it will
    never throw unexpected errors in an unrelated workflow).

    Returns:
        When loading a weight file in TensorFlow format, returns the same status
        object as `tf.train.Checkpoint.restore`. When graph building, restore
        ops are run automatically as soon as the network is built (on first call
        for user-defined classes inheriting from `Model`, immediately if it is
        already built).

        When loading weights in HDF5 format, returns `None`.
    Args:
        input_shape: Single tuple, TensorShape, or list of shapes, where shapes
            are tuples, integers, or TensorShapes.

    Raises:
        ImportError: If h5py is not available and the weight file is in HDF5
            format.
        ValueError: If `skip_mismatch` is set to `True` when `by_name` is
            `False`.
        ValueError:
            1. In case of invalid user-provided data (not of type tuple,
               list, or TensorShape).
            2. If the model requires call arguments that are agnostic
               to the input shapes (positional or kwarg in call signature).
            3. If not all layers were properly built.
            4. If float type inputs are not supported within the layers.

            In each of these cases, the user should build their model by calling it
            on real tensor data.
    """
    if dist_utils.is_tpu_strategy(self._distribution_strategy):
      if (self._distribution_strategy.extended.steps_per_run > 1 and
          (not network._is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
        raise ValueError('Load weights is not yet supported with TPUStrategy '
                         'with steps_per_run greater than 1.')
    return super(Model, self).load_weights(filepath, by_name, skip_mismatch)
    if self._is_graph_network:
      super(Model, self).build(input_shape)
      return

    if input_shape is None:
      raise ValueError('Input shape must be defined when calling build on a '
                       'model subclass network.')
    valid_types = (tuple, list, tensor_shape.TensorShape)
    if not isinstance(input_shape, valid_types):
      raise ValueError('Specified input shape is not one of the valid types. '
                       'Please specify a batch input shape of type tuple or '
                       'list of input shapes. User provided '
                       'input type: {}'.format(type(input_shape)))

    if input_shape and not self.inputs:
      # We create placeholders for the `None`s in the shape and build the model
      # in a Graph. Since tf.Variable is compatible with both eager execution
      # and graph building, the variables created after building the model in
      # a Graph are still valid when executing eagerly.
      if context.executing_eagerly():
        graph = func_graph.FuncGraph('build_graph')
      else:
        graph = backend.get_graph()
      with graph.as_default():
        if isinstance(input_shape, list):
          x = [base_layer_utils.generate_placeholders_from_shape(shape)
               for shape in input_shape]
        elif isinstance(input_shape, dict):
          x = {
              k: base_layer_utils.generate_placeholders_from_shape(shape)
              for k, shape in input_shape.items()
          }
        else:
          x = base_layer_utils.generate_placeholders_from_shape(input_shape)

        kwargs = {}
        call_signature = self._call_full_argspec
        call_args = call_signature.args
        # Exclude `self`, `inputs`, and any argument with a default value.
        if len(call_args) > 2:
          if call_signature.defaults:
            call_args = call_args[2:-len(call_signature.defaults)]
          else:
            call_args = call_args[2:]
          for arg in call_args:
            if arg == 'training':
              # Case where `training` is a positional arg with no default.
              kwargs['training'] = False
            else:
              # Has invalid call signature with unknown positional arguments.
              raise ValueError(
                  'Currently, you cannot build your model if it has '
                  'positional or keyword arguments that are not '
                  'inputs to the model, but are required for its '
                  '`call` method. Instead, in order to instantiate '
                  'and build your model, `call` your model on real '
                  'tensor data with all expected call arguments.')
        elif len(call_args) < 2:
          # Signature without `inputs`.
          raise ValueError('You can only call `build` on a model if its `call` '
                           'method accepts an `inputs` argument.')
        try:
          self.call(x, **kwargs)
        except (errors.InvalidArgumentError, TypeError):
          raise ValueError('You cannot build your model by calling `build` '
                           'if your layers do not support float type inputs. '
                           'Instead, in order to instantiate and build your '
                           'model, `call` your model on real tensor data (of '
                           'the correct dtype).')

    super(Model, self).build(input_shape)

  def call(self, inputs, training=None, mask=None):
    """Calls the model on new inputs.

    In this case `call` just reapplies
    all ops in the graph to the new inputs
    (e.g. build a new computational graph from the provided inputs).

    Arguments:
        inputs: A tensor or list of tensors.
        training: Boolean or boolean scalar tensor, indicating whether to run
            the `Network` in training mode or inference mode.
        mask: A mask or list of masks. A mask can be
            either a tensor or None (no mask).

    Returns:
        A tensor if there is a single output, or
        a list of tensors if there are more than one outputs.
    """
    raise NotImplementedError('When subclassing the `Model` class, you should '
                              'implement a `call` method.')

  def compile(self,
              optimizer='rmsprop',
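The `build()` override above lets a subclassed model create its variables from shapes alone. A short usage sketch, assuming a simple Dense-based subclass:

```python
import tensorflow as tf

class TwoLayer(tf.keras.Model):

  def __init__(self):
    super(TwoLayer, self).__init__()
    self.d1 = tf.keras.layers.Dense(16)
    self.d2 = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self.d2(self.d1(inputs))

model = TwoLayer()
# Standalone build: placeholders are generated for the shape and call() is
# traced once, so weights exist before any real data is seen.
model.build(input_shape=(None, 8))
model.summary()
```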
@@ -399,6 +585,10 @@ class Model(network.Network, version_utils.ModelVersionSelector):
        dtype='int64',
        aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)

  @property
  def _should_compute_mask(self):
    return False

  @property
  def metrics(self):
    """Returns the model's metrics added using `compile`, `add_metric` APIs.
@ -1661,6 +1851,564 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
verbose=verbose,
|
||||
callbacks=callbacks)
|
||||
|
||||
######################################################################
|
||||
# Functions below are not training related. They are for model weights
|
||||
# tracking, save/load, serialization, etc.
|
||||
######################################################################
|
||||
|
||||
@property
|
||||
def trainable_weights(self):
|
||||
self._assert_weights_created()
|
||||
return self._dedup_weights(
|
||||
trackable_layer_utils.gather_trainable_weights(
|
||||
trainable=self.trainable,
|
||||
sub_layers=self._layers,
|
||||
extra_variables=self._trainable_weights))
|
||||
|
||||
@property
|
||||
def non_trainable_weights(self):
|
||||
self._assert_weights_created()
|
||||
return self._dedup_weights(
|
||||
trackable_layer_utils.gather_non_trainable_weights(
|
||||
trainable=self.trainable,
|
||||
sub_layers=self._layers,
|
||||
extra_variables=self._non_trainable_weights +
|
||||
self._trainable_weights))
|
||||
|
||||
def get_weights(self):
|
||||
"""Retrieves the weights of the model.
|
||||
|
||||
Returns:
|
||||
A flat list of Numpy arrays.
|
||||
"""
|
||||
with self.distribute_strategy.scope():
|
||||
return super(Model, self).get_weights()
|
||||
|
||||
def save(self,
|
||||
filepath,
|
||||
overwrite=True,
|
||||
include_optimizer=True,
|
||||
save_format=None,
|
||||
signatures=None,
|
||||
options=None):
|
||||
"""Saves the model to Tensorflow SavedModel or a single HDF5 file.
|
||||
|
||||
The savefile includes:
|
||||
|
||||
- The model architecture, allowing to re-instantiate the model.
|
||||
- The model weights.
|
||||
- The state of the optimizer, allowing to resume training
|
||||
exactly where you left off.
|
||||
|
||||
This allows you to save the entirety of the state of a model
|
||||
in a single file.
|
||||
|
||||
Saved models can be reinstantiated via `keras.models.load_model`.
|
||||
The model returned by `load_model` is a compiled model ready to be used
|
||||
(unless the saved model was never compiled in the first place).
|
||||
|
||||
Models built with the Sequential and Functional API can be saved to both the
|
||||
HDF5 and SavedModel formats. Subclassed models can only be saved with the
|
||||
SavedModel format.
|
||||
|
||||
Note that the model weights may have different scoped names after being
|
||||
loaded. Scoped names include the model/layer names, such as
|
||||
`"dense_1/kernel:0"`. It is recommended that you use the layer properties to
|
||||
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
|
||||
|
||||
Arguments:
|
||||
filepath: String, PathLike, path to SavedModel or H5 file to save the
|
||||
model.
|
||||
overwrite: Whether to silently overwrite any existing file at the
|
||||
target location, or provide the user with a manual prompt.
|
||||
include_optimizer: If True, save optimizer's state together.
|
||||
save_format: Either `'tf'` or `'h5'`, indicating whether to save the
|
||||
model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X,
|
||||
and 'h5' in TF 1.X.
|
||||
signatures: Signatures to save with the SavedModel. Applicable to the
|
||||
'tf' format only. Please see the `signatures` argument in
|
||||
`tf.saved_model.save` for details.
|
||||
options: Optional `tf.saved_model.SaveOptions` object that specifies
|
||||
options for saving to SavedModel.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from keras.models import load_model
|
||||
|
||||
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
|
||||
del model # deletes the existing model
|
||||
|
||||
# returns a compiled model
|
||||
# identical to the previous one
|
||||
model = load_model('my_model.h5')
|
||||
```
|
||||
"""
|
||||
save.save_model(self, filepath, overwrite, include_optimizer, save_format,
|
||||
signatures, options)
|
||||
|
||||
def save_weights(self, filepath, overwrite=True, save_format=None):
|
||||
"""Saves all layer weights.
|
||||
|
||||
Either saves in HDF5 or in TensorFlow format based on the `save_format`
|
||||
argument.
|
||||
|
||||
When saving in HDF5 format, the weight file has:
|
||||
- `layer_names` (attribute), a list of strings
|
||||
(ordered names of model layers).
|
||||
- For every layer, a `group` named `layer.name`
|
||||
- For every such layer group, a group attribute `weight_names`,
|
||||
a list of strings
|
||||
(ordered names of weights tensor of the layer).
|
||||
- For every weight in the layer, a dataset
|
||||
storing the weight value, named after the weight tensor.
|
||||
|
||||
When saving in TensorFlow format, all objects referenced by the network are
|
||||
saved in the same format as `tf.train.Checkpoint`, including any `Layer`
|
||||
instances or `Optimizer` instances assigned to object attributes. For
|
||||
networks constructed from inputs and outputs using `tf.keras.Model(inputs,
|
||||
outputs)`, `Layer` instances used by the network are tracked/saved
|
||||
automatically. For user-defined classes which inherit from `tf.keras.Model`,
|
||||
`Layer` instances must be assigned to object attributes, typically in the
|
||||
constructor. See the documentation of `tf.train.Checkpoint` and
|
||||
`tf.keras.Model` for details.
|
||||
|
||||
While the formats are the same, do not mix `save_weights` and
|
||||
`tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be
|
||||
loaded using `Model.load_weights`. Checkpoints saved using
|
||||
`tf.train.Checkpoint.save` should be restored using the corresponding
|
||||
`tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
|
||||
`save_weights` for training checkpoints.
|
||||
|
||||
The TensorFlow format matches objects and variables by starting at a root
|
||||
object, `self` for `save_weights`, and greedily matching attribute
|
||||
names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this
|
||||
is the `Checkpoint` even if the `Checkpoint` has a model attached. This
|
||||
means saving a `tf.keras.Model` using `save_weights` and loading into a
|
||||
`tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match
|
||||
the `Model`'s variables. See the [guide to training
|
||||
checkpoints](https://www.tensorflow.org/guide/checkpoint) for details
|
||||
on the TensorFlow format.
|
||||
|
||||
Arguments:
|
||||
filepath: String or PathLike, path to the file to save the weights to.
|
||||
When saving in TensorFlow format, this is the prefix used for
|
||||
checkpoint files (multiple files are generated). Note that the '.h5'
|
||||
suffix causes weights to be saved in HDF5 format.
|
||||
overwrite: Whether to silently overwrite any existing file at the
|
||||
target location, or provide the user with a manual prompt.
|
||||
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
|
||||
'.keras' will default to HDF5 if `save_format` is `None`. Otherwise
|
||||
`None` defaults to 'tf'.
|
||||
|
||||
Raises:
|
||||
ImportError: If h5py is not available when attempting to save in HDF5
|
||||
format.
|
||||
ValueError: For invalid/unknown format arguments.
|
||||
"""
|
||||
self._assert_weights_created()
|
||||
filepath = path_to_string(filepath)
|
||||
filepath_is_h5 = _is_hdf5_filepath(filepath)
|
||||
if save_format is None:
|
||||
if filepath_is_h5:
|
||||
save_format = 'h5'
|
||||
else:
|
||||
save_format = 'tf'
|
||||
else:
|
||||
user_format = save_format.lower().strip()
|
||||
if user_format in ('tensorflow', 'tf'):
|
||||
save_format = 'tf'
|
||||
elif user_format in ('hdf5', 'h5', 'keras'):
|
||||
save_format = 'h5'
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
|
||||
save_format,))
|
||||
if save_format == 'tf' and filepath_is_h5:
|
||||
raise ValueError(
|
||||
('save_weights got save_format="tf"/"tensorflow", but the '
|
||||
'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
|
||||
'when saving in TensorFlow format.')
|
||||
% filepath)
|
||||
|
||||
if save_format == 'h5' and h5py is None:
|
||||
raise ImportError(
|
||||
'`save_weights` requires h5py when saving in hdf5.')
|
||||
if save_format == 'tf':
|
||||
check_filepath = filepath + '.index'
|
||||
else:
|
||||
check_filepath = filepath
|
||||
# If file exists and should not be overwritten:
|
||||
if not overwrite and os.path.isfile(check_filepath):
|
||||
proceed = ask_to_proceed_with_overwrite(check_filepath)
|
||||
if not proceed:
|
||||
return
|
||||
if save_format == 'h5':
|
||||
with h5py.File(filepath, 'w') as f:
|
||||
hdf5_format.save_weights_to_hdf5_group(f, self.layers)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
session = None
|
||||
else:
|
||||
session = backend.get_session()
|
||||
optimizer = getattr(self, 'optimizer', None)
|
||||
if (optimizer
|
||||
and not isinstance(optimizer, trackable.Trackable)):
|
||||
logging.warning(
|
||||
('This model was compiled with a Keras optimizer (%s) but is being '
|
||||
'saved in TensorFlow format with `save_weights`. The model\'s '
|
||||
'weights will be saved, but unlike with TensorFlow optimizers in '
|
||||
'the TensorFlow format the optimizer\'s state will not be '
|
||||
'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
|
||||
% (optimizer,))
|
||||
self._trackable_saver.save(filepath, session=session)
|
||||
# Record this checkpoint so it's visible from tf.train.latest_checkpoint.
|
||||
checkpoint_management.update_checkpoint_state_internal(
|
||||
save_dir=os.path.dirname(filepath),
|
||||
model_checkpoint_path=filepath,
|
||||
save_relative_paths=True,
|
||||
all_model_checkpoint_paths=[filepath])
|
||||
|
||||
def load_weights(self, filepath, by_name=False, skip_mismatch=False):
|
||||
"""Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
|
||||
|
||||
If `by_name` is False weights are loaded based on the network's
|
||||
topology. This means the architecture should be the same as when the weights
|
||||
were saved. Note that layers that don't have weights are not taken into
|
||||
account in the topological ordering, so adding or removing layers is fine as
|
||||
long as they don't have weights.
|
||||
|
||||
If `by_name` is True, weights are loaded into layers only if they share the
|
||||
same name. This is useful for fine-tuning or transfer-learning models where
|
||||
some of the layers have changed.
|
||||
|
||||
Only topological loading (`by_name=False`) is supported when loading weights
|
||||
from the TensorFlow format. Note that topological loading differs slightly
|
||||
between TensorFlow and HDF5 formats for user-defined classes inheriting from
|
||||
`tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
|
||||
TensorFlow format loads based on the object-local names of attributes to
|
||||
which layers are assigned in the `Model`'s constructor.
|
||||
|
||||
Arguments:
|
||||
filepath: String, path to the weights file to load. For weight files in
|
||||
TensorFlow format, this is the file prefix (the same as was passed
|
||||
to `save_weights`).
|
||||
by_name: Boolean, whether to load weights by name or by topological
|
||||
order. Only topological loading is supported for weight files in
|
||||
TensorFlow format.
|
||||
skip_mismatch: Boolean, whether to skip loading of layers where there is
|
||||
a mismatch in the number of weights, or a mismatch in the shape of
|
||||
the weight (only valid when `by_name=True`).
|
||||
|
||||
Returns:
|
||||
When loading a weight file in TensorFlow format, returns the same status
|
||||
object as `tf.train.Checkpoint.restore`. When graph building, restore
|
||||
ops are run automatically as soon as the network is built (on first call
|
||||
for user-defined classes inheriting from `Model`, immediately if it is
|
||||
already built).
|
||||
|
||||
When loading weights in HDF5 format, returns `None`.
|
||||
|
||||
Raises:
|
||||
ImportError: If h5py is not available and the weight file is in HDF5
|
||||
format.
|
||||
ValueError: If `skip_mismatch` is set to `True` when `by_name` is
|
||||
`False`.
|
||||
"""
|
||||
if dist_utils.is_tpu_strategy(self._distribution_strategy):
|
||||
if (self._distribution_strategy.extended.steps_per_run > 1 and
|
||||
(not _is_hdf5_filepath(filepath))):
|
||||
raise ValueError('Load weights is not yet supported with TPUStrategy '
|
||||
'with steps_per_run greater than 1.')
|
||||
if skip_mismatch and not by_name:
|
||||
raise ValueError(
|
||||
'When calling model.load_weights, skip_mismatch can only be set to '
|
||||
'True when by_name is True.')
|
||||
|
||||
filepath = path_to_string(filepath)
|
||||
if _is_hdf5_filepath(filepath):
|
||||
save_format = 'h5'
|
||||
else:
|
||||
try:
|
||||
py_checkpoint_reader.NewCheckpointReader(filepath)
|
||||
save_format = 'tf'
|
||||
except errors_impl.DataLossError:
|
||||
# The checkpoint is not readable in TensorFlow format. Try HDF5.
|
||||
save_format = 'h5'
|
||||
if save_format == 'tf':
|
||||
status = self._trackable_saver.restore(filepath)
|
||||
if by_name:
|
||||
raise NotImplementedError(
|
||||
'Weights may only be loaded based on topology into Models when '
|
||||
'loading TensorFlow-formatted weights (got by_name=True to '
|
||||
'load_weights).')
|
||||
if not context.executing_eagerly():
|
||||
session = backend.get_session()
|
||||
# Restore existing variables (if any) immediately, and set up a
|
||||
# streaming restore for any variables created in the future.
|
||||
trackable_utils.streaming_restore(status=status, session=session)
|
||||
status.assert_nontrivial_match()
|
||||
return status
|
||||
if h5py is None:
|
||||
raise ImportError(
|
||||
'`load_weights` requires h5py when loading weights from HDF5.')
|
||||
if not self._is_graph_network and not self.built:
|
||||
raise ValueError(
|
||||
'Unable to load weights saved in HDF5 format into a subclassed '
|
||||
'Model which has not created its variables yet. Call the Model '
|
||||
'first, then load the weights.')
|
||||
self._assert_weights_created()
|
||||
with h5py.File(filepath, 'r') as f:
|
||||
if 'layer_names' not in f.attrs and 'model_weights' in f:
|
||||
f = f['model_weights']
|
||||
if by_name:
|
||||
hdf5_format.load_weights_from_hdf5_group_by_name(
|
||||
f, self.layers, skip_mismatch=skip_mismatch)
|
||||
else:
|
||||
hdf5_format.load_weights_from_hdf5_group(f, self.layers)
|
||||
|
||||
def _updated_config(self):
|
||||
"""Util shared between different serialization methods.
|
||||
|
||||
Returns:
|
||||
Model config with Keras version information added.
|
||||
"""
|
||||
from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
|
||||
|
||||
config = self.get_config()
|
||||
model_config = {
|
||||
'class_name': self.__class__.__name__,
|
||||
'config': config,
|
||||
'keras_version': keras_version,
|
||||
'backend': backend.backend()
|
||||
}
|
||||
return model_config
|
||||
|
||||
def get_config(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
# Since only FunctionalModel produces config, the model can only
|
||||
# be constructed for FunctionalModel
|
||||
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
|
||||
return functional.Functional.from_config(
|
||||
config, custom_objects=custom_objects)
|
||||
|
||||
def to_json(self, **kwargs):
|
||||
"""Returns a JSON string containing the network configuration.
|
||||
|
||||
To load a network from a JSON save file, use
|
||||
`keras.models.model_from_json(json_string, custom_objects={})`.
|
||||
|
||||
Arguments:
|
||||
**kwargs: Additional keyword arguments
|
||||
to be passed to `json.dumps()`.
|
||||
|
||||
Returns:
|
||||
A JSON string.
|
||||
"""
|
||||
model_config = self._updated_config()
|
||||
return json.dumps(
|
||||
model_config, default=serialization.get_json_type, **kwargs)
|
||||
|
||||
def to_yaml(self, **kwargs):
|
||||
"""Returns a yaml string containing the network configuration.
|
||||
|
||||
To load a network from a yaml save file, use
|
||||
`keras.models.model_from_yaml(yaml_string, custom_objects={})`.
|
||||
|
||||
`custom_objects` should be a dictionary mapping
|
||||
the names of custom losses / layers / etc to the corresponding
|
||||
functions / classes.
|
||||
|
||||
Arguments:
|
||||
**kwargs: Additional keyword arguments
|
||||
to be passed to `yaml.dump()`.
|
||||
|
||||
Returns:
|
||||
A YAML string.
|
||||
|
||||
Raises:
|
||||
ImportError: if yaml module is not found.
|
||||
"""
|
||||
if yaml is None:
|
||||
raise ImportError(
|
||||
'Requires yaml module installed (`pip install pyyaml`).')
|
||||
return yaml.dump(self._updated_config(), **kwargs)
|
||||
|
||||
def reset_states(self):
|
||||
for layer in self.layers:
|
||||
if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
|
||||
layer.reset_states()
|
||||
|
||||

  @property
  @deprecation.deprecated(
      date=None,
      instructions='This property should not be used in TensorFlow 2.0, '
      'as updates are applied automatically.')
  @doc_controls.do_not_generate_docs
  def state_updates(self):
    """Deprecated, do NOT use!

    Returns the `updates` from all layers that are stateful.

    This is useful for separating training updates and
    state updates, e.g. when we need to update a layer's internal state
    during prediction.

    Returns:
      A list of update ops.
    """
    state_updates = []
    for layer in self.layers:
      if getattr(layer, 'stateful', False):
        if hasattr(layer, 'updates'):
          state_updates += layer.updates
    return state_updates

  @property
  def weights(self):
    """Returns the list of all layer variables/weights.

    Returns:
      A list of variables.
    """
    return self._dedup_weights(self._undeduplicated_weights)

  @property
  def _undeduplicated_weights(self):
    """Returns the undeduplicated list of all layer variables/weights."""
    self._assert_weights_created()
    weights = []
    for layer in self._layers:
      weights += layer.weights
    weights += (self._trainable_weights + self._non_trainable_weights)
    return weights

  def summary(self, line_length=None, positions=None, print_fn=None):
    """Prints a string summary of the network.

    Arguments:
        line_length: Total length of printed lines
            (e.g. set this to adapt the display to different
            terminal window sizes).
        positions: Relative or absolute positions of log elements
            in each line. If not provided,
            defaults to `[.33, .55, .67, 1.]`.
        print_fn: Print function to use. Defaults to `print`.
            It will be called on each line of the summary.
            You can set it to a custom function
            in order to capture the string summary.

    Raises:
        ValueError: if `summary()` is called before the model is built.
    """
    if not self.built:
      raise ValueError('This model has not yet been built. '
                       'Build the model first by calling `build()` or calling '
                       '`fit()` with some data, or specify '
                       'an `input_shape` argument in the first layer(s) for '
                       'automatic build.')
    layer_utils.print_summary(self,
                              line_length=line_length,
                              positions=positions,
                              print_fn=print_fn)
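
A short sketch showing how print_fn can capture the summary text instead of printing it:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])

lines = []
model.summary(print_fn=lines.append)  # each summary line is appended
summary_text = '\n'.join(lines)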

  @property
  def layers(self):
    return self._unique_sublayers()

  def get_layer(self, name=None, index=None):
    """Retrieves a layer based on either its name (unique) or index.

    If `name` and `index` are both provided, `index` will take precedence.
    Indices are based on order of horizontal graph traversal (bottom-up).

    Arguments:
        name: String, name of layer.
        index: Integer, index of layer.

    Returns:
        A layer instance.

    Raises:
        ValueError: In case of invalid layer name or index.
    """
    # TODO(fchollet): We could build a dictionary based on layer names
    # since they are constant, but we have not done that yet.
    if index is not None and name is not None:
      raise ValueError('Provide only a layer name or a layer index.')

    if index is not None:
      if len(self.layers) <= index:
        raise ValueError('Was asked to retrieve layer at index ' + str(index) +
                         ' but model only has ' + str(len(self.layers)) +
                         ' layers.')
      else:
        return self.layers[index]

    if name is not None:
      for layer in self.layers:
        if layer.name == name:
          return layer
      raise ValueError('No such layer: ' + name + '.')
    raise ValueError('Provide either a layer name or layer index.')
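
A usage sketch for get_layer, assuming the hypothetical layer names below; retrieval works by name or by index:

import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
x = tf.keras.layers.Dense(8, name='hidden')(inputs)
outputs = tf.keras.layers.Dense(1, name='head')(x)
model = tf.keras.Model(inputs, outputs)

hidden = model.get_layer(name='hidden')
first = model.get_layer(index=0)  # the InputLayer in this functional model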

  @trackable.no_automatic_dependency_tracking
  def _set_save_spec(self, inputs):
    if self._saved_model_inputs_spec is not None:
      return  # Already set.

    input_names = self.input_names
    if not input_names:
      input_names = compile_utils.create_pseudo_input_names(inputs)

    flat_inputs = nest.flatten(inputs)
    specs = []
    for name, tensor in zip(input_names, flat_inputs):
      specs.append(
          tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name))
    specs = nest.pack_sequence_as(inputs, specs)

    self._saved_model_inputs_spec = specs

  def _get_save_spec(self, dynamic_batch=True):
    if self._saved_model_inputs_spec is None:
      return None

    return nest.map_structure(
        lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
        self._saved_model_inputs_spec)

  def _assert_weights_created(self):
    """Asserts that all the weights for the model have been created.

    For a non-dynamic model, the weights must already be created after the
    layer has been called. For a dynamic model, the exact list of weights can
    never be known for certain since it may change at any time during execution.

    We run this check right before accessing weights or getting the Numpy value
    for the current weights. Otherwise, if the layer has never been called,
    the user would just get an empty list, which is misleading.

    Raises:
      ValueError: if the weights of the network have not yet been created.
    """
    if self.dynamic:
      return

    if ('build' in self.__class__.__dict__ and
        self.__class__ != Model and
        not self.built):
      # For any model that has customized build() method but hasn't
      # been invoked yet, this will cover both sequential and subclass model.
      # Also make sure to exclude Model class itself which has build() defined.
      raise ValueError('Weights for model %s have not yet been created. '
                       'Weights are created when the Model is first called on '
                       'inputs or `build()` is called with an `input_shape`.' %
                       self.name)
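
A sketch of the check above: a subclassed model that defines its own build() raises this ValueError if its weights are read before it is built (the class name below is hypothetical):

import tensorflow as tf

class MySubclassModel(tf.keras.Model):

  def __init__(self):
    super(MySubclassModel, self).__init__()
    self.dense = tf.keras.layers.Dense(1)

  def build(self, input_shape):
    self.dense.build(input_shape)
    super(MySubclassModel, self).build(input_shape)

  def call(self, inputs):
    return self.dense(inputs)

model = MySubclassModel()
try:
  _ = model.weights  # raises: weights have not yet been created
except ValueError as e:
  print(e)

model.build((None, 4))
print(len(model.weights))  # kernel and bias are now available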

  def _check_call_args(self, method_name):
    """Check that `call` has only one positional arg."""
    # Always allow first arg, regardless of arg name.

@ -1990,3 +2738,8 @@ def _disallow_inside_tf_function(method_name):
      'directly on `Tensor`s inside a `tf.function` like: `model(x)`.'
  ).format(method_name=method_name)
  raise RuntimeError(error_msg)


def _is_hdf5_filepath(filepath):
  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
          filepath.endswith('.hdf5'))

@ -43,7 +43,7 @@ from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.engine import network
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_distributed

@ -181,8 +181,8 @@ class Model(training_lib.Model):
          self._compile_time_distribution_strategy)
      if strategy:
        with strategy.scope():
          return network.Network.get_weights(self)
      return network.Network.get_weights(self)
          return base_layer.Layer.get_weights(self)
      return base_layer.Layer.get_weights(self)

  def load_weights(self, filepath, by_name=False, skip_mismatch=False):
    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.

@ -232,7 +232,7 @@ class Model(training_lib.Model):
    """
    if distributed_training_utils.is_tpu_strategy(self._distribution_strategy):
      if (self._distribution_strategy.extended.steps_per_run > 1 and
          (not network._is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
          (not training_lib._is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
        raise ValueError('Load weights is not yet supported with TPUStrategy '
                         'with steps_per_run greater than 1.')
    return super(Model, self).load_weights(filepath, by_name, skip_mismatch)

@ -491,6 +491,11 @@ class Model(training_lib.Model):
    """Returns the model's metrics added using `compile`, `add_metric` APIs."""
    metrics = []
    if self._is_compiled:
      if not hasattr(self, '_v1_compile_was_called'):
        # See b/155687393 for more details, the model is created as a v2
        # instance but converted to v1. Fallback to use base Model to retrieve
        # the metrics.
        return super(Model, self).metrics
      metrics += self._compile_metric_functions
    metrics.extend(self._metrics)
    metrics.extend(_get_metrics_from_layers(self._layers))

@ -504,6 +509,12 @@ class Model(training_lib.Model):
    # losses for backward compatibility.
    metrics_names = ['loss']
    if self._is_compiled:
      if not hasattr(self, '_v1_compile_was_called'):
        # See b/155687393 for more details, the model is created as a v2
        # instance but converted to v1. Fallback to use base Model to retrieve
        # the metrics name
        return super(Model, self).metrics_names

      # Add output loss metric names to the metric names list.
      if len(self._training_endpoints) > 1:
        metrics_names.extend([

@ -114,7 +114,7 @@ def populate_deserializable_objects():

  LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
  LOCAL.ALL_OBJECTS['InputSpec'] = input_spec.InputSpec
  LOCAL.ALL_OBJECTS['Network'] = models.Network
  LOCAL.ALL_OBJECTS['Functional'] = models.Functional
  LOCAL.ALL_OBJECTS['Model'] = models.Model
  LOCAL.ALL_OBJECTS['SequenceFeatures'] = SequenceFeatures
  LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential

@ -377,7 +377,8 @@ class TimeDistributedTest(keras_parameterized.TestCase):
        input_layer.compute_output_shape([None, 2, 4]).as_list(),
        [None, 2, 8])

  @keras_parameterized.run_all_keras_modes
  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  # TODO(scottzhu): check why v1 session failed.
  def test_TimeDistributed_with_mask_first_implementation(self):
    np.random.seed(100)
    rnn_layer = keras.layers.LSTM(4, return_sequences=True, stateful=True)

@ -23,7 +23,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import network
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_v1
@ -31,7 +31,6 @@ from tensorflow.python.keras.engine.base_layer import AddMetric
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils import generic_utils
@ -45,6 +44,7 @@ from tensorflow.python.util.tf_export import keras_export

# API entries importable from `keras.models`:
Model = training.Model  # pylint: disable=invalid-name
Sequential = sequential.Sequential  # pylint: disable=invalid-name
Functional = functional.Functional  # pylint: disable=invalid-name
save_model = save.save_model
load_model = save.load_model
model_from_config = model_config.model_from_config

@ -193,12 +193,12 @@ def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  model_config, created_layers = _clone_layers_and_model_config(
  model_configs, created_layers = _clone_layers_and_model_config(
      model, new_input_layers, layer_fn)
  # Reconstruct model from the config, using the cloned layers.
  input_tensors, output_tensors, created_layers = (
      network.reconstruct_from_config(model_config,
                                      created_layers=created_layers))
      functional.reconstruct_from_config(model_configs,
                                         created_layers=created_layers))
  metrics_names = model.metrics_names
  model = Model(input_tensors, output_tensors, name=model.name)
  # Layers not directly tied to outputs of the Model, such as loss layers
@ -209,8 +209,8 @@ def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
  if ancillary_layers:
    new_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if network._should_skip_first_node(layer) else layer.inbound_nodes
        for layer in created_layers.values()
        if functional._should_skip_first_node(layer)
        else layer.inbound_nodes for layer in created_layers.values()
    ])
    _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes)
  return model
@ -244,7 +244,8 @@ def _clone_layers_and_model_config(model, input_layers, layer_fn):
      created_layers[layer.name] = layer_fn(layer)
      return {}

  config = network.get_network_config(model, serialize_layer_fn=_copy_layer)
  config = functional.get_network_config(
      model, serialize_layer_fn=_copy_layer)
  return config, created_layers
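
A hedged sketch of the public entry point that exercises this code path, tf.keras.models.clone_model; the clone is rebuilt from the copied config and starts with freshly initialized weights:

import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

clone = tf.keras.models.clone_model(model)  # rebuilds the graph network
clone.set_weights(model.get_weights())      # optionally copy the weights over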

@ -495,7 +496,7 @@ def _in_place_subclassed_model_reset(model):
    # This will not work for nested subclassed models used as layers.
    # This would be theoretically possible to support, but would add complexity.
    # Only do it if users complain.
    if isinstance(layer, Network) and not layer._is_graph_network:
    if isinstance(layer, training.Model) and not layer._is_graph_network:
      raise ValueError('We do not support the use of nested subclassed models '
                       'in `model_to_estimator` at this time. Found nested '
                       'model: %s' % layer)

@ -1210,7 +1210,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
  def test_incompatible_checkpoint(self):
    save_path = trackable.Checkpoint().save(
        os.path.join(self.get_temp_dir(), 'ckpt'))
    m = keras.Model()
    m = DummySubclassModel()
    with self.assertRaisesRegexp(AssertionError, 'Nothing to load'):
      m.load_weights(save_path)
    m.dense = keras.layers.Dense(2)
@ -1222,7 +1222,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_directory_passed(self):
    with self.cached_session():
      m = keras.Model()
      m = DummySubclassModel()
      v = m.add_weight(name='v', shape=[])
      self.evaluate(v.assign(42.))
      prefix = os.path.join(self.get_temp_dir(),
@ -1235,7 +1235,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_relative_path(self):
    with self.cached_session():
      m = keras.Model()
      m = DummySubclassModel()
      v = m.add_weight(name='v', shape=[])
      os.chdir(self.get_temp_dir())

@ -1266,7 +1266,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_nonexistent_prefix_directory(self):
    with self.cached_session():
      m = keras.Model()
      m = DummySubclassModel()
      v = m.add_weight(name='v', shape=[])
      self.evaluate(v.assign(42.))
      prefix = os.path.join(self.get_temp_dir(),
@ -1276,5 +1276,10 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
      m.load_weights(prefix)
      self.assertEqual(42., self.evaluate(v))


class DummySubclassModel(training.Model):
  pass


if __name__ == '__main__':
  test.main()

@ -62,9 +62,9 @@ layers_module = LazyLoader(
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
network_lib = LazyLoader(
    "network_lib", globals(),
    "tensorflow.python.keras.engine.network")
functional_lib = LazyLoader(
    "functional_lib", globals(),
    "tensorflow.python.keras.engine.functional")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
@ -142,7 +142,7 @@ def _is_graph_network(layer):
  # pylint: disable=protected-access
  if isinstance(layer, RevivedNetwork):
    return False
  elif isinstance(layer, network_lib.Network):
  elif isinstance(layer, functional_lib.Functional):
    return (layer._is_graph_network or
            isinstance(layer, models_lib.Sequential))
  return False
@ -371,7 +371,8 @@ class KerasObjectLoader(tf_load.Loader):
    # functional or Sequential model.
    model_is_functional_or_sequential = (
        metadata.get('is_graph_network', False) or
        metadata['class_name'] == 'Sequential')
        metadata['class_name'] == 'Sequential' or
        metadata['class_name'] == 'Functional')
    if not (generic_utils.validate_config(config) and
            model_is_functional_or_sequential):
      return None  # Revive as custom model.
@ -383,7 +384,8 @@ class KerasObjectLoader(tf_load.Loader):
    if class_name == 'Sequential':
      model = models_lib.Sequential(name=config['name'])
    else:
      model = models_lib.Model(name=config['name'])
      model = models_lib.Functional(
          inputs=[], outputs=[], name=config['name'])

    # Record this model and its layers. This will later be used to reconstruct
    # the model.
@ -561,10 +563,11 @@ class KerasObjectLoader(tf_load.Loader):
      if not model.built and not isinstance(input_specs, dict):
        model.build(input_shapes)
    else:
      (inputs, outputs, created_layers) = network_lib.reconstruct_from_config(
          config, created_layers={layer.name: layer for layer in layers})
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      network_lib.connect_ancillary_layers(model, created_layers)
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype and trainable status.
    _set_network_attributes_from_metadata(model)
@ -764,7 +767,7 @@ def revive_custom_object(identifier, metadata):
  revived_classes = {
      '_tf_keras_layer': (RevivedLayer, base_layer.Layer),
      '_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
      '_tf_keras_network': (RevivedNetwork, network_lib.Network),
      '_tf_keras_network': (RevivedNetwork, functional_lib.Functional),
      '_tf_keras_model': (RevivedNetwork, model_class),
      '_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential),
  }
@ -852,7 +855,7 @@ def _revive_setter(layer, name, value):
    layer._track_trackable(value, name=name)
    layer._serialized_attributes[name] = value
    # pylint: enable=protected-access
  elif (isinstance(layer, network_lib.Network) and
  elif (isinstance(layer, functional_lib.Functional) and
        re.match(r'^layer(_with_weights)?-[\d+]', name) is not None):
    # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
    # network._track_layers, should not be added as an attribute.
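
A sketch of the load path this change affects: saving a functional model in the SavedModel format and loading it back revives it through the Functional/Sequential branch rather than as a custom model (the path below is hypothetical):

import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

model.save('/tmp/functional_model')  # SavedModel format (no .h5 suffix)
restored = tf.keras.models.load_model('/tmp/functional_model')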

@ -20,11 +20,11 @@ from __future__ import print_function

from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import network_serialization
from tensorflow.python.keras.saving.saved_model import layer_serialization
from tensorflow.python.keras.saving.saved_model import save_impl


class ModelSavedModelSaver(network_serialization.NetworkSavedModelSaver):
class ModelSavedModelSaver(layer_serialization.LayerSavedModelSaver):
  """Model SavedModel serialization."""

  @property
@ -33,6 +33,10 @@ class ModelSavedModelSaver(network_serialization.NetworkSavedModelSaver):

  def _python_properties_internal(self):
    metadata = super(ModelSavedModelSaver, self)._python_properties_internal()
    # Network stateful property is dependent on the child layers.
    metadata.pop('stateful')
    metadata['is_graph_network'] = self.obj._is_graph_network  # pylint: disable=protected-access

    metadata.update(
        saving_utils.model_metadata(
            self.obj, include_optimizer=True, require_config=False))

@ -18,22 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.saving.saved_model import layer_serialization
from tensorflow.python.keras.saving.saved_model import model_serialization


# Network serialization is pretty much the same as layer serialization.
class NetworkSavedModelSaver(layer_serialization.LayerSavedModelSaver):
# FunctionalModel serialization is pretty much the same as Model serialization.
class NetworkSavedModelSaver(model_serialization.ModelSavedModelSaver):
  """Network serialization."""

  @property
  def object_identifier(self):
    return '_tf_keras_network'

  def _python_properties_internal(self):
    metadata = super(NetworkSavedModelSaver, self)._python_properties_internal()

    # Network stateful property is dependent on the child layers.
    metadata.pop('stateful')

    metadata['is_graph_network'] = self.obj._is_graph_network  # pylint: disable=protected-access
    return metadata

@ -53,12 +53,12 @@ class SplitUtilsTest(keras_parameterized.TestCase):
    inputs = keras.Input(10)
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    self._check_model_class(model.__class__)
    self._check_model_class(model.__class__.__bases__[0])
    self._check_layer_class(model)

  def test_sequential_model(self):
    model = keras.Sequential([keras.layers.Dense(1)])
    model_class = model.__class__.__bases__[0]
    model_class = model.__class__.__bases__[0].__bases__[0]
    self._check_model_class(model_class)
    self._check_layer_class(model)
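
A quick sketch of the hierarchy these updated assertions expect after the change (Sequential derives from Functional, which derives from Model); the exact base-class positions are what the test relies on and may differ in other releases:

import tensorflow as tf

seq = tf.keras.Sequential([tf.keras.layers.Dense(1)])
functional_cls = seq.__class__.__bases__[0]
model_cls = functional_cls.__bases__[0]
print(functional_cls.__name__)  # expected: 'Functional'
print(model_cls.__name__)       # expected: 'Model'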

@ -55,10 +55,10 @@ def check_pydot():


def is_wrapped_model(layer):
  from tensorflow.python.keras.engine import network
  from tensorflow.python.keras.engine import functional
  from tensorflow.python.keras.layers import wrappers
  return (isinstance(layer, wrappers.Wrapper) and
          isinstance(layer.layer, network.Network))
          isinstance(layer.layer, functional.Functional))


def add_edge(dot, src, dst):
@ -98,7 +98,7 @@ def model_to_dot(model,
  """
  from tensorflow.python.keras.layers import wrappers
  from tensorflow.python.keras.engine import sequential
  from tensorflow.python.keras.engine import network
  from tensorflow.python.keras.engine import functional

  if not check_pydot():
    message = (
@ -147,7 +147,8 @@ def model_to_dot(model,
      class_name = layer.__class__.__name__

      if isinstance(layer, wrappers.Wrapper):
        if expand_nested and isinstance(layer.layer, network.Network):
        if expand_nested and isinstance(layer.layer,
                                        functional.Functional):
          submodel_wrapper = model_to_dot(layer.layer, show_shapes,
                                          show_layer_names, rankdir,
                                          expand_nested,
@ -162,7 +163,7 @@ def model_to_dot(model,
          child_class_name = layer.layer.__class__.__name__
          class_name = '{}({})'.format(class_name, child_class_name)

      if expand_nested and isinstance(layer, network.Network):
      if expand_nested and isinstance(layer, functional.Functional):
        submodel_not_wrapper = model_to_dot(layer, show_shapes,
                                            show_layer_names, rankdir,
                                            expand_nested,
@ -200,7 +201,8 @@ def model_to_dot(model,
                                                     inputlabels,
                                                     outputlabels)

      if not expand_nested or not isinstance(layer, network.Network):
      if not expand_nested or not isinstance(
          layer, functional.Functional):
        node = pydot.Node(layer_id, label=label)
        dot.add_node(node)

@ -218,16 +220,17 @@ def model_to_dot(model,
            add_edge(dot, inbound_layer_id, layer_id)
          else:
            # if inbound_layer is not Model or wrapped Model
            if (not isinstance(inbound_layer, network.Network) and
            if (not isinstance(inbound_layer,
                               functional.Functional) and
                not is_wrapped_model(inbound_layer)):
              # if current layer is not Model or wrapped Model
              if (not isinstance(layer, network.Network) and
              if (not isinstance(layer, functional.Functional) and
                  not is_wrapped_model(layer)):
                assert dot.get_node(inbound_layer_id)
                assert dot.get_node(layer_id)
                add_edge(dot, inbound_layer_id, layer_id)
              # if current layer is Model
              elif isinstance(layer, network.Network):
              elif isinstance(layer, functional.Functional):
                add_edge(dot, inbound_layer_id,
                         sub_n_first_node[layer.name].get_name())
              # if current layer is wrapped Model
@ -236,9 +239,9 @@ def model_to_dot(model,
              name = sub_w_first_node[layer.layer.name].get_name()
              add_edge(dot, layer_id, name)
            # if inbound_layer is Model
            elif isinstance(inbound_layer, network.Network):
            elif isinstance(inbound_layer, functional.Functional):
              name = sub_n_last_node[inbound_layer.name].get_name()
              if isinstance(layer, network.Network):
              if isinstance(layer, functional.Functional):
                output_name = sub_n_first_node[layer.name].get_name()
                add_edge(dot, name, output_name)
              else:
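
A usage sketch for the updated isinstance checks: with pydot and graphviz installed, nested functional models are expanded into sub-graphs when expand_nested=True (the output file name below is hypothetical):

import tensorflow as tf

inner_in = tf.keras.Input(shape=(4,))
inner_out = tf.keras.layers.Dense(2)(inner_in)
inner = tf.keras.Model(inner_in, inner_out, name='inner')

outer_in = tf.keras.Input(shape=(4,))
outer_out = tf.keras.layers.Dense(1)(inner(outer_in))
outer = tf.keras.Model(outer_in, outer_out, name='outer')

tf.keras.utils.plot_model(outer, to_file='outer.png', expand_nested=True)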

@ -1,7 +1,6 @@
path: "tensorflow.keras.Model"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -175,7 +174,7 @@ tf_class {
  }
  member_method {
    name: "compute_mask"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_output_shape"

@ -1,8 +1,8 @@
path: "tensorflow.keras.Sequential"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.functional.Functional\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"

@ -2,7 +2,6 @@ path: "tensorflow.keras.experimental.LinearModel"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.premade.linear.LinearModel\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -176,7 +175,7 @@ tf_class {
  }
  member_method {
    name: "compute_mask"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_output_shape"

@ -2,7 +2,6 @@ path: "tensorflow.keras.experimental.WideDeepModel"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.premade.wide_deep.WideDeepModel\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -176,7 +175,7 @@ tf_class {
  }
  member_method {
    name: "compute_mask"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_output_shape"

@ -1,7 +1,6 @@
path: "tensorflow.keras.models.Model"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -175,7 +174,7 @@ tf_class {
  }
  member_method {
    name: "compute_mask"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_output_shape"

@ -1,8 +1,8 @@
path: "tensorflow.keras.models.Sequential"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.functional.Functional\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"

@ -1,7 +1,6 @@
path: "tensorflow.keras.Model"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -175,7 +174,7 @@ tf_class {
  }
  member_method {
    name: "compute_mask"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_output_shape"

@ -1,8 +1,8 @@
path: "tensorflow.keras.Sequential"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.functional.Functional\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"

@ -2,7 +2,6 @@ path: "tensorflow.keras.experimental.LinearModel"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.premade.linear.LinearModel\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -176,7 +175,7 @@ tf_class {
  }
  member_method {
    name: "compute_mask"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_output_shape"

@ -2,7 +2,6 @@ path: "tensorflow.keras.experimental.WideDeepModel"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.premade.wide_deep.WideDeepModel\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -176,7 +175,7 @@ tf_class {
  }
  member_method {
    name: "compute_mask"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_output_shape"

@ -1,7 +1,6 @@
path: "tensorflow.keras.models.Model"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
@ -175,7 +174,7 @@ tf_class {
  }
  member_method {
    name: "compute_mask"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_output_shape"

@ -1,8 +1,8 @@
path: "tensorflow.keras.models.Sequential"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.functional.Functional\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"