diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index 2014c0dde3f..b192ba726f4 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -81,9 +81,9 @@ class ClusterTest(test.TestCase): self.assertLessEqual(1, len(peak_mem)) snapshot = peak_mem['/job:localhost/replica:0/task:0/device:CPU:0'] peak_usage = snapshot[0] - self.assertEqual(52, peak_usage) + self.assertEqual(12, peak_usage) live_tensors = snapshot[1] - self.assertEqual(15, len(live_tensors)) + self.assertEqual(5, len(live_tensors)) def testVirtualCluster(self): with ops.Graph().as_default() as g: @@ -107,8 +107,8 @@ class ClusterTest(test.TestCase): disable_timeline=False, devices=[named_device]) op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item) - self.assertEqual(run_time, 0.000545) - self.assertEqual(len(op_perfs), 15) + self.assertEqual(run_time, 0.000209) + self.assertEqual(len(op_perfs), 5) estimated_perf = grappler_cluster.EstimatePerformance(named_device) self.assertEqual(7680.0, estimated_perf)