Patch summary (free text before the first "diff --git" line is skipped by
patch tools):

  * Imports `dataset_ops` and `array_ops` into fault_tolerance_test.py.
  * Inside `class Model`, defines a local `dataset_fn` that wraps one
    (1000, 1000) uniform-random tensor in a list, builds a dataset from it
    with `DatasetV2.from_tensors`, and calls `repeat()` on it. The result of
    `cluster_coord.create_per_worker_dataset(dataset_fn)` is passed to
    `iter()` and stored as `self.iterator`.
  * `train_fn` now takes `iterator` as a parameter. Each of its five loop
    steps multiplies `array_ops.squeeze(next(iterator))` by `self.w`, then
    multiplies a fresh (1000, 1000) uniform-random matrix by that product
    before `self.w.assign_add(x)`.
  * `schedule_training_functions` passes `args=(self.iterator,)` to
    `cluster_coord.schedule`.
  * The old timing comment on `train_fn` is dropped.

NOTE(review): leading whitespace on the hunk lines below was reconstructed
using 2-space indentation; the hunk line counts match the headers, but
confirm the indentation against the target file before applying.

diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
index cc075d09c3d..96ac19aff94 100644
--- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
+++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
@@ -24,6 +24,7 @@ import threading
 import time
 
 from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
@@ -35,6 +36,7 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -61,19 +63,26 @@ class Model(object):
         dtype=dtypes.float32)
     self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
 
+    def dataset_fn():
+      data = random_ops.random_uniform((1000, 1000))
+      dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
+      return dataset
+
+    self.iterator = iter(
+        self.cluster_coord.create_per_worker_dataset(dataset_fn))
+
   @def_function.function
-  def train_fn(self):
-    # train_fn roughly took 0.5s to execute on Intel Xeon Gold 6154 (3.00GHZ)
-    # w/o any compilation optimization (two worker setup).
+  def train_fn(self, iterator):
     for _ in math_ops.range(5):
-      x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), self.w)
+      x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
+      x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), x)
       self.w.assign_add(x)
     self.iterations.assign_add(1)
 
   def schedule_training_functions(self, num_steps):
     with self.strategy.scope():
       for _ in range(num_steps):
-        self.cluster_coord.schedule(self.train_fn)
+        self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))
 
   def join_training_functions(self):
     self.cluster_coord.join()