The CollectiveHints class is also renamed to CommunicationOptions, and the communication enum is added to it. CommunicationOptions stays experimental because the detailed options may change, but it is clear that these cross-device communications need an options argument. PiperOrigin-RevId: 337547832 Change-Id: I376171672698d5923b4e52f2567d4a584c8e21b6
590 lines
22 KiB
Python
590 lines
22 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Test DistributionStrategy, ReplicaContext, and supporting APIs."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from absl.testing import parameterized
|
|
|
|
from tensorflow.python.autograph.core import converter_testing
|
|
from tensorflow.python.data.ops import dataset_ops
|
|
from tensorflow.python.distribute import combinations
|
|
from tensorflow.python.distribute import distribute_lib
|
|
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
|
from tensorflow.python.distribute import input_lib
|
|
from tensorflow.python.distribute import reduce_util
|
|
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.eager import def_function
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.ops import variable_scope
|
|
from tensorflow.python.ops import variables
|
|
from tensorflow.python.platform import test
|
|
from tensorflow.python.training import server_lib
|
|
from tensorflow.python.util import nest
|
|
|
|
|
|
class _TestReplicaContext(distribute_lib.ReplicaContext):
  """Replica context whose `merge_call` just echoes a keyword argument.

  Instead of running `fn` in a cross-replica context, `merge_call` hands back
  whatever was passed as `test_arg`, so tests can observe that the call was
  routed through this context.
  """

  def merge_call(self, fn, *args, **kwargs):
    """Return `kwargs["test_arg"]`; `fn` and positional `args` are ignored.

    Raises:
      KeyError: if no `test_arg` keyword argument was supplied.
    """
    del fn, args  # Unused: this test double never invokes the merge function.
    return kwargs["test_arg"]
|
|
|
|
|
|
def _get_test_variable(name, synchronization, aggregation):
|
|
return {
|
|
"name": name,
|
|
"synchronization": synchronization,
|
|
"aggregation": aggregation
|
|
}
|
|
|
|
|
|
def _test_input_fn(input_context):
  """Input function that yields the scalar 1.0 forever.

  Args:
    input_context: ignored; accepted only to match the input-fn calling
      convention.

  Returns:
    A `DatasetV2` repeating the single element 1.0 indefinitely.
  """
  del input_context  # Unused.
  single_element = dataset_ops.DatasetV2.from_tensors(1.)
  return single_element.repeat()
|
|
|
|
|
|
class _TestStrategy(distribute_lib.Strategy):
  """Minimal `Strategy` whose extended API is provided by `_TestExtended`."""

  def __init__(self):
    extended = _TestExtended(self)
    super(_TestStrategy, self).__init__(extended)
