From b6c3476c764005df55ec7507d5af1af5baa9d52e Mon Sep 17 00:00:00 2001 From: Rick Chao Date: Mon, 29 Apr 2019 15:59:55 -0700 Subject: [PATCH] Remove a dependency on multi_worker_test from multi_worker_optimizer_comparison_test since it's not needed. Resolve a TODO to remove the function def. PiperOrigin-RevId: 245841884 --- tensorflow/python/keras/distribute/BUILD | 1 - .../distribute/multi_worker_optimizer_comparison_test.py | 6 ++---- tensorflow/python/keras/distribute/multi_worker_test.py | 7 ------- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index f769c79cf13..487ce1909ed 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -386,6 +386,5 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_oss", # TODO(b/130035424) ], ) diff --git a/tensorflow/python/keras/distribute/multi_worker_optimizer_comparison_test.py b/tensorflow/python/keras/distribute/multi_worker_optimizer_comparison_test.py index e9294b50b2b..27f0e1ad9bf 100644 --- a/tensorflow/python/keras/distribute/multi_worker_optimizer_comparison_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_optimizer_comparison_test.py @@ -32,7 +32,6 @@ from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import multi_worker_test_base as test_base from tensorflow.python.keras.distribute import mnist_multi_worker -from tensorflow.python.keras.distribute import multi_worker_test from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.optimizer_v2 import gradient_descent @@ -42,7 +41,6 @@ from tensorflow.python.training import gradient_descent as gradient_descent_v1 from tensorflow.python.training import rmsprop as rmsprop_v1 -get_strategy_object = multi_worker_test.get_strategy_object # TODO(rchao): Move maybe_shard_dataset to shared util. maybe_shard_dataset = mnist_multi_worker.maybe_shard_dataset @@ -96,7 +94,7 @@ class KerasMultiWorkerOptimizerTest(test_base.IndependentWorkerTestBase, # Clear Keras session to reset device assignment keras.backend._SESSION.session = None - strategy = get_strategy_object(strategy_cls) + strategy = strategy_cls() with strategy.scope(): train_ds = get_input_datasets() @@ -126,7 +124,7 @@ class KerasMultiWorkerOptimizerTest(test_base.IndependentWorkerTestBase, cluster_spec) threads_to_join = [] - strategy = get_strategy_object(strategy_cls) + strategy = strategy_cls() if strategy.extended.experimental_between_graph: for ts in threads.values(): threads_to_join.extend(ts) diff --git a/tensorflow/python/keras/distribute/multi_worker_test.py b/tensorflow/python/keras/distribute/multi_worker_test.py index 968e3d371ae..411d02e197d 100644 --- a/tensorflow/python/keras/distribute/multi_worker_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_test.py @@ -296,13 +296,6 @@ def _run_standalone_client(test_obj, strategy, cluster_spec): cluster_spec=cluster_spec) -# TODO(yuefengz): remove this function once -# multi_worker_optimizer_comparison_test no longer depends on it. -def get_strategy_object(strategy_cls): - # CollectiveAllReduceStrategy and ParameterServerStrategy. - return strategy_cls() - - class KerasMultiWorkerTestStandaloneClient(test.TestCase, parameterized.TestCase):