Pull out tpu_test_wrapper Starlark logic for reuse.

PiperOrigin-RevId: 314462066
Change-Id: I07b36148c4ce223f99fb74735e83527b7d85d627
This commit is contained in:
Revan Sopher 2020-06-02 21:16:11 -07:00 committed by TensorFlower Gardener
parent c13f5253e0
commit 47e08b4449
2 changed files with 114 additions and 66 deletions

View File

@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Provides python test rules for Cloud TPU."""
load(
"//tensorflow/python/tpu:tpu_test_wrapper.bzl",
_get_kwargs_for_wrapping = "get_kwargs_for_wrapping",
)
def tpu_py_test(
name,
tags = None,
@ -36,71 +40,12 @@ def tpu_py_test(
args: Arguments to apply to tests.
**kwargs: Additional named arguments to apply to tests.
"""
tags = tags or []
tags = [
"tpu",
"no_pip",
"no_gpu",
"nomac",
"local",
] + tags
test_main = kwargs.get("srcs")
if not test_main or len(test_main) > 1:
fail('"srcs" should be a list of exactly one python file.')
test_main = test_main[0]
wrapper_src = _copy_test_source(
"//tensorflow/python/tpu:tpu_test_wrapper.py",
)
kwargs["python_version"] = kwargs.get("python_version", "PY3")
kwargs["srcs"].append(wrapper_src)
kwargs["deps"].append("//tensorflow/python:client_testlib")
kwargs["main"] = wrapper_src
args = [
"--wrapped_tpu_test_module_relative=.%s" % test_main.rsplit(".", 1)[0],
] + args
native.py_test(
name = name,
tags = tags,
args = args,
**kwargs
)
def _copy_test_source(src):
"""Creates a genrule copying src into the current directory.
This silences a Bazel warning, and is necessary for relative import of the
user test to work.
This genrule checks existing rules to avoid duplicating the source if
another call has already produced the file. Note that this will fail
weirdly if two source files have the same filename, as whichever one is
copied in first will win and other tests will unexpectedly run the wrong
file. We don't expect to see this case, since we're only copying the one
test wrapper around.
Args:
src: The source file we would like to use.
Returns:
The path of a copy of this source file, inside the current package.
"""
name = src.rpartition(":")[-1].rpartition("/")[-1] # Get basename.
new_main = "%s/%s" % (native.package_name(), name)
new_name = "_gen_" + name
if not native.existing_rule(new_name):
native.genrule(
name = new_name,
srcs = [src],
outs = [new_main],
cmd = "cp $< $@",
**_get_kwargs_for_wrapping(
name,
tags,
args,
**kwargs
)
return new_main
)

View File

@ -0,0 +1,103 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Provides Starlark helpers for Cloud TPU."""
def get_kwargs_for_wrapping(
name,
tags = None,
args = [],
**kwargs):
"""Generates the kwargs for constructing a wrapped TPU test.
Args:
name: Name of test. Will be prefixed by accelerator versions.
tags: BUILD tags to apply to tests.
args: Arguments to apply to tests.
**kwargs: Additional named arguments to apply to tests.
Returns:
A dict to be splatted into a py_binary or py_test.
"""
tags = tags or []
tags = [
"tpu",
"no_pip",
"no_gpu",
"nomac",
"local",
] + tags
test_main = kwargs.get("srcs")
if not test_main or len(test_main) > 1:
fail('"srcs" should be a list of exactly one python file.')
test_main = test_main[0]
wrapper_src = _copy_test_source(
"//tensorflow/python/tpu:tpu_test_wrapper.py",
)
deps = depset(kwargs["deps"])
kwargs["python_version"] = kwargs.get("python_version", "PY3")
kwargs["srcs"] = [wrapper_src] + kwargs["srcs"]
kwargs["deps"] = depset(
["//tensorflow/python:client_testlib"],
transitive = [deps],
)
kwargs["main"] = wrapper_src
args = [
"--wrapped_tpu_test_module_relative=.%s" % test_main.rsplit(".", 1)[0],
] + args
kwargs["name"] = name
kwargs["tags"] = tags
kwargs["args"] = args
return kwargs
def _copy_test_source(src):
"""Creates a genrule copying src into the current directory.
This silences a Bazel warning, and is necessary for relative import of the
user test to work.
This genrule checks existing rules to avoid duplicating the source if
another call has already produced the file. Note that this will fail
weirdly if two source files have the same filename, as whichever one is
copied in first will win and other tests will unexpectedly run the wrong
file. We don't expect to see this case, since we're only copying the one
test wrapper around.
Args:
src: The source file we would like to use.
Returns:
The path of a copy of this source file, inside the current package.
"""
name = src.rpartition(":")[-1].rpartition("/")[-1] # Get basename.
new_main = "%s/%s" % (native.package_name(), name)
new_name = "_gen_" + name
if not native.existing_rule(new_name):
native.genrule(
name = new_name,
srcs = [src],
outs = [new_main],
cmd = "cp $< $@",
)
return new_main