[tf.data] Enabling a newly passing test.
PiperOrigin-RevId: 305730251 Change-Id: I273537bcf89023d263485e0aac951dd14ad21044
This commit is contained in:
parent
1f9d744e66
commit
d3fdebcec3
@ -270,21 +270,19 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
"counter", (), dtypes.int32, use_resource=True)
|
"counter", (), dtypes.int32, use_resource=True)
|
||||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||||
lambda _: counter_var.assign_add(1))
|
lambda _: counter_var.assign_add(1))
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
replicated_ds = distribute.replicate(dataset0,
|
||||||
replicated_ds = distribute.replicate(dataset0,
|
[self._device1, self._device2])
|
||||||
[self._device1, self._device2])
|
dataset1 = replicated_ds[self._device1]
|
||||||
dataset1 = replicated_ds[self._device1]
|
dataset2 = replicated_ds[self._device2]
|
||||||
dataset2 = replicated_ds[self._device2]
|
with ops.device(self._device0):
|
||||||
with ops.device(self._device0):
|
self.assertDatasetProduces(
|
||||||
get_next0 = self.getNext(dataset0)
|
dataset0, range(1, 101), requires_initialization=True)
|
||||||
with ops.device(self._device1):
|
with ops.device(self._device1):
|
||||||
get_next1 = self.getNext(dataset1)
|
self.assertDatasetProduces(
|
||||||
with ops.device(self._device2):
|
dataset1, range(101, 201), requires_initialization=True)
|
||||||
get_next2 = self.getNext(dataset2)
|
with ops.device(self._device2):
|
||||||
for _ in range(100):
|
self.assertDatasetProduces(
|
||||||
self.evaluate(get_next0())
|
dataset2, range(201, 301), requires_initialization=True)
|
||||||
self.evaluate(get_next1())
|
|
||||||
self.evaluate(get_next2())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user