This is a purely mechanical change. All that is done is:

* Deleted python/keras/mixed_precision/experimental/__init__.py
* All other files in python/keras/mixed_precision/experimental/ are moved one directory up, out of the experimental/ folder
* All Python imports, BUILD dependencies, and other references to the old experimental files are adjusted to refer to the new location

This changes the API golden files, but there is no API change. The golden files referred to the full paths of the classes in "is_instance" sections, and those full paths have changed.

PiperOrigin-RevId: 338345459
Change-Id: I9eefc2bea49b71f26ef7ec3563364a3f1d54abe6
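For illustration (a reconstructed example, not part of the original commit message), a typical import adjusted by this change looks like:

    # Before: module lived under the experimental/ folder
    from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
    # After: module moved one directory up
    from tensorflow.python.keras.mixed_precision import loss_scale_optimizer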
122 lines | 4.7 KiB | Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Built-in optimizer classes.

For more examples see the base class `tf.keras.optimizers.Optimizer`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v1 import Optimizer
from tensorflow.python.keras.optimizer_v1 import TFOptimizer
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
from tensorflow.python.keras.optimizer_v2 import ftrl
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.util.tf_export import keras_export

@keras_export('keras.optimizers.serialize')
def serialize(optimizer):
  return serialize_keras_object(optimizer)
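
# Illustrative usage (a reconstructed example, not part of the original file):
# `serialize` returns a JSON-serializable dict describing the optimizer, e.g.
# serialize(adam_v2.Adam()) yields something like
# {'class_name': 'Adam', 'config': {'learning_rate': 0.001, ...}}.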

@keras_export('keras.optimizers.deserialize')
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Arguments:
      config: Optimizer configuration dictionary.
      custom_objects: Optional dictionary mapping names (strings) to custom
        objects (classes and functions) to be considered during
        deserialization.

  Returns:
      A Keras Optimizer instance.
  """
  # loss_scale_optimizer has a direct dependency on optimizer; import it here
  # rather than at the top to avoid the cyclic dependency.
  from tensorflow.python.keras.mixed_precision import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl,
      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
      # LossScaleOptimizerV1 deserializes into LossScaleOptimizer, as
      # LossScaleOptimizerV1 will be removed soon but deserializing it will
      # still be supported.
      'lossscaleoptimizerv1': loss_scale_optimizer.LossScaleOptimizer,
  }

  # Make deserialization case-insensitive for built-in optimizers.
  if config['class_name'].lower() in all_classes:
    config['class_name'] = config['class_name'].lower()
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')
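
# Illustrative usage (a reconstructed example, not part of the original file):
# deserializing a built-in optimizer from a config dict. Lookup in
# `all_classes` is case-insensitive, so 'Adam' and 'adam' resolve to the same
# class.
#
#   opt = deserialize({'class_name': 'Adam',
#                      'config': {'learning_rate': 0.001}})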

@keras_export('keras.optimizers.get')
def get(identifier):
  """Retrieves a Keras Optimizer instance.

  Arguments:
      identifier: Optimizer identifier, one of
          - String: name of an optimizer
          - Dictionary: configuration dictionary
          - Keras Optimizer instance (it will be returned unchanged)
          - TensorFlow Optimizer instance (it will be wrapped as a Keras
            Optimizer)

  Returns:
      A Keras Optimizer instance.

  Raises:
      ValueError: If `identifier` cannot be interpreted.
  """
  if isinstance(identifier, (Optimizer, optimizer_v2.OptimizerV2)):
    return identifier
  # Wrap TF optimizer instances
  elif isinstance(identifier, tf_optimizer_module.Optimizer):
    opt = TFOptimizer(identifier)
    K.track_tf_optimizer(opt)
    return opt
  elif isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  else:
    raise ValueError(
        'Could not interpret optimizer identifier: {}'.format(identifier))
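
# Illustrative usage (a reconstructed example, not part of the original file):
# each of the following returns a Keras optimizer instance, exercising the
# four identifier kinds accepted by `get`.
#
#   get('adam')                                                  # by name
#   get({'class_name': 'SGD', 'config': {'learning_rate': 0.1}})  # from config
#   get(adam_v2.Adam())                                          # returned unchanged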