# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the input_lib library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import threading

from absl.testing import parameterized
import numpy as np

from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
from tensorflow.python.util import nest


class DistributedIteratorTestBase(test.TestCase):
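  """Base class with helpers for building and checking distributed input.

  `_wrap_iterator` and `_wrap_dataset` construct the distributed input object
  under test, and `_test_input_iteration` drives it and compares the
  per-replica results against `expected_values`.
  """
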
  # The passed input_context is to create a sharded dataset in between-graph
  # case.
  def _wrap_iterator(self,
                     input_type,
                     dataset_or_input_fn,
                     input_workers,
                     devices,
                     split_batch_by,
                     strategy,
                     input_context=None):
    # The `input_context` passed in is to shard dataset for
    # MultiWorkerMirroredStrategy. It doesn't apply to in-graph case where
    # multiple InputContexts are needed.
    if input_type == "input_fn":
      self.assertIsNone(
          input_context,
          msg=("The `input_context` arg is only used to shard the dataset in "
               "`MultiWorkerMirroredStrategy` when the input type is dataset."))

      input_contexts = []
      for i in range(input_workers.num_workers):
        input_contexts.append(
            distribute_lib.InputContext(
                # Note: `input_workers.num_workers` is always 1 in between-graph
                # case.
                num_input_pipelines=input_workers.num_workers,
                input_pipeline_id=i,
                num_replicas_in_sync=len(devices)))

      iterator = input_lib.InputFunctionIterator(
          dataset_or_input_fn,
          input_workers,
          input_contexts,
          strategy)
    else:
      iterator = input_lib.DatasetIterator(
          dataset_or_input_fn,
          input_workers,
          strategy,
          split_batch_by=split_batch_by,
          input_context=input_context)
    return iterator

  def _wrap_dataset(self,
                    input_type,
                    dataset,
                    input_workers,
                    split_batch_by,
                    strategy,
                    input_context=None):
    if input_type == "dataset":
      if tf2.enabled():
        return input_lib.DistributedDataset(
            dataset,
            input_workers,
            strategy,
            split_batch_by=split_batch_by,
            input_context=input_context)
      else:
        return input_lib.DistributedDatasetV1(
            dataset,
            input_workers,
            strategy,
            split_batch_by=split_batch_by,
            input_context=input_context)
    else:
      return strategy.experimental_distribute_datasets_from_function(dataset)

  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_or_input_fn,
                            worker_device_pairs,
                            expected_values,
                            strategy,
                            sess=None,
                            split_batch_by=None,
                            input_context=None):
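    """Builds the distributed input per `api_type`/`input_type` and checks it.

    Iterates via `get_next`, `get_next_as_optional`, or a Python `for` loop
    (depending on `iteration_type`) and compares the per-replica values
    against `expected_values`.
    """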
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type,
          dataset_or_input_fn,
          input_workers,
          devices,
          split_batch_by,
          strategy,
          input_context=input_context)
    else:
      # wrapping into a dataset:
      dataset = self._wrap_dataset(
          input_type,
          dataset_or_input_fn,
          input_workers,
          split_batch_by,
          strategy,
          input_context=input_context)

      if ops.executing_eagerly_outside_functions():
        iterator = iter(dataset)
      else:
        if isinstance(dataset, input_lib.DistributedDatasetV1):
          iterator = dataset.make_initializable_iterator()
        else:
          self.skipTest("unsupported test combination")

    if isinstance(iterator, composite_tensor.CompositeTensor):
      nest.assert_same_structure(iterator, iterator._type_spec,
                                 expand_composites=True)

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))

      def test_get_next(iterator):
        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])

          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

        with self.assertRaises(errors.OutOfRangeError):
          next_element = iterator.get_next()
          evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])

        # After re-initializing the iterator, we should be able to iterate
        # again.
        if not ops.executing_eagerly_outside_functions():
          evaluate(control_flow_ops.group(iterator.initializer))
        else:
          if api_type == "wrap_into_iterator":
            self.skipTest("unsupported test combination")
          else:
            iterator = iter(dataset)

        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])
          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

      def test_get_next_as_optional(iterator):
        for expected_value in expected_values:
          next_element = iterator.get_next_as_optional()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element.get_value())
              for r in range(len(devices))
          ])

          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

        next_element = iterator.get_next_as_optional()
        self.assertFalse(self.evaluate(next_element.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          evaluate([
              distribute_utils.select_replica(r, next_element.get_value())
              for r in range(len(devices))
          ])

      test_get_next(iterator)

      # Re-initializing the iterator.
      if not tf2.enabled():
        self.skipTest("Not testing get_next_as_optional in TF1")
      else:
        if api_type == "wrap_into_iterator":
          self.skipTest("unsupported test combination")
        else:
          iterator = iter(dataset)

      test_get_next_as_optional(iterator)

    if iteration_type == "for_loop" and context.executing_eagerly():
      actual_values = []
      for x in dataset:
        computed_value = self.evaluate(
            [distribute_utils.select_replica(r, x)
             for r in range(len(devices))])
        actual_values.append(computed_value)
      for i, expected_value in enumerate(expected_values):
        self.assertEqual(len(expected_value), len(actual_values[i]))
        for j in range(len(expected_value)):
          self.assertAllEqual(expected_value[j], actual_values[i][j])

  def _create_dataset_or_input_fn(self, input_type, input_fn):
    if input_type == "input_fn":
      return input_fn
    else:
      return input_fn(distribute_lib.InputContext())


class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
                                          parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ]))
  def testDisablingOwnedIteratorsInTF2(self, distribution, input_type):
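    """Checks that `_enable_legacy_iterators` switches to V1 iterator types."""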
    if not tf2.enabled():
      self.skipTest("unsupported test combination")

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    input_workers = input_lib.InputWorkers(worker_device_pairs)
    if input_type == "dataset":
      dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn,
                                                       input_workers,
                                                       distribution)
    else:
      dist_dataset = input_lib.get_distributed_datasets_from_function(
          dataset_or_input_fn, input_workers, [distribute_lib.InputContext()],
          distribution)

    # Default Iterator types in TF2.
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIterator)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerOwnedDatasetIterator)

    # Disable creating owned iterators by setting a property on the strategy.
    distribution._enable_legacy_iterators = True
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIteratorV1)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerDatasetIterator)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ]))
  def testMultiDeviceIterInitialize(self, distribution):
    if tf2.enabled():
      self.skipTest("Only V1 is supported.")
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)

    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dist_dataset = input_lib.get_distributed_dataset(
        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)

    iterator = dataset_ops.make_one_shot_iterator(dist_dataset)

    @def_function.function
    def init_func_for_iter():
      self.evaluate(iterator.initializer)

    init_func_for_iter()

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    else:
      dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i] for i in range(10)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
                                 distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i+1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.tpu_strategy],
          enable_get_next_as_optional=[True, False]))
  def testTPU(self, input_type, api_type, iteration_type, distribution,
              enable_get_next_as_optional):
    worker_device_pairs = collections.OrderedDict()
    for tpu_device in distribution.extended.worker_devices:
      host_device = device_util.get_host_for_device(tpu_device)
      worker_device_pairs.setdefault(host_device, [])
      worker_device_pairs[host_device].append(tpu_device)
    worker_device_pairs = worker_device_pairs.items()
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i + 1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]

    def dataset_fn(ctx):
      del ctx
      if tf2.enabled():
        dataset1 = dataset_ops.DatasetV2.range(10)
        dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2)
        return dataset_ops.DatasetV2.zip((dataset1, dataset2))
      else:
        dataset1 = dataset_ops.Dataset.range(10)
        dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
        return dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ]))
  def testIterableIterator(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.DatasetV2.range(10)
    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                     distribution)

    iterator = iter(dist_dataset)
    for i, element in enumerate(iterator):
      self.assertEqual(i, element.numpy())

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ]))
  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
                               drop_remainder, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(  # pylint: disable=g-long-lambda
          2, drop_remainder=drop_remainder)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(  # pylint: disable=g-long-lambda
          2, drop_remainder=drop_remainder)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    # The last global batch only contains data for one replica.
    if drop_remainder:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
    else:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          split_batch_by=[None, 2],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplitting(self, input_type, api_type, iteration_type,
                         split_batch_by, distribution,
                         enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    batch_size = 10
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    updated_batch_size = (
        batch_size // split_batch_by if split_batch_by else batch_size)
    expected_values = [[range(i, i+updated_batch_size),
                        range(i+updated_batch_size, i+2*updated_batch_size)]
                       for i in range(0, 100, updated_batch_size*2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        sess=None,
        split_batch_by=split_batch_by)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
      ))
  def testCacheAcrossIteration(self, distribution):
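    """Checks that a cached dataset yields the same batches on a second pass."""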
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    dataset = dataset_ops.Dataset.range(10).shuffle(10).cache().batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          reshuffle=[True, False]))
  def testShuffleAcrossIterations(self, distribution, reshuffle):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    if not reshuffle and not compat.forward_compatible(2020, 5, 22):
      self.skipTest("Functionality currently not supported.")

    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, reshuffle_each_iteration=reshuffle).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    if reshuffle:
      self.assertNotAllEqual(first_epoch, second_epoch)
    else:
      self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ]))
  def testGetNextOptionalShape(self, distribution):
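    """Checks that per-replica element shapes stay static in a tf.function."""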
    batch_size = 8
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "feature": array_ops.ones([batch_size, 10]),
        "label": array_ops.ones([batch_size]),
    })
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync

    @def_function.function
    def train_fn():
      for data in dist_dataset:
        data = nest.map_structure(distribution.experimental_local_results, data)
        feature = data["feature"]
        label = data["label"]

        # Assert that the shapes are still static on all replicas.
        for replica_id in range(distribution.num_replicas_in_sync):
          self.assertEqual([per_replica_batch_size, 10],
                           feature[replica_id].shape)
          self.assertEqual([per_replica_batch_size], label[replica_id].shape)

    train_fn()


class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                                        parameterized.TestCase):
  """Tests for DistributedDataset with non-dense tensors."""

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          defun_type=["lambda", "tf_function"],
      ))
  def testRaggedSparse(self, distribution, input_type, drop_remainder,
                       defun_type):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    defun = {"lambda": lambda f: f,
             "tf_function": def_function.function}[defun_type]
    distribution.extended.experimental_enable_get_next_as_optional = True
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)
    dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
                                 distribution.extended._input_workers,
                                 len(distribution.extended.worker_devices),
                                 distribution)
    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    self.assertAllEqual(
        distribute_utils.select_replica(0, per_replica_batch["dense"]),
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
    self.assertAllEqual(
        distribute_utils.select_replica(1, per_replica_batch["dense"]),
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
    # Transitively check the ragged and sparse tensors by densification.
    for i in range(2):
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["ragged"]).values,
          6)
      self.assertAllEqual(
          distribute_utils.select_replica(
              i, per_replica_batch["ragged"]).to_tensor(),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["sparse"]).indices,
          6)
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(
              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):
        per_replica_sums = distribution.run(
            (lambda x: math_ops.reduce_sum(x.values)) if all(
                map(sparse_tensor.is_sparse, per_replica_values.values)) else
            math_ops.reduce_sum, (per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_for_loop(dataset):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      for batch in dataset:
        sums = _reduce(sums, batch)
      return sums

    def sum_while_loop(iterator, reduce_fn):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      while True:
        try:
          sums = reduce_fn(sums, iterator)
        except (StopIteration, errors.OutOfRangeError):
          return sums

    while_sums = sum_while_loop(
        iter(dataset),
        defun(lambda state, iterator: _reduce(state, next(iterator))))
    self.assertAllEqual(
        nest.flatten(while_sums),
        # When there's no partial batch, the sum is smaller.
        [200. if drop_remainder else 310.] * 3)
    for_sums = defun(sum_for_loop)(dataset)
    # For loops always call get next as optional inside tf functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round-robined over 2 replicas).
    expected_for_sum = 200.
    if (not drop_remainder or (
        defun_type == "tf_function" and input_type == "input_fn")):
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          tensor_type=["sparse", "ragged"],
          enable_get_next_as_optional=[True, False]
      ))
  def testRaggedSparseGetNextAsOptional(
      self, distribution, input_type, drop_remainder, tensor_type,
      enable_get_next_as_optional):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          tensor_type: (ragged_tensor if tensor_type == "ragged" else
                        ragged_tensor.to_sparse()),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.experimental_distribute_datasets_from_function(
          dataset_fn)
    iterator = iter(ds)

    self.assertEqual(iterator._enable_get_next_as_optional,
                     (not drop_remainder) and enable_get_next_as_optional)


class DistributedIteratorMultiWorkerTest(
    multi_worker_test_base.MultiWorkerTestBase, DistributedIteratorTestBase,
    parameterized.TestCase):

  def _cpu_devices(self):
    return [
        ("/job:worker/replica:0/task:0",
         ["/job:worker/replica:0/task:0/device:CPU:0"]),
        ("/job:worker/replica:0/task:1",
         ["/job:worker/replica:0/task:1/device:CPU:0"])]

  def _cpu_and_one_gpu_devices(self):
    return [
        ("/job:worker/replica:0/task:0", [
            "/job:worker/replica:0/task:0/device:GPU:0",
            "/job:worker/replica:0/task:0/device:CPU:0"
        ]),
        ("/job:worker/replica:0/task:1", [
            "/job:worker/replica:0/task:1/device:GPU:0",
            "/job:worker/replica:0/task:1/device:CPU:0"
        ])
    ]

  @combinations.generate(combinations.combine(
      mode=["graph"],
      input_type=["dataset"],
      api_type=["wrap_into_iterator", "wrap_into_dataset"],
      iteration_type=["get_next", "for_loop"],
      auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
  def testAutoshardingOption(self, input_type, api_type, iteration_type,
                             auto_shard_policy):
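    """Checks expected elements under AUTO vs. OFF auto-shard policies."""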
    ds_option = dataset_ops.Options()
    ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
    if tf2.enabled():
      dataset_fn = (
          lambda _: dataset_ops.DatasetV2.range(4).with_options(ds_option))
    else:
      dataset_fn = (
          lambda _: dataset_ops.Dataset.range(4).with_options(ds_option))
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    strategy = mirrored_strategy.MirroredStrategy(
        devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]),
        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
            ["/job:worker/task:0", "/job:worker/task:1"], 1))
    worker_devices = self._cpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      if auto_shard_policy == AutoShardPolicy.AUTO:
        expected_values = [[0, 1], [2, 3]]
      else:
        expected_values = [[0, 0], [1, 1], [2, 2], [3, 3]]
      self._test_input_iteration(input_type, api_type, iteration_type,
                                 dataset_or_input_fn, worker_devices,
                                 expected_values, strategy, sess)

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          enable_get_next_as_optional=[True, False]))
  def testOneDevicePerWorker(self, input_type, api_type, iteration_type,
                             enable_get_next_as_optional):
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(4)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    strategy = mirrored_strategy.MirroredStrategy(
        devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]),
        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
            ["/job:worker/task:0", "/job:worker/task:1"], 1))
    worker_devices = self._cpu_devices()
    with context.graph_mode(), strategy.scope(), self.cached_session() as sess:

      if input_type == "dataset":
        # Autosharded
        expected_values = [[0, 1], [2, 3]]
      else:
        expected_values = [[0, 0], [1, 1], [2, 2], [3, 3]]
      strategy.extended.experimental_enable_get_next_as_optional = (
          enable_get_next_as_optional)
      self._test_input_iteration(
          input_type,
          api_type,
          iteration_type,
          dataset_or_input_fn,
          worker_devices,
          expected_values,
          strategy,
          sess=sess)

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          enable_get_next_as_optional=[True, False],
          required_gpus=1))
  def testTwoDevicesPerWorker(self, input_type, api_type, iteration_type,
                              enable_get_next_as_optional):
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(4)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    strategy = mirrored_strategy.MirroredStrategy(
        devices=(self._cpu_and_one_gpu_devices()[0][1] +
                 self._cpu_and_one_gpu_devices()[1][1]),
        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
            ["/job:worker/task:0", "/job:worker/task:1"], 2))
    worker_devices = self._cpu_and_one_gpu_devices()
    with context.graph_mode(), strategy.scope(), self.cached_session() as sess:

      if input_type == "dataset":
        # Autosharded
        expected_values = [[0, 2, 1, 3]]
      else:
        expected_values = [[0, 1, 0, 1], [2, 3, 2, 3]]
      strategy.extended.experimental_enable_get_next_as_optional = (
          enable_get_next_as_optional)
      self._test_input_iteration(
          input_type,
          api_type,
          iteration_type,
          dataset_or_input_fn,
          worker_devices,
          expected_values,
          strategy,
          sess=sess)

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          enable_get_next_as_optional=[True, False]))
  def testTupleDataset(self, input_type, api_type, iteration_type,
                       enable_get_next_as_optional):
    strategy = mirrored_strategy.MirroredStrategy(
        devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]),
        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
            ["/job:worker/task:0", "/job:worker/task:1"], 1))
    worker_devices = self._cpu_devices()

    def dataset_fn(ctx):
      del ctx
      if tf2.enabled():
        dataset1 = dataset_ops.DatasetV2.range(4)
        dataset2 = dataset_ops.DatasetV2.range(4).map(lambda x: x**2)
        return dataset_ops.DatasetV2.zip((dataset1, dataset2))
      else:
        dataset1 = dataset_ops.Dataset.range(4)
        dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
        return dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    with context.graph_mode(), strategy.scope(), self.cached_session() as sess:

      if input_type == "dataset":
        # Autosharded
        expected_values = [[(0, 0), (1, 1)], [(2, 4), (3, 9)]]
      else:
        expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)]
      strategy.extended.experimental_enable_get_next_as_optional = (
          enable_get_next_as_optional)
      self._test_input_iteration(
          input_type,
          api_type,
          iteration_type,
          dataset_or_input_fn,
          worker_devices,
          expected_values,
          strategy,
          sess=sess)

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          required_gpus=1))
  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type):
    strategy = mirrored_strategy.MirroredStrategy(
        devices=(self._cpu_and_one_gpu_devices()[0][1] +
                 self._cpu_and_one_gpu_devices()[1][1]),
        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
            ["/job:worker/task:0", "/job:worker/task:1"], 2))
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(2)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    worker_devices = self._cpu_and_one_gpu_devices()
    with context.graph_mode(), strategy.scope(), self.cached_session() as sess:
      if input_type == "dataset":
        # Autosharded
        expected_values = [[[0, 1], [4, 5], [2, 3], [6, 7]], [[8], [], [], []]]
      else:
        expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
                           [[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]]
      strategy.extended.experimental_enable_get_next_as_optional = True
      self._test_input_iteration(
          input_type,
          api_type,
          iteration_type,
          dataset_or_input_fn,
          worker_devices,
          expected_values,
          strategy,
          sess=sess)

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next"],
          strategy_cls=[
              collective_all_reduce_strategy.CollectiveAllReduceStrategy,
              parameter_server_strategy.ParameterServerStrategy,
          ],
          required_gpus=0))
  def testUnevenDatasetBatchesBetweenGraph(self, input_type, api_type,
                                           iteration_type, strategy_cls):
    if api_type == "wrap_into_dataset" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")
    if tf2.enabled():
      # The V2 tests are skipped since we don't support creating an
      # iterator for DistributedDataset in graph mode.
      self.skipTest("unsupported test combination")
    # Environment variable is global, we need locking when patching TF_CONFIG.
    lock = threading.Lock()

    def _worker_fn(task_type, task_id, num_gpus):
      del num_gpus
      tf_config = {
          "cluster": self._cluster_spec,
          "task": {
              "type": task_type,
              "index": task_id
          }
      }
      with context.graph_mode(), lock, test.mock.patch.dict(
          "os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
        strategy = strategy_cls()
      with context.graph_mode(), strategy.scope(), self.cached_session(
          target="grpc://" + self._cluster_spec[task_type][task_id]) as sess:
        if tf2.enabled():
          dataset_fn = lambda _: dataset_ops.DatasetV2.range(5).batch(2)
        else:
          dataset_fn = lambda _: dataset_ops.Dataset.range(5).batch(2)
        dataset_or_input_fn = self._create_dataset_or_input_fn(
            input_type, dataset_fn)
        if (input_type == "dataset" and strategy_cls is
            collective_all_reduce_strategy.CollectiveAllReduceStrategy):
          # Autosharded
          if task_id == 0:
            expected_values = [[[0, 1]], [[4]]]
          else:
            expected_values = [[[2, 3]], [[]]]

          # input_context is for between-graph auto-sharding.
          input_context = distribute_lib.InputContext(
              num_input_pipelines=2,
              input_pipeline_id=task_id,
              num_replicas_in_sync=2)
        else:
          expected_values = [[[0, 1]], [[2, 3]], [[4]]]
          input_context = None

        self._test_input_iteration(
            input_type,
            api_type,
            iteration_type,
            dataset_or_input_fn,
            [("/job:%s/task:%d" %
              (task_type, task_id), strategy.extended.worker_devices)],
            expected_values,
            strategy,
            sess=sess,
            input_context=input_context)

    self._run_between_graph_clients(_worker_fn, self._cluster_spec, 0)

  @combinations.generate(
      combinations.combine(
          mode=["graph"], input_type=["input_fn"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          required_gpus=1))
  def testDifferentDatasets(self, input_type, api_type, iteration_type):
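    """Checks iteration when each input pipeline returns a different dataset."""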
    def dataset_fn(ctx):
      if ctx.input_pipeline_id == 0:
        return dataset_ops.Dataset.range(8).batch(2)
      else:
        return dataset_ops.Dataset.range(9).batch(2)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    strategy = mirrored_strategy.MirroredStrategy(
        devices=(self._cpu_and_one_gpu_devices()[0][1] +
                 self._cpu_and_one_gpu_devices()[1][1]),
        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
            ["/job:worker/task:0", "/job:worker/task:1"], 2))
    worker_devices = self._cpu_and_one_gpu_devices()
    with context.graph_mode(), strategy.scope(), self.cached_session() as sess:

      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]]
      strategy.extended.experimental_enable_get_next_as_optional = True
      self._test_input_iteration(
          input_type,
          api_type,
          iteration_type,
          dataset_or_input_fn,
          worker_devices,
          expected_values,
          strategy,
          sess=sess)

if __name__ == "__main__":
  test.main()