Add a multi worker input test with a single file
This test describes about the current behavior, which may not be ideal. PiperOrigin-RevId: 314974505 Change-Id: I95ef46cbb2f5b5dd2dce72a2ba5d3b7e33dd1642
This commit is contained in:
parent
735ded004a
commit
d66ae5d65f
|
@ -12,6 +12,9 @@ exports_files(["LICENSE"])
|
|||
|
||||
py_library(
|
||||
name = "distribute_test_lib_pip",
|
||||
data = [
|
||||
"//tensorflow/python/distribute/testdata:text_input",
|
||||
],
|
||||
deps = [
|
||||
":combinations",
|
||||
":model_combinations",
|
||||
|
@ -1767,6 +1770,9 @@ py_test(
|
|||
cuda_py_test(
|
||||
name = "strategy_common_test",
|
||||
srcs = ["strategy_common_test.py"],
|
||||
data = [
|
||||
"//tensorflow/python/distribute/testdata:text_input",
|
||||
],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
|
@ -1783,10 +1789,12 @@ cuda_py_test(
|
|||
":strategy_test_lib",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/data/experimental/ops:distribute_options",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
|
|
|
@ -20,7 +20,9 @@ from __future__ import print_function
|
|||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
|
@ -28,8 +30,10 @@ from tensorflow.python.distribute import strategy_combinations
|
|||
from tensorflow.python.distribute import strategy_test_lib
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
|
@ -39,7 +43,7 @@ class StrategyReduceTest(test.TestCase, parameterized.TestCase):
|
|||
combinations.combine(
|
||||
strategy=[strategy_combinations.multi_worker_mirrored_two_workers] +
|
||||
strategy_combinations.strategies_minus_tpu,
|
||||
mode=['eager']))
|
||||
mode=["eager"]))
|
||||
def testSimpleReduce(self, strategy):
|
||||
|
||||
def fn_eager():
|
||||
|
@ -65,14 +69,14 @@ class StrategyReduceTest(test.TestCase, parameterized.TestCase):
|
|||
self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
|
||||
mode=["eager"]))
|
||||
class DistributedCollectiveAllReduceStrategyTest(
|
||||
strategy_test_lib.DistributionTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
|
||||
mode=['eager']))
|
||||
def testDatasetFromFunction(self, strategy):
|
||||
def dataset_fn(input_context):
|
||||
global_batch_size = 10
|
||||
|
@ -95,6 +99,37 @@ class DistributedCollectiveAllReduceStrategyTest(
|
|||
sum_value.numpy(),
|
||||
expected_sum_on_workers[multi_worker_test_base.get_task_index()])
|
||||
|
||||
def testDatasetFromFunctionSingleFile(self, strategy):
|
||||
files = [
|
||||
test.test_src_dir_path("python/distribute/testdata/input0.txt"),
|
||||
]
|
||||
dataset = readers.TextLineDataset(files).batch(4)
|
||||
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||
# Different workers may get different errors.
|
||||
if multi_worker_test_base.get_task_index() == 0:
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"aren't enough elements"):
|
||||
strategy.experimental_local_results(iter(dataset).get_next())
|
||||
else:
|
||||
with self.assertRaisesRegex(errors.OutOfRangeError, "End of sequence"):
|
||||
strategy.experimental_local_results(iter(dataset).get_next())
|
||||
|
||||
if __name__ == '__main__':
|
||||
def testDatasetFromFunctionSingleFileAutoShardPolicyData(self, strategy):
|
||||
files = [
|
||||
test.test_src_dir_path("python/distribute/testdata/input0.txt"),
|
||||
]
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_distribute.auto_shard_policy = (
|
||||
distribute_options.AutoShardPolicy.DATA)
|
||||
dataset = readers.TextLineDataset(files).map(
|
||||
string_ops.string_to_number).batch(4).with_options(options)
|
||||
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||
result = strategy.experimental_local_results(iter(dataset).get_next())
|
||||
expected_on_workers = [[[1., 2.]], [[3., 4.]]]
|
||||
self.assertAllEqual(
|
||||
self.evaluate(result),
|
||||
expected_on_workers[multi_worker_test_base.get_task_index()])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "text_input",
|
||||
srcs = [
|
||||
"input0.txt",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,6 @@
|
|||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
Loading…
Reference in New Issue