|
|
|
|
|
|
class _TestExtended(distribute_lib.StrategyExtendedV1):
  """Single-replica test double for the strategy "extended" API.

  Variables are replaced by dicts (see `_get_test_variable`), reductions return
  their input unchanged, and replica-local calls run once as replica 0.
  """

  def __init__(self, distribute):
    super(_TestExtended, self).__init__(distribute)
    # A single worker ("") with a single CPU device.
    self._input_workers = input_lib.InputWorkers([("", ["/device:CPU:0"])])

  def _call_for_each_replica(self, fn, args, kwargs):
    """Run `fn(*args, **kwargs)` once inside a `_TestReplicaContext`."""
    replica_ctx = _TestReplicaContext(
        self._container_strategy(), replica_id_in_sync_group=0)
    with replica_ctx:
      return fn(*args, **kwargs)

  def _create_variable(self, next_creator, **kwargs):
    """Return a dict describing the requested variable instead of creating it."""
    del next_creator  # Deliberately bypass real variable creation.
    return _get_test_variable(
        kwargs["name"], kwargs["synchronization"], kwargs["aggregation"])

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Wrap `input_fn` in an `InputFunctionIterator` over the single worker.

    `replication_mode` is accepted for interface compatibility but ignored.
    """
    workers = self._input_workers
    input_contexts = [distribute_lib.InputContext()]
    return input_lib.InputFunctionIterator(
        input_fn, workers, input_contexts, self._container_strategy())

  def _distribute_datasets_from_function(self, dataset_fn, options):
    """Call `dataset_fn` with a default `InputContext`; `options` is ignored."""
    del options  # Unused.
    return dataset_fn(distribute_lib.InputContext())

  def _local_results(self, value):
    """Return `value` wrapped in a one-element tuple."""
    return (value,)

  def _reduce_to(self, reduce_op, value, destinations, options):
    """Identity reduction: return `value` unchanged."""
    del reduce_op, destinations, options  # Unused.
    return value

  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    """Call `fn(ctx, iterator.get_next())` `iterations` times; return `ctx`.

    `initial_loop_values` is accepted but ignored.
    """
    # TODO(tomhennigan) This is missing many things (e.g. ctx.run_op).
    ctx = input_lib.MultiStepContext()
    for _ in range(iterations):
      fn(ctx, iterator.get_next())
    return ctx

  def _update(self, var, fn, args, kwargs, group):
    """Apply `fn(var, *args, **kwargs)` through `_update_non_slot`."""
    # _update() differs from _update_non_slot() only in that `var` is
    # prepended to the positional arguments passed to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    """Call `fn(*args, **kwargs)`; `colocate_with` is ignored.

    When `group` is falsy, the result structure is mapped through
    `self._unwrap` before being returned.
    """
    del colocate_with  # Unused.
    result = fn(*args, **kwargs)
    if not group:
      result = nest.map_structure(self._unwrap, result)
    return result

  def _get_local_replica_id(self, replica_id_in_sync_group):
    """Return `replica_id_in_sync_group` unchanged."""
    return replica_id_in_sync_group
|
|
|
|
|
|
def _assert_in_default_state(t):
  """Assert that no distribution strategy is currently active.

  Checks that the current replica context is the default one, that there is no
  cross-replica context, and that `get_strategy()` is the default strategy
  while `has_strategy()` is False.

  Args:
    t: the `TestCase` whose assertion methods are used.
  """
  default_replica_ctx = ds_context._get_default_replica_context()
  t.assertIs(default_replica_ctx, ds_context.get_replica_context())
  t.assertIsNone(ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  default_strategy = ds_context._get_default_strategy()
  t.assertIs(default_strategy, ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
|
|
|
|
|
|
def _run_in_and_out_of_scope(unbound_test_method):
|
|
def wrapper(test_case):
|
|
dist = _TestStrategy()
|
|
# Running in the default (replica) scope should be supported.
|
|
_assert_in_default_state(test_case)
|
|
unbound_test_method(test_case, dist)
|
|
# As well as running in the strategy scope.
|
|
with dist.scope():
|
|
unbound_test_method(test_case, dist)
|
|
_assert_in_default_state(test_case)
|
|
# When run under a different strategy the test method should fail.
|
|
another_strategy = _TestStrategy()
|
|
msg = "Mixing different .*Strategy objects"
|
|
with test_case.assertRaisesRegex(RuntimeError, msg):
|
|
with another_strategy.scope():
|
|
unbound_test_method(test_case, dist)
|
|
return wrapper
|
|
|
|
|
|
class TestStrategyTest(test.TestCase):
  """Tests `_TestStrategy` together with the replica/cross-replica context API."""

  def testCallForEachReplica(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()

    def run_fn():
      replica_context = ds_context.get_replica_context()
      self.assertIsNotNone(replica_context)
      self.assertIs(None, ds_context.get_cross_replica_context())
      self.assertFalse(ds_context.in_cross_replica_context())
      self.assertTrue(ds_context.has_strategy())
      self.assertIs(dist, ds_context.get_strategy())
      self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
      expected_value = _get_test_variable(
          "bar", variable_scope.VariableSynchronization.AUTO,
          variable_scope.VariableAggregation.NONE)
      self.assertDictEqual(expected_value,
                           variable_scope.variable(1.0, name="bar"))

    dist.extended.call_for_each_replica(run_fn)
    with dist.scope():
      dist.extended.call_for_each_replica(run_fn)
    _assert_in_default_state(self)

  def testScope(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
      self.assertIs(None, ds_context.get_replica_context())
      self.assertIs(dist, ds_context.get_cross_replica_context())
      self.assertTrue(ds_context.in_cross_replica_context())
      self.assertTrue(ds_context.has_strategy())
      self.assertIs(dist, ds_context.get_strategy())
      expected_value = _get_test_variable(
          "baz", variable_scope.VariableSynchronization.AUTO,
          variable_scope.VariableAggregation.NONE)
      self.assertDictEqual(expected_value,
                           variable_scope.variable(1.0, name="baz"))
    _assert_in_default_state(self)

  def testScopeDeviceNestingError(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    # Open a device scope with dist.scope().
    dist.extended._default_device = "/device:GPU:0"
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    with ops.device("/device:CPU:0"):
      with self.assertRaisesRegex(RuntimeError, "Device scope nesting error"):
        scope.__exit__(None, None, None)
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testScopeVarCreatorNestingError(self):

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    with variable_scope.variable_creator_scope(creator):
      with self.assertRaisesRegex(RuntimeError,
                                  "Variable creator scope nesting error"):
        scope.__exit__(None, None, None)
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testScopeVarScopeNestingError(self):
    # We create a new graph here to simplify clean-up, since the error
    # we are triggering happens in the middle of scope.__exit__() and
    # leaves us in a weird state.
    with ops.Graph().as_default():
      _assert_in_default_state(self)
      dist = _TestStrategy()
      scope = dist.scope()
      scope.__enter__()
      self.assertIs(dist, ds_context.get_strategy())
      with variable_scope.variable_scope("AA"):
        with self.assertRaisesRegex(RuntimeError,
                                    "Variable scope nesting error"):
          scope.__exit__(None, None, None)
    _assert_in_default_state(self)

  def testSettingSynchronizationAndAggregation(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
      expected_value = _get_test_variable(
          "baz", variable_scope.VariableSynchronization.ON_WRITE,
          variable_scope.VariableAggregation.MEAN)
      self.assertDictEqual(
          expected_value,
          variable_scope.variable(
              1.0,
              name="baz",
              synchronization=variable_scope.VariableSynchronization.ON_WRITE,
              aggregation=variable_scope.VariableAggregation.MEAN))
    _assert_in_default_state(self)

  def testSetStrategy(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    dist2 = _TestStrategy()
    ds_context.experimental_set_strategy(dist)
    self.assertIs(None, ds_context.get_replica_context())
    self.assertIs(dist, ds_context.get_cross_replica_context())
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(dist, ds_context.get_strategy())
    expected_value = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    self.assertDictEqual(expected_value,
                         variable_scope.variable(1.0, name="baz"))
    ds_context.experimental_set_strategy(dist2)
    self.assertIs(dist2, ds_context.get_strategy())
    ds_context.experimental_set_strategy(None)
    _assert_in_default_state(self)

  def testSetStrategyInScope(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
      with self.assertRaisesRegex(
          RuntimeError,
          "Must not be called inside a `tf.distribute.Strategy` scope"):
        ds_context.experimental_set_strategy(_TestStrategy())
      with self.assertRaisesRegex(
          RuntimeError,
          "Must not be called inside a `tf.distribute.Strategy` scope"):
        ds_context.experimental_set_strategy(dist)
      with self.assertRaisesRegex(
          RuntimeError,
          "Must not be called inside a `tf.distribute.Strategy` scope"):
        ds_context.experimental_set_strategy(None)
    _assert_in_default_state(self)

  def testSameScopeNesting(self):
    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope_a = dist.scope()
    with scope_a:
      self.assertIs(dist, ds_context.get_strategy())
      scope_b = dist.scope()
      with scope_b:
        self.assertIs(dist, ds_context.get_strategy())
        with scope_a:
          self.assertIs(dist, ds_context.get_strategy())
        self.assertIs(dist, ds_context.get_strategy())
      self.assertIs(dist, ds_context.get_strategy())
      dist2 = _TestStrategy()
      scope2 = dist2.scope()
      with self.assertRaisesRegex(
          RuntimeError, "Mixing different tf.distribute.Strategy objects"):
        with scope2:
          pass
    _assert_in_default_state(self)
    with scope_b:
      self.assertIs(dist, ds_context.get_strategy())
    _assert_in_default_state(self)

  @_run_in_and_out_of_scope
  def testMakeInputFnIterator(self, dist):
    self.assertIsNotNone(dist.make_input_fn_iterator(_test_input_fn))

  @_run_in_and_out_of_scope
  def testReduce(self, dist):
    x = constant_op.constant(1.)
    x_r = dist.reduce(reduce_util.ReduceOp.MEAN, x, axis=None)
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))

  def testReductions_acceptStringOps(self):
    dist = _TestStrategy()
    for op in ("mean", "MEAN", "sum", "SUM"):
      x = constant_op.constant(1.)
      y = constant_op.constant(1.)
      x_r = dist.reduce(op, x, axis=None)
      self.assertEqual(self.evaluate(x), self.evaluate(x_r))
      x_r = dist.extended.reduce_to(op, x, "/CPU:0")
      self.assertEqual(self.evaluate(x), self.evaluate(x_r))
      x_r, y_r = dist.extended.batch_reduce_to(op,
                                               ((x, "/CPU:0"), (y, "/CPU:0")))
      self.assertEqual(self.evaluate(x), self.evaluate(x_r))
      self.assertEqual(self.evaluate(y), self.evaluate(y_r))

  @_run_in_and_out_of_scope
  def testExperimentalRunStepsOnIterator(self, dist):
    all_inputs = []
    dataset = dataset_ops.Dataset.from_tensors(1.).repeat()
    dist.extended.experimental_run_steps_on_iterator(
        lambda _, inputs: all_inputs.append(self.evaluate(inputs)),
        dataset_ops.make_one_shot_iterator(dataset))
    self.assertEqual(all_inputs, [1.])

  @_run_in_and_out_of_scope
  def testReduceTo(self, dist):
    x = constant_op.constant(1.)
    x_r = dist.extended.reduce_to(reduce_util.ReduceOp.MEAN, x, "/CPU:0")
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))

  @_run_in_and_out_of_scope
  def testBatchReduceTo(self, dist):
    x = constant_op.constant(1.)
    y = constant_op.constant(1.)
    x_r, y_r = dist.extended.batch_reduce_to(reduce_util.ReduceOp.MEAN,
                                             ((x, "/CPU:0"), (y, "/CPU:0")))
    self.assertEqual(self.evaluate(x), self.evaluate(x_r))
    self.assertEqual(self.evaluate(y), self.evaluate(y_r))

  @_run_in_and_out_of_scope
  def testUpdate(self, dist):
    with dist.scope():
      v = variables.Variable(1.)
    t = constant_op.constant(2.)
    assign_calls = []

    def assign_fn(vv, tt):
      self.assertIs(vv, v)
      self.assertIs(tt, t)
      assign_calls.append(1)

    dist.extended.update(v, assign_fn, (t,))
    # Without this check the assertions inside assign_fn would pass vacuously
    # if update() never invoked it (mirrors testUpdateNonSlot).
    self.assertEqual(len(assign_calls), 1)

  @_run_in_and_out_of_scope
  def testUpdateAutoGraph(self, dist):
    with dist.scope():
      v = variables.Variable(1.)
    t = constant_op.constant(2.)

    def assign_fn(unused_vv, unused_tt):
      self.assertTrue(converter_testing.is_inside_generated_code())

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      dist.extended.update(v, assign_fn, (t,))

    test_fn()

  @_run_in_and_out_of_scope
  def testUpdateNonSlot(self, dist):
    t = constant_op.constant(2.)
    update_calls = []
    dist.extended.update_non_slot(t, lambda: update_calls.append(1))
    self.assertEqual(len(update_calls), 1)

  @_run_in_and_out_of_scope
  def testUpdateNonSlotAutoGraph(self, dist):
    t = constant_op.constant(2.)

    def update_fn():
      self.assertTrue(converter_testing.is_inside_generated_code())

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      dist.extended.update_non_slot(t, update_fn)

    test_fn()

  def testClusterResolverDefaultNotImplemented(self):
    dist = _TestStrategy()
    self.assertIsNone(dist.cluster_resolver)
    base_cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })
    cluster_resolver = SimpleClusterResolver(base_cluster_spec)
    dist.extended._cluster_resolver = cluster_resolver
    self.assertIs(dist.cluster_resolver, cluster_resolver)
|
|
|
|
|
|
# _TestStrategy2 is like _TestStrategy, except that its extended object
# (_TestExtended2) hands variable creation on to the next creator instead of
# returning dict stand-ins.
class _TestStrategy2(distribute_lib.Strategy):
  """`Strategy` whose extended API is provided by `_TestExtended2`."""

  def __init__(self):
    extended = _TestExtended2(self)
    super(_TestStrategy2, self).__init__(extended)
|
|
|
|
|
|
class _TestExtended2(_TestExtended):
  """`_TestExtended` variant that does not intercept variable creation."""

  def _create_variable(self, next_creator, **kwargs):
    """Delegate to `next_creator` with the same keyword arguments."""
    created = next_creator(**kwargs)
    return created
|
|
|
|
|
|
class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
  """Tests for the default strategy returned by `_get_default_strategy()`."""

  def testMergeCall(self):
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      self.assertIs(ds_context._get_default_strategy(), dist)
      self.assertIs(None, ds_context.get_replica_context())
      self.assertIs(dist, ds_context.get_cross_replica_context())
      self.assertTrue(ds_context.in_cross_replica_context())
      self.assertIs(dist, ds_context.get_strategy())
      self.assertFalse(ds_context.has_strategy())
      return "foo_" + s

    replica_ctx = ds_context.get_replica_context()
    self.assertIs(ds_context._get_default_replica_context(), replica_ctx)
    self.assertEqual("foo_bar", replica_ctx.merge_call(merge_fn, args=("bar",)))
    _assert_in_default_state(self)

  def testMergeCallAutoGraph(self):
    _assert_in_default_state(self)

    def merge_fn(_, s):
      self.assertTrue(converter_testing.is_inside_generated_code())
      return s

    @def_function.function  # AutoGraph is default-on only within tf.function
    def test_fn():
      replica_ctx = ds_context.get_replica_context()
      replica_ctx.merge_call(merge_fn, args=("bar",))

    test_fn()

  def testScopeMostlyNoOp(self):
    _assert_in_default_state(self)

    test_strategy = _TestStrategy2()
    with test_strategy.scope():
      variable_scope.variable(1.0, name="before")

    default_strategy = ds_context._get_default_strategy()
    scope = default_strategy.scope()
    with scope:
      _assert_in_default_state(self)

      with test_strategy.scope():
        with self.assertRaisesRegex(
            RuntimeError, "Mixing different tf.distribute.Strategy objects"):
          variable_scope.variable(1.0, name="error")

      with scope:
        _assert_in_default_state(self)

        with test_strategy.scope():
          with self.assertRaisesRegex(
              RuntimeError, "Mixing different tf.distribute.Strategy objects"):
            variable_scope.variable(1.0, name="also_error")

      _assert_in_default_state(self)

    _assert_in_default_state(self)
    with test_strategy.scope():
      variable_scope.variable(1.0, name="after")

  def testExperimentalRunV2(self):
    default_strategy = ds_context._get_default_strategy()
    dataset = dataset_ops.Dataset.range(10).batch(2)
    iterator = default_strategy.extended._make_dataset_iterator(dataset)
    next_val = iterator.get_next()

    def train_step(input_data):
      return input_data

    for _ in range(2):
      default_strategy.run(train_step, args=(next_val,))

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testDistributedDatasets(self):
    default_strategy = ds_context._get_default_strategy()
    if context.executing_eagerly():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
      dist_dataset = default_strategy.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      next_val = next(iter(dist_dataset))
    else:
      dataset_fn = lambda _: dataset_ops.DatasetV1.range(10).batch(2)
      dist_dataset = default_strategy.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      iterator = dist_dataset.make_initializable_iterator()
      self.evaluate(iterator.initializer)
      next_val = iterator.get_next()
    self.assertAllEqual([0, 1], self.evaluate(next_val))

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testDistributedDatasetsFromFunction(self):
    default_strategy = ds_context._get_default_strategy()
    # Both modes distribute the same dataset; only the way it is iterated
    # differs.
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
    dist_dataset_from_func = (
        default_strategy.distribute_datasets_from_function(dataset_fn))
    if context.executing_eagerly():
      next_val = next(iter(dist_dataset_from_func))
    else:
      # Graph mode needs an explicitly initialized iterator; check its first
      # batch just like the eager branch (and testDistributedDatasets) does.
      iterator = dataset_ops.make_initializable_iterator(
          dist_dataset_from_func)
      self.evaluate(iterator.initializer)
      next_val = iterator.get_next()
    self.assertAllEqual([0, 1], self.evaluate(next_val))

  @combinations.generate(combinations.combine(tf_api_version=1))
  def testV1(self):
    self.assertIsInstance(ds_context.get_strategy(), distribute_lib.StrategyV1)

  @combinations.generate(combinations.combine(tf_api_version=2))
  def testV2(self):
    self.assertIsInstance(ds_context.get_strategy(), distribute_lib.Strategy)
|
|
|
|
|
|
class InputContextTest(test.TestCase):
  """Tests for `distribute_lib.InputContext` accessors and formatting."""

  def testProperties(self):
    ctx = distribute_lib.InputContext(
        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
    self.assertEqual(6, ctx.num_replicas_in_sync)
    self.assertEqual(1, ctx.input_pipeline_id)
    self.assertEqual(2, ctx.num_input_pipelines)

  def testPerReplicaBatchSize(self):
    ctx = distribute_lib.InputContext(
        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
    # 12 splits evenly over 6 replicas; 13 does not and must be rejected.
    self.assertEqual(2, ctx.get_per_replica_batch_size(12))
    with self.assertRaises(ValueError):
      ctx.get_per_replica_batch_size(13)

  def testStr(self):
    cases = (
        (1, 0, "tf.distribute.InputContext(input pipeline id 0, total: 1)"),
        (3, 1, "tf.distribute.InputContext(input pipeline id 1, total: 3)"),
    )
    for num_pipelines, pipeline_id, expected in cases:
      ctx = distribute_lib.InputContext(
          num_input_pipelines=num_pipelines,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=42)
      self.assertEqual(expected, str(ctx))
|
|
|
|
|
|
if __name__ == "__main__":
  # Hand control to TensorFlow's test entry point when run as a script.
  test.main()
|