Remove enable collective ops tests

We now have multi worker unit tests so these are already covered. We should test
the API, not the implementation.

PiperOrigin-RevId: 344128126
Change-Id: I83e237e0620d65b7b6cfc84df2c686bd6b55a019
This commit is contained in:
Ran Chen 2020-11-24 14:03:58 -08:00 committed by TensorFlower Gardener
parent 51e3032483
commit 8ca9530005

View File

@ -400,48 +400,6 @@ class DistributedCollectiveAllReduceStrategyTest(
self.assertEqual(['CollectiveReduce'],
new_rewrite_options.scoped_allocator_opts.enable_op)
def _get_strategy_with_mocked_methods(self):
mock_called = [False]
# pylint: disable=dangerous-default-value
def mock_enable_collective_ops(server_def, mock_called=mock_called):
self.assertEqual('worker', server_def.job_name)
self.assertEqual(1, server_def.task_index)
self.assertEqual('grpc', server_def.protocol)
mock_called[0] = True
def mock_configure_collective_ops(*args, **kwargs):
del args, kwargs
with test.mock.patch.object(context.context(), 'enable_collective_ops',
mock_enable_collective_ops), \
test.mock.patch.object(context.context(), 'configure_collective_ops',
mock_configure_collective_ops):
strategy, _, _ = self._get_test_object(
task_type='worker', task_id=1, num_gpus=2)
return strategy, mock_called
@combinations.generate(combinations.combine(mode=['eager']))
def testEnableCollectiveOps(self):
# We cannot enable check health with this test because it mocks
# enable_collective_ops.
CollectiveAllReduceExtended._enable_check_health = False
strategy, mock_called = self._get_strategy_with_mocked_methods()
CollectiveAllReduceExtended._enable_check_health = True
self.assertTrue(strategy.extended._std_server_started)
self.assertTrue(mock_called[0])
@combinations.generate(combinations.combine(mode=['eager']))
def testEnableCollectiveOpsAndClusterResolver(self):
# We cannot enable check health with this test because it mocks
# enable_collective_ops.
CollectiveAllReduceExtended._enable_check_health = False
strategy, _ = self._get_strategy_with_mocked_methods()
CollectiveAllReduceExtended._enable_check_health = True
self.assertEqual(strategy.cluster_resolver.task_type, 'worker')
self.assertEqual(strategy.cluster_resolver.task_id, 1)
class DistributedCollectiveAllReduceStrategyTestWithChief(
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):