Enable TF Hub's tests with KerasLayer.
PiperOrigin-RevId: 253730171
This commit is contained in:
parent
e906cdde1b
commit
99d71381f2
@ -23,6 +23,7 @@ py_library(
|
|||||||
],
|
],
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":distribution_strategy_utils",
|
||||||
":mnist_util",
|
":mnist_util",
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
@ -38,14 +39,26 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "distribution_strategy_utils",
|
||||||
|
srcs = ["distribution_strategy_utils.py"],
|
||||||
|
visibility = ["//tensorflow:internal"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python/distribute:strategy_combinations",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "saved_model_test",
|
name = "saved_model_test",
|
||||||
srcs = [
|
srcs = [
|
||||||
"saved_model_test.py",
|
"saved_model_test.py",
|
||||||
],
|
],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
|
":distribution_strategy_utils",
|
||||||
":integration_scripts",
|
":integration_scripts",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python/distribute:combinations",
|
||||||
],
|
],
|
||||||
shard_count = 4,
|
shard_count = 4,
|
||||||
tags = [
|
tags = [
|
||||||
|
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utils related to tf.distribute.strategy."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
|
|
||||||
|
_strategies = [
|
||||||
|
strategy_combinations.one_device_strategy,
|
||||||
|
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||||
|
strategy_combinations.mirrored_strategy_with_one_gpu,
|
||||||
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
strategy_combinations.mirrored_strategy_with_two_gpus,
|
||||||
|
]
|
||||||
|
|
||||||
|
named_strategies = collections.OrderedDict()
|
||||||
|
named_strategies[None] = None
|
||||||
|
|
||||||
|
|
||||||
|
for strategy in _strategies:
|
||||||
|
named_strategies[str(strategy)] = strategy
|
@ -33,7 +33,7 @@ def _load_random_data(num_train_and_test):
|
|||||||
|
|
||||||
def load_reshaped_data(use_fashion_mnist=False, fake_tiny_data=False):
|
def load_reshaped_data(use_fashion_mnist=False, fake_tiny_data=False):
|
||||||
"""Returns MNIST or Fashion MNIST or fake train and test data."""
|
"""Returns MNIST or Fashion MNIST or fake train and test data."""
|
||||||
load = ((lambda: _load_random_data([10, 50])) if fake_tiny_data else
|
load = ((lambda: _load_random_data([16, 128])) if fake_tiny_data else
|
||||||
tf.keras.datasets.fashion_mnist.load_data if use_fashion_mnist else
|
tf.keras.datasets.fashion_mnist.load_data if use_fashion_mnist else
|
||||||
tf.keras.datasets.mnist.load_data)
|
tf.keras.datasets.mnist.load_data)
|
||||||
(x_train, y_train), (x_test, y_test) = load()
|
(x_train, y_train), (x_test, y_test) = load()
|
||||||
|
@ -23,10 +23,12 @@ import os
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow.compat.v2 as tf
|
import tensorflow.compat.v2 as tf
|
||||||
|
|
||||||
from tensorflow.examples.saved_model.integration_tests import integration_scripts
|
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
|
||||||
|
from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts
|
||||||
|
from tensorflow.python.distribute import combinations
|
||||||
|
|
||||||
|
|
||||||
class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase):
|
class SavedModelTest(scripts.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def __init__(self, method_name="runTest", has_extra_deps=False):
|
def __init__(self, method_name="runTest", has_extra_deps=False):
|
||||||
super(SavedModelTest, self).__init__(method_name)
|
super(SavedModelTest, self).__init__(method_name)
|
||||||
@ -71,40 +73,48 @@ class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase):
|
|||||||
self.assertCommandSucceeded(
|
self.assertCommandSucceeded(
|
||||||
"use_text_embedding_in_dataset", model_dir=export_dir)
|
"use_text_embedding_in_dataset", model_dir=export_dir)
|
||||||
|
|
||||||
NAMED_PARAMETERS_FOR_TEST_MNIST_CNN = (
|
@combinations.generate(
|
||||||
("", dict()),
|
combinations.combine(
|
||||||
("_with_retraining", dict(
|
named_strategy=list(ds_utils.named_strategies.values()),
|
||||||
retrain=True,
|
retrain_flag_value=["true", "false"],
|
||||||
regularization_loss_multiplier=2, # Test impact of b/134528831.
|
regularization_loss_multiplier=[0, 2]),
|
||||||
)),
|
test_combinations=[combinations.NamedGPUCombination()])
|
||||||
("_with_mirrored_strategy", dict(
|
def test_mnist_cnn(self, named_strategy, retrain_flag_value,
|
||||||
retrain=True, # That's the relevant case for distribution.
|
regularization_loss_multiplier):
|
||||||
use_mirrored_strategy=True,
|
|
||||||
)),
|
|
||||||
)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(*NAMED_PARAMETERS_FOR_TEST_MNIST_CNN)
|
|
||||||
def test_mnist_cnn(self, use_kwargs):
|
|
||||||
self.skipIfMissingExtraDeps()
|
self.skipIfMissingExtraDeps()
|
||||||
if use_kwargs.get("use_mirrored_strategy", None):
|
|
||||||
self.skipTest(
|
|
||||||
"b/129134185 - saved model and distribution strategy integration")
|
|
||||||
fast_test_mode = True
|
fast_test_mode = True
|
||||||
temp_dir = self.get_temp_dir()
|
temp_dir = self.get_temp_dir()
|
||||||
feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")
|
feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")
|
||||||
full_model_dir = os.path.join(temp_dir, "full_model")
|
|
||||||
|
# TODO(b/135043074): remove this if-else.
|
||||||
|
if named_strategy is None:
|
||||||
|
full_model_dir = os.path.join(temp_dir, "full_model")
|
||||||
|
else:
|
||||||
|
full_model_dir = None
|
||||||
|
|
||||||
self.assertCommandSucceeded(
|
self.assertCommandSucceeded(
|
||||||
"export_mnist_cnn", fast_test_mode=fast_test_mode,
|
"export_mnist_cnn",
|
||||||
|
fast_test_mode=fast_test_mode,
|
||||||
export_dir=feature_extrator_dir)
|
export_dir=feature_extrator_dir)
|
||||||
|
|
||||||
self.assertCommandSucceeded(
|
self.assertCommandSucceeded(
|
||||||
"use_mnist_cnn", fast_test_mode=fast_test_mode,
|
"use_mnist_cnn",
|
||||||
|
fast_test_mode=fast_test_mode,
|
||||||
input_saved_model_dir=feature_extrator_dir,
|
input_saved_model_dir=feature_extrator_dir,
|
||||||
output_saved_model_dir=full_model_dir, **use_kwargs)
|
output_saved_model_dir=full_model_dir,
|
||||||
self.assertCommandSucceeded(
|
strategy=str(named_strategy),
|
||||||
"deploy_mnist_cnn", fast_test_mode=fast_test_mode,
|
retrain=retrain_flag_value,
|
||||||
saved_model_dir=full_model_dir)
|
regularization_loss_multiplier=regularization_loss_multiplier)
|
||||||
|
|
||||||
|
if full_model_dir is not None:
|
||||||
|
self.assertCommandSucceeded(
|
||||||
|
"deploy_mnist_cnn",
|
||||||
|
fast_test_mode=fast_test_mode,
|
||||||
|
saved_model_dir=full_model_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
integration_scripts.MaybeRunScriptInstead()
|
scripts.MaybeRunScriptInstead()
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -31,6 +31,7 @@ from absl import flags
|
|||||||
import tensorflow.compat.v2 as tf
|
import tensorflow.compat.v2 as tf
|
||||||
import tensorflow_hub as hub
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
|
||||||
from tensorflow.examples.saved_model.integration_tests import mnist_util
|
from tensorflow.examples.saved_model.integration_tests import mnist_util
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
@ -57,12 +58,29 @@ flags.DEFINE_bool(
|
|||||||
flags.DEFINE_bool(
|
flags.DEFINE_bool(
|
||||||
'fast_test_mode', False,
|
'fast_test_mode', False,
|
||||||
'Shortcut training for running in unit tests.')
|
'Shortcut training for running in unit tests.')
|
||||||
flags.DEFINE_bool(
|
|
||||||
'use_mirrored_strategy', False,
|
|
||||||
'Whether to use mirrored distribution strategy.')
|
|
||||||
flags.DEFINE_string(
|
flags.DEFINE_string(
|
||||||
'output_saved_model_dir', None,
|
'output_saved_model_dir', None,
|
||||||
'Directory of the SavedModel that was exported for reuse.')
|
'Directory of the SavedModel that was exported for reuse.')
|
||||||
|
flags.DEFINE_string('strategy', None,
|
||||||
|
'Name of the distribution strategy to use.')
|
||||||
|
|
||||||
|
|
||||||
|
class MaybeDistributionScope(object):
|
||||||
|
"""Provides a context allowing no distribution strategy."""
|
||||||
|
|
||||||
|
def __init__(self, distribution):
|
||||||
|
self._distribution = distribution
|
||||||
|
self._scope = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if self._distribution:
|
||||||
|
self._scope = self._distribution.scope()
|
||||||
|
self._scope.__enter__()
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, value, traceback):
|
||||||
|
if self._distribution:
|
||||||
|
self._scope.__exit__(exc_type, value, traceback)
|
||||||
|
self._scope = None
|
||||||
|
|
||||||
|
|
||||||
def make_feature_extractor(saved_model_path, trainable,
|
def make_feature_extractor(saved_model_path, trainable,
|
||||||
@ -108,12 +126,11 @@ def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5):
|
|||||||
def main(argv):
|
def main(argv):
|
||||||
del argv
|
del argv
|
||||||
|
|
||||||
if FLAGS.use_mirrored_strategy:
|
named_strategy = (
|
||||||
strategy = tf.distribute.MirroredStrategy()
|
ds_utils.named_strategies.get(FLAGS.strategy) if FLAGS.strategy else None)
|
||||||
else:
|
strategy = named_strategy.strategy if named_strategy else None
|
||||||
strategy = tf.distribute.get_strategy()
|
|
||||||
|
|
||||||
with strategy.scope():
|
with MaybeDistributionScope(strategy):
|
||||||
feature_extractor = make_feature_extractor(
|
feature_extractor = make_feature_extractor(
|
||||||
FLAGS.input_saved_model_dir,
|
FLAGS.input_saved_model_dir,
|
||||||
FLAGS.retrain,
|
FLAGS.retrain,
|
||||||
@ -137,7 +154,7 @@ def main(argv):
|
|||||||
verbose=1,
|
verbose=1,
|
||||||
validation_data=(x_test, y_test))
|
validation_data=(x_test, y_test))
|
||||||
|
|
||||||
if FLAGS.output_saved_model_dir:
|
if FLAGS.output_saved_model_dir and FLAGS.output_saved_model_dir != 'None':
|
||||||
tf.saved_model.save(model, FLAGS.output_saved_model_dir)
|
tf.saved_model.save(model, FLAGS.output_saved_model_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@ -72,8 +72,9 @@ BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
|
|||||||
|
|
||||||
|
|
||||||
CacheKey = collections.namedtuple("CacheKey", [
|
CacheKey = collections.namedtuple("CacheKey", [
|
||||||
"input_signature", "parent_graph", "device_functions",
|
"input_signature", "parent_graph", "device_functions", "colocation_stack",
|
||||||
"colocation_stack"])
|
"in_cross_replica_context"
|
||||||
|
])
|
||||||
|
|
||||||
CacheKey.replace = CacheKey._replace # pylint: disable=protected-access
|
CacheKey.replace = CacheKey._replace # pylint: disable=protected-access
|
||||||
|
|
||||||
@ -1562,9 +1563,15 @@ class Function(object):
|
|||||||
device_functions = tuple(default_graph._device_functions_outer_to_inner)
|
device_functions = tuple(default_graph._device_functions_outer_to_inner)
|
||||||
else:
|
else:
|
||||||
device_functions = ()
|
device_functions = ()
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
in_cross_replica_context = False
|
||||||
|
try:
|
||||||
|
in_cross_replica_context = (strategy_stack[-1].replica_context is None) # pylint: disable=protected-access
|
||||||
|
except (AttributeError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
return CacheKey(input_signature, parent_graph, device_functions,
|
return CacheKey(input_signature, parent_graph, device_functions,
|
||||||
colocation_stack)
|
colocation_stack, in_cross_replica_context)
|
||||||
|
|
||||||
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
|
||||||
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
|
||||||
@ -1664,6 +1671,7 @@ class Function(object):
|
|||||||
if self.input_signature is None or args is not None or kwargs is not None:
|
if self.input_signature is None or args is not None or kwargs is not None:
|
||||||
args, kwargs = self._function_spec.canonicalize_function_inputs(
|
args, kwargs = self._function_spec.canonicalize_function_inputs(
|
||||||
*args, **kwargs)
|
*args, **kwargs)
|
||||||
|
|
||||||
cache_key = self._cache_key(args, kwargs)
|
cache_key = self._cache_key(args, kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user