Fork the tf_contextlib to keras and use tf_decorator which is a __internal__ API.

PiperOrigin-RevId: 331790478
Change-Id: Idb3f4fec529b0a1c012d3fd2f67d5f6f440d6806
This commit is contained in:
Scott Zhu 2020-09-15 09:51:08 -07:00 committed by TensorFlower Gardener
parent d0791dcf1a
commit e63c9f8418
10 changed files with 53 additions and 7 deletions

View File

@ -95,6 +95,7 @@ py_library(
"//tensorflow/python/distribute:multi_worker_util", "//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/keras/engine:keras_tensor", "//tensorflow/python/keras/engine:keras_tensor",
"//tensorflow/python/keras/utils:control_flow_util", "//tensorflow/python/keras/utils:control_flow_util",
"//tensorflow/python/keras/utils:tf_contextlib",
"//tensorflow/python/keras/utils:tf_inspect", "//tensorflow/python/keras/utils:tf_inspect",
], ],
) )

View File

@ -55,6 +55,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend_config from tensorflow.python.keras import backend_config
from tensorflow.python.keras.engine import keras_tensor from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops from tensorflow.python.ops import clip_ops
@ -83,7 +84,6 @@ from tensorflow.python.training.tracking import util as tracking_util
from tensorflow.python.util import dispatch from tensorflow.python.util import dispatch
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import object_identity from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls from tensorflow.tools.docs import doc_controls

View File

@ -40,6 +40,7 @@ from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -49,7 +50,6 @@ from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
def set_weights(distribution_strategy, dist_model, weights): def set_weights(distribution_strategy, dist_model, weights):

View File

@ -27,12 +27,12 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import function_utils from tensorflow.python.util import function_utils
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# Avoid breaking users who directly import this symbol from this file. # Avoid breaking users who directly import this symbol from this file.

View File

@ -53,6 +53,7 @@ from tensorflow.python.keras.saving.saved_model import load as keras_load
from tensorflow.python.keras.saving.saved_model import save_impl as keras_save from tensorflow.python.keras.saving.saved_model import save_impl as keras_save
from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
@ -63,7 +64,6 @@ from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import load as tf_load from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import save as tf_save from tensorflow.python.saved_model import save as tf_save
from tensorflow.python.util import tf_contextlib
class LayerWithLearningPhase(keras.engine.base_layer.Layer): class LayerWithLearningPhase(keras.engine.base_layer.Layer):

View File

@ -46,8 +46,8 @@ from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2 from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2 from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator

View File

@ -197,6 +197,15 @@ py_library(
], ],
) )
py_library(
name = "tf_contextlib",
srcs = ["tf_contextlib.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
],
)
py_library( py_library(
name = "tf_inspect", name = "tf_inspect",
srcs = ["tf_inspect.py"], srcs = ["tf_inspect.py"],

View File

@ -29,9 +29,9 @@ import types as python_types
import numpy as np import numpy as np
import six import six
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export

View File

@ -0,0 +1,36 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFDecorator-aware replacements for the contextlib module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib as _contextlib
from tensorflow.python.util import tf_decorator
def contextmanager(target):
"""A tf_decorator-aware wrapper for `contextlib.contextmanager`.
Usage is identical to `contextlib.contextmanager`.
Args:
target: A callable to be wrapped in a contextmanager.
Returns:
A callable that can be used inside of a `with` statement.
"""
context_manager = _contextlib.contextmanager(target)
return tf_decorator.make_decorator(target, context_manager, 'contextmanager')

View File

@ -30,13 +30,13 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import object_identity from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
def is_tensor_or_tensor_list(v): def is_tensor_or_tensor_list(v):