diff --git a/tensorflow/examples/saved_model/integration_tests/BUILD b/tensorflow/examples/saved_model/integration_tests/BUILD index 93e0e1ca0d2..b81c40132fd 100644 --- a/tensorflow/examples/saved_model/integration_tests/BUILD +++ b/tensorflow/examples/saved_model/integration_tests/BUILD @@ -23,6 +23,7 @@ py_library( ], visibility = ["//tensorflow:internal"], deps = [ + ":distribution_strategy_utils", ":mnist_util", "//tensorflow:tensorflow_py", "@absl_py//absl/testing:parameterized", @@ -38,14 +39,26 @@ py_library( ], ) +py_library( + name = "distribution_strategy_utils", + srcs = ["distribution_strategy_utils.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python/distribute:strategy_combinations", + ], +) + cuda_py_test( name = "saved_model_test", srcs = [ "saved_model_test.py", ], additional_deps = [ + ":distribution_strategy_utils", ":integration_scripts", + "@absl_py//absl/testing:parameterized", "//tensorflow:tensorflow_py", + "//tensorflow/python/distribute:combinations", ], shard_count = 4, tags = [ diff --git a/tensorflow/examples/saved_model/integration_tests/distribution_strategy_utils.py b/tensorflow/examples/saved_model/integration_tests/distribution_strategy_utils.py new file mode 100644 index 00000000000..4fdfc3dc8c0 --- /dev/null +++ b/tensorflow/examples/saved_model/integration_tests/distribution_strategy_utils.py @@ -0,0 +1,38 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils related to tf.distribute.strategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.distribute import strategy_combinations + +_strategies = [ + strategy_combinations.one_device_strategy, + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_one_gpu, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.mirrored_strategy_with_two_gpus, +] + +named_strategies = collections.OrderedDict() +named_strategies[None] = None + + +for strategy in _strategies: + named_strategies[str(strategy)] = strategy diff --git a/tensorflow/examples/saved_model/integration_tests/mnist_util.py b/tensorflow/examples/saved_model/integration_tests/mnist_util.py index fccd7188f62..8e4cdac748f 100644 --- a/tensorflow/examples/saved_model/integration_tests/mnist_util.py +++ b/tensorflow/examples/saved_model/integration_tests/mnist_util.py @@ -33,7 +33,7 @@ def _load_random_data(num_train_and_test): def load_reshaped_data(use_fashion_mnist=False, fake_tiny_data=False): """Returns MNIST or Fashion MNIST or fake train and test data.""" - load = ((lambda: _load_random_data([10, 50])) if fake_tiny_data else + load = ((lambda: _load_random_data([16, 128])) if fake_tiny_data else tf.keras.datasets.fashion_mnist.load_data if use_fashion_mnist else tf.keras.datasets.mnist.load_data) (x_train, y_train), (x_test, y_test) = load() diff --git a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py index 29bbac4d5c4..04cd392dedf 100644 --- a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py +++ b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py @@ -23,10 +23,12 @@ import os from absl.testing import parameterized import tensorflow.compat.v2 as tf -from tensorflow.examples.saved_model.integration_tests import integration_scripts +from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils +from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts +from tensorflow.python.distribute import combinations -class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase): +class SavedModelTest(scripts.TestCase, parameterized.TestCase): def __init__(self, method_name="runTest", has_extra_deps=False): super(SavedModelTest, self).__init__(method_name) @@ -71,40 +73,48 @@ class SavedModelTest(integration_scripts.TestCase, parameterized.TestCase): self.assertCommandSucceeded( "use_text_embedding_in_dataset", model_dir=export_dir) - NAMED_PARAMETERS_FOR_TEST_MNIST_CNN = ( - ("", dict()), - ("_with_retraining", dict( - retrain=True, - regularization_loss_multiplier=2, # Test impact of b/134528831. - )), - ("_with_mirrored_strategy", dict( - retrain=True, # That's the relevant case for distribution. - use_mirrored_strategy=True, - )), - ) + @combinations.generate( + combinations.combine( + named_strategy=list(ds_utils.named_strategies.values()), + retrain_flag_value=["true", "false"], + regularization_loss_multiplier=[0, 2]), + test_combinations=[combinations.NamedGPUCombination()]) + def test_mnist_cnn(self, named_strategy, retrain_flag_value, + regularization_loss_multiplier): - @parameterized.named_parameters(*NAMED_PARAMETERS_FOR_TEST_MNIST_CNN) - def test_mnist_cnn(self, use_kwargs): self.skipIfMissingExtraDeps() - if use_kwargs.get("use_mirrored_strategy", None): - self.skipTest( - "b/129134185 - saved model and distribution strategy integration") + fast_test_mode = True temp_dir = self.get_temp_dir() feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor") - full_model_dir = os.path.join(temp_dir, "full_model") + + # TODO(b/135043074): remove this if-else. + if named_strategy is None: + full_model_dir = os.path.join(temp_dir, "full_model") + else: + full_model_dir = None + self.assertCommandSucceeded( - "export_mnist_cnn", fast_test_mode=fast_test_mode, + "export_mnist_cnn", + fast_test_mode=fast_test_mode, export_dir=feature_extrator_dir) + self.assertCommandSucceeded( - "use_mnist_cnn", fast_test_mode=fast_test_mode, + "use_mnist_cnn", + fast_test_mode=fast_test_mode, input_saved_model_dir=feature_extrator_dir, - output_saved_model_dir=full_model_dir, **use_kwargs) - self.assertCommandSucceeded( - "deploy_mnist_cnn", fast_test_mode=fast_test_mode, - saved_model_dir=full_model_dir) + output_saved_model_dir=full_model_dir, + strategy=str(named_strategy), + retrain=retrain_flag_value, + regularization_loss_multiplier=regularization_loss_multiplier) + + if full_model_dir is not None: + self.assertCommandSucceeded( + "deploy_mnist_cnn", + fast_test_mode=fast_test_mode, + saved_model_dir=full_model_dir) if __name__ == "__main__": - integration_scripts.MaybeRunScriptInstead() + scripts.MaybeRunScriptInstead() tf.test.main() diff --git a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py index cec671b46ac..ea363bc28a6 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py +++ b/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py @@ -31,6 +31,7 @@ from absl import flags import tensorflow.compat.v2 as tf import tensorflow_hub as hub +from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils from tensorflow.examples.saved_model.integration_tests import mnist_util FLAGS = flags.FLAGS @@ -57,12 +58,29 @@ flags.DEFINE_bool( flags.DEFINE_bool( 'fast_test_mode', False, 'Shortcut training for running in unit tests.') -flags.DEFINE_bool( - 'use_mirrored_strategy', False, - 'Whether to use mirrored distribution strategy.') flags.DEFINE_string( 'output_saved_model_dir', None, 'Directory of the SavedModel that was exported for reuse.') +flags.DEFINE_string('strategy', None, + 'Name of the distribution strategy to use.') + + +class MaybeDistributionScope(object): + """Provides a context allowing no distribution strategy.""" + + def __init__(self, distribution): + self._distribution = distribution + self._scope = None + + def __enter__(self): + if self._distribution: + self._scope = self._distribution.scope() + self._scope.__enter__() + + def __exit__(self, exc_type, value, traceback): + if self._distribution: + self._scope.__exit__(exc_type, value, traceback) + self._scope = None def make_feature_extractor(saved_model_path, trainable, @@ -108,12 +126,11 @@ def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5): def main(argv): del argv - if FLAGS.use_mirrored_strategy: - strategy = tf.distribute.MirroredStrategy() - else: - strategy = tf.distribute.get_strategy() + named_strategy = ( + ds_utils.named_strategies.get(FLAGS.strategy) if FLAGS.strategy else None) + strategy = named_strategy.strategy if named_strategy else None - with strategy.scope(): + with MaybeDistributionScope(strategy): feature_extractor = make_feature_extractor( FLAGS.input_saved_model_dir, FLAGS.retrain, @@ -137,7 +154,7 @@ def main(argv): verbose=1, validation_data=(x_test, y_test)) - if FLAGS.output_saved_model_dir: + if FLAGS.output_saved_model_dir and FLAGS.output_saved_model_dir != 'None': tf.saved_model.save(model, FLAGS.output_saved_model_dir) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index fe43393a695..3105deb3512 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -72,8 +72,9 @@ BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name" CacheKey = collections.namedtuple("CacheKey", [ - "input_signature", "parent_graph", "device_functions", - "colocation_stack"]) + "input_signature", "parent_graph", "device_functions", "colocation_stack", + "in_cross_replica_context" +]) CacheKey.replace = CacheKey._replace # pylint: disable=protected-access @@ -1562,9 +1563,15 @@ class Function(object): device_functions = tuple(default_graph._device_functions_outer_to_inner) else: device_functions = () - # pylint: enable=protected-access + + in_cross_replica_context = False + try: + in_cross_replica_context = (strategy_stack[-1].replica_context is None) # pylint: disable=protected-access + except (AttributeError, IndexError): + pass + return CacheKey(input_signature, parent_graph, device_functions, - colocation_stack) + colocation_stack, in_cross_replica_context) def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None): """Create a `ConcreteFunction` from `args` and `kwargs`.""" @@ -1664,6 +1671,7 @@ class Function(object): if self.input_signature is None or args is not None or kwargs is not None: args, kwargs = self._function_spec.canonicalize_function_inputs( *args, **kwargs) + cache_key = self._cache_key(args, kwargs) try: