Add a multi worker input test with a single file
This test describes about the current behavior, which may not be ideal. PiperOrigin-RevId: 314974505 Change-Id: I95ef46cbb2f5b5dd2dce72a2ba5d3b7e33dd1642
This commit is contained in:
parent
735ded004a
commit
d66ae5d65f
@ -12,6 +12,9 @@ exports_files(["LICENSE"])
|
|||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "distribute_test_lib_pip",
|
name = "distribute_test_lib_pip",
|
||||||
|
data = [
|
||||||
|
"//tensorflow/python/distribute/testdata:text_input",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":combinations",
|
":combinations",
|
||||||
":model_combinations",
|
":model_combinations",
|
||||||
@ -1767,6 +1770,9 @@ py_test(
|
|||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "strategy_common_test",
|
name = "strategy_common_test",
|
||||||
srcs = ["strategy_common_test.py"],
|
srcs = ["strategy_common_test.py"],
|
||||||
|
data = [
|
||||||
|
"//tensorflow/python/distribute/testdata:text_input",
|
||||||
|
],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"multi_and_single_gpu",
|
"multi_and_single_gpu",
|
||||||
@ -1783,10 +1789,12 @@ cuda_py_test(
|
|||||||
":strategy_test_lib",
|
":strategy_test_lib",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:constant_op",
|
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python/data/experimental/ops:distribute_options",
|
||||||
"//tensorflow/python/data/ops:dataset_ops",
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
|
"//tensorflow/python/data/ops:readers",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
|
|||||||
@ -20,7 +20,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.data.experimental.ops import distribute_options
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.data.ops import readers
|
||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.distribute import combinations
|
||||||
from tensorflow.python.distribute import multi_worker_test_base
|
from tensorflow.python.distribute import multi_worker_test_base
|
||||||
from tensorflow.python.distribute import reduce_util
|
from tensorflow.python.distribute import reduce_util
|
||||||
@ -28,8 +30,10 @@ from tensorflow.python.distribute import strategy_combinations
|
|||||||
from tensorflow.python.distribute import strategy_test_lib
|
from tensorflow.python.distribute import strategy_test_lib
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import string_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -39,7 +43,7 @@ class StrategyReduceTest(test.TestCase, parameterized.TestCase):
|
|||||||
combinations.combine(
|
combinations.combine(
|
||||||
strategy=[strategy_combinations.multi_worker_mirrored_two_workers] +
|
strategy=[strategy_combinations.multi_worker_mirrored_two_workers] +
|
||||||
strategy_combinations.strategies_minus_tpu,
|
strategy_combinations.strategies_minus_tpu,
|
||||||
mode=['eager']))
|
mode=["eager"]))
|
||||||
def testSimpleReduce(self, strategy):
|
def testSimpleReduce(self, strategy):
|
||||||
|
|
||||||
def fn_eager():
|
def fn_eager():
|
||||||
@ -65,14 +69,14 @@ class StrategyReduceTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)
|
self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)
|
||||||
|
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
|
||||||
|
mode=["eager"]))
|
||||||
class DistributedCollectiveAllReduceStrategyTest(
|
class DistributedCollectiveAllReduceStrategyTest(
|
||||||
strategy_test_lib.DistributionTestBase,
|
strategy_test_lib.DistributionTestBase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
|
|
||||||
@combinations.generate(
|
|
||||||
combinations.combine(
|
|
||||||
strategy=[strategy_combinations.multi_worker_mirrored_two_workers],
|
|
||||||
mode=['eager']))
|
|
||||||
def testDatasetFromFunction(self, strategy):
|
def testDatasetFromFunction(self, strategy):
|
||||||
def dataset_fn(input_context):
|
def dataset_fn(input_context):
|
||||||
global_batch_size = 10
|
global_batch_size = 10
|
||||||
@ -95,6 +99,37 @@ class DistributedCollectiveAllReduceStrategyTest(
|
|||||||
sum_value.numpy(),
|
sum_value.numpy(),
|
||||||
expected_sum_on_workers[multi_worker_test_base.get_task_index()])
|
expected_sum_on_workers[multi_worker_test_base.get_task_index()])
|
||||||
|
|
||||||
|
def testDatasetFromFunctionSingleFile(self, strategy):
|
||||||
|
files = [
|
||||||
|
test.test_src_dir_path("python/distribute/testdata/input0.txt"),
|
||||||
|
]
|
||||||
|
dataset = readers.TextLineDataset(files).batch(4)
|
||||||
|
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||||
|
# Different workers may get different errors.
|
||||||
|
if multi_worker_test_base.get_task_index() == 0:
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"aren't enough elements"):
|
||||||
|
strategy.experimental_local_results(iter(dataset).get_next())
|
||||||
|
else:
|
||||||
|
with self.assertRaisesRegex(errors.OutOfRangeError, "End of sequence"):
|
||||||
|
strategy.experimental_local_results(iter(dataset).get_next())
|
||||||
|
|
||||||
if __name__ == '__main__':
|
def testDatasetFromFunctionSingleFileAutoShardPolicyData(self, strategy):
|
||||||
|
files = [
|
||||||
|
test.test_src_dir_path("python/distribute/testdata/input0.txt"),
|
||||||
|
]
|
||||||
|
options = dataset_ops.Options()
|
||||||
|
options.experimental_distribute.auto_shard_policy = (
|
||||||
|
distribute_options.AutoShardPolicy.DATA)
|
||||||
|
dataset = readers.TextLineDataset(files).map(
|
||||||
|
string_ops.string_to_number).batch(4).with_options(options)
|
||||||
|
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||||
|
result = strategy.experimental_local_results(iter(dataset).get_next())
|
||||||
|
expected_on_workers = [[[1., 2.]], [[3., 4.]]]
|
||||||
|
self.assertAllEqual(
|
||||||
|
self.evaluate(result),
|
||||||
|
expected_on_workers[multi_worker_test_base.get_task_index()])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
combinations.main()
|
combinations.main()
|
||||||
|
|||||||
11
tensorflow/python/distribute/testdata/BUILD
vendored
Normal file
11
tensorflow/python/distribute/testdata/BUILD
vendored
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package(
|
||||||
|
default_visibility = ["//tensorflow:internal"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "text_input",
|
||||||
|
srcs = [
|
||||||
|
"input0.txt",
|
||||||
|
],
|
||||||
|
)
|
||||||
6
tensorflow/python/distribute/testdata/input0.txt
vendored
Normal file
6
tensorflow/python/distribute/testdata/input0.txt
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
Loading…
x
Reference in New Issue
Block a user