Clean up the import-not-at-top usages in Keras.

Add explanations for the cases where the deferred import is unavoidable.
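The pattern being cleaned up looks roughly like the sketch below, shown with two hypothetical modules (not the actual TensorFlow files): a module-level import would close an import cycle, so the import is deferred into the function body, where it resolves only at call time, after both modules have finished initializing.

# metrics.py (hypothetical)
from layers import Layer  # metrics need layers at import time

class Mean(Layer):
  def __call__(self, value):
    return value

# layers.py (hypothetical)
# A top-level `from metrics import Mean` here would complete the cycle
# layers -> metrics -> layers and fail with an ImportError at startup.
# Deferring the import into the function body breaks the cycle.

class Layer(object):
  pass

def create_mean_metric(value, name=None):
  from metrics import Mean  # pylint: disable=g-import-not-at-top
  metric_obj = Mean()
  return metric_obj, metric_obj(value)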

PiperOrigin-RevId: 277118837
Change-Id: I9b16ee1ce89b26cf0b07ae4e3eb2af337b11dd9c
commit ad6be7820b
parent a3540d16d8
Author: Scott Zhu (committed by TensorFlower Gardener)
Date:   2019-10-28 12:27:42 -07:00

5 changed files with 7 additions and 33 deletions


@@ -73,6 +73,7 @@ from tensorflow.python.ops import variables as variables_module
 from tensorflow.python.ops.ragged import ragged_concat_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import moving_averages
 from tensorflow.python.util import nest
 from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_contextlib
@@ -1601,11 +1602,6 @@ def moving_average_update(x, value, momentum):
   Returns:
       An Operation to update the variable.
   """
-  # `training` is higher-up than the Keras backend in the abstraction hierarchy.
-  # In particular, `training` depends on layers, and thus on Keras.
-  # moving_averages, being low-level ops, should not be part of the training
-  # module.
-  from tensorflow.python.training import moving_averages  # pylint: disable=g-import-not-at-top
   zero_debias = not tf2.enabled()
   return moving_averages.assign_moving_average(
       x, value, momentum, zero_debias=zero_debias)
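For reference, a minimal usage sketch of the function this hunk touches (assuming the TF 2.0-era backend API; the variable and values are illustrative):

import tensorflow as tf
from tensorflow.python.keras import backend as K

x = tf.Variable(0.0)  # the moving average being maintained
K.moving_average_update(x, value=1.0, momentum=0.9)
# With zero-debiasing off, x moves toward `value` by (1 - momentum):
# x becomes 0.0 - (0.0 - 1.0) * 0.1 = 0.1.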


@@ -41,7 +41,8 @@ _call_context = threading.local()
 def create_mean_metric(value, name=None):
-  # TODO(psv): Remove this import when b/110718070 is fixed.
+  # Importing keras will import base_layer and then this module, and metrics
+  # rely on base_layer, which results in a cyclic dependency.
   from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
   metric_obj = metrics_module.Mean(name=name)
   return metric_obj, metric_obj(value)
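The helper is internal, but its behavior maps directly onto the public metrics API; a sketch of the equivalent (assuming tf.keras.metrics.Mean, the public counterpart of metrics_module.Mean):

import tensorflow as tf

m = tf.keras.metrics.Mean(name='my_mean')  # stateful running mean
m(2.0)           # updates state, returns the current mean: 2.0
result = m(4.0)  # running mean is now (2.0 + 4.0) / 2 = 3.0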


@@ -1,23 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Mixed precision API."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer import LossScaleOptimizer
-from tensorflow.python.keras.mixed_precision.experimental.policy import global_policy
-from tensorflow.python.keras.mixed_precision.experimental.policy import Policy
-from tensorflow.python.keras.mixed_precision.experimental.policy import set_policy
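The deleted file only re-exported symbols from their defining modules, so the same objects stayed reachable through the public experimental API. A usage sketch (hedged: this assumes the TF 2.0/2.1-era tf.keras.mixed_precision.experimental surface, which was later reworked):

import tensorflow as tf

# Set a global dtype policy and wrap an optimizer with dynamic loss scaling.
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
opt = tf.keras.optimizers.SGD(learning_rate=0.01)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')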


@@ -794,6 +794,8 @@ def deserialize(config, custom_objects=None):
   Returns:
       A Keras Optimizer instance.
   """
+  # loss_scale_optimizer has a direct dependency on optimizer; import it here
+  # rather than at the top to avoid the cyclic dependency.
   from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
   all_classes = {
       'adadelta': adadelta_v2.Adadelta,
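A usage sketch of the surrounding function (public as tf.keras.optimizers.deserialize; the config shape is the one produced by tf.keras.optimizers.serialize):

import tensorflow as tf

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
config = tf.keras.optimizers.serialize(opt)
# config is a dict like {'class_name': 'Adam', 'config': {...}}
restored = tf.keras.optimizers.deserialize(config)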


@@ -20,6 +20,8 @@ from __future__ import print_function
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine.training import Model
+from tensorflow.python.keras.layers.core import Lambda
+from tensorflow.python.keras.layers.merge import concatenate
 from tensorflow.python.ops import array_ops
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import keras_export
@@ -152,10 +154,6 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
   Raises:
       ValueError: if the `gpus` argument does not match available devices.
   """
-  # pylint: disable=g-import-not-at-top
-  from tensorflow.python.keras.layers.core import Lambda
-  from tensorflow.python.keras.layers.merge import concatenate
   if isinstance(gpus, (list, tuple)):
     if len(gpus) <= 1:
       raise ValueError('For multi-gpu usage to be effective, '
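For orientation, a usage sketch of multi_gpu_model (already deprecated at the time of this commit and removed in later TF releases; requires the stated number of visible GPUs):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
# Replicates the model on 2 GPUs and splits each input batch between them.
parallel = tf.keras.utils.multi_gpu_model(model, gpus=2)
parallel.compile(optimizer='sgd', loss='mse')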