apply test review changes

This commit is contained in:
kushanam 2020-10-16 14:39:36 -07:00
parent 7eb9ba4ef9
commit 6b7d4554f5
2 changed files with 7 additions and 6 deletions

View File

@ -1461,9 +1461,9 @@ class _SingleWorkerDatasetIteratorBase(object):
raise NotImplementedError("must be implemented in descendants")
def _format_data_list_with_options(self, data_list):
"""hange the data list to tuple type if required
"""Change the data list to tuple type if required
The OwnedMultiDeviceIterator returns the tuple data type,
while the PER_REPLICA iterator (when used with prefetch enabled)
while the PER_REPLICA iterator (when used with prefetch disabled)
returns without the enclosed tuple. This is to fix the inconsistency.
"""
if (self._options

View File

@ -1370,10 +1370,10 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, parameterize
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
for x in ds:
# validating the values
assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
# validating the values
x = next(iter(ds))
assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
@combinations.generate(
combinations.combine(
@ -1409,6 +1409,7 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, parameterize
# validating the values
assert x.values[0].numpy() == expected[i]
assert x.values[1].numpy() == expected[i] * 2
assert i == len(expected) - 1
if __name__ == "__main__":
combinations.main()