Add a TFRT utility class.
PiperOrigin-RevId: 335926412 Change-Id: I4e0ad69c6a977c7fa131fa6cc3a0a838901ed72e
This commit is contained in:
parent
fb973b1c7d
commit
727a8e8798
@ -8,6 +8,7 @@ tensorflow/go/op/wrappers.go
|
||||
tensorflow/lite/micro/build_def.bzl
|
||||
tensorflow/python/autograph/core/config.py
|
||||
tensorflow/python/eager/benchmarks_test_base.py
|
||||
tensorflow/python/framework/tfrt_utils.py
|
||||
tensorflow/python/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/python/tpu/profiler/pip_package/README
|
||||
tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh
|
||||
|
@ -2200,12 +2200,19 @@ py_library(
|
||||
|
||||
# Including this as a dependency will result in tests to use TFRT.
|
||||
# TODO(b/170139514): Add a curated list of TFRT native ops.
|
||||
# TODO(kkb): Deprecated by `tfrt_utils`, remove.
|
||||
py_library(
|
||||
name = "is_tfrt_test_true",
|
||||
srcs = ["framework/is_tfrt_test_true.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tfrt_utils",
|
||||
srcs = ["framework/tfrt_utils.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "distributed_framework_test_lib",
|
||||
srcs_version = "PY2AND3",
|
||||
|
@ -212,6 +212,7 @@ py_library(
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python:tfrt_utils",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.eager import executor
|
||||
from tensorflow.python.eager import monitoring
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import tfrt_utils
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import is_in_graph_mode
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
@ -335,19 +336,6 @@ class _TensorCacheDeleter(object):
|
||||
del _tensor_caches_map[self._context_id]
|
||||
|
||||
|
||||
# If the below import is made available through the BUILD rule, then this
|
||||
# function is overridden and will instead return True and cause Tensorflow
|
||||
# graphs to run with TFRT.
|
||||
def is_tfrt_enabled():
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top
|
||||
except: # pylint: disable=bare-except
|
||||
pass
|
||||
|
||||
|
||||
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
|
||||
# TODO(agarwal): consider keeping the corresponding Graph here.
|
||||
class Context(object):
|
||||
@ -423,10 +411,10 @@ class Context(object):
|
||||
raise ValueError(
|
||||
"execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode)
|
||||
if execution_mode is None:
|
||||
execution_mode = ASYNC if is_tfrt_enabled() else SYNC
|
||||
execution_mode = ASYNC if tfrt_utils.enabled() else SYNC
|
||||
self._default_is_async = execution_mode == ASYNC
|
||||
self._lazy_remote_inputs_copy = None
|
||||
self._use_tfrt = is_tfrt_enabled()
|
||||
self._use_tfrt = tfrt_utils.enabled()
|
||||
self._server_def = server_def
|
||||
self._collective_ops_server_def = None
|
||||
self._collective_leader = None
|
||||
|
@ -66,6 +66,7 @@ from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import tfrt_utils
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
@ -119,17 +120,6 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# Uses the same mechanism as above to selectively enable TFRT.
|
||||
def is_tfrt_enabled():
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top, unused-import
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
|
||||
def _get_object_count_by_type():
|
||||
return collections.Counter([type(obj).__name__ for obj in gc.get_objects()])
|
||||
|
||||
@ -1831,7 +1821,7 @@ def disable_tfrt(unused_description):
|
||||
"""Execute the test only if tfrt is not enabled."""
|
||||
|
||||
if tf_inspect.isclass(cls_or_func):
|
||||
if is_tfrt_enabled():
|
||||
if tfrt_utils.enabled():
|
||||
return None
|
||||
else:
|
||||
return cls_or_func
|
||||
@ -1839,7 +1829,7 @@ def disable_tfrt(unused_description):
|
||||
def decorator(func):
|
||||
|
||||
def decorated(self, *args, **kwargs):
|
||||
if is_tfrt_enabled():
|
||||
if tfrt_utils.enabled():
|
||||
return
|
||||
else:
|
||||
return func(self, *args, **kwargs)
|
||||
|
24
tensorflow/python/framework/tfrt_utils.py
Normal file
24
tensorflow/python/framework/tfrt_utils.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for TFRT migration."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
def enabled():
|
||||
"""Returns true if TFRT should be enabled."""
|
||||
return False
|
Loading…
Reference in New Issue
Block a user