Add ability to run saved_model integration tests as cuda_py_test.
This required changing the integration scripts to be triggered via the same binary as the test instead of via a py_binary added as a data dependency of the test. PiperOrigin-RevId: 247029977
This commit is contained in:
parent
6135ce95b2
commit
9e51c039a1
@ -2,18 +2,16 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
# This target bundles many scripts into a single py_binary so they can be
|
||||
# executed by saved_model_test without exploding the data dependencies.
|
||||
py_binary(
|
||||
name = "run_script",
|
||||
py_library(
|
||||
name = "integration_scripts",
|
||||
srcs = [
|
||||
"export_mnist_cnn.py",
|
||||
"export_rnn_cell.py",
|
||||
"export_simple_text_embedding.py",
|
||||
"export_text_rnn_model.py",
|
||||
"run_script.py",
|
||||
"integration_scripts.py",
|
||||
"use_mnist_cnn.py",
|
||||
"use_model_in_sequential_keras.py",
|
||||
"use_rnn_cell.py",
|
||||
@ -46,25 +44,19 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
cuda_py_test(
|
||||
name = "saved_model_test",
|
||||
srcs = [
|
||||
"saved_model_test.py",
|
||||
],
|
||||
data = [
|
||||
":run_script",
|
||||
additional_deps = [
|
||||
":integration_scripts",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
shard_count = 4,
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
# NOTE: Split SavedModelTest due to Forge input size limit.
|
||||
"no_cuda_on_cpu_tap", # forge input size exceeded
|
||||
"noasan", # forge input size exceeded
|
||||
"nomsan", # forge input size exceeded
|
||||
"notsan", # forge input size exceeded
|
||||
"no_pip", # b/131697937
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,65 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility to write SavedModel integration tests.
|
||||
|
||||
SavedModel testing requires isolation between the process that creates and
|
||||
consumes it. This file helps doing that by relaunching the same binary that
|
||||
calls `assertCommandSucceeded` with an environment flag indicating what source
|
||||
file to execute. That binary must start by calling `MaybeRunScriptInstead`.
|
||||
|
||||
This allows to wire this into existing building systems without having to depend
|
||||
on data dependencies. And as so allow to keep a fixed binary size and allows
|
||||
interop with GPU tests.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
class TestCase(tf.test.TestCase):
|
||||
"""Base class to write SavedModel integration tests."""
|
||||
|
||||
def assertCommandSucceeded(self, script_name, **flags):
|
||||
"""Runs an integration test script with given flags."""
|
||||
run_script = sys.argv[0]
|
||||
if run_script.endswith(".py"):
|
||||
command_parts = [sys.executable, run_script]
|
||||
else:
|
||||
command_parts = [run_script]
|
||||
for flag_key, flag_value in flags.items():
|
||||
command_parts.append("--%s=%s" % (flag_key, flag_value))
|
||||
env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name)
|
||||
logging.info("Running: %s with environment flags %s" % (command_parts, env))
|
||||
subprocess.check_call(command_parts, env=dict(os.environ, **env))
|
||||
|
||||
|
||||
def MaybeRunScriptInstead():
|
||||
if "SCRIPT_NAME" in os.environ:
|
||||
# Append current path to import path and execute `SCRIPT_NAME` main.
|
||||
sys.path.extend([os.path.dirname(__file__)])
|
||||
module_name = os.environ["SCRIPT_NAME"]
|
||||
retval = app.run(importlib.import_module(module_name).main)
|
||||
sys.exit(retval)
|
@ -1,36 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility to create a single py_binary that can call multiple py_binaries.
|
||||
|
||||
This simulates executing a python script by importing a module name by the
|
||||
environment 'SCRIPT_NAME' and executing its main via `app.run`.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Append current path to import path and execute `SCRIPT_NAME` main.
|
||||
sys.path.extend([os.path.dirname(__file__)])
|
||||
module_name = os.environ['SCRIPT_NAME']
|
||||
app.run(importlib.import_module(module_name).main)
|
@ -18,26 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.examples.saved_model.integration_tests import integration_scripts
|
||||
|
||||
|
||||
class SavedModelTest(tf.test.TestCase):
|
||||
|
||||
def assertCommandSucceeded(self, script_name, **flags):
|
||||
"""Runs a test script via run_script."""
|
||||
run_script = resource_loader.get_path_to_datafile("run_script")
|
||||
command_parts = [run_script]
|
||||
for flag_key, flag_value in flags.items():
|
||||
command_parts.append("--%s=%s" % (flag_key, flag_value))
|
||||
env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name)
|
||||
logging.info("Running: %s with environment flags %s" % (command_parts, env))
|
||||
subprocess.check_call(command_parts, env=dict(os.environ, **env))
|
||||
class SavedModelTest(integration_scripts.TestCase):
|
||||
|
||||
def test_text_rnn(self):
|
||||
export_dir = self.get_temp_dir()
|
||||
@ -57,6 +43,9 @@ class SavedModelTest(tf.test.TestCase):
|
||||
"use_model_in_sequential_keras", model_dir=export_dir)
|
||||
|
||||
def test_text_embedding_in_dataset(self):
|
||||
if tf.test.is_gpu_available():
|
||||
self.skipTest("b/132156097 - fails if there is a gpu available")
|
||||
|
||||
export_dir = self.get_temp_dir()
|
||||
self.assertCommandSucceeded(
|
||||
"export_simple_text_embedding", export_dir=export_dir)
|
||||
@ -86,4 +75,5 @@ class SavedModelTest(tf.test.TestCase):
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
integration_scripts.MaybeRunScriptInstead()
|
||||
tf.test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user