Re-enable multi process pool runner tests
PiperOrigin-RevId: 339084496 Change-Id: I0085dcc86ddcf8ff237234bc15e50815587442b5
This commit is contained in:
parent
7f532163ff
commit
e357988ab3
@ -590,13 +590,9 @@ class MultiProcessPoolRunnerTest(test.TestCase):
|
||||
self.assertAllEqual(result, [1, 1])
|
||||
|
||||
def test_global_pool(self):
|
||||
if multi_process_runner.is_oss():
|
||||
self.skipTest('TODO(b/170360740): Failing in OSS')
|
||||
_global_pool.run(fn_that_does_nothing)
|
||||
|
||||
def test_nested_pool(self):
|
||||
if multi_process_runner.is_oss():
|
||||
self.skipTest('TODO(b/170360740): Failing in OSS')
|
||||
|
||||
def fn():
|
||||
# This runs in sub processes, so they are each using their own
|
||||
|
||||
@ -309,52 +309,50 @@ class LocalCollectiveAllReduceStrategy(
|
||||
with policy.policy_scope('mixed_float16'):
|
||||
self._test_mixed_precision(None, None, required_gpus)
|
||||
|
||||
# TODO(b/170360740): Timeout in OSS
|
||||
if not multi_process_runner.is_oss():
|
||||
|
||||
@ds_combinations.generate(
|
||||
combinations.combine(
|
||||
strategy=[
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
],
|
||||
mode=['eager']))
|
||||
class DistributedCollectiveAllReduceStrategyEagerTest(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
@ds_combinations.generate(
|
||||
combinations.combine(
|
||||
strategy=[
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
],
|
||||
mode=['eager']))
|
||||
class DistributedCollectiveAllReduceStrategyEagerTest(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def testFitWithoutStepsPerEpochPartialBatch(self, strategy):
|
||||
def testFitWithoutStepsPerEpochPartialBatch(self, strategy):
|
||||
|
||||
def _model_fn():
|
||||
x = layers.Input(shape=(1,), name='input')
|
||||
y = layers.Dense(1, name='dense')(x)
|
||||
model = training.Model(x, y)
|
||||
return model
|
||||
def _model_fn():
|
||||
x = layers.Input(shape=(1,), name='input')
|
||||
y = layers.Dense(1, name='dense')(x)
|
||||
model = training.Model(x, y)
|
||||
return model
|
||||
|
||||
def _get_dataset():
|
||||
inputs = array_ops.expand_dims_v2(
|
||||
constant_op.constant(range(10)), axis=1)
|
||||
targets = array_ops.expand_dims_v2(
|
||||
constant_op.constant(range(10)), axis=1)
|
||||
# Make global batch size 12 for 2 replicas and a non-repeated dataset
|
||||
# with 10 elements so that we have partial batch
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
(inputs, targets)).batch(
|
||||
12, drop_remainder=False)
|
||||
return dataset
|
||||
def _get_dataset():
|
||||
inputs = array_ops.expand_dims_v2(
|
||||
constant_op.constant(range(10)), axis=1)
|
||||
targets = array_ops.expand_dims_v2(
|
||||
constant_op.constant(range(10)), axis=1)
|
||||
# Make global batch size 12 for 2 replicas and a non-repeated dataset
|
||||
# with 10 elements so that we have partial batch
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
(inputs, targets)).batch(
|
||||
12, drop_remainder=False)
|
||||
return dataset
|
||||
|
||||
with strategy.scope():
|
||||
optimizer_fn = gradient_descent_keras.SGD
|
||||
optimizer = optimizer_fn(0.001)
|
||||
model = _model_fn()
|
||||
loss = 'mse'
|
||||
metrics = ['mae']
|
||||
model.compile(optimizer, loss, metrics=metrics)
|
||||
dataset = _get_dataset()
|
||||
kernel_before = model.get_weights()[0][0]
|
||||
model.fit(dataset, epochs=10)
|
||||
kernel_after = model.get_weights()[0][0]
|
||||
self.assertNotEqual(kernel_before, kernel_after)
|
||||
self.assertGreater(abs(kernel_before - 1), abs(kernel_after - 1))
|
||||
with strategy.scope():
|
||||
optimizer_fn = gradient_descent_keras.SGD
|
||||
optimizer = optimizer_fn(0.001)
|
||||
model = _model_fn()
|
||||
loss = 'mse'
|
||||
metrics = ['mae']
|
||||
model.compile(optimizer, loss, metrics=metrics)
|
||||
dataset = _get_dataset()
|
||||
kernel_before = model.get_weights()[0][0]
|
||||
model.fit(dataset, epochs=10)
|
||||
kernel_after = model.get_weights()[0][0]
|
||||
self.assertNotEqual(kernel_before, kernel_after)
|
||||
self.assertGreater(abs(kernel_before - 1), abs(kernel_after - 1))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user