From 8ca95300057c175ff63f0b5acd2ca17dc5f13a08 Mon Sep 17 00:00:00 2001 From: Ran Chen <crccw@google.com> Date: Tue, 24 Nov 2020 14:03:58 -0800 Subject: [PATCH] Remove enable collective ops tests We now have multi worker unit tests so these are already covered. We should test the API, not the implementation. PiperOrigin-RevId: 344128126 Change-Id: I83e237e0620d65b7b6cfc84df2c686bd6b55a019 --- .../collective_all_reduce_strategy_test.py | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py index afcf1959e02..f27164d32b5 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py @@ -400,48 +400,6 @@ class DistributedCollectiveAllReduceStrategyTest( self.assertEqual(['CollectiveReduce'], new_rewrite_options.scoped_allocator_opts.enable_op) - def _get_strategy_with_mocked_methods(self): - mock_called = [False] - - # pylint: disable=dangerous-default-value - def mock_enable_collective_ops(server_def, mock_called=mock_called): - self.assertEqual('worker', server_def.job_name) - self.assertEqual(1, server_def.task_index) - self.assertEqual('grpc', server_def.protocol) - mock_called[0] = True - - def mock_configure_collective_ops(*args, **kwargs): - del args, kwargs - - with test.mock.patch.object(context.context(), 'enable_collective_ops', - mock_enable_collective_ops), \ - test.mock.patch.object(context.context(), 'configure_collective_ops', - mock_configure_collective_ops): - strategy, _, _ = self._get_test_object( - task_type='worker', task_id=1, num_gpus=2) - - return strategy, mock_called - - @combinations.generate(combinations.combine(mode=['eager'])) - def testEnableCollectiveOps(self): - # We cannot enable check health with this test because it mocks - # enable_collective_ops. - CollectiveAllReduceExtended._enable_check_health = False - strategy, mock_called = self._get_strategy_with_mocked_methods() - CollectiveAllReduceExtended._enable_check_health = True - self.assertTrue(strategy.extended._std_server_started) - self.assertTrue(mock_called[0]) - - @combinations.generate(combinations.combine(mode=['eager'])) - def testEnableCollectiveOpsAndClusterResolver(self): - # We cannot enable check health with this test because it mocks - # enable_collective_ops. - CollectiveAllReduceExtended._enable_check_health = False - strategy, _ = self._get_strategy_with_mocked_methods() - CollectiveAllReduceExtended._enable_check_health = True - self.assertEqual(strategy.cluster_resolver.task_type, 'worker') - self.assertEqual(strategy.cluster_resolver.task_id, 1) - class DistributedCollectiveAllReduceStrategyTestWithChief( CollectiveAllReduceStrategyTestBase, parameterized.TestCase):