Move tfrt_enabled target duplication logic to tf_py_test

The tfrt_enabled target duplication logic was initially added to cuda_py_test,
but there are tests that directly use tf_py_test, so add to tf_py_test instead.

PiperOrigin-RevId: 307719942
Change-Id: I90bb1009c5cadf1a5a0a3d036a3a174d31b84eac
This commit is contained in:
Kibeom Kim 2020-04-21 18:10:38 -07:00 committed by TensorFlower Gardener
parent bd0feba7b4
commit 86fd918743

View File

@ -2223,8 +2223,6 @@ def tf_py_test(
deps = deps + tf_additional_xla_deps_py()
if grpc_enabled:
deps = deps + tf_additional_grpc_deps_py()
if tfrt_enabled:
deps = deps + ["//tensorflow/python:is_tfrt_test_true"]
# NOTE(ebrevdo): This is a workaround for depset() not being able to tell
# the difference between 'dep' and 'clean_dep(dep)'.
@ -2253,6 +2251,23 @@ def tf_py_test(
deps = depset(deps + xla_test_true_list),
**kwargs
)
if tfrt_enabled:
py_test(
name = name + "_tfrt",
size = size,
srcs = srcs,
args = args,
data = data,
flaky = flaky,
kernels = kernels,
main = main,
shard_count = shard_count,
tags = tags,
visibility = [clean_dep("//tensorflow:internal")] +
additional_visibility,
deps = depset(deps + xla_test_true_list + ["//tensorflow/python:is_tfrt_test_true"]),
**kwargs
)
register_extension_info(
extension_name = "tf_py_test",