Refactor keras dependency to a common utility.
PiperOrigin-RevId: 338155812 Change-Id: Ibbd933514dbb76032a7c982a9233e183d311ab37
This commit is contained in:
parent
db14325178
commit
e3b2f635b9
@ -33,10 +33,10 @@ from tensorflow.python.keras.utils import control_flow_util
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_util_v2
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.training.tracking import base as tracking
|
||||
from tensorflow.python.util import keras_deps
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
@ -417,7 +417,7 @@ def call_context():
|
||||
return call_ctx
|
||||
|
||||
|
||||
control_flow_util_v2._register_keras_layer_context_function(call_context) # pylint: disable=protected-access
|
||||
keras_deps.register_call_context_function(call_context)
|
||||
|
||||
|
||||
class CallContext(object):
|
||||
|
@ -28,11 +28,11 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework.func_graph import FuncGraph
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import control_flow_v2_func_graphs
|
||||
from tensorflow.python.util import keras_deps
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
|
||||
_KERAS_LAYER_CONTEXT_FUNCTION = None
|
||||
_DISABLE_LOWER_USING_SWITCH_MERGE = False
|
||||
|
||||
|
||||
@ -242,18 +242,11 @@ def _is_tpu_strategy(strategy):
|
||||
strategy.__class__.__name__.startswith("TPUStrategy"))
|
||||
|
||||
|
||||
def _register_keras_layer_context_function(func):
|
||||
global _KERAS_LAYER_CONTEXT_FUNCTION
|
||||
# TODO(scottzhu): Disable duplicated inject once keras is moved to
|
||||
# third_party/py/keras.
|
||||
_KERAS_LAYER_CONTEXT_FUNCTION = func
|
||||
|
||||
|
||||
def _is_building_keras_layer():
|
||||
# TODO(srbs): Remove this function when we no long support session with Keras.
|
||||
global _KERAS_LAYER_CONTEXT_FUNCTION
|
||||
if _KERAS_LAYER_CONTEXT_FUNCTION is not None:
|
||||
return _KERAS_LAYER_CONTEXT_FUNCTION().layer is not None
|
||||
keras_call_context_function = keras_deps.get_call_context_function()
|
||||
if keras_call_context_function:
|
||||
return keras_call_context_function().layer is not None
|
||||
else:
|
||||
return False
|
||||
|
||||
|
44
tensorflow/python/util/keras_deps.py
Normal file
44
tensorflow/python/util/keras_deps.py
Normal file
@ -0,0 +1,44 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Interface that provides access to Keras dependencies.
|
||||
|
||||
This library is a common interface that contains Keras functions needed by
|
||||
TensorFlow and TensorFlow Lite and is required as per the dependency inversion
|
||||
principle (https://en.wikipedia.org/wiki/Dependency_inversion_principle). As per
|
||||
this principle, high-level modules (eg: TensorFlow and TensorFlow Lite) should
|
||||
not depend on low-level modules (eg: Keras) and instead both should depend on a
|
||||
common interface such as this file.
|
||||
"""
|
||||
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
_KERAS_CALL_CONTEXT_FUNCTION = None
|
||||
|
||||
|
||||
def register_call_context_function(func):
|
||||
global _KERAS_CALL_CONTEXT_FUNCTION
|
||||
# TODO(scottzhu): Disable duplicated inject once keras is moved to
|
||||
# third_party/py/keras.
|
||||
_KERAS_CALL_CONTEXT_FUNCTION = func
|
||||
|
||||
|
||||
def get_call_context_function():
|
||||
global _KERAS_CALL_CONTEXT_FUNCTION
|
||||
return _KERAS_CALL_CONTEXT_FUNCTION
|
Loading…
Reference in New Issue
Block a user