Update grappler/cluster_test.py reflecting changes in random_uniform op
PiperOrigin-RevId: 283834913 Change-Id: If2869e4786d7c10c180bf6e7cb2ba5361a46c3cb
This commit is contained in:
parent
41228d7f14
commit
6d7926bb87
@ -81,9 +81,9 @@ class ClusterTest(test.TestCase):
|
||||
self.assertLessEqual(1, len(peak_mem))
|
||||
snapshot = peak_mem['/job:localhost/replica:0/task:0/device:CPU:0']
|
||||
peak_usage = snapshot[0]
|
||||
self.assertEqual(52, peak_usage)
|
||||
self.assertEqual(12, peak_usage)
|
||||
live_tensors = snapshot[1]
|
||||
self.assertEqual(15, len(live_tensors))
|
||||
self.assertEqual(5, len(live_tensors))
|
||||
|
||||
def testVirtualCluster(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
@ -107,8 +107,8 @@ class ClusterTest(test.TestCase):
|
||||
disable_timeline=False,
|
||||
devices=[named_device])
|
||||
op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item)
|
||||
self.assertEqual(run_time, 0.000545)
|
||||
self.assertEqual(len(op_perfs), 15)
|
||||
self.assertEqual(run_time, 0.000209)
|
||||
self.assertEqual(len(op_perfs), 5)
|
||||
|
||||
estimated_perf = grappler_cluster.EstimatePerformance(named_device)
|
||||
self.assertEqual(7680.0, estimated_perf)
|
||||
|
Loading…
Reference in New Issue
Block a user