Creates a tfrt_enabled flag for TF python tests which if set will run the test with TFRT enabled. It creates a new BUILD target for tests when this flag is set which ensures that the EagerContext for that test has tfrt_enabled.
PiperOrigin-RevId: 304278801 Change-Id: I03c61c997d2be126399040cfabb600c62b5cf4cc
This commit is contained in:
parent
805e47cea9
commit
811cd600ee
@ -1967,6 +1967,13 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Including this as a dependency will result in tests to use TFRT.
|
||||||
|
py_library(
|
||||||
|
name = "is_tfrt_test_true",
|
||||||
|
srcs = ["framework/is_tfrt_test_true.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "distributed_framework_test_lib",
|
name = "distributed_framework_test_lib",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
|
@ -336,6 +336,19 @@ class _TensorCacheDeleter(object):
|
|||||||
del _tensor_caches_map[self._context_id]
|
del _tensor_caches_map[self._context_id]
|
||||||
|
|
||||||
|
|
||||||
|
# If the below import is made available through the BUILD rule, then this
|
||||||
|
# function is overridden and will instead return True and cause Tensorflow
|
||||||
|
# graphs to run with TFRT.
|
||||||
|
def is_tfrt_enabled():
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top
|
||||||
|
except: # pylint: disable=bare-except
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
|
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
|
||||||
# TODO(agarwal): consider keeping the corresponding Graph here.
|
# TODO(agarwal): consider keeping the corresponding Graph here.
|
||||||
class Context(object):
|
class Context(object):
|
||||||
@ -411,7 +424,7 @@ class Context(object):
|
|||||||
execution_mode = SYNC
|
execution_mode = SYNC
|
||||||
self._default_is_async = execution_mode == ASYNC
|
self._default_is_async = execution_mode == ASYNC
|
||||||
self._lazy_remote_inputs_copy = None
|
self._lazy_remote_inputs_copy = None
|
||||||
self._use_tfrt = None
|
self._use_tfrt = is_tfrt_enabled()
|
||||||
self._server_def = server_def
|
self._server_def = server_def
|
||||||
self._collective_ops_server_def = None
|
self._collective_ops_server_def = None
|
||||||
self._collective_leader = None
|
self._collective_leader = None
|
||||||
|
30
tensorflow/python/framework/is_tfrt_test_true.py
Normal file
30
tensorflow/python/framework/is_tfrt_test_true.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Including this as a dependency will result in Tensorflow tests using TFRT.
|
||||||
|
|
||||||
|
This function is defined by default in eager/context.py to False. The context
|
||||||
|
then attempts to import this module. If this file is made available through the
|
||||||
|
BUILD rule, then this function is overridden and will instead cause
|
||||||
|
Tensorflow eager execution to run with TFRT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
|
||||||
|
def is_tfrt_enabled():
|
||||||
|
"""Returns true to state TFRT should be enabled for Tensorflow tests."""
|
||||||
|
return True
|
@ -2207,6 +2207,7 @@ def tf_py_test(
|
|||||||
xla_enable_strict_auto_jit = False,
|
xla_enable_strict_auto_jit = False,
|
||||||
xla_enabled = False,
|
xla_enabled = False,
|
||||||
grpc_enabled = False,
|
grpc_enabled = False,
|
||||||
|
tfrt_enabled = False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Create one or more python tests with extra tensorflow dependencies."""
|
"""Create one or more python tests with extra tensorflow dependencies."""
|
||||||
xla_test_true_list = []
|
xla_test_true_list = []
|
||||||
@ -2223,6 +2224,8 @@ def tf_py_test(
|
|||||||
deps = deps + tf_additional_xla_deps_py()
|
deps = deps + tf_additional_xla_deps_py()
|
||||||
if grpc_enabled:
|
if grpc_enabled:
|
||||||
deps = deps + tf_additional_grpc_deps_py()
|
deps = deps + tf_additional_grpc_deps_py()
|
||||||
|
if tfrt_enabled:
|
||||||
|
deps = deps + ["//tensorflow/python:is_tfrt_test_true"]
|
||||||
|
|
||||||
# NOTE(ebrevdo): This is a workaround for depset() not being able to tell
|
# NOTE(ebrevdo): This is a workaround for depset() not being able to tell
|
||||||
# the difference between 'dep' and 'clean_dep(dep)'.
|
# the difference between 'dep' and 'clean_dep(dep)'.
|
||||||
@ -2381,6 +2384,7 @@ def py_tests(
|
|||||||
xla_enable_strict_auto_jit = False,
|
xla_enable_strict_auto_jit = False,
|
||||||
xla_enabled = False,
|
xla_enabled = False,
|
||||||
grpc_enabled = False,
|
grpc_enabled = False,
|
||||||
|
tfrt_enabled = False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
if "additional_deps" in kwargs:
|
if "additional_deps" in kwargs:
|
||||||
fail("Use `deps` to specify dependencies. `additional_deps` has been replaced with the standard pattern of `deps`.")
|
fail("Use `deps` to specify dependencies. `additional_deps` has been replaced with the standard pattern of `deps`.")
|
||||||
@ -2400,6 +2404,7 @@ def py_tests(
|
|||||||
tags = tags,
|
tags = tags,
|
||||||
xla_enabled = xla_enabled,
|
xla_enabled = xla_enabled,
|
||||||
xla_enable_strict_auto_jit = xla_enable_strict_auto_jit,
|
xla_enable_strict_auto_jit = xla_enable_strict_auto_jit,
|
||||||
|
tfrt_enabled = tfrt_enabled,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user