Creates a tfrt_enabled flag for TF python tests which if set will run the test with TFRT enabled. It creates a new BUILD target for tests when this flag is set which ensures that the EagerContext for that test has tfrt_enabled.

PiperOrigin-RevId: 304278801
Change-Id: I03c61c997d2be126399040cfabb600c62b5cf4cc
This commit is contained in:
Rohan Jain 2020-04-01 15:51:36 -07:00 committed by TensorFlower Gardener
parent 805e47cea9
commit 811cd600ee
4 changed files with 56 additions and 1 deletions

View File

@ -1967,6 +1967,13 @@ py_library(
srcs_version = "PY2AND3",
)
# Including this as a dependency will result in tests to use TFRT.
py_library(
name = "is_tfrt_test_true",
srcs = ["framework/is_tfrt_test_true.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "distributed_framework_test_lib",
srcs_version = "PY2AND3",

View File

@ -336,6 +336,19 @@ class _TensorCacheDeleter(object):
del _tensor_caches_map[self._context_id]
# If the below import is made available through the BUILD rule, then this
# function is overridden and will instead return True and cause Tensorflow
# graphs to run with TFRT.
def is_tfrt_enabled():
return None
try:
from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top
except: # pylint: disable=bare-except
pass
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
@ -411,7 +424,7 @@ class Context(object):
execution_mode = SYNC
self._default_is_async = execution_mode == ASYNC
self._lazy_remote_inputs_copy = None
self._use_tfrt = None
self._use_tfrt = is_tfrt_enabled()
self._server_def = server_def
self._collective_ops_server_def = None
self._collective_leader = None

View File

@ -0,0 +1,30 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Including this as a dependency will result in Tensorflow tests using TFRT.
This function is defined by default in eager/context.py to False. The context
then attempts to import this module. If this file is made available through the
BUILD rule, then this function is overridden and will instead cause
Tensorflow eager execution to run with TFRT.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def is_tfrt_enabled():
"""Returns true to state TFRT should be enabled for Tensorflow tests."""
return True

View File

@ -2207,6 +2207,7 @@ def tf_py_test(
xla_enable_strict_auto_jit = False,
xla_enabled = False,
grpc_enabled = False,
tfrt_enabled = False,
**kwargs):
"""Create one or more python tests with extra tensorflow dependencies."""
xla_test_true_list = []
@ -2223,6 +2224,8 @@ def tf_py_test(
deps = deps + tf_additional_xla_deps_py()
if grpc_enabled:
deps = deps + tf_additional_grpc_deps_py()
if tfrt_enabled:
deps = deps + ["//tensorflow/python:is_tfrt_test_true"]
# NOTE(ebrevdo): This is a workaround for depset() not being able to tell
# the difference between 'dep' and 'clean_dep(dep)'.
@ -2381,6 +2384,7 @@ def py_tests(
xla_enable_strict_auto_jit = False,
xla_enabled = False,
grpc_enabled = False,
tfrt_enabled = False,
**kwargs):
if "additional_deps" in kwargs:
fail("Use `deps` to specify dependencies. `additional_deps` has been replaced with the standard pattern of `deps`.")
@ -2400,6 +2404,7 @@ def py_tests(
tags = tags,
xla_enabled = xla_enabled,
xla_enable_strict_auto_jit = xla_enable_strict_auto_jit,
tfrt_enabled = tfrt_enabled,
**kwargs
)