200 lines
7.0 KiB
Python
200 lines
7.0 KiB
Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# =============================================================================
|
|
"""Wrapper for Python TPU tests.
|
|
|
|
The py_tpu_test macro will actually use this file as its main, building and
|
|
executing the user-provided test file as a py_binary instead. This lets us do
|
|
important work behind the scenes, without complicating the tests themselves.
|
|
|
|
The main responsibilities of this file are:
|
|
- Define standard set of model flags if test did not. This allows us to
|
|
safely set flags at the Bazel invocation level using --test_arg.
|
|
- Pick a random directory on GCS to use for each test case, and set it as the
|
|
default value of --model_dir. This is similar to how Bazel provides each
|
|
test with a fresh local directory in $TEST_TMPDIR.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import ast
|
|
import importlib
|
|
import os
|
|
import sys
|
|
import uuid
|
|
|
|
from tensorflow.python.platform import flags
|
|
from tensorflow.python.util import tf_inspect
|
|
|
|
FLAGS = flags.FLAGS
# Name of the wrapped test module. import_user_module() hands this value to
# importlib.import_module() as the (relative) module name.
flags.DEFINE_string(
    'wrapped_tpu_test_module_relative', None,
    'The Python-style relative path to the user-given test. If test is in same '
    'directory as BUILD file as is common, then "test.py" would be ".test".')
# The default is read from the environment once, when this module is loaded,
# and is None if TEST_UNDECLARED_OUTPUTS_DIR is not set.
flags.DEFINE_string('test_dir_base',
                    os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'GCS path to root directory for temporary test files.')
# calculate_parent_python_path() searches the invoked file path for the last
# occurrence of this substring to find where the import path begins.
flags.DEFINE_string(
    'bazel_repo_root', 'tensorflow/python',
    'Substring of a bazel filepath beginning the python absolute import path.')
|
|
|
|
# List of flags which all TPU tests should accept. maybe_define_flags() creates
# each of these as a string flag with a None default unless a flag of that name
# has already been defined (for example by the wrapped test module).
REQUIRED_FLAGS = ['tpu', 'zone', 'project', 'model_dir']
|
|
|
|
|
|
def maybe_define_flags():
  """Registers every flag in REQUIRED_FLAGS that does not exist yet.

  Each missing flag is created as a string flag whose default is None. A flag
  that is already registered is left exactly as it was defined.
  """
  for flag_name in REQUIRED_FLAGS:
    try:
      flags.DEFINE_string(flag_name, None, 'flag defined by test lib')
    except flags.DuplicateFlagError:
      # Somebody already defined this flag; keep their definition.
      continue
|
|
|
|
|
|
def set_random_test_dir():
  """Makes a fresh, randomly named directory the default for --model_dir.

  The directory name is a random UUID (hex form) joined onto --test_dir_base.
  Only the flag default is changed here; nothing is created on disk or GCS.
  """
  base_dir = FLAGS.test_dir_base
  unique_name = uuid.uuid4().hex
  FLAGS.set_default('model_dir', os.path.join(base_dir, unique_name))
|
|
|
|
|
|
def calculate_parent_python_path(test_filepath):
  """Converts the invoked file path into its parent's dotted import path.

  Args:
    test_filepath: The filepath which Bazel invoked
      (ex: /filesystem/path/tensorflow/tensorflow/python/tpu/tpu_test)

  Returns:
    Absolute import path of parent (ex: tensorflow.python.tpu).

  Raises:
    ValueError: if bazel_repo_root does not appear within test_filepath.
  """
  repo_root = FLAGS.bazel_repo_root

  # Split on the LAST occurrence of the repo root; a single piece means the
  # root never occurred in the path.
  pieces = test_filepath.rsplit(repo_root, 1)
  if len(pieces) < 2:
    raise ValueError('Filepath "%s" does not contain repo root "%s"' %
                     (test_filepath, repo_root))

  # Keep the repo root itself plus everything after it.
  rooted_path = repo_root + pieces[1]

  # Strip the final component, i.e. the name of the test wrapper.
  parent_dir = rooted_path.rsplit('/', 1)[0]

  # Directory separators become package separators.
  return '.'.join(parent_dir.split('/'))
|
|
|
|
|
|
def import_user_module():
  """Imports the user test module named by --wrapped_tpu_test_module_relative.

  The flag value is treated as a relative module name and resolved against the
  package computed from the path this program was invoked with (sys.argv[0]).
  Importing executes the module's top-level statements, which includes any flag
  definitions it contains.

  Returns:
    The user test module.
  """
  relative_name = FLAGS.wrapped_tpu_test_module_relative
  parent_package = calculate_parent_python_path(sys.argv[0])
  return importlib.import_module(relative_name, parent_package)
|
|
|
|
|
|
def _is_test_class(obj):
|
|
"""Check if arbitrary object is a test class (not a test object!).
|
|
|
|
Args:
|
|
obj: An arbitrary object from within a module.
|
|
|
|
Returns:
|
|
True iff obj is a test class inheriting at some point from a module
|
|
named "TestCase". This is because we write tests using different underlying
|
|
test libraries.
|
|
"""
|
|
return (tf_inspect.isclass(obj)
|
|
and 'TestCase' in (p.__name__ for p in tf_inspect.getmro(obj)))
|
|
|
|
|
|
# Namespace dict of this module, captured while the module body executes
# (vars() with no argument at module level). move_test_classes_into_scope()
# writes into it to create new module-level names.
module_variables = vars()
|
|
|
|
|
|
def move_test_classes_into_scope(wrapped_test_module):
  """Re-exports the wrapped module's test classes from this module.

  The test runner works by inspecting the main module for TestCase classes, so
  giving each wrapped TestCase a module-level name here causes the runner to
  execute it. Every class is stored under its original name prefixed with
  "tpu_test_imported_".

  Args:
    wrapped_test_module: The user-provided test code to run.
  """
  for attr_name, attr_value in wrapped_test_module.__dict__.items():
    if not _is_test_class(attr_value):
      continue
    exported_name = 'tpu_test_imported_%s' % attr_name
    module_variables[exported_name] = attr_value
|
|
|
|
|
|
def run_user_main(wrapped_test_module):
  """Runs the "if __name__ == '__main__'" at the bottom of a module.

  TensorFlow practice is to have a main if at the bottom of the module which
  might call an API compat function before calling test.main().

  Since this is a statement, not a function, we can't cleanly reference it, but
  we can inspect it from the user module and run it in the context of that
  module so all imports and variables are available to it.

  Only a top-level `if` whose condition is exactly `__name__ == "__main__"`
  (either quote style) is recognised. If several exist, the last one in the
  file is used, and only its body is executed (any `else` branch is ignored).

  Args:
    wrapped_test_module: The user-provided test code to run.

  Raises:
    NotImplementedError: If main block was not found in module. This should not
      be caught, as it is likely an error on the user's part -- absltest is all
      too happy to report a successful status (and zero tests executed) if a
      user forgets to end a class with "test.main()".
  """
  tree = ast.parse(tf_inspect.getsource(wrapped_test_module))

  # Get string representation of just the condition `__name__ == "__main__"`.
  # Comparing AST dumps makes the match independent of quoting and whitespace.
  target = ast.dump(ast.parse('if __name__ == "__main__": pass').body[0].test)

  # `tree.body` is a list of top-level statements in the module, like imports
  # and class definitions. We search for our main block, starting from the end.
  for expr in reversed(tree.body):
    if isinstance(expr, ast.If) and ast.dump(expr.test) == target:
      break
  else:
    # The message names the exact condition searched for, so the user can see
    # what their module is missing.
    raise NotImplementedError(
        'Could not find `if __name__ == "__main__":` block in %s.' %
        wrapped_test_module.__name__)

  # expr is defined because we would have raised an error otherwise.
  new_ast = ast.Module(body=expr.body, type_ignores=[])  # pylint:disable=undefined-loop-variable
  # The user module's namespace is passed as the locals mapping, so names in
  # the main block resolve against the user module first, and assignments made
  # by the main block land in the user module.
  exec(  # pylint:disable=exec-used
      compile(new_ast, '<ast>', 'exec'),
      globals(),
      wrapped_test_module.__dict__,
  )
|
|
|
|
|
|
if __name__ == '__main__':
  # Partially parse flags, since module to import is specified by flag.
  # NOTE(review): known_only=True presumably lets flags that are not defined
  # yet pass through into `unparsed` instead of raising -- confirm against the
  # absl flags documentation.
  unparsed = FLAGS(sys.argv, known_only=True)
  # Import the user module before defining the required flags: importing runs
  # the user's own flag definitions, and maybe_define_flags() then only fills
  # in the REQUIRED_FLAGS that are still missing.
  user_module = import_user_module()
  maybe_define_flags()
  # Parse remaining flags.
  FLAGS(unparsed)
  # 'model_dir' is guaranteed to be defined at this point, so its default can
  # now be pointed at a fresh random directory.
  set_random_test_dir()

  # Expose the user's TestCase classes as names on this module, then execute
  # the body of the user's own `if __name__ == "__main__":` block.
  move_test_classes_into_scope(user_module)
  run_user_main(user_module)
|