Separate Keras and Estimator ModeKeys into different classes to avoid headache with updating Estimator ModeKeys.

Detailed list of changes:
* Revert Estimator Modekeys back V1 ModeKeys.
* Added helper functions in SavedModel model utils to deal with different Estimator and Keras Mode Keys.
* Estimator and Keras ModeKeys now reside in saved_model/model_utils/mode_keys.

PiperOrigin-RevId: 233484078
This commit is contained in:
Katherine Wu 2019-02-11 14:51:51 -08:00 committed by TensorFlower Gardener
parent f5463b160c
commit c1043a02f9
17 changed files with 246 additions and 94 deletions

View File

@ -31,10 +31,10 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop
from tensorflow.python.training.mode_keys import ModeKeys
_RANDOM_SEED = 1337
_TRAIN_SIZE = 200

View File

@ -143,7 +143,6 @@ py_library(
":map_fn",
":math_ops",
":metrics",
":mode_keys",
":nccl_ops",
":nn",
":ops",
@ -6076,29 +6075,6 @@ py_binary(
],
)
py_library(
name = "mode_keys",
srcs = [
"training/mode_keys.py",
],
srcs_version = "PY2AND3",
deps = [
":util",
],
)
tf_py_test(
name = "mode_keys_test",
size = "small",
srcs = [
"training/mode_keys_test.py",
],
additional_deps = [
":client_testlib",
":mode_keys",
],
)
pyx_library(
name = "framework_fast_tensor_util",
srcs = ["framework/fast_tensor_util.pyx"],

View File

@ -161,6 +161,7 @@ py_library(
":engine_utils",
":initializers",
":losses",
":mode_keys",
":optimizers",
":regularizers",
":saving",
@ -188,9 +189,9 @@ py_library(
deps = [
":backend",
":engine_utils",
":mode_keys",
":optimizers",
"//tensorflow/python:lib",
"//tensorflow/python:mode_keys",
"//tensorflow/python:saver",
"//tensorflow/python/saved_model",
"//tensorflow/python/saved_model/model_utils",
@ -218,6 +219,7 @@ py_library(
deps = [
":backend",
":engine_utils",
":mode_keys",
],
)
@ -369,6 +371,17 @@ py_library(
],
)
py_library(
name = "mode_keys",
srcs = [
"utils/mode_keys.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/saved_model/model_utils:mode_keys",
],
)
tf_py_test(
name = "integration_test",
size = "medium",

View File

@ -36,10 +36,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.data_utils import Sequence
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.mode_keys import ModeKeys
from tensorflow.python.util.tf_export import keras_export
try:

View File

@ -33,11 +33,11 @@ from tensorflow.python.keras import callbacks
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.mode_keys import ModeKeys
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib

View File

@ -46,11 +46,11 @@ from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.mode_keys import ModeKeys
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

View File

@ -33,8 +33,8 @@ from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.generic_utils import make_batches
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.mode_keys import ModeKeys
try:
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top

View File

@ -34,9 +34,9 @@ from tensorflow.python.keras.engine import partial_batch_padding_handler as padd
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.mode_keys import ModeKeys
from tensorflow.python.util import nest

View File

@ -33,8 +33,8 @@ from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.mode_keys import ModeKeys
from tensorflow.python.util import nest

View File

@ -27,6 +27,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.saving import model_from_json
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
@ -35,7 +36,6 @@ from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import model_utils
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import mode_keys
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.util import compat

View File

@ -33,12 +33,12 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.saving import saved_model as keras_saved_model
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import model_utils
from tensorflow.python.training import mode_keys
from tensorflow.python.training import training as training_module

View File

@ -12,22 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model modeKeys for TensorFlow and Estimator."""
"""Keras model mode constants."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class ModeKeys(object):
"""Standard names for model modes.
The following standard keys are defined:
* `TRAIN`: training/fitting mode.
* `TEST`: testing/evaluation mode.
* `PREDICT`: prediction/inference mode.
"""
TRAIN = 'train'
TEST = 'test'
PREDICT = 'predict'
# pylint: disable=unused-import
from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys
# pylint: enable=unused-import

View File

@ -30,6 +30,7 @@ py_library(
deps = [
":export_output",
":export_utils",
":mode_keys",
],
)
@ -70,7 +71,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":export_output",
"//tensorflow/python:mode_keys",
":mode_keys",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python/saved_model:signature_constants",
@ -98,3 +99,19 @@ py_test(
"//tensorflow/python/saved_model:signature_def_utils",
],
)
py_library(
name = "mode_keys",
srcs = ["mode_keys.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "mode_keys_test",
srcs = ["mode_keys_test.py"],
srcs_version = "PY2AND3",
deps = [
":mode_keys",
"//tensorflow/python:client_testlib",
],
)

View File

@ -30,29 +30,26 @@ from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
from tensorflow.python.training import mode_keys
from tensorflow.python.saved_model.model_utils import mode_keys
from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys
from tensorflow.python.util import compat
# Mapping of the modes to appropriate MetaGraph tags in the SavedModel.
EXPORT_TAG_MAP = {
mode_keys.ModeKeys.PREDICT: [tag_constants.SERVING],
mode_keys.ModeKeys.TRAIN: [tag_constants.TRAINING],
mode_keys.ModeKeys.TEST: [tag_constants.EVAL],
}
EXPORT_TAG_MAP = mode_keys.ModeKeyMap(**{
ModeKeys.PREDICT: [tag_constants.SERVING],
ModeKeys.TRAIN: [tag_constants.TRAINING],
ModeKeys.TEST: [tag_constants.EVAL]})
# For every exported mode, a SignatureDef map should be created using the
# functions `export_outputs_for_mode` and `build_all_signature_defs`. By
# default, this map will contain a single Signature that defines the input
# tensors and output predictions, losses, and/or metrics (depending on the mode)
# The default keys used in the SignatureDef map are defined below.
SIGNATURE_KEY_MAP = {
mode_keys.ModeKeys.PREDICT:
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
mode_keys.ModeKeys.TRAIN:
signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY,
mode_keys.ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY
}
SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(**{
ModeKeys.PREDICT: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
ModeKeys.TRAIN: signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY,
ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY})
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
@ -281,9 +278,9 @@ def export_outputs_for_mode(
'this function. Please ensure that you are using the new ModeKeys.'
.format(mode, SIGNATURE_KEY_MAP.keys()))
signature_key = SIGNATURE_KEY_MAP[mode]
if mode == mode_keys.ModeKeys.PREDICT:
if mode_keys.is_predict(mode):
return get_export_outputs(serving_export_outputs, predictions)
elif mode == mode_keys.ModeKeys.TRAIN:
elif mode_keys.is_eval(mode):
return {signature_key: export_output_lib.TrainOutput(
loss=loss, predictions=predictions, metrics=metrics)}
else:

View File

@ -0,0 +1,124 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for managing different mode strings used by Keras and Estimator models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
class KerasModeKeys(object):
"""Standard names for model modes.
The following standard keys are defined:
* `TRAIN`: training/fitting mode.
* `TEST`: testing/evaluation mode.
* `PREDICT`: prediction/inference mode.
"""
TRAIN = 'train'
TEST = 'test'
PREDICT = 'predict'
class ModeKeys(object):
"""Standard names for model modes.
The following standard keys are defined:
* `TRAIN`: training/fitting mode.
* `TEST`: testing/evaluation mode.
* `PREDICT`: prediction/inference mode.
"""
TRAIN = 'train'
TEST = 'test'
PREDICT = 'predict'
# TODO(kathywu): Remove copy in Estimator after nightlies
class EstimatorModeKeys(object):
"""Standard names for Estimator model modes.
The following standard keys are defined:
* `TRAIN`: training/fitting mode.
* `EVAL`: testing/evaluation mode.
* `PREDICT`: predication/inference mode.
"""
TRAIN = 'train'
EVAL = 'eval'
PREDICT = 'infer'
def is_predict(mode):
return mode in [KerasModeKeys.PREDICT, EstimatorModeKeys.PREDICT]
def is_eval(mode):
return mode in [KerasModeKeys.TEST, EstimatorModeKeys.EVAL]
def is_train(mode):
return mode in [KerasModeKeys.TRAIN, EstimatorModeKeys.TRAIN]
class ModeKeyMap(collections.Mapping):
"""Map using ModeKeys as keys.
This class creates an immutable mapping from modes to values. For example,
SavedModel export of Keras and Estimator models use this to map modes to their
corresponding MetaGraph tags/SignatureDef keys.
Since this class uses modes, rather than strings, as keys, both "predict"
(Keras's PREDICT ModeKey) and "infer" (Estimator's PREDICT ModeKey) map to the
same value.
"""
def __init__(self, **kwargs):
self._internal_dict = {}
self._keys = []
for key in kwargs:
self._keys.append(key)
dict_key = self._get_internal_key(key)
if dict_key in self._internal_dict:
raise ValueError(
'Error creating ModeKeyMap. Multiple keys/values found for {} mode.'
.format(dict_key))
self._internal_dict[dict_key] = kwargs[key]
def _get_internal_key(self, key):
"""Return keys used for the internal dictionary."""
if is_train(key):
return KerasModeKeys.TRAIN
if is_eval(key):
return KerasModeKeys.TEST
if is_predict(key):
return KerasModeKeys.PREDICT
raise ValueError('Invalid mode key: {}.'.format(key))
def __getitem__(self, key):
return self._internal_dict[self._get_internal_key(key)]
def __iter__(self):
return iter(self._keys)
def __len__(self):
return len(self._keys)

View File

@ -0,0 +1,65 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModeKey Tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.platform import test
from tensorflow.python.saved_model.model_utils import mode_keys
class ModeKeyMapTest(test.TestCase):
def test_map(self):
mode_map = mode_keys.ModeKeyMap(**{
mode_keys.KerasModeKeys.PREDICT: 3,
mode_keys.KerasModeKeys.TEST: 1
})
# Test dictionary __getitem__
self.assertEqual(3, mode_map[mode_keys.KerasModeKeys.PREDICT])
self.assertEqual(3, mode_map[mode_keys.EstimatorModeKeys.PREDICT])
self.assertEqual(1, mode_map[mode_keys.KerasModeKeys.TEST])
self.assertEqual(1, mode_map[mode_keys.EstimatorModeKeys.EVAL])
with self.assertRaises(KeyError):
_ = mode_map[mode_keys.KerasModeKeys.TRAIN]
with self.assertRaises(KeyError):
_ = mode_map[mode_keys.EstimatorModeKeys.TRAIN]
with self.assertRaisesRegexp(ValueError, 'Invalid mode'):
_ = mode_map['serve']
# Test common dictionary methods
self.assertLen(mode_map, 2)
self.assertEqual({1, 3}, set(mode_map.values()))
self.assertEqual(
{mode_keys.KerasModeKeys.TEST, mode_keys.KerasModeKeys.PREDICT},
set(mode_map.keys()))
# Map is immutable
with self.assertRaises(TypeError):
mode_map[mode_keys.KerasModeKeys.TEST] = 1
def test_invalid_init(self):
with self.assertRaisesRegexp(ValueError, 'Multiple keys/values found'):
_ = mode_keys.ModeKeyMap(**{
mode_keys.KerasModeKeys.PREDICT: 3,
mode_keys.EstimatorModeKeys.PREDICT: 1
})
if __name__ == '__main__':
test.main()

View File

@ -1,29 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.train.ModeKeys."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.platform import test
from tensorflow.python.training import mode_keys
class ModeKeysTest(test.TestCase):
def testKeyEquality(self):
self.assertEqual(mode_keys.ModeKeys.PREDICT, 'predict')
self.assertEqual(mode_keys.ModeKeys.TRAIN, 'train')
self.assertEqual(mode_keys.ModeKeys.TEST, 'test')