From 47e08b44498f6ef6f726693e2cd935c8f8561571 Mon Sep 17 00:00:00 2001 From: Revan Sopher Date: Tue, 2 Jun 2020 21:16:11 -0700 Subject: [PATCH] Pull out tpu_test_wrapper Starlark logic for reuse. PiperOrigin-RevId: 314462066 Change-Id: I07b36148c4ce223f99fb74735e83527b7d85d627 --- tensorflow/python/tpu/tpu.bzl | 77 +++------------ tensorflow/python/tpu/tpu_test_wrapper.bzl | 103 +++++++++++++++++++++ 2 files changed, 114 insertions(+), 66 deletions(-) create mode 100644 tensorflow/python/tpu/tpu_test_wrapper.bzl diff --git a/tensorflow/python/tpu/tpu.bzl b/tensorflow/python/tpu/tpu.bzl index ba58e57e90d..5453702d64d 100644 --- a/tensorflow/python/tpu/tpu.bzl +++ b/tensorflow/python/tpu/tpu.bzl @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= - """Provides python test rules for Cloud TPU.""" +load( + "//tensorflow/python/tpu:tpu_test_wrapper.bzl", + _get_kwargs_for_wrapping = "get_kwargs_for_wrapping", +) + def tpu_py_test( name, tags = None, @@ -36,71 +40,12 @@ def tpu_py_test( args: Arguments to apply to tests. **kwargs: Additional named arguments to apply to tests. """ - tags = tags or [] - - tags = [ - "tpu", - "no_pip", - "no_gpu", - "nomac", - "local", - ] + tags - - test_main = kwargs.get("srcs") - if not test_main or len(test_main) > 1: - fail('"srcs" should be a list of exactly one python file.') - test_main = test_main[0] - - wrapper_src = _copy_test_source( - "//tensorflow/python/tpu:tpu_test_wrapper.py", - ) - - kwargs["python_version"] = kwargs.get("python_version", "PY3") - kwargs["srcs"].append(wrapper_src) - kwargs["deps"].append("//tensorflow/python:client_testlib") - kwargs["main"] = wrapper_src - - args = [ - "--wrapped_tpu_test_module_relative=.%s" % test_main.rsplit(".", 1)[0], - ] + args native.py_test( - name = name, - tags = tags, - args = args, - **kwargs - ) - -def _copy_test_source(src): - """Creates a genrule copying src into the current directory. - - This silences a Bazel warning, and is necessary for relative import of the - user test to work. - - This genrule checks existing rules to avoid duplicating the source if - another call has already produced the file. Note that this will fail - weirdly if two source files have the same filename, as whichever one is - copied in first will win and other tests will unexpectedly run the wrong - file. We don't expect to see this case, since we're only copying the one - test wrapper around. - - Args: - src: The source file we would like to use. - - Returns: - The path of a copy of this source file, inside the current package. - """ - name = src.rpartition(":")[-1].rpartition("/")[-1] # Get basename. - - new_main = "%s/%s" % (native.package_name(), name) - new_name = "_gen_" + name - - if not native.existing_rule(new_name): - native.genrule( - name = new_name, - srcs = [src], - outs = [new_main], - cmd = "cp $< $@", + **_get_kwargs_for_wrapping( + name, + tags, + args, + **kwargs ) - - return new_main + ) diff --git a/tensorflow/python/tpu/tpu_test_wrapper.bzl b/tensorflow/python/tpu/tpu_test_wrapper.bzl new file mode 100644 index 00000000000..c8ea62c0e5a --- /dev/null +++ b/tensorflow/python/tpu/tpu_test_wrapper.bzl @@ -0,0 +1,103 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Provides Starlark helpers for Cloud TPU.""" + +def get_kwargs_for_wrapping( + name, + tags = None, + args = [], + **kwargs): + """Generates the kwargs for constructing a wrapped TPU test. + + Args: + name: Name of test. Will be prefixed by accelerator versions. + tags: BUILD tags to apply to tests. + args: Arguments to apply to tests. + **kwargs: Additional named arguments to apply to tests. + + Returns: + A dict to be splatted into a py_binary or py_test. + """ + tags = tags or [] + + tags = [ + "tpu", + "no_pip", + "no_gpu", + "nomac", + "local", + ] + tags + + test_main = kwargs.get("srcs") + if not test_main or len(test_main) > 1: + fail('"srcs" should be a list of exactly one python file.') + test_main = test_main[0] + + wrapper_src = _copy_test_source( + "//tensorflow/python/tpu:tpu_test_wrapper.py", + ) + + deps = depset(kwargs["deps"]) + kwargs["python_version"] = kwargs.get("python_version", "PY3") + kwargs["srcs"] = [wrapper_src] + kwargs["srcs"] + kwargs["deps"] = depset( + ["//tensorflow/python:client_testlib"], + transitive = [deps], + ) + kwargs["main"] = wrapper_src + + args = [ + "--wrapped_tpu_test_module_relative=.%s" % test_main.rsplit(".", 1)[0], + ] + args + + kwargs["name"] = name + kwargs["tags"] = tags + kwargs["args"] = args + + return kwargs + +def _copy_test_source(src): + """Creates a genrule copying src into the current directory. + + This silences a Bazel warning, and is necessary for relative import of the + user test to work. + + This genrule checks existing rules to avoid duplicating the source if + another call has already produced the file. Note that this will fail + weirdly if two source files have the same filename, as whichever one is + copied in first will win and other tests will unexpectedly run the wrong + file. We don't expect to see this case, since we're only copying the one + test wrapper around. + + Args: + src: The source file we would like to use. + + Returns: + The path of a copy of this source file, inside the current package. + """ + name = src.rpartition(":")[-1].rpartition("/")[-1] # Get basename. + + new_main = "%s/%s" % (native.package_name(), name) + new_name = "_gen_" + name + + if not native.existing_rule(new_name): + native.genrule( + name = new_name, + srcs = [src], + outs = [new_main], + cmd = "cp $< $@", + ) + + return new_main