Cleanup the no-at-top imports in Keras.
Give some explanation for certain case about why it is not avoidable. PiperOrigin-RevId: 277118837 Change-Id: I9b16ee1ce89b26cf0b07ae4e3eb2af337b11dd9c
This commit is contained in:
parent
a3540d16d8
commit
ad6be7820b
@ -73,6 +73,7 @@ from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import moving_averages
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
@ -1601,11 +1602,6 @@ def moving_average_update(x, value, momentum):
|
||||
Returns:
|
||||
An Operation to update the variable.
|
||||
"""
|
||||
# `training` is higher-up than the Keras backend in the abstraction hierarchy.
|
||||
# In particular, `training` depends on layers, and thus on Keras.
|
||||
# moving_averages, being low-level ops, should not be part of the training
|
||||
# module.
|
||||
from tensorflow.python.training import moving_averages # pylint: disable=g-import-not-at-top
|
||||
zero_debias = not tf2.enabled()
|
||||
return moving_averages.assign_moving_average(
|
||||
x, value, momentum, zero_debias=zero_debias)
|
||||
|
@ -41,7 +41,8 @@ _call_context = threading.local()
|
||||
|
||||
|
||||
def create_mean_metric(value, name=None):
|
||||
# TODO(psv): Remove this import when b/110718070 is fixed.
|
||||
# import keras will import base_layer and then this module, and metric relies
|
||||
# on base_layer, which result into a cyclic dependency.
|
||||
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
|
||||
metric_obj = metrics_module.Mean(name=name)
|
||||
return metric_obj, metric_obj(value)
|
||||
|
@ -1,23 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Mixed precision API."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer import LossScaleOptimizer
|
||||
from tensorflow.python.keras.mixed_precision.experimental.policy import global_policy
|
||||
from tensorflow.python.keras.mixed_precision.experimental.policy import Policy
|
||||
from tensorflow.python.keras.mixed_precision.experimental.policy import set_policy
|
@ -794,6 +794,8 @@ def deserialize(config, custom_objects=None):
|
||||
Returns:
|
||||
A Keras Optimizer instance.
|
||||
"""
|
||||
# loss_scale_optimizer has a direct dependency of optimizer, import here
|
||||
# rather than top to avoid the cyclic dependency.
|
||||
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer # pylint: disable=g-import-not-at-top
|
||||
all_classes = {
|
||||
'adadelta': adadelta_v2.Adadelta,
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine.training import Model
|
||||
from tensorflow.python.keras.layers.core import Lambda
|
||||
from tensorflow.python.keras.layers.merge import concatenate
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -152,10 +154,6 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
|
||||
Raises:
|
||||
ValueError: if the `gpus` argument does not match available devices.
|
||||
"""
|
||||
# pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.layers.core import Lambda
|
||||
from tensorflow.python.keras.layers.merge import concatenate
|
||||
|
||||
if isinstance(gpus, (list, tuple)):
|
||||
if len(gpus) <= 1:
|
||||
raise ValueError('For multi-gpu usage to be effective, '
|
||||
|
Loading…
Reference in New Issue
Block a user