Separate Keras and Estimator ModeKeys into different classes to avoid headache with updating Estimator ModeKeys.
Detailed list of changes: * Revert Estimator Modekeys back V1 ModeKeys. * Added helper functions in SavedModel model utils to deal with different Estimator and Keras Mode Keys. * Estimator and Keras ModeKeys now reside in saved_model/model_utils/mode_keys. PiperOrigin-RevId: 233484078
This commit is contained in:
		
							parent
							
								
									f5463b160c
								
							
						
					
					
						commit
						c1043a02f9
					
				| @ -31,10 +31,10 @@ from tensorflow.python.framework import random_seed | ||||
| from tensorflow.python.keras import testing_utils | ||||
| from tensorflow.python.keras.engine import distributed_training_utils | ||||
| from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras | ||||
| from tensorflow.python.keras.utils.mode_keys import ModeKeys | ||||
| from tensorflow.python.ops.parsing_ops import gen_parsing_ops | ||||
| from tensorflow.python.training import gradient_descent | ||||
| from tensorflow.python.training import rmsprop | ||||
| from tensorflow.python.training.mode_keys import ModeKeys | ||||
| 
 | ||||
| _RANDOM_SEED = 1337 | ||||
| _TRAIN_SIZE = 200 | ||||
|  | ||||
| @ -143,7 +143,6 @@ py_library( | ||||
|         ":map_fn", | ||||
|         ":math_ops", | ||||
|         ":metrics", | ||||
|         ":mode_keys", | ||||
|         ":nccl_ops", | ||||
|         ":nn", | ||||
|         ":ops", | ||||
| @ -6076,29 +6075,6 @@ py_binary( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_library( | ||||
|     name = "mode_keys", | ||||
|     srcs = [ | ||||
|         "training/mode_keys.py", | ||||
|     ], | ||||
|     srcs_version = "PY2AND3", | ||||
|     deps = [ | ||||
|         ":util", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| tf_py_test( | ||||
|     name = "mode_keys_test", | ||||
|     size = "small", | ||||
|     srcs = [ | ||||
|         "training/mode_keys_test.py", | ||||
|     ], | ||||
|     additional_deps = [ | ||||
|         ":client_testlib", | ||||
|         ":mode_keys", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| pyx_library( | ||||
|     name = "framework_fast_tensor_util", | ||||
|     srcs = ["framework/fast_tensor_util.pyx"], | ||||
|  | ||||
| @ -161,6 +161,7 @@ py_library( | ||||
|         ":engine_utils", | ||||
|         ":initializers", | ||||
|         ":losses", | ||||
|         ":mode_keys", | ||||
|         ":optimizers", | ||||
|         ":regularizers", | ||||
|         ":saving", | ||||
| @ -188,9 +189,9 @@ py_library( | ||||
|     deps = [ | ||||
|         ":backend", | ||||
|         ":engine_utils", | ||||
|         ":mode_keys", | ||||
|         ":optimizers", | ||||
|         "//tensorflow/python:lib", | ||||
|         "//tensorflow/python:mode_keys", | ||||
|         "//tensorflow/python:saver", | ||||
|         "//tensorflow/python/saved_model", | ||||
|         "//tensorflow/python/saved_model/model_utils", | ||||
| @ -218,6 +219,7 @@ py_library( | ||||
|     deps = [ | ||||
|         ":backend", | ||||
|         ":engine_utils", | ||||
|         ":mode_keys", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| @ -369,6 +371,17 @@ py_library( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_library( | ||||
|     name = "mode_keys", | ||||
|     srcs = [ | ||||
|         "utils/mode_keys.py", | ||||
|     ], | ||||
|     srcs_version = "PY2AND3", | ||||
|     deps = [ | ||||
|         "//tensorflow/python/saved_model/model_utils:mode_keys", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| tf_py_test( | ||||
|     name = "integration_test", | ||||
|     size = "medium", | ||||
|  | ||||
| @ -36,10 +36,10 @@ from tensorflow.python.framework import ops | ||||
| from tensorflow.python.keras import backend as K | ||||
| from tensorflow.python.keras.utils.data_utils import Sequence | ||||
| from tensorflow.python.keras.utils.generic_utils import Progbar | ||||
| from tensorflow.python.keras.utils.mode_keys import ModeKeys | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops import summary_ops_v2 | ||||
| from tensorflow.python.platform import tf_logging as logging | ||||
| from tensorflow.python.training.mode_keys import ModeKeys | ||||
| from tensorflow.python.util.tf_export import keras_export | ||||
| 
 | ||||
| try: | ||||
|  | ||||
| @ -33,11 +33,11 @@ from tensorflow.python.keras import callbacks | ||||
| from tensorflow.python.keras import metrics as metrics_module | ||||
| from tensorflow.python.keras import optimizers | ||||
| from tensorflow.python.keras.optimizer_v2 import optimizer_v2 | ||||
| from tensorflow.python.keras.utils.mode_keys import ModeKeys | ||||
| from tensorflow.python.ops import control_flow_ops | ||||
| from tensorflow.python.ops import math_ops | ||||
| from tensorflow.python.ops import variables | ||||
| from tensorflow.python.platform import tf_logging as logging | ||||
| from tensorflow.python.training.mode_keys import ModeKeys | ||||
| from tensorflow.python.util import nest | ||||
| from tensorflow.python.util import tf_contextlib | ||||
| 
 | ||||
|  | ||||
| @ -46,11 +46,11 @@ from tensorflow.python.keras.saving import saving_utils | ||||
| from tensorflow.python.keras.utils import data_utils | ||||
| from tensorflow.python.keras.utils import losses_utils | ||||
| from tensorflow.python.keras.utils.generic_utils import slice_arrays | ||||
| from tensorflow.python.keras.utils.mode_keys import ModeKeys | ||||
| from tensorflow.python.ops import math_ops | ||||
| from tensorflow.python.platform import tf_logging as logging | ||||
| from tensorflow.python.training import optimizer as tf_optimizer_module | ||||
| from tensorflow.python.training.checkpointable import base as checkpointable | ||||
| from tensorflow.python.training.mode_keys import ModeKeys | ||||
| from tensorflow.python.util import nest | ||||
| from tensorflow.python.util.tf_export import keras_export | ||||
| 
 | ||||
|  | ||||
| @ -33,8 +33,8 @@ from tensorflow.python.keras.engine import distributed_training_utils | ||||
| from tensorflow.python.keras.engine import training_utils | ||||
| from tensorflow.python.keras.utils.generic_utils import make_batches | ||||
| from tensorflow.python.keras.utils.generic_utils import slice_arrays | ||||
| from tensorflow.python.keras.utils.mode_keys import ModeKeys | ||||
| from tensorflow.python.platform import tf_logging as logging | ||||
| from tensorflow.python.training.mode_keys import ModeKeys | ||||
| 
 | ||||
| try: | ||||
|   from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top | ||||
|  | ||||
| @ -34,9 +34,9 @@ from tensorflow.python.keras.engine import partial_batch_padding_handler as padd | ||||
| from tensorflow.python.keras.engine import training_arrays | ||||
| from tensorflow.python.keras.engine import training_utils | ||||
| from tensorflow.python.keras.utils.generic_utils import Progbar | ||||
| from tensorflow.python.keras.utils.mode_keys import ModeKeys | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.platform import tf_logging as logging | ||||
| from tensorflow.python.training.mode_keys import ModeKeys | ||||
| from tensorflow.python.util import nest | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -33,8 +33,8 @@ from tensorflow.python.keras import callbacks as cbks | ||||
| from tensorflow.python.keras.engine import training_utils | ||||
| from tensorflow.python.keras.utils import data_utils | ||||
| from tensorflow.python.keras.utils import generic_utils | ||||
| from tensorflow.python.keras.utils.mode_keys import ModeKeys | ||||
| from tensorflow.python.platform import tf_logging as logging | ||||
| from tensorflow.python.training.mode_keys import ModeKeys | ||||
| from tensorflow.python.util import nest | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -27,6 +27,7 @@ from tensorflow.python.keras import backend as K | ||||
| from tensorflow.python.keras import optimizers | ||||
| from tensorflow.python.keras.saving import model_from_json | ||||
| from tensorflow.python.keras.saving import saving_utils | ||||
| from tensorflow.python.keras.utils import mode_keys | ||||
| from tensorflow.python.lib.io import file_io | ||||
| from tensorflow.python.ops import variables | ||||
| from tensorflow.python.platform import tf_logging as logging | ||||
| @ -35,7 +36,6 @@ from tensorflow.python.saved_model import constants | ||||
| from tensorflow.python.saved_model import model_utils | ||||
| from tensorflow.python.saved_model import save as save_lib | ||||
| from tensorflow.python.saved_model import utils_impl as saved_model_utils | ||||
| from tensorflow.python.training import mode_keys | ||||
| from tensorflow.python.training import saver as saver_lib | ||||
| from tensorflow.python.training.checkpointable import graph_view | ||||
| from tensorflow.python.util import compat | ||||
|  | ||||
| @ -33,12 +33,12 @@ from tensorflow.python.framework import tensor_spec | ||||
| from tensorflow.python.framework import test_util | ||||
| from tensorflow.python.keras.engine import training | ||||
| from tensorflow.python.keras.saving import saved_model as keras_saved_model | ||||
| from tensorflow.python.keras.utils import mode_keys | ||||
| from tensorflow.python.keras.utils import tf_utils | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.platform import test | ||||
| from tensorflow.python.saved_model import loader_impl | ||||
| from tensorflow.python.saved_model import model_utils | ||||
| from tensorflow.python.training import mode_keys | ||||
| from tensorflow.python.training import training as training_module | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -12,22 +12,11 @@ | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Model modeKeys for TensorFlow and Estimator.""" | ||||
| """Keras model mode constants.""" | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| 
 | ||||
| class ModeKeys(object): | ||||
|   """Standard names for model modes. | ||||
| 
 | ||||
|   The following standard keys are defined: | ||||
| 
 | ||||
|   * `TRAIN`: training/fitting mode. | ||||
|   * `TEST`: testing/evaluation mode. | ||||
|   * `PREDICT`: prediction/inference mode. | ||||
|   """ | ||||
| 
 | ||||
|   TRAIN = 'train' | ||||
|   TEST = 'test' | ||||
|   PREDICT = 'predict' | ||||
| # pylint: disable=unused-import | ||||
| from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys | ||||
| # pylint: enable=unused-import | ||||
| @ -30,6 +30,7 @@ py_library( | ||||
|     deps = [ | ||||
|         ":export_output", | ||||
|         ":export_utils", | ||||
|         ":mode_keys", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| @ -70,7 +71,7 @@ py_library( | ||||
|     srcs_version = "PY2AND3", | ||||
|     deps = [ | ||||
|         ":export_output", | ||||
|         "//tensorflow/python:mode_keys", | ||||
|         ":mode_keys", | ||||
|         "//tensorflow/python:platform", | ||||
|         "//tensorflow/python:util", | ||||
|         "//tensorflow/python/saved_model:signature_constants", | ||||
| @ -98,3 +99,19 @@ py_test( | ||||
|         "//tensorflow/python/saved_model:signature_def_utils", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_library( | ||||
|     name = "mode_keys", | ||||
|     srcs = ["mode_keys.py"], | ||||
|     srcs_version = "PY2AND3", | ||||
| ) | ||||
| 
 | ||||
| py_test( | ||||
|     name = "mode_keys_test", | ||||
|     srcs = ["mode_keys_test.py"], | ||||
|     srcs_version = "PY2AND3", | ||||
|     deps = [ | ||||
|         ":mode_keys", | ||||
|         "//tensorflow/python:client_testlib", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
| @ -30,29 +30,26 @@ from tensorflow.python.saved_model import signature_constants | ||||
| from tensorflow.python.saved_model import signature_def_utils | ||||
| from tensorflow.python.saved_model import tag_constants | ||||
| from tensorflow.python.saved_model.model_utils import export_output as export_output_lib | ||||
| from tensorflow.python.training import mode_keys | ||||
| from tensorflow.python.saved_model.model_utils import mode_keys | ||||
| from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys | ||||
| from tensorflow.python.util import compat | ||||
| 
 | ||||
| 
 | ||||
| # Mapping of the modes to appropriate MetaGraph tags in the SavedModel. | ||||
| EXPORT_TAG_MAP = { | ||||
|     mode_keys.ModeKeys.PREDICT: [tag_constants.SERVING], | ||||
|     mode_keys.ModeKeys.TRAIN: [tag_constants.TRAINING], | ||||
|     mode_keys.ModeKeys.TEST: [tag_constants.EVAL], | ||||
| } | ||||
| EXPORT_TAG_MAP = mode_keys.ModeKeyMap(**{ | ||||
|     ModeKeys.PREDICT: [tag_constants.SERVING], | ||||
|     ModeKeys.TRAIN: [tag_constants.TRAINING], | ||||
|     ModeKeys.TEST: [tag_constants.EVAL]}) | ||||
| 
 | ||||
| # For every exported mode, a SignatureDef map should be created using the | ||||
| # functions `export_outputs_for_mode` and `build_all_signature_defs`. By | ||||
| # default, this map will contain a single Signature that defines the input | ||||
| # tensors and output predictions, losses, and/or metrics (depending on the mode) | ||||
| # The default keys used in the SignatureDef map are defined below. | ||||
| SIGNATURE_KEY_MAP = { | ||||
|     mode_keys.ModeKeys.PREDICT: | ||||
|         signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, | ||||
|     mode_keys.ModeKeys.TRAIN: | ||||
|         signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY, | ||||
|     mode_keys.ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY | ||||
| } | ||||
| SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(**{ | ||||
|     ModeKeys.PREDICT: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, | ||||
|     ModeKeys.TRAIN: signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY, | ||||
|     ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY}) | ||||
| 
 | ||||
| _SINGLE_FEATURE_DEFAULT_NAME = 'feature' | ||||
| _SINGLE_RECEIVER_DEFAULT_NAME = 'input' | ||||
| @ -281,9 +278,9 @@ def export_outputs_for_mode( | ||||
|         'this function. Please ensure that you are using the new ModeKeys.' | ||||
|         .format(mode, SIGNATURE_KEY_MAP.keys())) | ||||
|   signature_key = SIGNATURE_KEY_MAP[mode] | ||||
|   if mode == mode_keys.ModeKeys.PREDICT: | ||||
|   if mode_keys.is_predict(mode): | ||||
|     return get_export_outputs(serving_export_outputs, predictions) | ||||
|   elif mode == mode_keys.ModeKeys.TRAIN: | ||||
|   elif mode_keys.is_eval(mode): | ||||
|     return {signature_key: export_output_lib.TrainOutput( | ||||
|         loss=loss, predictions=predictions, metrics=metrics)} | ||||
|   else: | ||||
|  | ||||
							
								
								
									
										124
									
								
								tensorflow/python/saved_model/model_utils/mode_keys.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								tensorflow/python/saved_model/model_utils/mode_keys.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,124 @@ | ||||
| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Utils for managing different mode strings used by Keras and Estimator models. | ||||
| """ | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import collections | ||||
| 
 | ||||
| 
 | ||||
| class KerasModeKeys(object): | ||||
|   """Standard names for model modes. | ||||
| 
 | ||||
|   The following standard keys are defined: | ||||
| 
 | ||||
|   * `TRAIN`: training/fitting mode. | ||||
|   * `TEST`: testing/evaluation mode. | ||||
|   * `PREDICT`: prediction/inference mode. | ||||
|   """ | ||||
| 
 | ||||
|   TRAIN = 'train' | ||||
|   TEST = 'test' | ||||
|   PREDICT = 'predict' | ||||
| 
 | ||||
| 
 | ||||
| class ModeKeys(object): | ||||
|   """Standard names for model modes. | ||||
| 
 | ||||
|   The following standard keys are defined: | ||||
| 
 | ||||
|   * `TRAIN`: training/fitting mode. | ||||
|   * `TEST`: testing/evaluation mode. | ||||
|   * `PREDICT`: prediction/inference mode. | ||||
|   """ | ||||
| 
 | ||||
|   TRAIN = 'train' | ||||
|   TEST = 'test' | ||||
|   PREDICT = 'predict' | ||||
| 
 | ||||
| 
 | ||||
| # TODO(kathywu): Remove copy in Estimator after nightlies | ||||
| class EstimatorModeKeys(object): | ||||
|   """Standard names for Estimator model modes. | ||||
| 
 | ||||
|   The following standard keys are defined: | ||||
| 
 | ||||
|   * `TRAIN`: training/fitting mode. | ||||
|   * `EVAL`: testing/evaluation mode. | ||||
|   * `PREDICT`: predication/inference mode. | ||||
|   """ | ||||
| 
 | ||||
|   TRAIN = 'train' | ||||
|   EVAL = 'eval' | ||||
|   PREDICT = 'infer' | ||||
| 
 | ||||
| 
 | ||||
| def is_predict(mode): | ||||
|   return mode in [KerasModeKeys.PREDICT, EstimatorModeKeys.PREDICT] | ||||
| 
 | ||||
| 
 | ||||
| def is_eval(mode): | ||||
|   return mode in [KerasModeKeys.TEST, EstimatorModeKeys.EVAL] | ||||
| 
 | ||||
| 
 | ||||
| def is_train(mode): | ||||
|   return mode in [KerasModeKeys.TRAIN, EstimatorModeKeys.TRAIN] | ||||
| 
 | ||||
| 
 | ||||
| class ModeKeyMap(collections.Mapping): | ||||
|   """Map using ModeKeys as keys. | ||||
| 
 | ||||
|   This class creates an immutable mapping from modes to values. For example, | ||||
|   SavedModel export of Keras and Estimator models use this to map modes to their | ||||
|   corresponding MetaGraph tags/SignatureDef keys. | ||||
| 
 | ||||
|   Since this class uses modes, rather than strings, as keys, both "predict" | ||||
|   (Keras's PREDICT ModeKey) and "infer" (Estimator's PREDICT ModeKey) map to the | ||||
|   same value. | ||||
|   """ | ||||
| 
 | ||||
|   def __init__(self, **kwargs): | ||||
|     self._internal_dict = {} | ||||
|     self._keys = [] | ||||
|     for key in kwargs: | ||||
|       self._keys.append(key) | ||||
|       dict_key = self._get_internal_key(key) | ||||
|       if dict_key in self._internal_dict: | ||||
|         raise ValueError( | ||||
|             'Error creating ModeKeyMap. Multiple keys/values found for {} mode.' | ||||
|             .format(dict_key)) | ||||
|       self._internal_dict[dict_key] = kwargs[key] | ||||
| 
 | ||||
|   def _get_internal_key(self, key): | ||||
|     """Return keys used for the internal dictionary.""" | ||||
|     if is_train(key): | ||||
|       return KerasModeKeys.TRAIN | ||||
|     if is_eval(key): | ||||
|       return KerasModeKeys.TEST | ||||
|     if is_predict(key): | ||||
|       return KerasModeKeys.PREDICT | ||||
|     raise ValueError('Invalid mode key: {}.'.format(key)) | ||||
| 
 | ||||
|   def __getitem__(self, key): | ||||
|     return self._internal_dict[self._get_internal_key(key)] | ||||
| 
 | ||||
|   def __iter__(self): | ||||
|     return iter(self._keys) | ||||
| 
 | ||||
|   def __len__(self): | ||||
|     return len(self._keys) | ||||
							
								
								
									
										65
									
								
								tensorflow/python/saved_model/model_utils/mode_keys_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								tensorflow/python/saved_model/model_utils/mode_keys_test.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,65 @@ | ||||
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """ModeKey Tests.""" | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.python.platform import test | ||||
| from tensorflow.python.saved_model.model_utils import mode_keys | ||||
| 
 | ||||
| 
 | ||||
| class ModeKeyMapTest(test.TestCase): | ||||
| 
 | ||||
|   def test_map(self): | ||||
|     mode_map = mode_keys.ModeKeyMap(**{ | ||||
|         mode_keys.KerasModeKeys.PREDICT: 3, | ||||
|         mode_keys.KerasModeKeys.TEST: 1 | ||||
|     }) | ||||
| 
 | ||||
|     # Test dictionary __getitem__ | ||||
|     self.assertEqual(3, mode_map[mode_keys.KerasModeKeys.PREDICT]) | ||||
|     self.assertEqual(3, mode_map[mode_keys.EstimatorModeKeys.PREDICT]) | ||||
|     self.assertEqual(1, mode_map[mode_keys.KerasModeKeys.TEST]) | ||||
|     self.assertEqual(1, mode_map[mode_keys.EstimatorModeKeys.EVAL]) | ||||
|     with self.assertRaises(KeyError): | ||||
|       _ = mode_map[mode_keys.KerasModeKeys.TRAIN] | ||||
|     with self.assertRaises(KeyError): | ||||
|       _ = mode_map[mode_keys.EstimatorModeKeys.TRAIN] | ||||
|     with self.assertRaisesRegexp(ValueError, 'Invalid mode'): | ||||
|       _ = mode_map['serve'] | ||||
| 
 | ||||
|     # Test common dictionary methods | ||||
|     self.assertLen(mode_map, 2) | ||||
|     self.assertEqual({1, 3}, set(mode_map.values())) | ||||
|     self.assertEqual( | ||||
|         {mode_keys.KerasModeKeys.TEST, mode_keys.KerasModeKeys.PREDICT}, | ||||
|         set(mode_map.keys())) | ||||
| 
 | ||||
|     # Map is immutable | ||||
|     with self.assertRaises(TypeError): | ||||
|       mode_map[mode_keys.KerasModeKeys.TEST] = 1 | ||||
| 
 | ||||
|   def test_invalid_init(self): | ||||
|     with self.assertRaisesRegexp(ValueError, 'Multiple keys/values found'): | ||||
|       _ = mode_keys.ModeKeyMap(**{ | ||||
|           mode_keys.KerasModeKeys.PREDICT: 3, | ||||
|           mode_keys.EstimatorModeKeys.PREDICT: 1 | ||||
|       }) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   test.main() | ||||
| @ -1,29 +0,0 @@ | ||||
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Tests for `tf.train.ModeKeys.""" | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.python.platform import test | ||||
| from tensorflow.python.training import mode_keys | ||||
| 
 | ||||
| 
 | ||||
| class ModeKeysTest(test.TestCase): | ||||
| 
 | ||||
|   def testKeyEquality(self): | ||||
|     self.assertEqual(mode_keys.ModeKeys.PREDICT, 'predict') | ||||
|     self.assertEqual(mode_keys.ModeKeys.TRAIN, 'train') | ||||
|     self.assertEqual(mode_keys.ModeKeys.TEST, 'test') | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user