Add dataset_fn to fault_tolerance_test to increase test coverage.

PiperOrigin-RevId: 341748168
Change-Id: I13c0da9a5e6f4aea69f1306da243e4f1963f9f9f
Author: Yuefeng Zhou, 2020-11-10 19:44:41 -08:00 (committed by TensorFlower Gardener)
Parent: 88385067a2
Commit: 24d1fba948

@@ -24,6 +24,7 @@ import threading
 import time
 
 from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
@@ -35,6 +36,7 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -61,19 +63,26 @@ class Model(object):
         dtype=dtypes.float32)
     self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
 
+    def dataset_fn():
+      data = random_ops.random_uniform((1000, 1000))
+      dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
+      return dataset
+
+    self.iterator = iter(
+        self.cluster_coord.create_per_worker_dataset(dataset_fn))
+
   @def_function.function
-  def train_fn(self):
+  # train_fn roughly took 0.5s to execute on Intel Xeon Gold 6154 (3.00GHZ)
+  # w/o any compilation optimization (two worker setup).
+  def train_fn(self, iterator):
     for _ in math_ops.range(5):
-      x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), self.w)
+      x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
       x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), x)
       self.w.assign_add(x)
     self.iterations.assign_add(1)
 
   def schedule_training_functions(self, num_steps):
     with self.strategy.scope():
       for _ in range(num_steps):
-        self.cluster_coord.schedule(self.train_fn)
+        self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))
 
   def join_training_functions(self):
     self.cluster_coord.join()
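
For context, the snippet below is a minimal sketch (not part of this commit) of the pattern the updated test exercises: a per-worker dataset created through a ClusterCoordinator, whose iterator is passed as an argument to the scheduled tf.function. It uses the public tf.distribute API rather than the internal test modules in the diff above, and it assumes a parameter-server cluster (worker and ps tasks) is already configured via TF_CONFIG; cluster bring-up is omitted.

# Sketch only: illustrates per-worker datasets with ClusterCoordinator.schedule;
# assumes TF >= 2.4 and a cluster described by TF_CONFIG.
import tensorflow as tf

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

with strategy.scope():
  w = tf.Variable(tf.random.uniform((1000, 1000)))

def dataset_fn():
  # Each worker builds its own copy of this dataset.
  data = tf.random.uniform((1000, 1000))
  return tf.data.Dataset.from_tensors([data]).repeat()

# Create the per-worker dataset once; pass its iterator to each scheduled call
# so every worker reads from its own local iterator.
per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

@tf.function
def train_fn(iterator):
  # next(iterator) yields a (1, 1000, 1000) batch; squeeze it before matmul.
  x = tf.matmul(tf.squeeze(next(iterator)), w)
  w.assign_add(x)

for _ in range(10):
  coordinator.schedule(train_fn, args=(per_worker_iterator,))
coordinator.join()

Passing the iterator through args (rather than capturing it in the closure) mirrors the change to schedule_training_functions above and is what lets the coordinator replay a step on a different worker after a failure.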