STT-tensorflow/tensorflow/python/keras/optimizers.py
Reed Wanderman-Milne 642c3e8498 Move mixed precision files out of experimental/ directory.
This is a purely mechanical change. All that is done is:
* Deleted python/keras/mixed_precision/experimental/__init__.py
* All other files in python/keras/mixed_precision/experimental/ are moved one directory up, out of the experimental/ folder
* All Python imports, BUILD dependencies, and other references to the old experimental files are adjusted to refer to the new location

This changes the API golden files, but there is no API change. The golden files referred to the full paths of the classes in "is_instance" sections, and the full paths have changed.

PiperOrigin-RevId: 338345459
Change-Id: I9eefc2bea49b71f26ef7ec3563364a3f1d54abe6
2020-10-21 15:19:21 -07:00

122 lines
4.7 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Built-in optimizer classes.
For more examples see the base class `tf.keras.optimizers.Optimizer`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v1 import Optimizer
from tensorflow.python.keras.optimizer_v1 import TFOptimizer
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
from tensorflow.python.keras.optimizer_v2 import ftrl
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.optimizers.serialize')
def serialize(optimizer):
  """Serializes an optimizer by delegating to `serialize_keras_object`.

  Arguments:
    optimizer: The optimizer object to serialize.

  Returns:
    The value produced by `serialize_keras_object` for `optimizer`.
  """
  serialized = serialize_keras_object(optimizer)
  return serialized
@keras_export('keras.optimizers.deserialize')
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Built-in optimizer names are matched case-insensitively. When the class name
  has to be normalized, the normalization is applied to a shallow copy, so the
  caller's top-level `config` dictionary keeps its original `class_name`.

  Arguments:
    config: Optimizer configuration dictionary. It is indexed with the
      `'class_name'` key, whose value must be a string.
    custom_objects: Optional dictionary mapping names (strings) to custom
      objects (classes and functions) to be considered during deserialization.

  Returns:
    A Keras Optimizer instance.
  """
  # loss_scale_optimizer has a direct dependency of optimizer, import here
  # rather than top to avoid the cyclic dependency.
  from tensorflow.python.keras.mixed_precision import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl,
      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
      # LossScaleOptimizerV1 deserializes into LossScaleOptimizer, as
      # LossScaleOptimizerV1 will be removed soon but deserializing it will
      # still be supported.
      'lossscaleoptimizerv1': loss_scale_optimizer.LossScaleOptimizer,
  }

  # Make deserialization case-insensitive for built-in optimizers.
  lowered_name = config['class_name'].lower()
  if lowered_name in all_classes:
    # Rebind to a shallow copy instead of assigning into the argument, so the
    # dictionary owned by the caller is not rewritten as a side effect.
    config = dict(config, class_name=lowered_name)
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')
@keras_export('keras.optimizers.get')
def get(identifier):
  """Retrieves a Keras Optimizer instance.

  Arguments:
    identifier: Optimizer identifier, one of
      - String: name of an optimizer
      - Dictionary: configuration dictionary.
      - Keras Optimizer instance (it will be returned unchanged).
      - TensorFlow Optimizer instance (it will be wrapped as a Keras
        Optimizer).

  Returns:
    A Keras Optimizer instance.

  Raises:
    ValueError: If `identifier` cannot be interpreted.
  """
  # Keras optimizers (v1 or v2) are already in the desired form.
  keras_optimizer_types = (Optimizer, optimizer_v2.OptimizerV2)
  if isinstance(identifier, keras_optimizer_types):
    return identifier

  # Wrap TF optimizer instances
  if isinstance(identifier, tf_optimizer_module.Optimizer):
    wrapped = TFOptimizer(identifier)
    K.track_tf_optimizer(wrapped)
    return wrapped

  # A dictionary is handed to `deserialize` as-is.
  if isinstance(identifier, dict):
    return deserialize(identifier)

  # A bare name becomes a config with an empty `config` entry.
  if isinstance(identifier, six.string_types):
    return deserialize({'class_name': str(identifier), 'config': {}})

  raise ValueError(
      'Could not interpret optimizer identifier: {}'.format(identifier))