apply test review changes

This commit is contained in:
kushanam 2020-10-16 14:39:36 -07:00
parent 7eb9ba4ef9
commit 6b7d4554f5
2 changed files with 7 additions and 6 deletions

View File

@ -1461,9 +1461,9 @@ class _SingleWorkerDatasetIteratorBase(object):
raise NotImplementedError("must be implemented in descendants") raise NotImplementedError("must be implemented in descendants")
def _format_data_list_with_options(self, data_list): def _format_data_list_with_options(self, data_list):
"""hange the data list to tuple type if required """Change the data list to tuple type if required
The OwnedMultiDeviceIterator returns the tuple data type, The OwnedMultiDeviceIterator returns the tuple data type,
while the PER_REPLICA iterator (when used with prefetch enabled) while the PER_REPLICA iterator (when used with prefetch disabled)
returns without the enclosed tuple. This is to fix the inconsistency. returns without the enclosed tuple. This is to fix the inconsistency.
""" """
if (self._options if (self._options

View File

@ -1370,10 +1370,10 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, parameterize
ds = distribution.experimental_distribute_datasets_from_function( ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options) dataset_fn, input_options)
for x in ds: # validating the values
# validating the values x = next(iter(ds))
assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5])) assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10])) assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
@ -1409,6 +1409,7 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, parameterize
# validating the values # validating the values
assert x.values[0].numpy() == expected[i] assert x.values[0].numpy() == expected[i]
assert x.values[1].numpy() == expected[i] * 2 assert x.values[1].numpy() == expected[i] * 2
assert i == len(expected) - 1
if __name__ == "__main__": if __name__ == "__main__":
combinations.main() combinations.main()