Move tfrt_enabled target duplication logic to tf_py_test
The tfrt_enabled target duplication logic was initially added to cuda_py_test, but there are tests that directly use tf_py_test, so add to tf_py_test instead. PiperOrigin-RevId: 307719942 Change-Id: I90bb1009c5cadf1a5a0a3d036a3a174d31b84eac
This commit is contained in:
parent
bd0feba7b4
commit
86fd918743
@ -2223,8 +2223,6 @@ def tf_py_test(
|
||||
deps = deps + tf_additional_xla_deps_py()
|
||||
if grpc_enabled:
|
||||
deps = deps + tf_additional_grpc_deps_py()
|
||||
if tfrt_enabled:
|
||||
deps = deps + ["//tensorflow/python:is_tfrt_test_true"]
|
||||
|
||||
# NOTE(ebrevdo): This is a workaround for depset() not being able to tell
|
||||
# the difference between 'dep' and 'clean_dep(dep)'.
|
||||
@ -2253,6 +2251,23 @@ def tf_py_test(
|
||||
deps = depset(deps + xla_test_true_list),
|
||||
**kwargs
|
||||
)
|
||||
if tfrt_enabled:
|
||||
py_test(
|
||||
name = name + "_tfrt",
|
||||
size = size,
|
||||
srcs = srcs,
|
||||
args = args,
|
||||
data = data,
|
||||
flaky = flaky,
|
||||
kernels = kernels,
|
||||
main = main,
|
||||
shard_count = shard_count,
|
||||
tags = tags,
|
||||
visibility = [clean_dep("//tensorflow:internal")] +
|
||||
additional_visibility,
|
||||
deps = depset(deps + xla_test_true_list + ["//tensorflow/python:is_tfrt_test_true"]),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
register_extension_info(
|
||||
extension_name = "tf_py_test",
|
||||
|
Loading…
x
Reference in New Issue
Block a user