Add a multi worker input test with a single file

This test describes about the current behavior, which may not be ideal.

PiperOrigin-RevId: 314974505
Change-Id: I95ef46cbb2f5b5dd2dce72a2ba5d3b7e33dd1642
This commit is contained in:
Ran Chen 2020-06-05 12:25:42 -07:00 committed by TensorFlower Gardener
parent 735ded004a
commit d66ae5d65f
4 changed files with 67 additions and 7 deletions

View File

@ -12,6 +12,9 @@ exports_files(["LICENSE"])
py_library(
name = "distribute_test_lib_pip",
data = [
"//tensorflow/python/distribute/testdata:text_input",
],
deps = [
":combinations",
":model_combinations",
@ -1767,6 +1770,9 @@ py_test(
cuda_py_test(
name = "strategy_common_test",
srcs = ["strategy_common_test.py"],
data = [
"//tensorflow/python/distribute/testdata:text_input",
],
python_version = "PY3",
tags = [
"multi_and_single_gpu",
@ -1783,10 +1789,12 @@ cuda_py_test(
":strategy_test_lib",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python/data/experimental/ops:distribute_options",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/eager:def_function",
"@absl_py//absl/testing:parameterized",
],

View File

@ -20,7 +20,9 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
@ -28,8 +30,10 @@ from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
@ -39,7 +43,7 @@ class StrategyReduceTest(test.TestCase, parameterized.TestCase):
combinations.combine(
strategy=[strategy_combinations.multi_worker_mirrored_two_workers] +
strategy_combinations.strategies_minus_tpu,
mode=['eager']))
mode=["eager"]))
def testSimpleReduce(self, strategy):
def fn_eager():
@ -65,14 +69,14 @@ class StrategyReduceTest(test.TestCase, parameterized.TestCase):
self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)
@combinations.generate(
combinations.combine(
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
mode=["eager"]))
class DistributedCollectiveAllReduceStrategyTest(
strategy_test_lib.DistributionTestBase,
parameterized.TestCase):
@combinations.generate(
combinations.combine(
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
mode=['eager']))
def testDatasetFromFunction(self, strategy):
def dataset_fn(input_context):
global_batch_size = 10
@ -95,6 +99,37 @@ class DistributedCollectiveAllReduceStrategyTest(
sum_value.numpy(),
expected_sum_on_workers[multi_worker_test_base.get_task_index()])
def testDatasetFromFunctionSingleFile(self, strategy):
files = [
test.test_src_dir_path("python/distribute/testdata/input0.txt"),
]
dataset = readers.TextLineDataset(files).batch(4)
dataset = strategy.experimental_distribute_dataset(dataset)
# Different workers may get different errors.
if multi_worker_test_base.get_task_index() == 0:
with self.assertRaisesRegex(errors.InvalidArgumentError,
"aren't enough elements"):
strategy.experimental_local_results(iter(dataset).get_next())
else:
with self.assertRaisesRegex(errors.OutOfRangeError, "End of sequence"):
strategy.experimental_local_results(iter(dataset).get_next())
if __name__ == '__main__':
def testDatasetFromFunctionSingleFileAutoShardPolicyData(self, strategy):
files = [
test.test_src_dir_path("python/distribute/testdata/input0.txt"),
]
options = dataset_ops.Options()
options.experimental_distribute.auto_shard_policy = (
distribute_options.AutoShardPolicy.DATA)
dataset = readers.TextLineDataset(files).map(
string_ops.string_to_number).batch(4).with_options(options)
dataset = strategy.experimental_distribute_dataset(dataset)
result = strategy.experimental_local_results(iter(dataset).get_next())
expected_on_workers = [[[1., 2.]], [[3., 4.]]]
self.assertAllEqual(
self.evaluate(result),
expected_on_workers[multi_worker_test_base.get_task_index()])
if __name__ == "__main__":
combinations.main()

View File

@ -0,0 +1,11 @@
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "text_input",
srcs = [
"input0.txt",
],
)

View File

@ -0,0 +1,6 @@
1
2
3
4
5
6