Enable TF Hub's tests with KerasLayer.

PiperOrigin-RevId: 253730171
This commit is contained in:
Shining Sun 2019-06-17 23:09:12 -07:00 committed by TensorFlower Gardener
parent e906cdde1b
commit 99d71381f2
6 changed files with 126 additions and 40 deletions

View File

@ -23,6 +23,7 @@ py_library(
],
visibility = ["//tensorflow:internal"],
deps = [
":distribution_strategy_utils",
":mnist_util",
"//tensorflow:tensorflow_py",
"@absl_py//absl/testing:parameterized",
@ -38,14 +39,26 @@ py_library(
],
)
py_library(
name = "distribution_strategy_utils",
srcs = ["distribution_strategy_utils.py"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python/distribute:strategy_combinations",
],
)
cuda_py_test(
name = "saved_model_test",
srcs = [
"saved_model_test.py",
],
additional_deps = [
":distribution_strategy_utils",
":integration_scripts",
"@absl_py//absl/testing:parameterized",
"//tensorflow:tensorflow_py",
"//tensorflow/python/distribute:combinations",
],
shard_count = 4,
tags = [

View File

@ -0,0 +1,38 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to tf.distribute.strategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.distribute import strategy_combinations
_strategies = [
strategy_combinations.one_device_strategy,
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_one_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
]
named_strategies = collections.OrderedDict()
named_strategies[None] = None
for strategy in _strategies:
named_strategies[str(strategy)] = strategy

View File

@ -33,7 +33,7 @@ def _load_random_data(num_train_and_test):
def load_reshaped_data(use_fashion_mnist=False, fake_tiny_data=False):
"""Returns MNIST or Fashion MNIST or fake train and test data."""
load = ((lambda: _load_random_data([10, 50])) if fake_tiny_data else
load = ((lambda: _load_random_data([16, 128])) if fake_tiny_data else
tf.keras.datasets.fashion_mnist.load_data if use_fashion_mnist else
tf.keras.datasets.mnist.load_data)
(x_train, y_train), (x_test, y_test) = load()

View File

@ -23,10 +23,12 @@ import os
from absl.testing import parameterized
import tensorflow.compat.v2 as tf
from tensorflow.examples.saved_model.integration_tests import integration_scripts
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts
from tensorflow.python.distribute import combinations
class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase):
class SavedModelTest(scripts.TestCase, parameterized.TestCase):
def __init__(self, method_name="runTest", has_extra_deps=False):
super(SavedModelTest, self).__init__(method_name)
@ -71,40 +73,48 @@ class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase):
self.assertCommandSucceeded(
"use_text_embedding_in_dataset", model_dir=export_dir)
NAMED_PARAMETERS_FOR_TEST_MNIST_CNN = (
("", dict()),
("_with_retraining", dict(
retrain=True,
regularization_loss_multiplier=2, # Test impact of b/134528831.
)),
("_with_mirrored_strategy", dict(
retrain=True, # That's the relevant case for distribution.
use_mirrored_strategy=True,
)),
)
@combinations.generate(
combinations.combine(
named_strategy=list(ds_utils.named_strategies.values()),
retrain_flag_value=["true", "false"],
regularization_loss_multiplier=[0, 2]),
test_combinations=[combinations.NamedGPUCombination()])
def test_mnist_cnn(self, named_strategy, retrain_flag_value,
regularization_loss_multiplier):
@parameterized.named_parameters(*NAMED_PARAMETERS_FOR_TEST_MNIST_CNN)
def test_mnist_cnn(self, use_kwargs):
self.skipIfMissingExtraDeps()
if use_kwargs.get("use_mirrored_strategy", None):
self.skipTest(
"b/129134185 - saved model and distribution strategy integration")
fast_test_mode = True
temp_dir = self.get_temp_dir()
feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")
full_model_dir = os.path.join(temp_dir, "full_model")
# TODO(b/135043074): remove this if-else.
if named_strategy is None:
full_model_dir = os.path.join(temp_dir, "full_model")
else:
full_model_dir = None
self.assertCommandSucceeded(
"export_mnist_cnn", fast_test_mode=fast_test_mode,
"export_mnist_cnn",
fast_test_mode=fast_test_mode,
export_dir=feature_extrator_dir)
self.assertCommandSucceeded(
"use_mnist_cnn", fast_test_mode=fast_test_mode,
"use_mnist_cnn",
fast_test_mode=fast_test_mode,
input_saved_model_dir=feature_extrator_dir,
output_saved_model_dir=full_model_dir, **use_kwargs)
self.assertCommandSucceeded(
"deploy_mnist_cnn", fast_test_mode=fast_test_mode,
saved_model_dir=full_model_dir)
output_saved_model_dir=full_model_dir,
strategy=str(named_strategy),
retrain=retrain_flag_value,
regularization_loss_multiplier=regularization_loss_multiplier)
if full_model_dir is not None:
self.assertCommandSucceeded(
"deploy_mnist_cnn",
fast_test_mode=fast_test_mode,
saved_model_dir=full_model_dir)
if __name__ == "__main__":
integration_scripts.MaybeRunScriptInstead()
scripts.MaybeRunScriptInstead()
tf.test.main()

