Creates a tfrt_enabled flag for TF python tests which if set will run the test with TFRT enabled. It creates a new BUILD target for tests when this flag is set which ensures that the EagerContext for that test has tfrt_enabled.
PiperOrigin-RevId: 304278801 Change-Id: I03c61c997d2be126399040cfabb600c62b5cf4cc
This commit is contained in:
parent
805e47cea9
commit
811cd600ee
@ -1967,6 +1967,13 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
# Including this as a dependency will result in tests to use TFRT.
|
||||
py_library(
|
||||
name = "is_tfrt_test_true",
|
||||
srcs = ["framework/is_tfrt_test_true.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "distributed_framework_test_lib",
|
||||
srcs_version = "PY2AND3",
|
||||
|
@ -336,6 +336,19 @@ class _TensorCacheDeleter(object):
|
||||
del _tensor_caches_map[self._context_id]
|
||||
|
||||
|
||||
# If the below import is made available through the BUILD rule, then this
|
||||
# function is overridden and will instead return True and cause Tensorflow
|
||||
# graphs to run with TFRT.
|
||||
def is_tfrt_enabled():
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top
|
||||
except: # pylint: disable=bare-except
|
||||
pass
|
||||
|
||||
|
||||
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
|
||||
# TODO(agarwal): consider keeping the corresponding Graph here.
|
||||
class Context(object):
|
||||
@ -411,7 +424,7 @@ class Context(object):
|
||||
execution_mode = SYNC
|
||||
self._default_is_async = execution_mode == ASYNC
|
||||
self._lazy_remote_inputs_copy = None
|
||||
self._use_tfrt = None
|
||||
self._use_tfrt = is_tfrt_enabled()
|
||||
self._server_def = server_def
|
||||
self._collective_ops_server_def = None
|
||||
self._collective_leader = None
|
||||
|
30
tensorflow/python/framework/is_tfrt_test_true.py
Normal file
30
tensorflow/python/framework/is_tfrt_test_true.py
Normal file
@ -0,0 +1,30 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Including this as a dependency will result in Tensorflow tests using TFRT.
|
||||
|
||||
This function is defined by default in eager/context.py to False. The context
|
||||
then attempts to import this module. If this file is made available through the
|
||||
BUILD rule, then this function is overridden and will instead cause
|
||||
Tensorflow eager execution to run with TFRT.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
def is_tfrt_enabled():
|
||||
"""Returns true to state TFRT should be enabled for Tensorflow tests."""
|
||||
return True
|
@ -2207,6 +2207,7 @@ def tf_py_test(
|
||||
xla_enable_strict_auto_jit = False,
|
||||
xla_enabled = False,
|
||||
grpc_enabled = False,
|
||||
tfrt_enabled = False,
|
||||
**kwargs):
|
||||
"""Create one or more python tests with extra tensorflow dependencies."""
|
||||
xla_test_true_list = []
|
||||
@ -2223,6 +2224,8 @@ def tf_py_test(
|
||||
deps = deps + tf_additional_xla_deps_py()
|
||||
if grpc_enabled:
|
||||
deps = deps + tf_additional_grpc_deps_py()
|
||||
if tfrt_enabled:
|
||||
deps = deps + ["//tensorflow/python:is_tfrt_test_true"]
|
||||
|
||||
# NOTE(ebrevdo): This is a workaround for depset() not being able to tell
|
||||
# the difference between 'dep' and 'clean_dep(dep)'.
|
||||
@ -2381,6 +2384,7 @@ def py_tests(
|
||||
xla_enable_strict_auto_jit = False,
|
||||
xla_enabled = False,
|
||||
grpc_enabled = False,
|
||||
tfrt_enabled = False,
|
||||
**kwargs):
|
||||
if "additional_deps" in kwargs:
|
||||
fail("Use `deps` to specify dependencies. `additional_deps` has been replaced with the standard pattern of `deps`.")
|
||||
@ -2400,6 +2404,7 @@ def py_tests(
|
||||
tags = tags,
|
||||
xla_enabled = xla_enabled,
|
||||
xla_enable_strict_auto_jit = xla_enable_strict_auto_jit,
|
||||
tfrt_enabled = tfrt_enabled,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user