Remove a dependency on multi_worker_test from multi_worker_optimizer_comparison_test since it's not needed. Resolve a TODO to remove the function def.
PiperOrigin-RevId: 245841884
This commit is contained in:
parent
7d847cc329
commit
b6c3476c76
@ -386,6 +386,5 @@ cuda_py_test(
|
||||
],
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_oss", # TODO(b/130035424)
|
||||
],
|
||||
)
|
||||
|
@ -32,7 +32,6 @@ from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
||||
from tensorflow.python.keras.distribute import mnist_multi_worker
|
||||
from tensorflow.python.keras.distribute import multi_worker_test
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent
|
||||
@ -42,7 +41,6 @@ from tensorflow.python.training import gradient_descent as gradient_descent_v1
|
||||
from tensorflow.python.training import rmsprop as rmsprop_v1
|
||||
|
||||
|
||||
get_strategy_object = multi_worker_test.get_strategy_object
|
||||
# TODO(rchao): Move maybe_shard_dataset to shared util.
|
||||
maybe_shard_dataset = mnist_multi_worker.maybe_shard_dataset
|
||||
|
||||
@ -96,7 +94,7 @@ class KerasMultiWorkerOptimizerTest(test_base.IndependentWorkerTestBase,
|
||||
|
||||
# Clear Keras session to reset device assignment
|
||||
keras.backend._SESSION.session = None
|
||||
strategy = get_strategy_object(strategy_cls)
|
||||
strategy = strategy_cls()
|
||||
|
||||
with strategy.scope():
|
||||
train_ds = get_input_datasets()
|
||||
@ -126,7 +124,7 @@ class KerasMultiWorkerOptimizerTest(test_base.IndependentWorkerTestBase,
|
||||
cluster_spec)
|
||||
|
||||
threads_to_join = []
|
||||
strategy = get_strategy_object(strategy_cls)
|
||||
strategy = strategy_cls()
|
||||
if strategy.extended.experimental_between_graph:
|
||||
for ts in threads.values():
|
||||
threads_to_join.extend(ts)
|
||||
|
@ -296,13 +296,6 @@ def _run_standalone_client(test_obj, strategy, cluster_spec):
|
||||
cluster_spec=cluster_spec)
|
||||
|
||||
|
||||
# TODO(yuefengz): remove this function once
|
||||
# multi_worker_optimizer_comparison_test no longer depends on it.
|
||||
def get_strategy_object(strategy_cls):
|
||||
# CollectiveAllReduceStrategy and ParameterServerStrategy.
|
||||
return strategy_cls()
|
||||
|
||||
|
||||
class KerasMultiWorkerTestStandaloneClient(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user