Refactor keras dependency to a common utility.

PiperOrigin-RevId: 338155812
Change-Id: Ibbd933514dbb76032a7c982a9233e183d311ab37
This commit is contained in:
Meghna Natraj 2020-10-20 16:01:20 -07:00 committed by TensorFlower Gardener
parent db14325178
commit e3b2f635b9
3 changed files with 50 additions and 13 deletions

View File

@ -33,10 +33,10 @@ from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as tracking
from tensorflow.python.util import keras_deps
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@ -417,7 +417,7 @@ def call_context():
return call_ctx
control_flow_util_v2._register_keras_layer_context_function(call_context) # pylint: disable=protected-access
keras_deps.register_call_context_function(call_context)
class CallContext(object):

View File

@ -28,11 +28,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.util import keras_deps
from tensorflow.python.util import tf_contextlib
_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
_KERAS_LAYER_CONTEXT_FUNCTION = None
_DISABLE_LOWER_USING_SWITCH_MERGE = False
@ -242,18 +242,11 @@ def _is_tpu_strategy(strategy):
strategy.__class__.__name__.startswith("TPUStrategy"))
def _register_keras_layer_context_function(func):
global _KERAS_LAYER_CONTEXT_FUNCTION
# TODO(scottzhu): Disable duplicated inject once keras is moved to
# third_party/py/keras.
_KERAS_LAYER_CONTEXT_FUNCTION = func
def _is_building_keras_layer():
# TODO(srbs): Remove this function when we no long support session with Keras.
global _KERAS_LAYER_CONTEXT_FUNCTION
if _KERAS_LAYER_CONTEXT_FUNCTION is not None:
return _KERAS_LAYER_CONTEXT_FUNCTION().layer is not None
keras_call_context_function = keras_deps.get_call_context_function()
if keras_call_context_function:
return keras_call_context_function().layer is not None
else:
return False

View File

@ -0,0 +1,44 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Interface that provides access to Keras dependencies.
This library is a common interface that contains Keras functions needed by
TensorFlow and TensorFlow Lite and is required as per the dependency inversion
principle (https://en.wikipedia.org/wiki/Dependency_inversion_principle). As per
this principle, high-level modules (eg: TensorFlow and TensorFlow Lite) should
not depend on low-level modules (eg: Keras) and instead both should depend on a
common interface such as this file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
_KERAS_CALL_CONTEXT_FUNCTION = None
def register_call_context_function(func):
global _KERAS_CALL_CONTEXT_FUNCTION
# TODO(scottzhu): Disable duplicated inject once keras is moved to
# third_party/py/keras.
_KERAS_CALL_CONTEXT_FUNCTION = func
def get_call_context_function():
global _KERAS_CALL_CONTEXT_FUNCTION
return _KERAS_CALL_CONTEXT_FUNCTION