apply test review changes
This commit is contained in:
parent
7eb9ba4ef9
commit
6b7d4554f5
@ -1461,9 +1461,9 @@ class _SingleWorkerDatasetIteratorBase(object):
|
|||||||
raise NotImplementedError("must be implemented in descendants")
|
raise NotImplementedError("must be implemented in descendants")
|
||||||
|
|
||||||
def _format_data_list_with_options(self, data_list):
|
def _format_data_list_with_options(self, data_list):
|
||||||
"""hange the data list to tuple type if required
|
"""Change the data list to tuple type if required
|
||||||
The OwnedMultiDeviceIterator returns the tuple data type,
|
The OwnedMultiDeviceIterator returns the tuple data type,
|
||||||
while the PER_REPLICA iterator (when used with prefetch enabled)
|
while the PER_REPLICA iterator (when used with prefetch disabled)
|
||||||
returns without the enclosed tuple. This is to fix the inconsistency.
|
returns without the enclosed tuple. This is to fix the inconsistency.
|
||||||
"""
|
"""
|
||||||
if (self._options
|
if (self._options
|
||||||
|
|||||||
@ -1370,10 +1370,10 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, parameterize
|
|||||||
ds = distribution.experimental_distribute_datasets_from_function(
|
ds = distribution.experimental_distribute_datasets_from_function(
|
||||||
dataset_fn, input_options)
|
dataset_fn, input_options)
|
||||||
|
|
||||||
for x in ds:
|
# validating the values
|
||||||
# validating the values
|
x = next(iter(ds))
|
||||||
assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
|
assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
|
||||||
assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
|
assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.combine(
|
combinations.combine(
|
||||||
@ -1409,6 +1409,7 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, parameterize
|
|||||||
# validating the values
|
# validating the values
|
||||||
assert x.values[0].numpy() == expected[i]
|
assert x.values[0].numpy() == expected[i]
|
||||||
assert x.values[1].numpy() == expected[i] * 2
|
assert x.values[1].numpy() == expected[i] * 2
|
||||||
|
assert i == len(expected) - 1
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
combinations.main()
|
combinations.main()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user