apply test review changes
This commit is contained in:
parent
7eb9ba4ef9
commit
6b7d4554f5
@ -1461,9 +1461,9 @@ class _SingleWorkerDatasetIteratorBase(object):
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
def _format_data_list_with_options(self, data_list):
|
||||
"""hange the data list to tuple type if required
|
||||
"""Change the data list to tuple type if required
|
||||
The OwnedMultiDeviceIterator returns the tuple data type,
|
||||
while the PER_REPLICA iterator (when used with prefetch enabled)
|
||||
while the PER_REPLICA iterator (when used with prefetch disabled)
|
||||
returns without the enclosed tuple. This is to fix the inconsistency.
|
||||
"""
|
||||
if (self._options
|
||||
|
||||
@ -1370,10 +1370,10 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, parameterize
|
||||
ds = distribution.experimental_distribute_datasets_from_function(
|
||||
dataset_fn, input_options)
|
||||
|
||||
for x in ds:
|
||||
# validating the values
|
||||
assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
|
||||
assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
|
||||
# validating the values
|
||||
x = next(iter(ds))
|
||||
assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
|
||||
assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
@ -1409,6 +1409,7 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, parameterize
|
||||
# validating the values
|
||||
assert x.values[0].numpy() == expected[i]
|
||||
assert x.values[1].numpy() == expected[i] * 2
|
||||
assert i == len(expected) - 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user