View File

@ -31,6 +31,7 @@ from absl import flags
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
from tensorflow.examples.saved_model.integration_tests import mnist_util
FLAGS = flags.FLAGS
@ -57,12 +58,29 @@ flags.DEFINE_bool(
flags.DEFINE_bool(
'fast_test_mode', False,
'Shortcut training for running in unit tests.')
flags.DEFINE_bool(
'use_mirrored_strategy', False,
'Whether to use mirrored distribution strategy.')
flags.DEFINE_string(
'output_saved_model_dir', None,
'Directory of the SavedModel that was exported for reuse.')
flags.DEFINE_string('strategy', None,
'Name of the distribution strategy to use.')
class MaybeDistributionScope(object):
"""Provides a context allowing no distribution strategy."""
def __init__(self, distribution):
self._distribution = distribution
self._scope = None
def __enter__(self):
if self._distribution:
self._scope = self._distribution.scope()
self._scope.__enter__()
def __exit__(self, exc_type, value, traceback):
if self._distribution:
self._scope.__exit__(exc_type, value, traceback)
self._scope = None
def make_feature_extractor(saved_model_path, trainable,
@ -108,12 +126,11 @@ def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5):
def main(argv):
del argv
if FLAGS.use_mirrored_strategy:
strategy = tf.distribute.MirroredStrategy()
else:
strategy = tf.distribute.get_strategy()
named_strategy = (
ds_utils.named_strategies.get(FLAGS.strategy) if FLAGS.strategy else None)
strategy = named_strategy.strategy if named_strategy else None
with strategy.scope():
with MaybeDistributionScope(strategy):
feature_extractor = make_feature_extractor(
FLAGS.input_saved_model_dir,
FLAGS.retrain,
@ -137,7 +154,7 @@ def main(argv):
verbose=1,
validation_data=(x_test, y_test))
if FLAGS.output_saved_model_dir:
if FLAGS.output_saved_model_dir and FLAGS.output_saved_model_dir != 'None':
tf.saved_model.save(model, FLAGS.output_saved_model_dir)

View File

@ -72,8 +72,9 @@ BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
CacheKey = collections.namedtuple("CacheKey", [
"input_signature", "parent_graph", "device_functions",
"colocation_stack"])
"input_signature", "parent_graph", "device_functions", "colocation_stack",
"in_cross_replica_context"
])
CacheKey.replace = CacheKey._replace # pylint: disable=protected-access
@ -1562,9 +1563,15 @@ class Function(object):
device_functions = tuple(default_graph._device_functions_outer_to_inner)
else:
device_functions = ()
# pylint: enable=protected-access
in_cross_replica_context = False
try:
in_cross_replica_context = (strategy_stack[-1].replica_context is None) # pylint: disable=protected-access
except (AttributeError, IndexError):
pass
return CacheKey(input_signature, parent_graph, device_functions,
colocation_stack)
colocation_stack, in_cross_replica_context)
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
@ -1664,6 +1671,7 @@ class Function(object):
if self.input_signature is None or args is not None or kwargs is not None:
args, kwargs = self._function_spec.canonicalize_function_inputs(
*args, **kwargs)
cache_key = self._cache_key(args, kwargs)
try: