Add ability to run saved_model integration tests as cuda_py_test.
This required changing the integration scripts to be triggered via the same binary as the test instead of via a py_binary added as a data dependency of the test. PiperOrigin-RevId: 247029977
This commit is contained in:
parent
6135ce95b2
commit
9e51c039a1
@ -2,18 +2,16 @@ licenses(["notice"]) # Apache 2.0
|
|||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||||
|
|
||||||
# This target bundles many scripts into a single py_binary so they can be
|
py_library(
|
||||||
# executed by saved_model_test without exploding the data dependencies.
|
name = "integration_scripts",
|
||||||
py_binary(
|
|
||||||
name = "run_script",
|
|
||||||
srcs = [
|
srcs = [
|
||||||
"export_mnist_cnn.py",
|
"export_mnist_cnn.py",
|
||||||
"export_rnn_cell.py",
|
"export_rnn_cell.py",
|
||||||
"export_simple_text_embedding.py",
|
"export_simple_text_embedding.py",
|
||||||
"export_text_rnn_model.py",
|
"export_text_rnn_model.py",
|
||||||
"run_script.py",
|
"integration_scripts.py",
|
||||||
"use_mnist_cnn.py",
|
"use_mnist_cnn.py",
|
||||||
"use_model_in_sequential_keras.py",
|
"use_model_in_sequential_keras.py",
|
||||||
"use_rnn_cell.py",
|
"use_rnn_cell.py",
|
||||||
@ -46,25 +44,19 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
cuda_py_test(
|
||||||
name = "saved_model_test",
|
name = "saved_model_test",
|
||||||
srcs = [
|
srcs = [
|
||||||
"saved_model_test.py",
|
"saved_model_test.py",
|
||||||
],
|
],
|
||||||
data = [
|
additional_deps = [
|
||||||
":run_script",
|
":integration_scripts",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
],
|
],
|
||||||
shard_count = 4,
|
shard_count = 4,
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
tags = [
|
tags = [
|
||||||
# NOTE: Split SavedModelTest due to Forge input size limit.
|
|
||||||
"no_cuda_on_cpu_tap", # forge input size exceeded
|
|
||||||
"noasan", # forge input size exceeded
|
"noasan", # forge input size exceeded
|
||||||
"nomsan", # forge input size exceeded
|
"nomsan", # forge input size exceeded
|
||||||
"notsan", # forge input size exceeded
|
"notsan", # forge input size exceeded
|
||||||
"no_pip", # b/131697937
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow:tensorflow_py",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,65 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utility to write SavedModel integration tests.
|
||||||
|
|
||||||
|
SavedModel testing requires isolation between the process that creates and
|
||||||
|
consumes it. This file helps doing that by relaunching the same binary that
|
||||||
|
calls `assertCommandSucceeded` with an environment flag indicating what source
|
||||||
|
file to execute. That binary must start by calling `MaybeRunScriptInstead`.
|
||||||
|
|
||||||
|
This allows to wire this into existing building systems without having to depend
|
||||||
|
on data dependencies. And as so allow to keep a fixed binary size and allows
|
||||||
|
interop with GPU tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
import tensorflow.compat.v2 as tf
|
||||||
|
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase(tf.test.TestCase):
|
||||||
|
"""Base class to write SavedModel integration tests."""
|
||||||
|
|
||||||
|
def assertCommandSucceeded(self, script_name, **flags):
|
||||||
|
"""Runs an integration test script with given flags."""
|
||||||
|
run_script = sys.argv[0]
|
||||||
|
if run_script.endswith(".py"):
|
||||||
|
command_parts = [sys.executable, run_script]
|
||||||
|
else:
|
||||||
|
command_parts = [run_script]
|
||||||
|
for flag_key, flag_value in flags.items():
|
||||||
|
command_parts.append("--%s=%s" % (flag_key, flag_value))
|
||||||
|
env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name)
|
||||||
|
logging.info("Running: %s with environment flags %s" % (command_parts, env))
|
||||||
|
subprocess.check_call(command_parts, env=dict(os.environ, **env))
|
||||||
|
|
||||||
|
|
||||||
|
def MaybeRunScriptInstead():
|
||||||
|
if "SCRIPT_NAME" in os.environ:
|
||||||
|
# Append current path to import path and execute `SCRIPT_NAME` main.
|
||||||
|
sys.path.extend([os.path.dirname(__file__)])
|
||||||
|
module_name = os.environ["SCRIPT_NAME"]
|
||||||
|
retval = app.run(importlib.import_module(module_name).main)
|
||||||
|
sys.exit(retval)
|
@ -1,36 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Utility to create a single py_binary that can call multiple py_binaries.
|
|
||||||
|
|
||||||
This simulates executing a python script by importing a module name by the
|
|
||||||
environment 'SCRIPT_NAME' and executing its main via `app.run`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# Append current path to import path and execute `SCRIPT_NAME` main.
|
|
||||||
sys.path.extend([os.path.dirname(__file__)])
|
|
||||||
module_name = os.environ['SCRIPT_NAME']
|
|
||||||
app.run(importlib.import_module(module_name).main)
|
|
@ -18,26 +18,12 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
import tensorflow.compat.v2 as tf
|
import tensorflow.compat.v2 as tf
|
||||||
|
|
||||||
from tensorflow.python.platform import resource_loader
|
from tensorflow.examples.saved_model.integration_tests import integration_scripts
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
|
||||||
|
|
||||||
|
|
||||||
class SavedModelTest(tf.test.TestCase):
|
class SavedModelTest(integration_scripts.TestCase):
|
||||||
|
|
||||||
def assertCommandSucceeded(self, script_name, **flags):
|
|
||||||
"""Runs a test script via run_script."""
|
|
||||||
run_script = resource_loader.get_path_to_datafile("run_script")
|
|
||||||
command_parts = [run_script]
|
|
||||||
for flag_key, flag_value in flags.items():
|
|
||||||
command_parts.append("--%s=%s" % (flag_key, flag_value))
|
|
||||||
env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name)
|
|
||||||
logging.info("Running: %s with environment flags %s" % (command_parts, env))
|
|
||||||
subprocess.check_call(command_parts, env=dict(os.environ, **env))
|
|
||||||
|
|
||||||
def test_text_rnn(self):
|
def test_text_rnn(self):
|
||||||
export_dir = self.get_temp_dir()
|
export_dir = self.get_temp_dir()
|
||||||
@ -57,6 +43,9 @@ class SavedModelTest(tf.test.TestCase):
|
|||||||
"use_model_in_sequential_keras", model_dir=export_dir)
|
"use_model_in_sequential_keras", model_dir=export_dir)
|
||||||
|
|
||||||
def test_text_embedding_in_dataset(self):
|
def test_text_embedding_in_dataset(self):
|
||||||
|
if tf.test.is_gpu_available():
|
||||||
|
self.skipTest("b/132156097 - fails if there is a gpu available")
|
||||||
|
|
||||||
export_dir = self.get_temp_dir()
|
export_dir = self.get_temp_dir()
|
||||||
self.assertCommandSucceeded(
|
self.assertCommandSucceeded(
|
||||||
"export_simple_text_embedding", export_dir=export_dir)
|
"export_simple_text_embedding", export_dir=export_dir)
|
||||||
@ -86,4 +75,5 @@ class SavedModelTest(tf.test.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
integration_scripts.MaybeRunScriptInstead()
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user