Fork the tf_contextlib to keras and use tf_decorator which is a __internal__ API.
PiperOrigin-RevId: 331790478 Change-Id: Idb3f4fec529b0a1c012d3fd2f67d5f6f440d6806
This commit is contained in:
parent
d0791dcf1a
commit
e63c9f8418
tensorflow/python/keras
@ -95,6 +95,7 @@ py_library(
|
||||
"//tensorflow/python/distribute:multi_worker_util",
|
||||
"//tensorflow/python/keras/engine:keras_tensor",
|
||||
"//tensorflow/python/keras/utils:control_flow_util",
|
||||
"//tensorflow/python/keras/utils:tf_contextlib",
|
||||
"//tensorflow/python/keras/utils:tf_inspect",
|
||||
],
|
||||
)
|
||||
|
@ -55,6 +55,7 @@ from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend_config
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.utils import control_flow_util
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
@ -83,7 +84,6 @@ from tensorflow.python.training.tracking import util as tracking_util
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
|
||||
|
@ -40,6 +40,7 @@ from tensorflow.python.keras import metrics as metrics_module
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -49,7 +50,6 @@ from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
def set_weights(distribution_strategy, dist_model, weights):
|
||||
|
@ -27,12 +27,12 @@ from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.mixed_precision.experimental import policy
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import function_utils
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# Avoid breaking users who directly import this symbol from this file.
|
||||
|
@ -53,6 +53,7 @@ from tensorflow.python.keras.saving.saved_model import load as keras_load
|
||||
from tensorflow.python.keras.saving.saved_model import save_impl as keras_save
|
||||
from tensorflow.python.keras.utils import control_flow_util
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -63,7 +64,6 @@ from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load as tf_load
|
||||
from tensorflow.python.saved_model import save as tf_save
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
class LayerWithLearningPhase(keras.engine.base_layer.Layer):
|
||||
|
@ -46,8 +46,8 @@ from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
|
||||
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
|
||||
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
|
||||
|
@ -197,6 +197,15 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tf_contextlib",
|
||||
srcs = ["tf_contextlib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tf_inspect",
|
||||
srcs = ["tf_inspect.py"],
|
||||
|
@ -29,9 +29,9 @@ import types as python_types
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
36
tensorflow/python/keras/utils/tf_contextlib.py
Normal file
36
tensorflow/python/keras/utils/tf_contextlib.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TFDecorator-aware replacements for the contextlib module."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib as _contextlib
|
||||
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
|
||||
def contextmanager(target):
|
||||
"""A tf_decorator-aware wrapper for `contextlib.contextmanager`.
|
||||
|
||||
Usage is identical to `contextlib.contextmanager`.
|
||||
|
||||
Args:
|
||||
target: A callable to be wrapped in a contextmanager.
|
||||
Returns:
|
||||
A callable that can be used inside of a `with` statement.
|
||||
"""
|
||||
context_manager = _contextlib.contextmanager(target)
|
||||
return tf_decorator.make_decorator(target, context_manager, 'contextmanager')
|
@ -30,13 +30,13 @@ from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
def is_tensor_or_tensor_list(v):
|
||||
|
Loading…
Reference in New Issue
Block a user