Split v1 and v2 Model class.

PiperOrigin-RevId: 281684883
Change-Id: I59920ca1ef5a0e87360c1485ab928d600f2f852d
This commit is contained in:
Thomas O'Malley 2019-11-20 23:35:43 -08:00 committed by TensorFlower Gardener
parent 6acdf87ae7
commit a2435de245
19 changed files with 3451 additions and 20 deletions

View File

@ -176,11 +176,13 @@ py_library(
"engine/training_eager.py",
"engine/training_generator.py",
"engine/training_utils.py",
"engine/training_v1.py",
"engine/training_v2.py",
"engine/training_v2_utils.py",
"metrics.py", # Need base_layer
"models.py",
"utils/metrics_utils.py",
"utils/version_utils.py",
],
srcs_version = "PY2AND3",
deps = [
@ -1251,6 +1253,17 @@ tf_py_test(
],
)
tf_py_test(
name = "version_utils_test",
size = "small",
srcs = ["utils/version_utils_test.py"],
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:client_testlib",
],
)
tf_py_test(
name = "tf_utils_test",
size = "small",

View File

@ -55,6 +55,7 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -78,7 +79,7 @@ _keras_api_gauge = monitoring.BoolGauge('/tensorflow/api/keras',
@keras_export('keras.models.Model', 'keras.Model')
class Model(network.Network):
class Model(network.Network, version_utils.VersionSelector):
"""`Model` groups layers into an object with training and inference features.
There are two ways to instantiate a `Model`:
@ -760,6 +761,8 @@ class Model(network.Network):
and what the model expects.
"""
_keras_api_gauge.get_cell('fit').set(True)
# Legacy graph support is contained in `training_v1.Model`.
version_utils.disallow_legacy_graph('Model', 'fit')
# Legacy support
if 'nb_epoch' in kwargs:
logging.warning(
@ -880,6 +883,7 @@ class Model(network.Network):
ValueError: in case of invalid arguments.
"""
_keras_api_gauge.get_cell('evaluate').set(True)
version_utils.disallow_legacy_graph('Model', 'evaluate')
self._assert_compile_was_called()
self._check_call_args('evaluate')
@ -959,6 +963,7 @@ class Model(network.Network):
that is not a multiple of the batch size.
"""
_keras_api_gauge.get_cell('predict').set(True)
version_utils.disallow_legacy_graph('Model', 'predict')
self._check_call_args('predict')
func = self._select_training_loop(x)

File diff suppressed because it is too large Load Diff

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import json
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import input_spec
@ -55,6 +56,9 @@ network_lib = LazyLoader(
training_lib = LazyLoader(
"training_lib", globals(),
"tensorflow.python.keras.engine.training")
training_lib_v1 = LazyLoader(
"training_lib", globals(),
"tensorflow.python.keras.engine.training_v1")
# pylint:enable=g-inconsistent-quotes
@ -196,11 +200,16 @@ class KerasObjectLoader(tf_load.Loader):
# pylint: enable=protected-access
def _recreate_base_user_object(self, proto):
if ops.executing_eagerly_outside_functions():
model_class = training_lib.Model
else:
model_class = training_lib_v1.Model
revived_classes = {
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
'_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
'_tf_keras_network': (RevivedNetwork, network_lib.Network),
'_tf_keras_model': (RevivedNetwork, training_lib.Model),
'_tf_keras_model': (RevivedNetwork, model_class),
'_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential)
}
@ -210,9 +219,7 @@ class KerasObjectLoader(tf_load.Loader):
parent_classes = revived_classes[proto.identifier]
metadata = json.loads(proto.metadata)
revived_cls = type(
compat.as_str(metadata['class_name']),
parent_classes,
{'__setattr__': parent_classes[1].__setattr__})
compat.as_str(metadata['class_name']), parent_classes, {})
return revived_cls._init_from_metadata(metadata) # pylint: disable=protected-access
return super(KerasObjectLoader, self)._recreate_base_user_object(proto)
@ -377,4 +384,3 @@ def _set_network_attributes_from_metadata(revived_obj):
revived_obj.activity_regularizer = regularizers.deserialize(
metadata['activity_regularizer'])
# pylint:enable=protected-access

View File

@ -92,7 +92,7 @@ class LayerWithUpdate(keras.layers.Layer):
return inputs
@test_util.run_all_in_graph_and_eager_modes
@keras_parameterized.run_all_keras_modes
class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'):
@ -264,6 +264,11 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
def test_compiled_model(self):
# TODO(b/134519980): Issue with model.fit if the model call function uses
# a tf.function (Graph mode only).
if not context.executing_eagerly():
return
input_arr = np.random.random((1, 3))
target_arr = np.random.random((1, 4))
@ -275,21 +280,18 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
# TODO(b/134519980): Issue with model.fit if the model call function uses
# a tf.function (Graph mode only).
with context.eager_mode():
loaded = keras_load.load(saved_model_dir)
actual_predict = loaded.predict(input_arr)
self.assertAllClose(expected_predict, actual_predict)
loaded = keras_load.load(saved_model_dir)
actual_predict = loaded.predict(input_arr)
self.assertAllClose(expected_predict, actual_predict)
loss_before = loaded.evaluate(input_arr, target_arr)
loaded.fit(input_arr, target_arr)
loss_after = loaded.evaluate(input_arr, target_arr)
self.assertLess(loss_after, loss_before)
predict = loaded.predict(input_arr)
loss_before = loaded.evaluate(input_arr, target_arr)
loaded.fit(input_arr, target_arr)
loss_after = loaded.evaluate(input_arr, target_arr)
self.assertLess(loss_after, loss_before)
predict = loaded.predict(input_arr)
ckpt_path = os.path.join(self.get_temp_dir(), 'weights')
loaded.save_weights(ckpt_path)
ckpt_path = os.path.join(self.get_temp_dir(), 'weights')
loaded.save_weights(ckpt_path)
# Ensure that the checkpoint is compatible with the original model.
model.load_weights(ckpt_path)

View File

@ -0,0 +1,70 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities for Keras classes with v1 and v2 versions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.util import lazy_loader
# TODO(b/134426265): Switch back to single-quotes once the issue
# with copybara is fixed.
# pylint: disable=g-inconsistent-quotes
training = lazy_loader.LazyLoader(
"training", globals(),
"tensorflow.python.keras.engine.training")
training_v1 = lazy_loader.LazyLoader(
"training_v1", globals(),
"tensorflow.python.keras.engine.training_v1")
# pylint: enable=g-inconsistent-quotes
# TODO(omalleyt): Extend to Layer class once Layer class is split.
class VersionSelector(object):
"""Chooses between Keras v1 and v2 Model class."""
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
new_cls = swap_class(cls, training.Model, training_v1.Model)
return object.__new__(new_cls)
def swap_class(cls, v2_cls, v1_cls):
"""Swaps in v2_cls or v1_cls depending on graph mode."""
if cls == object:
return cls
if cls in (v2_cls, v1_cls):
if ops.executing_eagerly_outside_functions():
return v2_cls
return v1_cls
# Recursively search superclasses to swap in the right Keras class.
cls.__bases__ = tuple(
swap_class(base, v2_cls, v1_cls) for base in cls.__bases__)
return cls
def disallow_legacy_graph(cls_name, method_name):
if not ops.executing_eagerly_outside_functions():
error_msg = (
"Calling `{cls_name}.{method_name}` in graph mode is not supported "
"when the `{cls_name}` instance was constructed with eager mode "
"enabled. Please construct your `{cls_name}` instance in graph mode or"
" call `{cls_name}.{method_name}` with eager mode enabled.")
error_msg = error_msg.format(cls_name=cls_name, method_name=method_name)
raise ValueError(error_msg)

View File

@ -0,0 +1,133 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras utilities to split v1 and v2 classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import numpy as np
import six
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_v1
from tensorflow.python.platform import test
@keras_parameterized.run_all_keras_modes
class SplitUtilsTest(keras_parameterized.TestCase):
def _check_model_class(self, model_class):
if ops.executing_eagerly_outside_functions():
self.assertEqual(model_class, training.Model)
else:
self.assertEqual(model_class, training_v1.Model)
def test_functional_model(self):
inputs = keras.Input(10)
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
self._check_model_class(model.__class__)
def test_sequential_model(self):
model = keras.Sequential([keras.layers.Dense(1)])
model_class = model.__class__.__bases__[0]
self._check_model_class(model_class)
def test_subclass_model(self):
class MyModel(keras.Model):
def call(self, x):
return 2 * x
model = MyModel()
model_class = model.__class__.__bases__[0]
self._check_model_class(model_class)
def test_multiple_subclass_model(self):
class Model1(keras.Model):
pass
class Model2(Model1):
def call(self, x):
return 2 * x
model = Model2()
model_class = model.__class__.__bases__[0].__bases__[0]
self._check_model_class(model_class)
def test_user_provided_metaclass(self):
@six.add_metaclass(abc.ABCMeta)
class AbstractModel(keras.Model):
@abc.abstractmethod
def call(self, inputs):
"""Calls the model."""
class MyModel(AbstractModel):
def call(self, inputs):
return 2 * inputs
with self.assertRaisesRegexp(TypeError, 'instantiate abstract class'):
AbstractModel()
model = MyModel()
model_class = model.__class__.__bases__[0].__bases__[0]
self._check_model_class(model_class)
def test_multiple_inheritance(self):
class Return2(object):
def return_2(self):
return 2
class MyModel(keras.Model, Return2):
def call(self, x):
return self.return_2() * x
model = MyModel()
bases = model.__class__.__bases__
self._check_model_class(bases[0])
self.assertEqual(bases[1], Return2)
self.assertEqual(model.return_2(), 2)
def test_fit_error(self):
if not ops.executing_eagerly_outside_functions():
# Error only appears on the v2 class.
return
model = keras.Sequential([keras.layers.Dense(1)])
model.compile('sgd', 'mse')
x, y = np.ones((10, 10)), np.ones((10, 1))
with context.graph_mode():
with self.assertRaisesRegexp(
ValueError, 'instance was constructed with eager mode enabled'):
model.fit(x, y, batch_size=2)
if __name__ == '__main__':
test.main()

View File

@ -6,6 +6,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -7,6 +7,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -7,6 +7,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -7,6 +7,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -6,6 +6,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -7,6 +7,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -6,6 +6,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -7,6 +7,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -7,6 +7,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -7,6 +7,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -6,6 +6,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"

View File

@ -7,6 +7,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.VersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"