From 8ca95300057c175ff63f0b5acd2ca17dc5f13a08 Mon Sep 17 00:00:00 2001
From: Ran Chen <crccw@google.com>
Date: Tue, 24 Nov 2020 14:03:58 -0800
Subject: [PATCH] Remove enable collective ops tests

We now have multi worker unit tests so these are already covered. We should test
the API, not the implementation.

PiperOrigin-RevId: 344128126
Change-Id: I83e237e0620d65b7b6cfc84df2c686bd6b55a019
---
 .../collective_all_reduce_strategy_test.py    | 42 -------------------
 1 file changed, 42 deletions(-)

diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
index afcf1959e02..f27164d32b5 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
@@ -400,48 +400,6 @@ class DistributedCollectiveAllReduceStrategyTest(
     self.assertEqual(['CollectiveReduce'],
                      new_rewrite_options.scoped_allocator_opts.enable_op)
 
-  def _get_strategy_with_mocked_methods(self):
-    mock_called = [False]
-
-    # pylint: disable=dangerous-default-value
-    def mock_enable_collective_ops(server_def, mock_called=mock_called):
-      self.assertEqual('worker', server_def.job_name)
-      self.assertEqual(1, server_def.task_index)
-      self.assertEqual('grpc', server_def.protocol)
-      mock_called[0] = True
-
-    def mock_configure_collective_ops(*args, **kwargs):
-      del args, kwargs
-
-    with test.mock.patch.object(context.context(), 'enable_collective_ops',
-                                mock_enable_collective_ops), \
-         test.mock.patch.object(context.context(), 'configure_collective_ops',
-                                mock_configure_collective_ops):
-      strategy, _, _ = self._get_test_object(
-          task_type='worker', task_id=1, num_gpus=2)
-
-    return strategy, mock_called
-
-  @combinations.generate(combinations.combine(mode=['eager']))
-  def testEnableCollectiveOps(self):
-    # We cannot enable check health with this test because it mocks
-    # enable_collective_ops.
-    CollectiveAllReduceExtended._enable_check_health = False
-    strategy, mock_called = self._get_strategy_with_mocked_methods()
-    CollectiveAllReduceExtended._enable_check_health = True
-    self.assertTrue(strategy.extended._std_server_started)
-    self.assertTrue(mock_called[0])
-
-  @combinations.generate(combinations.combine(mode=['eager']))
-  def testEnableCollectiveOpsAndClusterResolver(self):
-    # We cannot enable check health with this test because it mocks
-    # enable_collective_ops.
-    CollectiveAllReduceExtended._enable_check_health = False
-    strategy, _ = self._get_strategy_with_mocked_methods()
-    CollectiveAllReduceExtended._enable_check_health = True
-    self.assertEqual(strategy.cluster_resolver.task_type, 'worker')
-    self.assertEqual(strategy.cluster_resolver.task_id, 1)
-
 
 class DistributedCollectiveAllReduceStrategyTestWithChief(
     CollectiveAllReduceStrategyTestBase, parameterized.TestCase):