[tf.data] Enabling a newly passing test.

PiperOrigin-RevId: 305730251
Change-Id: I273537bcf89023d263485e0aac951dd14ad21044
This commit is contained in:
Jiri Simsa 2020-04-09 12:00:16 -07:00 committed by TensorFlower Gardener
parent 1f9d744e66
commit d3fdebcec3

View File

@ -270,21 +270,19 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
"counter", (), dtypes.int32, use_resource=True) "counter", (), dtypes.int32, use_resource=True)
dataset0 = dataset_ops.Dataset.range(100).map( dataset0 = dataset_ops.Dataset.range(100).map(
lambda _: counter_var.assign_add(1)) lambda _: counter_var.assign_add(1))
with self.assertRaises(errors.InvalidArgumentError): replicated_ds = distribute.replicate(dataset0,
replicated_ds = distribute.replicate(dataset0, [self._device1, self._device2])
[self._device1, self._device2]) dataset1 = replicated_ds[self._device1]
dataset1 = replicated_ds[self._device1] dataset2 = replicated_ds[self._device2]
dataset2 = replicated_ds[self._device2] with ops.device(self._device0):
with ops.device(self._device0): self.assertDatasetProduces(
get_next0 = self.getNext(dataset0) dataset0, range(1, 101), requires_initialization=True)
with ops.device(self._device1): with ops.device(self._device1):
get_next1 = self.getNext(dataset1) self.assertDatasetProduces(
with ops.device(self._device2): dataset1, range(101, 201), requires_initialization=True)
get_next2 = self.getNext(dataset2) with ops.device(self._device2):
for _ in range(100): self.assertDatasetProduces(
self.evaluate(get_next0()) dataset2, range(201, 301), requires_initialization=True)
self.evaluate(get_next1())
self.evaluate(get_next2())
if __name__ == "__main__": if __name__ == "__main__":