[tf.data] Enabling a newly passing test.

PiperOrigin-RevId: 305730251
Change-Id: I273537bcf89023d263485e0aac951dd14ad21044
This commit is contained in:
Jiri Simsa 2020-04-09 12:00:16 -07:00 committed by TensorFlower Gardener
parent 1f9d744e66
commit d3fdebcec3

View File

@ -270,21 +270,19 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
"counter", (), dtypes.int32, use_resource=True)
dataset0 = dataset_ops.Dataset.range(100).map(
lambda _: counter_var.assign_add(1))
with self.assertRaises(errors.InvalidArgumentError):
replicated_ds = distribute.replicate(dataset0,
[self._device1, self._device2])
dataset1 = replicated_ds[self._device1]
dataset2 = replicated_ds[self._device2]
with ops.device(self._device0):
get_next0 = self.getNext(dataset0)
self.assertDatasetProduces(
dataset0, range(1, 101), requires_initialization=True)
with ops.device(self._device1):
get_next1 = self.getNext(dataset1)
self.assertDatasetProduces(
dataset1, range(101, 201), requires_initialization=True)
with ops.device(self._device2):
get_next2 = self.getNext(dataset2)
for _ in range(100):
self.evaluate(get_next0())
self.evaluate(get_next1())
self.evaluate(get_next2())
self.assertDatasetProduces(
dataset2, range(201, 301), requires_initialization=True)
if __name__ == "__main__":