Add a TFRT utility class.

PiperOrigin-RevId: 335926412
Change-Id: I4e0ad69c6a977c7fa131fa6cc3a0a838901ed72e
This commit is contained in:
Kibeom Kim 2020-10-07 12:24:51 -07:00 committed by TensorFlower Gardener
parent fb973b1c7d
commit 727a8e8798
6 changed files with 39 additions and 28 deletions

View File

@ -8,6 +8,7 @@ tensorflow/go/op/wrappers.go
tensorflow/lite/micro/build_def.bzl
tensorflow/python/autograph/core/config.py
tensorflow/python/eager/benchmarks_test_base.py
tensorflow/python/framework/tfrt_utils.py
tensorflow/python/tpu/profiler/pip_package/BUILD
tensorflow/python/tpu/profiler/pip_package/README
tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh

View File

@ -2200,12 +2200,19 @@ py_library(
# Including this as a dependency will result in tests to use TFRT.
# TODO(b/170139514): Add a curated list of TFRT native ops.
# TODO(kkb): Deprecated by `tfrt_utils`, remove.
py_library(
name = "is_tfrt_test_true",
srcs = ["framework/is_tfrt_test_true.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "tfrt_utils",
srcs = ["framework/tfrt_utils.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "distributed_framework_test_lib",
srcs_version = "PY2AND3",

View File

@ -212,6 +212,7 @@ py_library(
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:tf2",
"//tensorflow/python:tfrt_utils",
"//tensorflow/python:util",
"//third_party/py/numpy",
],

View File

@ -39,6 +39,7 @@ from tensorflow.python.eager import executor
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import tfrt_utils
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib
@ -335,19 +336,6 @@ class _TensorCacheDeleter(object):
del _tensor_caches_map[self._context_id]
# If the below import is made available through the BUILD rule, then this
# function is overridden and will instead return True and cause Tensorflow
# graphs to run with TFRT.
def is_tfrt_enabled():
return None
try:
from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top
except: # pylint: disable=bare-except
pass
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
@ -423,10 +411,10 @@ class Context(object):
raise ValueError(
"execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode)
if execution_mode is None:
execution_mode = ASYNC if is_tfrt_enabled() else SYNC
execution_mode = ASYNC if tfrt_utils.enabled() else SYNC
self._default_is_async = execution_mode == ASYNC
self._lazy_remote_inputs_copy = None
self._use_tfrt = is_tfrt_enabled()
self._use_tfrt = tfrt_utils.enabled()
self._server_def = server_def
self._collective_ops_server_def = None
self._collective_leader = None

View File

@ -66,6 +66,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import tfrt_utils
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
@ -119,17 +120,6 @@ except ImportError:
pass
# Uses the same mechanism as above to selectively enable TFRT.
def is_tfrt_enabled():
return False
try:
from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top, unused-import
except Exception: # pylint: disable=broad-except
pass
def _get_object_count_by_type():
return collections.Counter([type(obj).__name__ for obj in gc.get_objects()])
@ -1831,7 +1821,7 @@ def disable_tfrt(unused_description):
"""Execute the test only if tfrt is not enabled."""
if tf_inspect.isclass(cls_or_func):
if is_tfrt_enabled():
if tfrt_utils.enabled():
return None
else:
return cls_or_func
@ -1839,7 +1829,7 @@ def disable_tfrt(unused_description):
def decorator(func):
def decorated(self, *args, **kwargs):
if is_tfrt_enabled():
if tfrt_utils.enabled():
return
else:
return func(self, *args, **kwargs)

View File

@ -0,0 +1,24 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for TFRT migration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def enabled():
"""Returns true if TFRT should be enabled."""
return False