STT-tensorflow/tensorflow/tools/test/run_and_gather_logs.py
Hye Soo Yang c396546ca3 PY3 Migration - //tensorflow/tools [2]
PiperOrigin-RevId: 275334989
Change-Id: Ia0b660e7ce7cde8e97f8d7cb3a39afe7fec63a7f
2019-10-17 14:18:30 -07:00

152 lines
4.9 KiB
Python

# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test runner for TensorFlow tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import shlex
from string import maketrans
import sys
import time
import six
from google.protobuf import json_format
from google.protobuf import text_format
from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.tools.test import run_and_gather_logs_lib
# pylint: disable=g-import-not-at-top
# pylint: disable=g-bad-import-order
# pylint: disable=unused-import
# Note: cpuinfo and psutil are not installed for you in the TensorFlow
# OSS tree. They are installable via pip.
try:
import cpuinfo
import psutil
except ImportError as e:
tf_logging.error("\n\n\nERROR: Unable to import necessary library: {}. "
"Issuing a soft exit.\n\n\n".format(e))
sys.exit(0)
# pylint: enable=g-bad-import-order
# pylint: enable=unused-import
# Parsed command-line flags. Stays None until the script's __main__ section
# assigns the result of argparse parsing to it; the functions below read it.
FLAGS = None
def gather_build_configuration():
  """Builds a BuildConfiguration proto from the parsed command-line flags.

  Returns:
    A test_log_pb2.BuildConfiguration whose `mode` is FLAGS.compilation_mode
    and whose `cc_flags` are the shell-split tokens of FLAGS.cc_flags, minus
    every token that begins with "-i".
  """
  config = test_log_pb2.BuildConfiguration()
  config.mode = FLAGS.compilation_mode

  # Keep every compiler flag except the include-style ones ("-i..." prefix).
  for token in shlex.split(FLAGS.cc_flags):
    if token.startswith("-i"):
      continue
    config.cc_flags.append(token)

  return config
def main(unused_args):
  """Runs the requested test target and emits the gathered results.

  All inputs come from the module-level FLAGS. The test is run through
  run_and_gather_logs_lib, the build configuration is attached to the
  results, and then:

    * if --test_log_output_dir is empty, the results are printed to stdout
      in text-proto form;
    * otherwise they are serialized to JSON and written to
      "<output dir>/<file name>.json", where the file name is
      --test_log_output_filename or, when that is empty, a name derived from
      --name plus the current UTC timestamp.

  Args:
    unused_args: Positional arguments supplied by the caller; ignored.
  """
  name = FLAGS.name
  test_name = FLAGS.test_name
  test_args = FLAGS.test_args
  benchmark_type = FLAGS.benchmark_type
  test_results, _ = run_and_gather_logs_lib.run_and_gather_logs(
      name, test_name=test_name, test_args=test_args,
      benchmark_type=benchmark_type)

  # Additional bits we receive from bazel
  test_results.build_configuration.CopyFrom(gather_build_configuration())

  if not FLAGS.test_log_output_dir:
    print(text_format.MessageToString(test_results))
    return

  if FLAGS.test_log_output_filename:
    file_name = FLAGS.test_log_output_filename
  else:
    # Turn the target name into a flat file name by mapping "/" and ":" to
    # "_". Plain str.replace is used instead of string.maketrans because the
    # latter does not exist on Python 3; the result is the same on both.
    flat_name = (
        six.ensure_str(name).strip("/").replace("/", "_").replace(":", "_"))
    file_name = flat_name + time.strftime("%Y%m%d%H%M%S", time.gmtime())

  if FLAGS.test_log_output_use_tmpdir:
    tmpdir = test.get_temp_dir()
    output_path = os.path.join(tmpdir, FLAGS.test_log_output_dir, file_name)
  else:
    output_path = os.path.join(
        os.path.abspath(FLAGS.test_log_output_dir), file_name)

  json_test_results = json_format.MessageToJson(test_results)
  json_path = six.ensure_str(output_path) + ".json"
  # Use a context manager so the file is flushed and closed even if the
  # write raises, rather than relying on garbage collection.
  with gfile.GFile(json_path, "w") as json_file:
    json_file.write(json_test_results)
  # Report the file that was actually written (including the .json suffix).
  tf_logging.info("Test results written to: %s", json_path)
if __name__ == "__main__":
  # Script entry point: declare the command-line flags, parse the ones this
  # script knows about into FLAGS, and pass everything else on to app.run.
  parser = argparse.ArgumentParser()

  # Custom "bool" type: only these spellings (case-insensitive) mean True;
  # any other value is parsed as False.
  parser.register(
      "type", "bool", lambda v: v.lower() in ("true", "t", "y", "yes"))

  # (flag, add_argument keyword options), listed in registration order.
  flag_specs = [
      ("--name",
       dict(type=str, default="", help="Benchmark target identifier.")),
      ("--test_name",
       dict(type=str, default="", help="Test target to run.")),
      ("--benchmark_type",
       dict(type=str,
            default="",
            help="BenchmarkType enum string (benchmark type).")),
      ("--test_args",
       dict(type=str, default="", help="Test arguments, space separated.")),
      # A bare "--test_log_output_use_tmpdir" (no value) means True.
      ("--test_log_output_use_tmpdir",
       dict(type="bool",
            nargs="?",
            const=True,
            default=False,
            help="Store the log output into tmpdir?")),
      ("--compilation_mode",
       dict(type=str,
            default="",
            help="Mode used during this build (e.g. opt, dbg).")),
      ("--cc_flags",
       dict(type=str, default="", help="CC flags used during this build.")),
      ("--test_log_output_dir",
       dict(type=str,
            default="",
            help="Directory to write benchmark results to.")),
      ("--test_log_output_filename",
       dict(type=str,
            default="",
            help="Filename to output benchmark results to. If the filename is not "
            "specified, it will be automatically created based on --name "
            "and current time.")),
  ]
  for flag_name, flag_options in flag_specs:
    parser.add_argument(flag_name, **flag_options)

  # Unrecognized arguments are not an error; they are forwarded to app.run
  # together with the program name.
  FLAGS, unparsed = parser.parse_known_args()
  forwarded_argv = [sys.argv[0]] + unparsed
  app.run(main=main, argv=forwarded_argv)