Enable TF Hub's tests with KerasLayer.
PiperOrigin-RevId: 253730171
This commit is contained in:
parent
e906cdde1b
commit
99d71381f2
@ -23,6 +23,7 @@ py_library(
|
||||
],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":distribution_strategy_utils",
|
||||
":mnist_util",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
@ -38,14 +39,26 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "distribution_strategy_utils",
|
||||
srcs = ["distribution_strategy_utils.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "saved_model_test",
|
||||
srcs = [
|
||||
"saved_model_test.py",
|
||||
],
|
||||
additional_deps = [
|
||||
":distribution_strategy_utils",
|
||||
":integration_scripts",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
],
|
||||
shard_count = 4,
|
||||
tags = [
|
||||
|
@ -0,0 +1,38 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utils related to tf.distribute.strategy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
|
||||
_strategies = [
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_one_gpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_two_gpus,
|
||||
]
|
||||
|
||||
named_strategies = collections.OrderedDict()
|
||||
named_strategies[None] = None
|
||||
|
||||
|
||||
for strategy in _strategies:
|
||||
named_strategies[str(strategy)] = strategy
|
@ -33,7 +33,7 @@ def _load_random_data(num_train_and_test):
|
||||
|
||||
def load_reshaped_data(use_fashion_mnist=False, fake_tiny_data=False):
|
||||
"""Returns MNIST or Fashion MNIST or fake train and test data."""
|
||||
load = ((lambda: _load_random_data([10, 50])) if fake_tiny_data else
|
||||
load = ((lambda: _load_random_data([16, 128])) if fake_tiny_data else
|
||||
tf.keras.datasets.fashion_mnist.load_data if use_fashion_mnist else
|
||||
tf.keras.datasets.mnist.load_data)
|
||||
(x_train, y_train), (x_test, y_test) = load()
|
||||
|
@ -23,10 +23,12 @@ import os
|
||||
from absl.testing import parameterized
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from tensorflow.examples.saved_model.integration_tests import integration_scripts
|
||||
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
|
||||
from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts
|
||||
from tensorflow.python.distribute import combinations
|
||||
|
||||
|
||||
class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase):
|
||||
class SavedModelTest(scripts.TestCase, parameterized.TestCase):
|
||||
|
||||
def __init__(self, method_name="runTest", has_extra_deps=False):
|
||||
super(SavedModelTest, self).__init__(method_name)
|
||||
@ -71,40 +73,48 @@ class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase):
|
||||
self.assertCommandSucceeded(
|
||||
"use_text_embedding_in_dataset", model_dir=export_dir)
|
||||
|
||||
NAMED_PARAMETERS_FOR_TEST_MNIST_CNN = (
|
||||
("", dict()),
|
||||
("_with_retraining", dict(
|
||||
retrain=True,
|
||||
regularization_loss_multiplier=2, # Test impact of b/134528831.
|
||||
)),
|
||||
("_with_mirrored_strategy", dict(
|
||||
retrain=True, # That's the relevant case for distribution.
|
||||
use_mirrored_strategy=True,
|
||||
)),
|
||||
)
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
named_strategy=list(ds_utils.named_strategies.values()),
|
||||
retrain_flag_value=["true", "false"],
|
||||
regularization_loss_multiplier=[0, 2]),
|
||||
test_combinations=[combinations.NamedGPUCombination()])
|
||||
def test_mnist_cnn(self, named_strategy, retrain_flag_value,
|
||||
regularization_loss_multiplier):
|
||||
|
||||
@parameterized.named_parameters(*NAMED_PARAMETERS_FOR_TEST_MNIST_CNN)
|
||||
def test_mnist_cnn(self, use_kwargs):
|
||||
self.skipIfMissingExtraDeps()
|
||||
if use_kwargs.get("use_mirrored_strategy", None):
|
||||
self.skipTest(
|
||||
"b/129134185 - saved model and distribution strategy integration")
|
||||
|
||||
fast_test_mode = True
|
||||
temp_dir = self.get_temp_dir()
|
||||
feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")
|
||||
|
||||
# TODO(b/135043074): remove this if-else.
|
||||
if named_strategy is None:
|
||||
full_model_dir = os.path.join(temp_dir, "full_model")
|
||||
else:
|
||||
full_model_dir = None
|
||||
|
||||
self.assertCommandSucceeded(
|
||||
"export_mnist_cnn", fast_test_mode=fast_test_mode,
|
||||
"export_mnist_cnn",
|
||||
fast_test_mode=fast_test_mode,
|
||||
export_dir=feature_extrator_dir)
|
||||
|
||||
self.assertCommandSucceeded(
|
||||
"use_mnist_cnn", fast_test_mode=fast_test_mode,
|
||||
"use_mnist_cnn",
|
||||
fast_test_mode=fast_test_mode,
|
||||
input_saved_model_dir=feature_extrator_dir,
|
||||
output_saved_model_dir=full_model_dir, **use_kwargs)
|
||||
output_saved_model_dir=full_model_dir,
|
||||
strategy=str(named_strategy),
|
||||
retrain=retrain_flag_value,
|
||||
regularization_loss_multiplier=regularization_loss_multiplier)
|
||||
|
||||
if full_model_dir is not None:
|
||||
self.assertCommandSucceeded(
|
||||
"deploy_mnist_cnn", fast_test_mode=fast_test_mode,
|
||||
"deploy_mnist_cnn",
|
||||
fast_test_mode=fast_test_mode,
|
||||
saved_model_dir=full_model_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
integration_scripts.MaybeRunScriptInstead()
|
||||
scripts.MaybeRunScriptInstead()
|
||||
tf.test.main()
|
||||
|
@ -31,6 +31,7 @@ from absl import flags
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
|
||||
from tensorflow.examples.saved_model.integration_tests import mnist_util
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
@ -57,12 +58,29 @@ flags.DEFINE_bool(
|
||||
flags.DEFINE_bool(
|
||||
'fast_test_mode', False,
|
||||
'Shortcut training for running in unit tests.')
|
||||
flags.DEFINE_bool(
|
||||
'use_mirrored_strategy', False,
|
||||
'Whether to use mirrored distribution strategy.')
|
||||
flags.DEFINE_string(
|
||||
'output_saved_model_dir', None,
|
||||
'Directory of the SavedModel that was exported for reuse.')
|
||||
flags.DEFINE_string('strategy', None,
|
||||
'Name of the distribution strategy to use.')
|
||||
|
||||
|
||||
class MaybeDistributionScope(object):
|
||||
"""Provides a context allowing no distribution strategy."""
|
||||
|
||||
def __init__(self, distribution):
|
||||
self._distribution = distribution
|
||||
self._scope = None
|
||||
|
||||
def __enter__(self):
|
||||
if self._distribution:
|
||||
self._scope = self._distribution.scope()
|
||||
self._scope.__enter__()
|
||||
|
||||
def __exit__(self, exc_type, value, traceback):
|
||||
if self._distribution:
|
||||
self._scope.__exit__(exc_type, value, traceback)
|
||||
self._scope = None
|
||||
|
||||
|
||||
def make_feature_extractor(saved_model_path, trainable,
|
||||
@ -108,12 +126,11 @@ def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5):
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
if FLAGS.use_mirrored_strategy:
|
||||
strategy = tf.distribute.MirroredStrategy()
|
||||
else:
|
||||
strategy = tf.distribute.get_strategy()
|
||||
named_strategy = (
|
||||
ds_utils.named_strategies.get(FLAGS.strategy) if FLAGS.strategy else None)
|
||||
strategy = named_strategy.strategy if named_strategy else None
|
||||
|
||||
with strategy.scope():
|
||||
with MaybeDistributionScope(strategy):
|
||||
feature_extractor = make_feature_extractor(
|
||||
FLAGS.input_saved_model_dir,
|
||||
FLAGS.retrain,
|
||||
@ -137,7 +154,7 @@ def main(argv):
|
||||
verbose=1,
|
||||
validation_data=(x_test, y_test))
|
||||
|
||||
if FLAGS.output_saved_model_dir:
|
||||
if FLAGS.output_saved_model_dir and FLAGS.output_saved_model_dir != 'None':
|
||||
tf.saved_model.save(model, FLAGS.output_saved_model_dir)
|
||||
|
||||
|
||||
|
@ -72,8 +72,9 @@ BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
|
||||
|
||||
|
||||
CacheKey = collections.namedtuple("CacheKey", [
|
||||
"input_signature", "parent_graph", "device_functions",
|
||||
"colocation_stack"])
|
||||
"input_signature", "parent_graph", "device_functions", "colocation_stack",
|
||||
"in_cross_replica_context"
|
||||
])
|
||||
|
||||
CacheKey.replace = CacheKey._replace # pylint: disable=protected-access
|
||||
|
||||
@ -1562,9 +1563,15 @@ class Function(object):
|
||||
device_functions = tuple(default_graph._device_functions_outer_to_inner)
|
||||
else:
|
||||
device_functions = ()
|
||||
# pylint: enable=protected-access
|
||||
|
||||
in_cross_replica_context = False
|
||||
try:
|
||||
in_cross_replica_context = (strategy_stack[-1].replica_context is None) # pylint: disable=protected-access
|
||||
except (AttributeError, IndexError):
|
||||
pass
|
||||
|
||||
return CacheKey(input_signature, parent_graph, device_functions,
|
||||
colocation_stack)
|
||||
colocation_stack, in_cross_replica_context)
|
||||
|
||||
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
||||
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
||||
@ -1664,6 +1671,7 @@ class Function(object):
|
||||
if self.input_signature is None or args is not None or kwargs is not None:
|
||||
args, kwargs = self._function_spec.canonicalize_function_inputs(
|
||||
*args, **kwargs)
|
||||
|
||||
cache_key = self._cache_key(args, kwargs)
|
||||
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user