Separate Keras and Estimator ModeKeys into different classes to avoid headache with updating Estimator ModeKeys.
Detailed list of changes: * Revert Estimator Modekeys back V1 ModeKeys. * Added helper functions in SavedModel model utils to deal with different Estimator and Keras Mode Keys. * Estimator and Keras ModeKeys now reside in saved_model/model_utils/mode_keys. PiperOrigin-RevId: 233484078
This commit is contained in:
parent
f5463b160c
commit
c1043a02f9
@ -31,10 +31,10 @@ from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import distributed_training_utils
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
|
||||
from tensorflow.python.training import gradient_descent
|
||||
from tensorflow.python.training import rmsprop
|
||||
from tensorflow.python.training.mode_keys import ModeKeys
|
||||
|
||||
_RANDOM_SEED = 1337
|
||||
_TRAIN_SIZE = 200
|
||||
|
@ -143,7 +143,6 @@ py_library(
|
||||
":map_fn",
|
||||
":math_ops",
|
||||
":metrics",
|
||||
":mode_keys",
|
||||
":nccl_ops",
|
||||
":nn",
|
||||
":ops",
|
||||
@ -6076,29 +6075,6 @@ py_binary(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "mode_keys",
|
||||
srcs = [
|
||||
"training/mode_keys.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":util",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "mode_keys_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"training/mode_keys_test.py",
|
||||
],
|
||||
additional_deps = [
|
||||
":client_testlib",
|
||||
":mode_keys",
|
||||
],
|
||||
)
|
||||
|
||||
pyx_library(
|
||||
name = "framework_fast_tensor_util",
|
||||
srcs = ["framework/fast_tensor_util.pyx"],
|
||||
|
@ -161,6 +161,7 @@ py_library(
|
||||
":engine_utils",
|
||||
":initializers",
|
||||
":losses",
|
||||
":mode_keys",
|
||||
":optimizers",
|
||||
":regularizers",
|
||||
":saving",
|
||||
@ -188,9 +189,9 @@ py_library(
|
||||
deps = [
|
||||
":backend",
|
||||
":engine_utils",
|
||||
":mode_keys",
|
||||
":optimizers",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:mode_keys",
|
||||
"//tensorflow/python:saver",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/saved_model/model_utils",
|
||||
@ -218,6 +219,7 @@ py_library(
|
||||
deps = [
|
||||
":backend",
|
||||
":engine_utils",
|
||||
":mode_keys",
|
||||
],
|
||||
)
|
||||
|
||||
@ -369,6 +371,17 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "mode_keys",
|
||||
srcs = [
|
||||
"utils/mode_keys.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python/saved_model/model_utils:mode_keys",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "integration_test",
|
||||
size = "medium",
|
||||
|
@ -36,10 +36,10 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.utils.data_utils import Sequence
|
||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.mode_keys import ModeKeys
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
try:
|
||||
|
@ -33,11 +33,11 @@ from tensorflow.python.keras import callbacks
|
||||
from tensorflow.python.keras import metrics as metrics_module
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.mode_keys import ModeKeys
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
@ -46,11 +46,11 @@ from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils import losses_utils
|
||||
from tensorflow.python.keras.utils.generic_utils import slice_arrays
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import optimizer as tf_optimizer_module
|
||||
from tensorflow.python.training.checkpointable import base as checkpointable
|
||||
from tensorflow.python.training.mode_keys import ModeKeys
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
@ -33,8 +33,8 @@ from tensorflow.python.keras.engine import distributed_training_utils
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.utils.generic_utils import make_batches
|
||||
from tensorflow.python.keras.utils.generic_utils import slice_arrays
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.mode_keys import ModeKeys
|
||||
|
||||
try:
|
||||
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
|
||||
|
@ -34,9 +34,9 @@ from tensorflow.python.keras.engine import partial_batch_padding_handler as padd
|
||||
from tensorflow.python.keras.engine import training_arrays
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.mode_keys import ModeKeys
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
|
@ -33,8 +33,8 @@ from tensorflow.python.keras import callbacks as cbks
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.mode_keys import ModeKeys
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.saving import model_from_json
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.utils import mode_keys
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -35,7 +36,6 @@ from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import model_utils
|
||||
from tensorflow.python.saved_model import save as save_lib
|
||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
from tensorflow.python.training import mode_keys
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training.checkpointable import graph_view
|
||||
from tensorflow.python.util import compat
|
||||
|
@ -33,12 +33,12 @@ from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.saving import saved_model as keras_saved_model
|
||||
from tensorflow.python.keras.utils import mode_keys
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import model_utils
|
||||
from tensorflow.python.training import mode_keys
|
||||
from tensorflow.python.training import training as training_module
|
||||
|
||||
|
||||
|
@ -12,22 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Model modeKeys for TensorFlow and Estimator."""
|
||||
"""Keras model mode constants."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class ModeKeys(object):
|
||||
"""Standard names for model modes.
|
||||
|
||||
The following standard keys are defined:
|
||||
|
||||
* `TRAIN`: training/fitting mode.
|
||||
* `TEST`: testing/evaluation mode.
|
||||
* `PREDICT`: prediction/inference mode.
|
||||
"""
|
||||
|
||||
TRAIN = 'train'
|
||||
TEST = 'test'
|
||||
PREDICT = 'predict'
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys
|
||||
# pylint: enable=unused-import
|
@ -30,6 +30,7 @@ py_library(
|
||||
deps = [
|
||||
":export_output",
|
||||
":export_utils",
|
||||
":mode_keys",
|
||||
],
|
||||
)
|
||||
|
||||
@ -70,7 +71,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":export_output",
|
||||
"//tensorflow/python:mode_keys",
|
||||
":mode_keys",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/saved_model:signature_constants",
|
||||
@ -98,3 +99,19 @@ py_test(
|
||||
"//tensorflow/python/saved_model:signature_def_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "mode_keys",
|
||||
srcs = ["mode_keys.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "mode_keys_test",
|
||||
srcs = ["mode_keys_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":mode_keys",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
@ -30,29 +30,26 @@ from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
|
||||
from tensorflow.python.training import mode_keys
|
||||
from tensorflow.python.saved_model.model_utils import mode_keys
|
||||
from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
# Mapping of the modes to appropriate MetaGraph tags in the SavedModel.
|
||||
EXPORT_TAG_MAP = {
|
||||
mode_keys.ModeKeys.PREDICT: [tag_constants.SERVING],
|
||||
mode_keys.ModeKeys.TRAIN: [tag_constants.TRAINING],
|
||||
mode_keys.ModeKeys.TEST: [tag_constants.EVAL],
|
||||
}
|
||||
EXPORT_TAG_MAP = mode_keys.ModeKeyMap(**{
|
||||
ModeKeys.PREDICT: [tag_constants.SERVING],
|
||||
ModeKeys.TRAIN: [tag_constants.TRAINING],
|
||||
ModeKeys.TEST: [tag_constants.EVAL]})
|
||||
|
||||
# For every exported mode, a SignatureDef map should be created using the
|
||||
# functions `export_outputs_for_mode` and `build_all_signature_defs`. By
|
||||
# default, this map will contain a single Signature that defines the input
|
||||
# tensors and output predictions, losses, and/or metrics (depending on the mode)
|
||||
# The default keys used in the SignatureDef map are defined below.
|
||||
SIGNATURE_KEY_MAP = {
|
||||
mode_keys.ModeKeys.PREDICT:
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
|
||||
mode_keys.ModeKeys.TRAIN:
|
||||
signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY,
|
||||
mode_keys.ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY
|
||||
}
|
||||
SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(**{
|
||||
ModeKeys.PREDICT: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
|
||||
ModeKeys.TRAIN: signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY,
|
||||
ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY})
|
||||
|
||||
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
|
||||
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
|
||||
@ -281,9 +278,9 @@ def export_outputs_for_mode(
|
||||
'this function. Please ensure that you are using the new ModeKeys.'
|
||||
.format(mode, SIGNATURE_KEY_MAP.keys()))
|
||||
signature_key = SIGNATURE_KEY_MAP[mode]
|
||||
if mode == mode_keys.ModeKeys.PREDICT:
|
||||
if mode_keys.is_predict(mode):
|
||||
return get_export_outputs(serving_export_outputs, predictions)
|
||||
elif mode == mode_keys.ModeKeys.TRAIN:
|
||||
elif mode_keys.is_eval(mode):
|
||||
return {signature_key: export_output_lib.TrainOutput(
|
||||
loss=loss, predictions=predictions, metrics=metrics)}
|
||||
else:
|
||||
|
124
tensorflow/python/saved_model/model_utils/mode_keys.py
Normal file
124
tensorflow/python/saved_model/model_utils/mode_keys.py
Normal file
@ -0,0 +1,124 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utils for managing different mode strings used by Keras and Estimator models.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
|
||||
class KerasModeKeys(object):
|
||||
"""Standard names for model modes.
|
||||
|
||||
The following standard keys are defined:
|
||||
|
||||
* `TRAIN`: training/fitting mode.
|
||||
* `TEST`: testing/evaluation mode.
|
||||
* `PREDICT`: prediction/inference mode.
|
||||
"""
|
||||
|
||||
TRAIN = 'train'
|
||||
TEST = 'test'
|
||||
PREDICT = 'predict'
|
||||
|
||||
|
||||
class ModeKeys(object):
|
||||
"""Standard names for model modes.
|
||||
|
||||
The following standard keys are defined:
|
||||
|
||||
* `TRAIN`: training/fitting mode.
|
||||
* `TEST`: testing/evaluation mode.
|
||||
* `PREDICT`: prediction/inference mode.
|
||||
"""
|
||||
|
||||
TRAIN = 'train'
|
||||
TEST = 'test'
|
||||
PREDICT = 'predict'
|
||||
|
||||
|
||||
# TODO(kathywu): Remove copy in Estimator after nightlies
|
||||
class EstimatorModeKeys(object):
|
||||
"""Standard names for Estimator model modes.
|
||||
|
||||
The following standard keys are defined:
|
||||
|
||||
* `TRAIN`: training/fitting mode.
|
||||
* `EVAL`: testing/evaluation mode.
|
||||
* `PREDICT`: predication/inference mode.
|
||||
"""
|
||||
|
||||
TRAIN = 'train'
|
||||
EVAL = 'eval'
|
||||
PREDICT = 'infer'
|
||||
|
||||
|
||||
def is_predict(mode):
|
||||
return mode in [KerasModeKeys.PREDICT, EstimatorModeKeys.PREDICT]
|
||||
|
||||
|
||||
def is_eval(mode):
|
||||
return mode in [KerasModeKeys.TEST, EstimatorModeKeys.EVAL]
|
||||
|
||||
|
||||
def is_train(mode):
|
||||
return mode in [KerasModeKeys.TRAIN, EstimatorModeKeys.TRAIN]
|
||||
|
||||
|
||||
class ModeKeyMap(collections.Mapping):
|
||||
"""Map using ModeKeys as keys.
|
||||
|
||||
This class creates an immutable mapping from modes to values. For example,
|
||||
SavedModel export of Keras and Estimator models use this to map modes to their
|
||||
corresponding MetaGraph tags/SignatureDef keys.
|
||||
|
||||
Since this class uses modes, rather than strings, as keys, both "predict"
|
||||
(Keras's PREDICT ModeKey) and "infer" (Estimator's PREDICT ModeKey) map to the
|
||||
same value.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._internal_dict = {}
|
||||
self._keys = []
|
||||
for key in kwargs:
|
||||
self._keys.append(key)
|
||||
dict_key = self._get_internal_key(key)
|
||||
if dict_key in self._internal_dict:
|
||||
raise ValueError(
|
||||
'Error creating ModeKeyMap. Multiple keys/values found for {} mode.'
|
||||
.format(dict_key))
|
||||
self._internal_dict[dict_key] = kwargs[key]
|
||||
|
||||
def _get_internal_key(self, key):
|
||||
"""Return keys used for the internal dictionary."""
|
||||
if is_train(key):
|
||||
return KerasModeKeys.TRAIN
|
||||
if is_eval(key):
|
||||
return KerasModeKeys.TEST
|
||||
if is_predict(key):
|
||||
return KerasModeKeys.PREDICT
|
||||
raise ValueError('Invalid mode key: {}.'.format(key))
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._internal_dict[self._get_internal_key(key)]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._keys)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._keys)
|
65
tensorflow/python/saved_model/model_utils/mode_keys_test.py
Normal file
65
tensorflow/python/saved_model/model_utils/mode_keys_test.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ModeKey Tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model.model_utils import mode_keys
|
||||
|
||||
|
||||
class ModeKeyMapTest(test.TestCase):
|
||||
|
||||
def test_map(self):
|
||||
mode_map = mode_keys.ModeKeyMap(**{
|
||||
mode_keys.KerasModeKeys.PREDICT: 3,
|
||||
mode_keys.KerasModeKeys.TEST: 1
|
||||
})
|
||||
|
||||
# Test dictionary __getitem__
|
||||
self.assertEqual(3, mode_map[mode_keys.KerasModeKeys.PREDICT])
|
||||
self.assertEqual(3, mode_map[mode_keys.EstimatorModeKeys.PREDICT])
|
||||
self.assertEqual(1, mode_map[mode_keys.KerasModeKeys.TEST])
|
||||
self.assertEqual(1, mode_map[mode_keys.EstimatorModeKeys.EVAL])
|
||||
with self.assertRaises(KeyError):
|
||||
_ = mode_map[mode_keys.KerasModeKeys.TRAIN]
|
||||
with self.assertRaises(KeyError):
|
||||
_ = mode_map[mode_keys.EstimatorModeKeys.TRAIN]
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid mode'):
|
||||
_ = mode_map['serve']
|
||||
|
||||
# Test common dictionary methods
|
||||
self.assertLen(mode_map, 2)
|
||||
self.assertEqual({1, 3}, set(mode_map.values()))
|
||||
self.assertEqual(
|
||||
{mode_keys.KerasModeKeys.TEST, mode_keys.KerasModeKeys.PREDICT},
|
||||
set(mode_map.keys()))
|
||||
|
||||
# Map is immutable
|
||||
with self.assertRaises(TypeError):
|
||||
mode_map[mode_keys.KerasModeKeys.TEST] = 1
|
||||
|
||||
def test_invalid_init(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'Multiple keys/values found'):
|
||||
_ = mode_keys.ModeKeyMap(**{
|
||||
mode_keys.KerasModeKeys.PREDICT: 3,
|
||||
mode_keys.EstimatorModeKeys.PREDICT: 1
|
||||
})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1,29 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.train.ModeKeys."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import mode_keys
|
||||
|
||||
|
||||
class ModeKeysTest(test.TestCase):
|
||||
|
||||
def testKeyEquality(self):
|
||||
self.assertEqual(mode_keys.ModeKeys.PREDICT, 'predict')
|
||||
self.assertEqual(mode_keys.ModeKeys.TRAIN, 'train')
|
||||
self.assertEqual(mode_keys.ModeKeys.TEST, 'test')
|
Loading…
x
Reference in New Issue
Block a user