diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 9ba5109611d..0f9ab84a270 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -95,6 +95,7 @@ py_library( "//tensorflow/python/distribute:multi_worker_util", "//tensorflow/python/keras/engine:keras_tensor", "//tensorflow/python/keras/utils:control_flow_util", + "//tensorflow/python/keras/utils:tf_contextlib", "//tensorflow/python/keras/utils:tf_inspect", ], ) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index ee28302e6c7..15eed32fe4b 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -55,6 +55,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend_config from tensorflow.python.keras.engine import keras_tensor from tensorflow.python.keras.utils import control_flow_util +from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -83,7 +84,6 @@ from tensorflow.python.training.tracking import util as tracking_util from tensorflow.python.util import dispatch from tensorflow.python.util import nest from tensorflow.python.util import object_identity -from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls diff --git a/tensorflow/python/keras/distribute/distributed_training_utils.py b/tensorflow/python/keras/distribute/distributed_training_utils.py index 07bbf3f2b1c..1876b69539a 100644 --- a/tensorflow/python/keras/distribute/distributed_training_utils.py +++ b/tensorflow/python/keras/distribute/distributed_training_utils.py @@ -40,6 +40,7 @@ from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.optimizer_v2 import optimizer_v2 +from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -49,7 +50,6 @@ from tensorflow.python.ops.ragged import ragged_concat_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -from tensorflow.python.util import tf_contextlib def set_weights(distribution_strategy, dist_model, weights): diff --git a/tensorflow/python/keras/legacy_tf_layers/base.py b/tensorflow/python/keras/legacy_tf_layers/base.py index f35aaf18b67..8052651efa7 100644 --- a/tensorflow/python/keras/legacy_tf_layers/base.py +++ b/tensorflow/python/keras/legacy_tf_layers/base.py @@ -27,12 +27,12 @@ from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import function_utils from tensorflow.python.util import nest -from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export # Avoid breaking users who directly import this symbol from this file. diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 316bcc622f8..9615fef54b9 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -53,6 +53,7 @@ from tensorflow.python.keras.saving.saved_model import load as keras_load from tensorflow.python.keras.saving.saved_model import save_impl as keras_save from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -63,7 +64,6 @@ from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import load as tf_load from tensorflow.python.saved_model import save as tf_save -from tensorflow.python.util import tf_contextlib class LayerWithLearningPhase(keras.engine.base_layer.Layer): diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index 9bc9a59b3d2..6d212e0cda3 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -46,8 +46,8 @@ from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2 from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2 +from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils import tf_inspect -from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD index 77e432ea631..79c99cc1870 100644 --- a/tensorflow/python/keras/utils/BUILD +++ b/tensorflow/python/keras/utils/BUILD @@ -197,6 +197,15 @@ py_library( ], ) +py_library( + name = "tf_contextlib", + srcs = ["tf_contextlib.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + ], +) + py_library( name = "tf_inspect", srcs = ["tf_inspect.py"], diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 7d53b008a7f..6b79ebf7581 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -29,9 +29,9 @@ import types as python_types import numpy as np import six +from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.util import nest -from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util.tf_export import keras_export diff --git a/tensorflow/python/keras/utils/tf_contextlib.py b/tensorflow/python/keras/utils/tf_contextlib.py new file mode 100644 index 00000000000..3830014d4ac --- /dev/null +++ b/tensorflow/python/keras/utils/tf_contextlib.py @@ -0,0 +1,36 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFDecorator-aware replacements for the contextlib module.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib as _contextlib + +from tensorflow.python.util import tf_decorator + + +def contextmanager(target): + """A tf_decorator-aware wrapper for `contextlib.contextmanager`. + + Usage is identical to `contextlib.contextmanager`. + + Args: + target: A callable to be wrapped in a contextmanager. + Returns: + A callable that can be used inside of a `with` statement. + """ + context_manager = _contextlib.contextmanager(target) + return tf_decorator.make_decorator(target, context_manager, 'contextmanager') diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 51cb1acc899..3e75da4ec13 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -30,13 +30,13 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.keras import backend as K +from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.util import nest from tensorflow.python.util import object_identity -from tensorflow.python.util import tf_contextlib def is_tensor_or_tensor_list(v):