Remove enable collective ops tests
We now have multi worker unit tests so these are already covered. We should test the API, not the implementation. PiperOrigin-RevId: 344128126 Change-Id: I83e237e0620d65b7b6cfc84df2c686bd6b55a019
This commit is contained in:
parent
51e3032483
commit
8ca9530005
@ -400,48 +400,6 @@ class DistributedCollectiveAllReduceStrategyTest(
|
||||
self.assertEqual(['CollectiveReduce'],
|
||||
new_rewrite_options.scoped_allocator_opts.enable_op)
|
||||
|
||||
def _get_strategy_with_mocked_methods(self):
|
||||
mock_called = [False]
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def mock_enable_collective_ops(server_def, mock_called=mock_called):
|
||||
self.assertEqual('worker', server_def.job_name)
|
||||
self.assertEqual(1, server_def.task_index)
|
||||
self.assertEqual('grpc', server_def.protocol)
|
||||
mock_called[0] = True
|
||||
|
||||
def mock_configure_collective_ops(*args, **kwargs):
|
||||
del args, kwargs
|
||||
|
||||
with test.mock.patch.object(context.context(), 'enable_collective_ops',
|
||||
mock_enable_collective_ops), \
|
||||
test.mock.patch.object(context.context(), 'configure_collective_ops',
|
||||
mock_configure_collective_ops):
|
||||
strategy, _, _ = self._get_test_object(
|
||||
task_type='worker', task_id=1, num_gpus=2)
|
||||
|
||||
return strategy, mock_called
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testEnableCollectiveOps(self):
|
||||
# We cannot enable check health with this test because it mocks
|
||||
# enable_collective_ops.
|
||||
CollectiveAllReduceExtended._enable_check_health = False
|
||||
strategy, mock_called = self._get_strategy_with_mocked_methods()
|
||||
CollectiveAllReduceExtended._enable_check_health = True
|
||||
self.assertTrue(strategy.extended._std_server_started)
|
||||
self.assertTrue(mock_called[0])
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testEnableCollectiveOpsAndClusterResolver(self):
|
||||
# We cannot enable check health with this test because it mocks
|
||||
# enable_collective_ops.
|
||||
CollectiveAllReduceExtended._enable_check_health = False
|
||||
strategy, _ = self._get_strategy_with_mocked_methods()
|
||||
CollectiveAllReduceExtended._enable_check_health = True
|
||||
self.assertEqual(strategy.cluster_resolver.task_type, 'worker')
|
||||
self.assertEqual(strategy.cluster_resolver.task_id, 1)
|
||||
|
||||
|
||||
class DistributedCollectiveAllReduceStrategyTestWithChief(
|
||||
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user