[tf.data service] Add test that different workers use independent shuffle orders.
If shuffle seeds are unspecified, shuffle order should be non-deterministically chosen on each worker. PiperOrigin-RevId: 317375549 Change-Id: I35e32cfbbfb8558451a079875b495708347a23bf
This commit is contained in:
parent
4b3576c081
commit
83fe1bad15
@ -31,6 +31,7 @@ from tensorflow.python.eager import def_function
|
|||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
@ -78,6 +79,28 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
results = [elem.numpy() for elem in ds]
|
results = [elem.numpy() for elem in ds]
|
||||||
self.assertEqual(list(range(num_elements)), results)
|
self.assertEqual(list(range(num_elements)), results)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testDifferentShuffleOrders(self):
|
||||||
|
random_seed.set_random_seed(None)
|
||||||
|
num_elements = 100
|
||||||
|
master_address = self.create_cluster(2)
|
||||||
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
|
ds = ds.shuffle(num_elements)
|
||||||
|
ds = _make_distributed_dataset(ds, master_address)
|
||||||
|
output = [elem.numpy() for elem in ds]
|
||||||
|
|
||||||
|
# The output will be two sequences of range(num_elements)
|
||||||
|
# non-deterministically interleaved together. If the orders of the elements
|
||||||
|
# were the same, first_order and second_order computed below will be equal.
|
||||||
|
first_order = {}
|
||||||
|
second_order = {}
|
||||||
|
for element in output:
|
||||||
|
if element in first_order:
|
||||||
|
second_order[element] = len(second_order)
|
||||||
|
else:
|
||||||
|
first_order[element] = len(first_order)
|
||||||
|
self.assertNotEqual(first_order, second_order)
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testMultipleEpochs(self):
|
def testMultipleEpochs(self):
|
||||||
num_elements = 3
|
num_elements = 3
|
||||||
|
Loading…
Reference in New Issue
Block a user