This is a purely mechanical change. All that is done is: * Deleted python/keras/mixed_precision/experimental/__init__.py * All other files in python/keras/mixed_precision/experimental/ are moved one directory up, out of the experimental/ folder * All Python imports, BUILD dependencies, and other references to the old experimental files are adjusted to refer to the new location. This changes the API golden files, but there is no API change. The golden files referred to the full paths of the classes in "is_instance" sections, and the full paths have changed. PiperOrigin-RevId: 338345459 Change-Id: I9eefc2bea49b71f26ef7ec3563364a3f1d54abe6
64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Contains keras-specific LossScale functionality.
|
|
|
|
These functions cannot be in the non-keras loss_scale.py file since they depend
|
|
on keras, and files outside of keras should not depend on files inside keras.
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import six
|
|
|
|
from tensorflow.python.keras.utils import generic_utils
|
|
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
|
|
|
|
|
|
def serialize(loss_scale):
  """Serializes a loss scale object into a Keras config.

  Args:
    loss_scale: The loss scale object to serialize.

  Returns:
    Whatever `generic_utils.serialize_keras_object` produces for
    `loss_scale`.
  """
  serialized_config = generic_utils.serialize_keras_object(loss_scale)
  return serialized_config
|
|
|
|
|
|
def deserialize(config, custom_objects=None):
  """Reconstructs a loss scale object from its serialized config.

  Args:
    config: The serialized loss scale config, forwarded unchanged to
      `generic_utils.deserialize_keras_object`.
    custom_objects: Optional extra objects, forwarded as `custom_objects` to
      `generic_utils.deserialize_keras_object`.

  Returns:
    The object built by `generic_utils.deserialize_keras_object`.
  """
  # Built-in loss scale classes that `config` may refer to by class name.
  builtin_names = ('FixedLossScale', 'DynamicLossScale')
  builtin_classes = {
      name: getattr(loss_scale_module, name) for name in builtin_names
  }
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=builtin_classes,
      custom_objects=custom_objects,
      printable_module_name='loss scale')
|
|
|
|
|
|
def get(identifier):
  """Get a loss scale object.

  Args:
    identifier: One of:
      * a dict, which is passed to `deserialize`;
      * an int or float, used to construct a `FixedLossScale`;
      * the string 'dynamic', which constructs a `DynamicLossScale` with
        default arguments;
      * an existing `LossScale` instance, returned unchanged;
      * None, which is returned as-is.

  Returns:
    The resulting loss scale object, or None if `identifier` is None.

  Raises:
    ValueError: If `identifier` is not one of the forms listed above.
  """
  if isinstance(identifier, dict):
    return deserialize(identifier)

  if isinstance(identifier, six.integer_types + (float,)):
    return loss_scale_module.FixedLossScale(identifier)
  # Only compare strings against 'dynamic': `==` on arbitrary objects (e.g.
  # array-like values) can return a non-bool, which would break the `if`.
  if isinstance(identifier, six.string_types) and identifier == 'dynamic':
    return loss_scale_module.DynamicLossScale()
  if isinstance(identifier, loss_scale_module.LossScale):
    return identifier
  if identifier is None:
    return None
  # Wrap in a 1-tuple so a tuple identifier is formatted as a value instead
  # of being treated as the `%` argument list (which raises TypeError).
  raise ValueError('Could not interpret loss scale identifier: %s' %
                   (identifier,))
|