[tf.data] Enabling a newly passing test.
PiperOrigin-RevId: 305730251 Change-Id: I273537bcf89023d263485e0aac951dd14ad21044
This commit is contained in:
parent
1f9d744e66
commit
d3fdebcec3
@ -270,21 +270,19 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
"counter", (), dtypes.int32, use_resource=True)
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: counter_var.assign_add(1))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
with ops.device(self._device0):
|
||||
get_next0 = self.getNext(dataset0)
|
||||
with ops.device(self._device1):
|
||||
get_next1 = self.getNext(dataset1)
|
||||
with ops.device(self._device2):
|
||||
get_next2 = self.getNext(dataset2)
|
||||
for _ in range(100):
|
||||
self.evaluate(get_next0())
|
||||
self.evaluate(get_next1())
|
||||
self.evaluate(get_next2())
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
with ops.device(self._device0):
|
||||
self.assertDatasetProduces(
|
||||
dataset0, range(1, 101), requires_initialization=True)
|
||||
with ops.device(self._device1):
|
||||
self.assertDatasetProduces(
|
||||
dataset1, range(101, 201), requires_initialization=True)
|
||||
with ops.device(self._device2):
|
||||
self.assertDatasetProduces(
|
||||
dataset2, range(201, 301), requires_initialization=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user