Add dataset_fn to fault_tolerance_test to increase test coverage.
PiperOrigin-RevId: 341748168 Change-Id: I13c0da9a5e6f4aea69f1306da243e4f1963f9f9f
This commit is contained in:
parent
88385067a2
commit
24d1fba948
@ -24,6 +24,7 @@ import threading
|
||||
import time
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
@ -35,6 +36,7 @@ from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
@ -61,19 +63,26 @@ class Model(object):
|
||||
dtype=dtypes.float32)
|
||||
self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
|
||||
|
||||
def dataset_fn():
|
||||
data = random_ops.random_uniform((1000, 1000))
|
||||
dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
|
||||
return dataset
|
||||
|
||||
self.iterator = iter(
|
||||
self.cluster_coord.create_per_worker_dataset(dataset_fn))
|
||||
|
||||
@def_function.function
|
||||
def train_fn(self):
|
||||
# train_fn roughly took 0.5s to execute on Intel Xeon Gold 6154 (3.00GHZ)
|
||||
# w/o any compilation optimization (two worker setup).
|
||||
def train_fn(self, iterator):
|
||||
for _ in math_ops.range(5):
|
||||
x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), self.w)
|
||||
x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
|
||||
x = math_ops.matmul(random_ops.random_uniform((1000, 1000)), x)
|
||||
self.w.assign_add(x)
|
||||
self.iterations.assign_add(1)
|
||||
|
||||
def schedule_training_functions(self, num_steps):
|
||||
with self.strategy.scope():
|
||||
for _ in range(num_steps):
|
||||
self.cluster_coord.schedule(self.train_fn)
|
||||
self.cluster_coord.schedule(self.train_fn, args=(self.iterator,))
|
||||
|
||||
def join_training_functions(self):
|
||||
self.cluster_coord.join()
|
||||
|
Loading…
Reference in New Issue
Block a user