diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 2aad135c23b..5d7658389be 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -8,6 +8,7 @@ tensorflow/go/op/wrappers.go tensorflow/lite/micro/build_def.bzl tensorflow/python/autograph/core/config.py tensorflow/python/eager/benchmarks_test_base.py +tensorflow/python/framework/tfrt_utils.py tensorflow/python/tpu/profiler/pip_package/BUILD tensorflow/python/tpu/profiler/pip_package/README tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 70ea4ebe1e2..01a4dbab751 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2200,12 +2200,19 @@ py_library( # Including this as a dependency will result in tests to use TFRT. # TODO(b/170139514): Add a curated list of TFRT native ops. +# TODO(kkb): Deprecated by `tfrt_utils`, remove. py_library( name = "is_tfrt_test_true", srcs = ["framework/is_tfrt_test_true.py"], srcs_version = "PY2AND3", ) +py_library( + name = "tfrt_utils", + srcs = ["framework/tfrt_utils.py"], + srcs_version = "PY2AND3", +) + py_library( name = "distributed_framework_test_lib", srcs_version = "PY2AND3", diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index b743f354cd2..f9ebb9cb4f2 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -212,6 +212,7 @@ py_library( "//tensorflow/python:pywrap_tf_session", "//tensorflow/python:pywrap_tfe", "//tensorflow/python:tf2", + "//tensorflow/python:tfrt_utils", "//tensorflow/python:util", "//third_party/py/numpy", ], diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index cacbc670189..8d5f4546731 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -39,6 +39,7 @@ from tensorflow.python.eager import executor from tensorflow.python.eager import monitoring from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import tfrt_utils from tensorflow.python.util import compat from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import tf_contextlib @@ -335,19 +336,6 @@ class _TensorCacheDeleter(object): del _tensor_caches_map[self._context_id] -# If the below import is made available through the BUILD rule, then this -# function is overridden and will instead return True and cause Tensorflow -# graphs to run with TFRT. -def is_tfrt_enabled(): - return None - - -try: - from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top -except: # pylint: disable=bare-except - pass - - # TODO(agarwal): rename to EagerContext / EagerRuntime ? # TODO(agarwal): consider keeping the corresponding Graph here. class Context(object): @@ -423,10 +411,10 @@ class Context(object): raise ValueError( "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode) if execution_mode is None: - execution_mode = ASYNC if is_tfrt_enabled() else SYNC + execution_mode = ASYNC if tfrt_utils.enabled() else SYNC self._default_is_async = execution_mode == ASYNC self._lazy_remote_inputs_copy = None - self._use_tfrt = is_tfrt_enabled() + self._use_tfrt = tfrt_utils.enabled() self._server_def = server_def self._collective_ops_server_def = None self._collective_leader = None diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index ee078dd0455..7f610393180 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -66,6 +66,7 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import tfrt_utils from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util @@ -119,17 +120,6 @@ except ImportError: pass -# Uses the same mechanism as above to selectively enable TFRT. -def is_tfrt_enabled(): - return False - - -try: - from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top, unused-import -except Exception: # pylint: disable=broad-except - pass - - def _get_object_count_by_type(): return collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) @@ -1831,7 +1821,7 @@ def disable_tfrt(unused_description): """Execute the test only if tfrt is not enabled.""" if tf_inspect.isclass(cls_or_func): - if is_tfrt_enabled(): + if tfrt_utils.enabled(): return None else: return cls_or_func @@ -1839,7 +1829,7 @@ def disable_tfrt(unused_description): def decorator(func): def decorated(self, *args, **kwargs): - if is_tfrt_enabled(): + if tfrt_utils.enabled(): return else: return func(self, *args, **kwargs) diff --git a/tensorflow/python/framework/tfrt_utils.py b/tensorflow/python/framework/tfrt_utils.py new file mode 100644 index 00000000000..0f973b6a5bb --- /dev/null +++ b/tensorflow/python/framework/tfrt_utils.py @@ -0,0 +1,24 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for TFRT migration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def enabled(): + """Returns true if TFRT should be enabled.""" + return False