Add dataset_fn to fault_tolerance_test to increase test coverage.

PiperOrigin-RevId: 341748168
Change-Id: I13c0da9a5e6f4aea69f1306da243e4f1963f9f9f
This commit is contained in:
Yuefeng Zhou 2020-11-10 19:44:41 -08:00 committed by TensorFlower Gardener
parent 88385067a2
commit 24d1fba948

View File

@@ -24,6 +24,7 @@ import threading
import time
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
@@ -35,6 +36,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -61,19 +63,26 @@ class Model(object):
        dtype=dtypes.float32)
    self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
def dataset_fn():
data = random_ops.random_uniform((1000, 1000))
dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
return dataset
self.iterator = iter(
self.cluster_coord.create_per_worker_dataset(dataset_fn))
  @def_function.function
  def train_fn(self, iterator):
    # train_fn roughly took 0.5s to execute on Intel Xeon Gold 6154 (3.00GHZ)
    # w/o any compilation optimization (two worker setup).
    for _ in math_ops.range(5):
      x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), x)
      self.w.assign_add(x)
    self.iterations.assign_add(1)
  def schedule_training_functions(self, num_steps):
    with self.strategy.scope():
      for _ in range(num_steps):
        self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))

  def join_training_functions(self):
    self.cluster_coord.join()