Fix for b/146078486. In py3, range() returns iterable object.
PiperOrigin-RevId: 285282691 Change-Id: I4e6520e5fa6e5df62e88c8fb8a0296ba86fb8465
This commit is contained in:
parent
9c4b6d9df8
commit
356de62421
@ -364,9 +364,9 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
computed_value = sess.run([values.select_replica(r, next_element)
|
||||
for r in range(len(devices))])
|
||||
if ignore_order:
|
||||
self.assertCountEqual(expected_value, computed_value)
|
||||
self.assertCountEqual(list(expected_value), list(computed_value))
|
||||
else:
|
||||
self.assertEqual(expected_value, computed_value)
|
||||
self.assertEqual(list(expected_value), list(computed_value))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
next_element = iterator.get_next()
|
||||
@ -382,9 +382,9 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
computed_value = sess.run([values.select_replica(r, next_element)
|
||||
for r in range(len(devices))])
|
||||
if ignore_order:
|
||||
self.assertCountEqual(expected_value, computed_value)
|
||||
self.assertCountEqual(list(expected_value), list(computed_value))
|
||||
else:
|
||||
self.assertEqual(expected_value, computed_value)
|
||||
self.assertEqual(list(expected_value), list(computed_value))
|
||||
|
||||
|
||||
class DistributedCollectiveAllReduceStrategyTest(
|
||||
|
Loading…
Reference in New Issue
Block a user