Remove a dependency on multi_worker_test from multi_worker_optimizer_comparison_test since it's not needed. Resolve a TODO to remove the function def.
PiperOrigin-RevId: 245841884
This commit is contained in:
parent
7d847cc329
commit
b6c3476c76
@ -386,6 +386,5 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
tags = [
|
tags = [
|
||||||
"multi_and_single_gpu",
|
"multi_and_single_gpu",
|
||||||
"no_oss", # TODO(b/130035424)
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -32,7 +32,6 @@ from tensorflow.python.distribute import combinations
|
|||||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||||
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
||||||
from tensorflow.python.keras.distribute import mnist_multi_worker
|
from tensorflow.python.keras.distribute import mnist_multi_worker
|
||||||
from tensorflow.python.keras.distribute import multi_worker_test
|
|
||||||
from tensorflow.python.keras.engine import base_layer
|
from tensorflow.python.keras.engine import base_layer
|
||||||
from tensorflow.python.keras.engine import sequential
|
from tensorflow.python.keras.engine import sequential
|
||||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent
|
from tensorflow.python.keras.optimizer_v2 import gradient_descent
|
||||||
@ -42,7 +41,6 @@ from tensorflow.python.training import gradient_descent as gradient_descent_v1
|
|||||||
from tensorflow.python.training import rmsprop as rmsprop_v1
|
from tensorflow.python.training import rmsprop as rmsprop_v1
|
||||||
|
|
||||||
|
|
||||||
get_strategy_object = multi_worker_test.get_strategy_object
|
|
||||||
# TODO(rchao): Move maybe_shard_dataset to shared util.
|
# TODO(rchao): Move maybe_shard_dataset to shared util.
|
||||||
maybe_shard_dataset = mnist_multi_worker.maybe_shard_dataset
|
maybe_shard_dataset = mnist_multi_worker.maybe_shard_dataset
|
||||||
|
|
||||||
@ -96,7 +94,7 @@ class KerasMultiWorkerOptimizerTest(test_base.IndependentWorkerTestBase,
|
|||||||
|
|
||||||
# Clear Keras session to reset device assignment
|
# Clear Keras session to reset device assignment
|
||||||
keras.backend._SESSION.session = None
|
keras.backend._SESSION.session = None
|
||||||
strategy = get_strategy_object(strategy_cls)
|
strategy = strategy_cls()
|
||||||
|
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
train_ds = get_input_datasets()
|
train_ds = get_input_datasets()
|
||||||
@ -126,7 +124,7 @@ class KerasMultiWorkerOptimizerTest(test_base.IndependentWorkerTestBase,
|
|||||||
cluster_spec)
|
cluster_spec)
|
||||||
|
|
||||||
threads_to_join = []
|
threads_to_join = []
|
||||||
strategy = get_strategy_object(strategy_cls)
|
strategy = strategy_cls()
|
||||||
if strategy.extended.experimental_between_graph:
|
if strategy.extended.experimental_between_graph:
|
||||||
for ts in threads.values():
|
for ts in threads.values():
|
||||||
threads_to_join.extend(ts)
|
threads_to_join.extend(ts)
|
||||||
|
@ -296,13 +296,6 @@ def _run_standalone_client(test_obj, strategy, cluster_spec):
|
|||||||
cluster_spec=cluster_spec)
|
cluster_spec=cluster_spec)
|
||||||
|
|
||||||
|
|
||||||
# TODO(yuefengz): remove this function once
|
|
||||||
# multi_worker_optimizer_comparison_test no longer depends on it.
|
|
||||||
def get_strategy_object(strategy_cls):
|
|
||||||
# CollectiveAllReduceStrategy and ParameterServerStrategy.
|
|
||||||
return strategy_cls()
|
|
||||||
|
|
||||||
|
|
||||||
class KerasMultiWorkerTestStandaloneClient(test.TestCase,
|
class KerasMultiWorkerTestStandaloneClient(test.TestCase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user