Remove a dependency on multi_worker_test from multi_worker_optimizer_comparison_test since it's not needed. Resolve a TODO to remove the function def.

PiperOrigin-RevId: 245841884
This commit is contained in:
Rick Chao 2019-04-29 15:59:55 -07:00 committed by TensorFlower Gardener
parent 7d847cc329
commit b6c3476c76
3 changed files with 2 additions and 12 deletions

View File

@ -386,6 +386,5 @@ cuda_py_test(
],
tags = [
"multi_and_single_gpu",
"no_oss", # TODO(b/130035424)
],
)

View File

@ -32,7 +32,6 @@ from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.keras.distribute import mnist_multi_worker
from tensorflow.python.keras.distribute import multi_worker_test
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.optimizer_v2 import gradient_descent
@ -42,7 +41,6 @@ from tensorflow.python.training import gradient_descent as gradient_descent_v1
from tensorflow.python.training import rmsprop as rmsprop_v1
get_strategy_object = multi_worker_test.get_strategy_object
# TODO(rchao): Move maybe_shard_dataset to shared util.
maybe_shard_dataset = mnist_multi_worker.maybe_shard_dataset
@ -96,7 +94,7 @@ class KerasMultiWorkerOptimizerTest(test_base.IndependentWorkerTestBase,
# Clear Keras session to reset device assignment
keras.backend._SESSION.session = None
strategy = get_strategy_object(strategy_cls)
strategy = strategy_cls()
with strategy.scope():
train_ds = get_input_datasets()
@ -126,7 +124,7 @@ class KerasMultiWorkerOptimizerTest(test_base.IndependentWorkerTestBase,
cluster_spec)
threads_to_join = []
strategy = get_strategy_object(strategy_cls)
strategy = strategy_cls()
if strategy.extended.experimental_between_graph:
for ts in threads.values():
threads_to_join.extend(ts)

View File

@ -296,13 +296,6 @@ def _run_standalone_client(test_obj, strategy, cluster_spec):
cluster_spec=cluster_spec)
# TODO(yuefengz): remove this function once
# multi_worker_optimizer_comparison_test no longer depends on it.
def get_strategy_object(strategy_cls):
# CollectiveAllReduceStrategy and ParameterServerStrategy.
return strategy_cls()
class KerasMultiWorkerTestStandaloneClient(test.TestCase,
parameterized.TestCase):