From 83fe1bad15e65e8db5e546d683d0ad591f19fad7 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Fri, 19 Jun 2020 14:02:55 -0700 Subject: [PATCH] [tf.data service] Add test that different workers use independent shuffle orders. If shuffle seeds are unspecified, shuffle order should be non-deterministically chosen on each worker. PiperOrigin-RevId: 317375549 Change-Id: I35e32cfbbfb8558451a079875b495708347a23bf --- .../kernel_tests/data_service_ops_test.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py index 796ab328980..488bf97f184 100644 --- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import tensor_array_ops @@ -78,6 +79,28 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): results = [elem.numpy() for elem in ds] self.assertEqual(list(range(num_elements)), results) + @combinations.generate(test_base.eager_only_combinations()) + def testDifferentShuffleOrders(self): + random_seed.set_random_seed(None) + num_elements = 100 + master_address = self.create_cluster(2) + ds = dataset_ops.Dataset.range(num_elements) + ds = ds.shuffle(num_elements) + ds = _make_distributed_dataset(ds, master_address) + output = [elem.numpy() for elem in ds] + + # The output will be two sequences of range(num_elements) + # non-deterministically interleaved together. 
If the orders of the elements + # were the same, first_order and second_order computed below would be equal. + first_order = {} + second_order = {} + for element in output: + if element in first_order: + second_order[element] = len(second_order) + else: + first_order[element] = len(first_order) + self.assertNotEqual(first_order, second_order) + @combinations.generate(test_base.eager_only_combinations()) def testMultipleEpochs(self): num_elements = 3