STT-tensorflow/tensorflow/python/distribute/mirrored_strategy_test.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MirroredStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import sys
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
GPU_TEST = "test_gpu" in sys.argv[0]
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
mode=["graph", "eager"]))
class MirroredTwoDeviceDistributionTest(
strategy_test_lib.DistributionTestBase,
strategy_test_lib.TwoDeviceDistributionTestBase,
parameterized.TestCase):
def testMinimizeLoss(self, distribution):
if context.executing_eagerly():
self._test_minimize_loss_eager(distribution)
else:
self._test_minimize_loss_graph(distribution)
def testReplicaId(self, distribution):
self._test_replica_id(distribution)
def testNumReplicasInSync(self, distribution):
self.assertEqual(2, distribution.num_replicas_in_sync)
def testCallAndMergeExceptions(self, distribution):
self._test_call_and_merge_exceptions(distribution)
def testRunRegroupError(self, distribution):
def run_fn():
replica_id = int(self.evaluate(_replica_id()))
# Generates a list with different lengths on different devices.
# Will fail in _regroup() (if more than one device).
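      # For example, with two replicas run_fn returns [] on replica 0 and [0]
      # on replica 1, so _regroup's structure check raises an AssertionError.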
return list(range(replica_id))
with distribution.scope(), self.assertRaises(AssertionError):
distribution.extended.call_for_each_replica(run_fn)
def testReduceToCpu(self, distribution):
with distribution.scope():
result = distribution.extended.call_for_each_replica(_replica_id)
reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=None)
expected = sum(range(distribution.num_replicas_in_sync))
self.assertEqual(expected, self.evaluate(reduced))
def reduce_axis_helper(self, distribution, replica_squared_fn):
with distribution.scope():
num_replicas = distribution.num_replicas_in_sync
result = distribution.extended.call_for_each_replica(replica_squared_fn)
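      # For example, with two replicas replica_squared_fn returns [0] on
      # replica 0 and [1, 1] on replica 1: the SUM over axis 0 and across
      # replicas is 0 + 1 + 1 = 2 = sum(x * (x + 1)), and the MEAN divides
      # that sum by the total element count 1 + 2 = 3.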
# sum
reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=0)
expected = sum(x * (x + 1) for x in range(num_replicas))
self.assertNear(expected, self.evaluate(reduced), 0.00001)
# mean
reduced = distribution.reduce(reduce_util.ReduceOp.MEAN, result, axis=0)
expected /= sum(x + 1 for x in range(num_replicas))
self.assertNear(expected, self.evaluate(reduced), 0.00001)
def testReduceAxisToCpu(self, distribution):
for dtype in (dtypes.float32, dtypes.int32):
def replica_squared_fn(dtype=dtype):
# Lists with different lengths on different replicas.
replica_id = _replica_id_as_int()
return array_ops.identity(
math_ops.cast([replica_id] * (replica_id + 1), dtype))
self.reduce_axis_helper(distribution, replica_squared_fn)
def set_v2_tensorshape(self, v2):
if v2:
tensor_shape.enable_v2_tensorshape()
else:
tensor_shape.disable_v2_tensorshape()
def testReduceAxisToCpuUnknownShape(self, distribution):
original_v2 = tensor_shape._TENSORSHAPE_V2_OVERRIDE # pylint: disable=protected-access
try:
for v2 in (False, True):
self.set_v2_tensorshape(v2)
for dtype in (dtypes.float32, dtypes.int32):
for shape in ((None,), None): # Test both unknown size and rank.
def replica_squared_fn(dtype=dtype, shape=shape):
# Lists with different lengths on different replicas.
replica_id = _replica_id_as_int()
tensor = math_ops.cast([replica_id] * (replica_id + 1), dtype)
# Erase shape information
return array_ops.placeholder_with_default(tensor, shape=shape)
self.reduce_axis_helper(distribution, replica_squared_fn)
finally:
self.set_v2_tensorshape(original_v2)
def testReplicateDataset(self, distribution):
if tf2.enabled() and not context.executing_eagerly():
self.skipTest("Skipping test since we do not support graph mode in TF 2")
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
input_fn = self._input_fn_to_test_input_context(
dataset_fn,
expected_num_replicas_in_sync=2,
expected_num_input_pipelines=1,
expected_input_pipeline_id=0)
self._test_input_fn_iterable(distribution, input_fn, expected_values)
def testMakeInputFnIteratorWithDataset(self, distribution):
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
input_fn = self._input_fn_to_test_input_context(
dataset_fn,
expected_num_replicas_in_sync=2,
expected_num_input_pipelines=1,
expected_input_pipeline_id=0)
iterator = distribution.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
expected_values)
def testMakeInputFnIteratorWithCallable(self, distribution):
def fn():
dataset = dataset_ops.Dataset.range(2).interleave(
(lambda _: dataset_ops.Dataset.range(10)), cycle_length=2)
it = dataset_ops.make_one_shot_iterator(dataset)
return it.get_next
expected_values = [[i, i] for i in range(0, 10)]
input_fn = self._input_fn_to_test_input_context(
fn,
expected_num_replicas_in_sync=2,
expected_num_input_pipelines=1,
expected_input_pipeline_id=0)
iterator = distribution.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
expected_values, test_reinitialize=False,
ignore_order=True)
def testNumpyDataset(self, distribution):
self._test_numpy_dataset(distribution)
def testGlobalStepUpdate(self, distribution):
self._test_global_step_update(distribution)
def testRun(self, distribution):
self._test_run(distribution)
def testAllReduceSum(self, distribution):
self._test_all_reduce_sum(distribution)
def testAllReduceSumGradients(self, distribution):
self._test_all_reduce_sum_gradients(distribution)
def testAllReduceSumGradientTape(self, distribution):
self._test_all_reduce_sum_gradient_tape(distribution)
def testAllReduceMean(self, distribution):
self._test_all_reduce_mean(distribution)
def testAllReduceMeanGradients(self, distribution):
self._test_all_reduce_mean_gradients(distribution)
def testAllReduceMeanGradientTape(self, distribution):
self._test_all_reduce_mean_gradient_tape(distribution)
def testSummaryForReplicaZeroOnly(self, distribution):
self._test_summary_for_replica_zero_only(distribution)
def testTrainableVariables(self, distribution):
self._test_trainable_variable(distribution)
def test_prefetch_to_device_dataset(self, distribution):
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=True)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if context.executing_eagerly():
item = next(iter(dataset))
else:
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest("unsupported test combination")
device_types = [
tf_device.DeviceSpec.from_string(tensor.device).device_type for
tensor in item.values]
expected_device_types = [
tf_device.DeviceSpec.from_string(device).device_type for
device in distribution.extended.worker_devices]
self.assertAllEqual(device_types, expected_device_types)
def test_prefetch_to_host_dataset(self, distribution):
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=False)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if context.executing_eagerly():
item = next(iter(dataset))
else:
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest("unsupported test combination")
device_types = {
tf_device.DeviceSpec.from_string(tensor.device).device_type for
tensor in item.values}
self.assertAllEqual(list(device_types), ["CPU"])
def one_device_combinations():
return combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_one_gpu,
],
mode=["graph", "eager"])
@combinations.generate(one_device_combinations())
class MirroredOneDeviceDistributionTest(
strategy_test_lib.DistributionTestBase,
strategy_test_lib.OneDeviceDistributionTestBase,
parameterized.TestCase):
def testMinimizeLoss(self, distribution):
if context.executing_eagerly():
self._test_minimize_loss_eager(distribution)
else:
self._test_minimize_loss_graph(distribution)
def testReplicaId(self, distribution):
self._test_replica_id(distribution)
def testCallAndMergeExceptions(self, distribution):
self._test_call_and_merge_exceptions(distribution)
def testRun(self, distribution):
self._test_run(distribution)
def testAllReduceSum(self, distribution):
self._test_all_reduce_sum(distribution)
def testAllReduceSumGradients(self, distribution):
self._test_all_reduce_sum_gradients(distribution)
def testAllReduceSumGradientTape(self, distribution):
self._test_all_reduce_sum_gradient_tape(distribution)
def testAllReduceMean(self, distribution):
self._test_all_reduce_mean(distribution)
def testAllReduceMeanGradients(self, distribution):
self._test_all_reduce_mean_gradients(distribution)
def testAllReduceMeanGradientTape(self, distribution):
self._test_all_reduce_mean_gradient_tape(distribution)
class MirroredStrategyVariableCreatorStackTest(
test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph"]))
def testCreatorStacksAreThreadLocal(self, distribution):
def model_fn():
replica_id_str = str(self.evaluate(_replica_id()))
def thread_creator_fn(next_creator, **kwargs):
return next_creator(**kwargs) + ":thread_" + replica_id_str
with variable_scope.variable_creator_scope(thread_creator_fn):
# Create a variable in this scope.
v = variable_scope.variable(1.0)
# This will pause the current thread, and execute the other thread.
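        # (Roughly: each replica's model_fn runs in its own thread under
        # call_for_each_replica; merge_call pauses the calling thread until
        # every replica reaches the merge point, runs the merge fn once in
        # cross-replica context, and then resumes the replica threads.)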
ds_context.get_replica_context().merge_call(lambda _: _)
return v
def main_thread_creator(next_creator, **kwargs):
      # For this test we deliberately ignore the underlying next_creator.
del next_creator, kwargs
return "main_thread"
with context.graph_mode(), \
distribution.scope(), \
variable_scope.variable_creator_scope(main_thread_creator):
result = distribution.extended.call_for_each_replica(model_fn)
result = distribution.experimental_local_results(result)
expected = ("main_thread:thread_0", "main_thread:thread_1")
self.assertEqual(expected, result)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class MirroredStrategyCallForEachReplicaTest(test.TestCase):
def testExecutingEagerlyOutsideFunction(self, distribution):
"""Verify we preserve the value of executing_eagerly_outside_functions()."""
def model_fn():
return ops.executing_eagerly_outside_functions()
originally = ops.executing_eagerly_outside_functions()
with distribution.scope():
in_scope = ops.executing_eagerly_outside_functions()
in_model_fn = distribution.extended.call_for_each_replica(model_fn)
unwrapped = distribution.experimental_local_results(in_model_fn)
self.assertEqual(in_scope, unwrapped[0])
self.assertEqual(in_scope, originally)
# Verify this all again, but this time in a FuncGraph.
with func_graph.FuncGraph("fg").as_default(), distribution.scope():
in_scope = ops.executing_eagerly_outside_functions()
in_model_fn = distribution.extended.call_for_each_replica(model_fn)
unwrapped = distribution.experimental_local_results(in_model_fn)
self.assertEqual(in_scope, unwrapped[0])
self.assertEqual(in_scope, originally)
def testFunctionInCallForEachReplica(self, distribution):
traces = []
@def_function.function
def model_fn():
traces.append(1)
return ds_context.get_replica_context().replica_id_in_sync_group
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(
(0, 1),
self.evaluate(distribution.experimental_local_results(result)))
self.assertLen(traces, distribution.num_replicas_in_sync)
def testFunctionInCallForEachReplicaInsideAnotherFunction(self, distribution):
traces = []
@def_function.function
def model_fn():
traces.append(1)
return ds_context.get_replica_context().replica_id_in_sync_group
@def_function.function
def step():
return distribution.extended.call_for_each_replica(model_fn)
with distribution.scope():
result = step()
self.assertEqual(
(0, 1),
self.evaluate(distribution.experimental_local_results(result)))
self.assertLen(traces, distribution.num_replicas_in_sync)
def testControlFlowFunctionInCallForEachReplicaWithMergeCall(
self, distribution):
def merge_fn(strategy, value):
return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)
@def_function.function
def model_fn():
def body_fn(i):
return ds_context.get_replica_context().merge_call(merge_fn, args=(i,))
return control_flow_ops.while_loop_v2(lambda i: i < 2, body_fn, [0])
with distribution.scope():
with self.assertRaisesRegex(
RuntimeError, "`merge_call` called while defining a new graph."):
distribution.extended.call_for_each_replica(model_fn)
def testNestedFunctionInCallForEachReplicaWithMergeCall(self, distribution):
def merge_fn(strategy, value):
return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)
def model_fn():
@def_function.function
def model_fn_nested():
t = constant_op.constant(1)
return ds_context.get_replica_context().merge_call(merge_fn, args=(t,))
return model_fn_nested()
with distribution.scope():
with self.assertRaisesRegex(
RuntimeError, "`merge_call` called while defining a new graph."):
distribution.extended.call_for_each_replica(model_fn)
def testFunctionInCallForEachReplicaWithMergeCall(self, distribution):
def merge_fn(_):
pass
@def_function.function
def model_fn():
ds_context.get_replica_context().merge_call(merge_fn)
return 0.
with distribution.scope():
self.assertEqual(
self.evaluate(distribution.extended.call_for_each_replica(model_fn)),
0.)
def testFunctionInCallForEachReplicaCached(self, distribution):
traces = []
@def_function.function
def model_fn():
traces.append(None)
self.assertEmpty(traces)
for i in range(10):
distribution.extended.call_for_each_replica(model_fn)
if i == 0:
num_devices = len(traces)
self.assertGreater(num_devices, 0)
else:
        # model_fn should not have been retraced, so the number of traces
        # should remain the same.
self.assertLen(traces, num_devices)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph"]))
class MirroredStrategyNameScopeTest(test.TestCase):
# NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
# testing this in eager mode.
def testNameScope(self, distribution):
def model_fn():
with ops.name_scope("foo"):
a = constant_op.constant(1.0, name="a")
ds_context.get_replica_context().merge_call(lambda _: _)
b = constant_op.constant(1.0, name="b")
return a, b
with context.graph_mode(), distribution.scope():
with ops.name_scope("main"):
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(2, len(result))
for v, name in zip(result, ["a", "b"]):
self.assertIsInstance(v, values.DistributedValues)
v0, v1 = distribution.experimental_local_results(v)
self.assertEqual("main/foo/" + name + ":0", v0.name)
self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name)
def testWithDefaultName(self, distribution):
def model_fn():
with ops.name_scope(None, "foo"):
a = constant_op.constant(1.0, name="a")
ds_context.get_replica_context().merge_call(lambda _: _)
b = constant_op.constant(2.0, name="b")
return a, b
with context.graph_mode(), distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(2, len(result))
for v, name in zip(result, ["a", "b"]):
self.assertIsInstance(v, values.DistributedValues)
v0, v1 = distribution.experimental_local_results(v)
self.assertEqual("foo/" + name + ":0", v0.name)
self.assertEqual("replica_1/foo/" + name + ":0", v1.name)
# variable_scope.variable() respects name scopes when creating
# variables. On the other hand variable_scope.get_variable() ignores name
# scopes but respects variable scope when creating variables. We test both
# methods of creating variables to make sure that we have the same
# variable names in both cases.
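  # For example (as the assertions below check), inside
  # `with ops.name_scope("main"):` a variable created with
  # variable_scope.variable(1.0, name="a") is named "main/a:0", whereas
  # variable_scope.get_variable("a", [1]) is named "a:0".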
def testNameScopeWithVariable(self, distribution):
def in_cross_replica(_):
c = variable_scope.variable(1.0, name="c")
return c
def model_fn():
b = variable_scope.variable(1.0, name="b")
with ops.name_scope("foo"):
c = ds_context.get_replica_context().merge_call(in_cross_replica)
return b, c
with context.graph_mode(), distribution.scope():
with ops.name_scope("main"):
a = variable_scope.variable(1.0, name="a")
result = distribution.extended.call_for_each_replica(model_fn)
result_b = result[0]
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
self.assertIsInstance(result_c, values.DistributedValues)
a0, a1 = distribution.experimental_local_results(a)
b0, b1 = distribution.experimental_local_results(result_b)
c0, c1 = distribution.experimental_local_results(result_c)
self.assertEqual("main/a:0", a0.name)
self.assertEqual("main/a/replica_1:0", a1.name)
self.assertEqual("main/b:0", b0.name)
self.assertEqual("main/b/replica_1:0", b1.name)
self.assertEqual("main/foo/c:0", c0.name)
self.assertEqual("main/foo/c/replica_1:0", c1.name)
def testNameScopeWithGetVariable(self, distribution):
def in_cross_replica(_):
c = variable_scope.get_variable("c", [1])
return c
def model_fn():
b = variable_scope.get_variable("b", [1])
with ops.name_scope("foo"):
c = ds_context.get_replica_context().merge_call(in_cross_replica)
return b, c
with context.graph_mode(), distribution.scope():
with ops.name_scope("main"):
a = variable_scope.get_variable("a", [1])
result = distribution.extended.call_for_each_replica(model_fn)
result_b = result[0]
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
self.assertIsInstance(result_c, values.DistributedValues)
a0, a1 = distribution.experimental_local_results(a)
b0, b1 = distribution.experimental_local_results(result_b)
c0, c1 = distribution.experimental_local_results(result_c)
self.assertEqual("a:0", a0.name)
self.assertEqual("a/replica_1:0", a1.name)
self.assertEqual("b:0", b0.name)
self.assertEqual("b/replica_1:0", b1.name)
self.assertEqual("c:0", c0.name)
self.assertEqual("c/replica_1:0", c1.name)
def testVariableScopeWithGetVariable(self, distribution):
def in_cross_replica(_):
c = variable_scope.get_variable("c", [1])
return c
def model_fn():
b = variable_scope.get_variable("b", [1])
with variable_scope.variable_scope("foo"):
c = ds_context.get_replica_context().merge_call(in_cross_replica)
return b, c
with context.graph_mode(), distribution.scope():
with variable_scope.variable_scope("main"):
a = variable_scope.get_variable("a", [1])
result = distribution.extended.call_for_each_replica(model_fn)
result_b = result[0]
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
self.assertIsInstance(result_c, values.DistributedValues)
a0, a1 = distribution.experimental_local_results(a)
b0, b1 = distribution.experimental_local_results(result_b)
c0, c1 = distribution.experimental_local_results(result_c)
self.assertEqual("main/a:0", a0.name)
self.assertEqual("main/a/replica_1:0", a1.name)
self.assertEqual("main/b:0", b0.name)
self.assertEqual("main/b/replica_1:0", b1.name)
self.assertEqual("main/foo/c:0", c0.name)
self.assertEqual("main/foo/c/replica_1:0", c1.name)
@combinations.generate(
combinations.combine(
distribution=[
combinations.NamedDistribution(
"Mirrored3Devices",
# pylint: disable=g-long-lambda
lambda: mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]),
required_gpus=2)
],
mode=["graph", "eager"]))
class MirroredThreeDeviceDistributionTest(
strategy_test_lib.DistributionTestBase,
parameterized.TestCase):
def testThreeDevices(self, distribution):
def model_fn():
v = variable_scope.variable(1.0, name="foo")
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
result = distribution.extended.call_for_each_replica(model_fn)
self.assertTrue(distribute_utils.is_mirrored(result))
self.assertEqual("foo:0", result.name)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class MirroredVariableUpdateTest(test.TestCase):
# The following tests check assign, assign_add and assign_sub on Mirrored
# variables in replica and cross replica context.
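  # For example, with MEAN aggregation and two replicas each assigning its
  # replica id (0 and 1) in replica context, the variable reads back as 0.5;
  # testAssignMirroredVarReplicaContext below verifies exactly this.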
def testAssignMirroredVarReplicaContextWithoutAggregationType(self,
distribution):
def var_fn():
v = variable_scope.variable(1.0, name="foo")
return v
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
def model_fn():
return mirrored_var.assign(5.0)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(5.0, self.evaluate(mirrored_var))
def testAssignMirroredVarReplicaContextWithSum(self, distribution):
# Test that we don't reduce a non-per-replica value with the "sum"
# aggregation type.
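    # Assigning a plain (non-per-replica) value with SUM aggregation is
    # rejected: summing the same 5.0 from every replica would change the
    # value, so the strategy raises a ValueError instead of silently reducing.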
def var_fn():
v = variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM)
return v
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
def model_fn():
return mirrored_var.assign(5.0)
with self.assertRaisesRegex(
ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
"with the given reduce op ReduceOp.SUM."):
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
def testAssignMirroredVarCrossDeviceContext(self, distribution):
def var_fn():
return variable_scope.variable(1.0, name="foo")
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
self.assertEqual(6.0, mirrored_var_result)
def testAssignMirroredVarReplicaContext(self, distribution):
def var_fn():
return variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
def model_fn():
value = math_ops.cast(
ds_context.get_replica_context().replica_id_in_sync_group,
mirrored_var.dtype)
return mirrored_var.assign(value)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(0.5, self.evaluate(mirrored_var))
def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution):
def var_fn():
return variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
def model_fn():
return mirrored_var.assign(5.0)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(5.0, self.evaluate(mirrored_var))
def testAssignAddMirroredVarCrossDeviceContext(self, distribution):
def var_fn():
return variable_scope.variable(1.0, name="foo")
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
# read_value == True
mirrored_var_result = self.evaluate(
mirrored_var.assign_add(6.0, read_value=True))
self.assertEqual(7.0, mirrored_var_result)
self.assertEqual(
7.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[0]))
self.assertEqual(
7.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[1]))
self.assertEqual(
distribution.extended.worker_devices[0], mirrored_var._devices[0])
self.assertEqual(
distribution.extended.worker_devices[1], mirrored_var._devices[1])
# read_value == False
self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
self.assertEqual(
9.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[0]))
self.assertEqual(
9.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[1]))
self.assertEqual(
distribution.extended.worker_devices[0], mirrored_var._devices[0])
self.assertEqual(
distribution.extended.worker_devices[1], mirrored_var._devices[1])
def testAssignAddMirroredVarReplicaContext(self, distribution):
def var_fn():
return variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
def model_fn():
value = math_ops.cast(
ds_context.get_replica_context().replica_id_in_sync_group,
mirrored_var.dtype)
return mirrored_var.assign_add(value)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(1.5, self.evaluate(mirrored_var))
def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution):
def var_fn():
return variable_scope.variable(
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
def model_fn():
return mirrored_var.assign_add(5.0)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(6.0, self.evaluate(mirrored_var))
def testAssignSubMirroredVarCrossDeviceContext(self, distribution):
def var_fn():
return variable_scope.variable(5.0, name="foo")
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
self.assertEqual(5.0, self.evaluate(mirrored_var))
mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
self.assertEqual(3.0, mirrored_var_result)
self.assertEqual(
3.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[0]))
self.assertEqual(
3.0,
self.evaluate(
distribution.experimental_local_results(mirrored_var)[1]))
self.assertEqual(
distribution.extended.worker_devices[0], mirrored_var._devices[0])
self.assertEqual(
distribution.extended.worker_devices[1], mirrored_var._devices[1])
def testAssignSubMirroredVarReplicaContext(self, distribution):
def var_fn():
return variable_scope.variable(
5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
self.assertEqual(5.0, self.evaluate(mirrored_var))
def model_fn():
value = math_ops.cast(
ds_context.get_replica_context().replica_id_in_sync_group,
mirrored_var.dtype)
return mirrored_var.assign_sub(value)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(4.5, self.evaluate(mirrored_var))
def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution):
def var_fn():
return variable_scope.variable(
5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.evaluate(variables.global_variables_initializer())
self.assertEqual(5.0, self.evaluate(mirrored_var))
def model_fn():
return mirrored_var.assign_sub(1.0)
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(4.0, self.evaluate(mirrored_var))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class MirroredAndSyncOnReadVariableInitializerTest(test.TestCase):
def testAssignMirroredVarInitializer(self, distribution):
    # This test is not eager compatible since in eager mode variables are
    # initialized upon construction instead of when the initialization op runs.
with context.graph_mode():
def var_fn():
v = variable_scope.variable(1.0, name="foo")
return v
with distribution.scope():
mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
self.assertFalse(self.evaluate(mirrored_var.is_initialized()))
self.evaluate(mirrored_var.initializer)
self.assertTrue(self.evaluate(mirrored_var.is_initialized()))
def testAssignReplicaLocalVarInitializer(self, distribution):
    # This test is not eager compatible since in eager mode variables are
    # initialized upon construction instead of when the initialization op runs.
with context.graph_mode():
def model_fn():
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
self.assertTrue(distribute_utils.is_sync_on_read(v_sum))
return v_sum
with distribution.scope():
sync_on_read_var = distribution.extended.call_for_each_replica(
model_fn)
self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
self.assertFalse(self.evaluate(sync_on_read_var.is_initialized()))
self.evaluate(sync_on_read_var.initializer)
self.assertTrue(self.evaluate(sync_on_read_var.is_initialized()))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class SyncOnReadVariableAssignTest(test.TestCase):
def testAssignReplicaLocalVarSumAggregation(self, distribution):
def model_fn():
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.SUM)
return v_sum
with distribution.scope():
sync_on_read_var = distribution.extended.call_for_each_replica(model_fn)
self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
self.evaluate(variables.global_variables_initializer())
      # Each replica has a value of 1.0 assigned to it in replica context.
      # When we read the value using `read_var` we should see the SUM of the
      # values on all the replicas.
self.assertEqual(2.0, self.evaluate(
distribution.extended.read_var(sync_on_read_var)))
# Assigning 6.0 in cross replica context will assign a value of
# 6.0/num_replicas to each replica.
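      # For example, with two replicas each local copy becomes 6.0 / 2 = 3.0,
      # and `read_var` sums them back to 3.0 + 3.0 = 6.0.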
tlv_ops = sync_on_read_var.assign(6.0)
self.evaluate(tlv_ops)
      # On reading the sync-on-read var we should get the assigned value back.
      # The values on all the replicas are added before being returned by
      # `read_var`.
self.assertEqual(6.0, self.evaluate(
distribution.extended.read_var(sync_on_read_var)))
def testAssignReplicaLocalVarMeanAggregation(self, distribution):
def model_fn():
v_sum = variable_scope.variable(
1.0,
synchronization=variable_scope.VariableSynchronization.ON_READ,
aggregation=variable_scope.VariableAggregation.MEAN)
return v_sum
with distribution.scope():
sync_on_read_var = distribution.extended.call_for_each_replica(model_fn)
self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
self.evaluate(variables.global_variables_initializer())
      # Each replica has a value of 1.0 assigned to it in replica context.
      # When we read the value using `read_var` we should see the MEAN of the
      # values on all replicas, which is the value assigned in replica context.
self.assertEqual(1.0, self.evaluate(
distribution.extended.read_var(sync_on_read_var)))
tlv_ops = sync_on_read_var.assign(6.0)
self.evaluate(tlv_ops)
# On reading the sync on read var we should get the MEAN of all values
# which is equal to the value assigned.
self.assertEqual(6.0, self.evaluate(
distribution.extended.read_var(sync_on_read_var)))
class MockModel(object):
def __init__(self, two_variables=False):
self.variables = []
self.variables.append(variable_scope.variable(1.25, name="dummy_var1"))
if two_variables:
self.variables.append(variable_scope.variable(2.0, name="dummy_var2"))
def __call__(self, factor=2):
x = factor * self.variables[0]
if len(self.variables) > 1:
x += self.variables[1]
return x
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
class MirroredStrategyDefunTest(test.TestCase):
def _call_and_check(self, distribution, model_fn, inputs, expected_result,
defuns, two_variables=False):
cpu_dev = device_util.canonicalize("CPU:0")
gpu_dev = device_util.canonicalize("GPU:0")
devices = [cpu_dev, gpu_dev]
with distribution.scope():
mock_model = MockModel(two_variables)
self.evaluate(variables.global_variables_initializer())
result = distribution.extended.call_for_each_replica(
model_fn, args=[mock_model] + inputs)
for r in range(len(devices)):
device_result = distribute_utils.select_replica(r, result)
device_expected_result = distribute_utils.select_replica(
r, expected_result)
self.assertAllClose(device_expected_result,
self.evaluate(device_result))
for defun in defuns:
      # `Function`s are specialized to the current device stack, so
      # call_for_each_replica has one trace per device. To check that the
      # expected set of variables was accessed on each trace, we first
      # retrieve each device-specific graph function.
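      # (With the gpu_and_cpu combination used by this test class, that means
      # one ConcreteFunction traced for the CPU replica and one for the GPU
      # replica.)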
per_replica_graph_functions = (
distribution.extended.call_for_each_replica(
defun.get_concrete_function, args=[mock_model] + inputs))
for i in range(len(devices)):
graph_function = distribution.experimental_local_results(
per_replica_graph_functions)[i]
# TODO(b/129555712): re-enable an assertion here that the two sets of
# variables are the same.
# self.assertEqual(set(graph_function.graph.variables),
# set(mock_model.variables))
del graph_function
def testVariableInDefun(self, distribution):
@function.defun
def times_two(mock_model):
return mock_model()
def model_fn(mock_model):
return times_two(mock_model)
self._call_and_check(distribution, model_fn, [], 2.5, [times_two])
def testVariableInNestedDefun(self, distribution):
@function.defun
def times_two(mock_model):
return mock_model()
@function.defun
def two_x_plus_one(mock_model):
return times_two(mock_model) + 1
def model_fn(mock_model):
return two_x_plus_one(mock_model)
self._call_and_check(distribution, model_fn, [], 3.5,
[times_two, two_x_plus_one])
def testTwoVariablesInNestedDefun(self, distribution):
@function.defun
def fn1(mock_model):
return mock_model()
@function.defun
def fn2(mock_model):
return fn1(mock_model) + 1
def model_fn(mock_model):
return fn2(mock_model)
self._call_and_check(distribution, model_fn, [], 5.5, [fn1, fn2],
two_variables=True)
def testGradientTapeOverNestedDefuns(self, distribution):
@function.defun
def fn1(mock_model):
return mock_model()
@function.defun
def fn2(mock_model):
return fn1(mock_model) + 1
def model_fn(mock_model):
with backprop.GradientTape(persistent=True) as gtape:
result = fn2(mock_model)
grads = gtape.gradient(result,
[v._get() for v in mock_model.variables])
return grads
self._call_and_check(distribution, model_fn, [], [2.0, 1.0], [fn1, fn2],
two_variables=True)
def testPassPerReplica(self, distribution):
@function.defun
def fn1(mock_model, factor):
return mock_model(factor)
factors = values.PerReplica((5.0, 3.0))
expected_result = values.PerReplica((5.0 * 1.25, 3.0 * 1.25))
self._call_and_check(distribution, fn1, [factors], expected_result, [fn1])
@combinations.generate(
combinations.combine(
distribution=[
combinations.NamedDistribution(
"Mirrored",
# pylint: disable=g-long-lambda
lambda: mirrored_strategy.MirroredStrategy(
devices=mirrored_strategy.all_local_devices(),
cross_device_ops=cross_device_ops_lib.ReductionToOneDevice(
),
),
required_gpus=1)
],
mode=["graph"]))
class MultiWorkerMirroredStrategyTest(
multi_worker_test_base.MultiWorkerTestBase,
strategy_test_lib.DistributionTestBase):
def _configure_distribution_strategy(self, distribution):
cluster_spec = server_lib.ClusterSpec({
"worker": ["/job:worker/task:0", "/job:worker/task:1"]
})
distribution.configure(cluster_spec=cluster_spec)
def test_num_replicas_in_sync(self, distribution):
self._configure_distribution_strategy(distribution)
    # The expected replica count is the total number of GPUs across the two
    # workers specified in the cluster spec.
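    # For example, with 2 GPUs per worker this reports 2 * 2 = 4 replicas in
    # sync.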
self.assertEqual(context.num_gpus() * 2, distribution.num_replicas_in_sync)
def testMinimizeLossGraph(self, distribution):
self._configure_distribution_strategy(distribution)
self._test_minimize_loss_graph(distribution, learning_rate=0.05)
def testDeviceScope(self, distribution):
"""Test the device scope of multi-worker MirroredStrategy."""
self._configure_distribution_strategy(distribution)
with distribution.scope():
a = constant_op.constant(1.)
with ops.device("/cpu:0"):
b = constant_op.constant(1.)
self.assertEqual(a.device, "/job:worker/task:0")
self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0")
def testMakeInputFnIteratorWithDataset(self, distribution):
self._configure_distribution_strategy(distribution)
dataset_fn = lambda: dataset_ops.Dataset.range(100)
num_gpus = context.num_gpus()
num_workers = 2
expected_values = [[i+j for j in range(num_gpus)] * num_workers
for i in range(0, 100, num_gpus)]
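    # For example, with one GPU per worker the first expected batch is [0, 0]:
    # each of the two workers reads element 0 from its own input pipeline over
    # the same range dataset.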
with context.graph_mode(), self.cached_session() as sess:
# `expected_input_pipeline_id` is None because the input_fn will be called
# multiple times, each with a different input_pipeline_id.
input_fn = self._input_fn_to_test_input_context(
dataset_fn,
expected_num_replicas_in_sync=num_workers*num_gpus,
expected_num_input_pipelines=num_workers,
expected_input_pipeline_id=None)
iterator = distribution.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(
iterator, distribution.extended.worker_devices, expected_values, sess)
def testMakeInputFnIteratorWithCallable(self, distribution):
self._configure_distribution_strategy(distribution)
def fn():
dataset = dataset_ops.Dataset.range(100)
it = dataset_ops.make_one_shot_iterator(dataset)
return it.get_next
num_gpus = context.num_gpus()
num_workers = 2
expected_values = []
for i in range(0, 100, num_gpus):
expected_values.append([i+j for j in range(num_gpus)] * num_workers)
with context.graph_mode(), self.cached_session() as sess:
# `expected_input_pipeline_id` is None because the input_fn will be called
# multiple times, each with a different input_pipeline_id.
input_fn = self._input_fn_to_test_input_context(
fn,
expected_num_replicas_in_sync=num_workers*num_gpus,
expected_num_input_pipelines=num_workers,
expected_input_pipeline_id=None)
iterator = distribution.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(
iterator, distribution.extended.worker_devices, expected_values, sess,
test_reinitialize=False, ignore_order=True)
def testUpdateConfigProto(self, distribution):
distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]})
config_proto = config_pb2.ConfigProto()
new_config = distribution.update_config_proto(config_proto)
# Verify isolate_session_state
self.assertTrue(new_config.isolate_session_state)
@combinations.generate(
combinations.combine(
distribution=[
combinations.NamedDistribution(
"Mirrored",
# pylint: disable=g-long-lambda
lambda: mirrored_strategy.MirroredStrategy(
devices=["/job:worker/task:0/gpu:{}".format(
i) for i in range(context.num_gpus())]),
required_gpus=1)
],
mode=["graph"]))
class RemoteSingleWorkerMirroredStrategyGraph(
multi_worker_test_base.SingleWorkerTestBaseGraph,
strategy_test_lib.RemoteSingleWorkerMirroredStrategyBase):
def _get_num_gpus(self):
return context.num_gpus()
def testNumReplicasInSync(self, distribution):
self._testNumReplicasInSync(distribution)
def testMinimizeLoss(self, distribution):
self._testMinimizeLoss(distribution)
def testDeviceScope(self, distribution):
self._testDeviceScope(distribution)
def testMakeInputFnIteratorWithDataset(self, distribution):
self._testMakeInputFnIteratorWithDataset(distribution)
def testMakeInputFnIteratorWithCallable(self, distribution):
self._testMakeInputFnIteratorWithCallable(distribution)
class MultiWorkerMirroredStrategyTestWithChief(
multi_worker_test_base.MultiWorkerTestBase,
strategy_test_lib.DistributionTestBase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers and 1 chief."""
cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=2, num_ps=0, has_chief=True)
cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
def _make_cross_device_ops(self):
return cross_device_ops_lib.ReductionToOneDevice()
def testMinimizeLossGraph(self):
with context.graph_mode():
strategy = mirrored_strategy.MirroredStrategy(
cross_device_ops=self._make_cross_device_ops())
strategy.configure(cluster_spec=self._cluster_spec)
self._test_minimize_loss_graph(strategy, learning_rate=0.05)
def testMinimizeLossGraphMirroredStrategy(self):
with context.graph_mode():
strategy = mirrored_strategy.MirroredStrategy(
mirrored_strategy.all_local_devices(),
cross_device_ops=self._make_cross_device_ops())
strategy.configure(cluster_spec=self._cluster_spec)
self._test_minimize_loss_graph(strategy, learning_rate=0.05)
def testMinimizeLossGraphMirroredStrategyWithOneNode(self):
with context.graph_mode():
cluster_spec = {}
cluster_spec["chief"] = self._cluster_spec["chief"]
tf_config = {"cluster": cluster_spec}
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
strategy = mirrored_strategy.MirroredStrategy()
if context.num_gpus() > 0:
self.assertIsInstance(strategy.extended._inferred_cross_device_ops,
cross_device_ops_lib.NcclAllReduce)
else:
self.assertIsInstance(strategy.extended._inferred_cross_device_ops,
cross_device_ops_lib.ReductionToOneDevice)
self.skipTest("b/130551176, run the following once fixed.")
self._test_minimize_loss_graph(strategy, learning_rate=0.05)
def testInitializeFromTFConfig(self):
with context.graph_mode():
tf_config = {"cluster": self._cluster_spec}
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
strategy = mirrored_strategy.MirroredStrategy(
cross_device_ops=self._make_cross_device_ops())
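        # The in-process cluster created in setUpClass has three tasks (one
        # chief and two workers), each contributing max(num_gpus, 1) local
        # replicas, hence the factor of 3 below.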
self.assertEqual(
max(context.num_gpus(), 1) * 3, strategy.num_replicas_in_sync)
def testSummaryForReplicaZeroOnly(self):
with context.graph_mode():
strategy = mirrored_strategy.MirroredStrategy(
mirrored_strategy.all_local_devices(),
cross_device_ops=self._make_cross_device_ops())
strategy.configure(cluster_spec=self._cluster_spec)
self._test_summary_for_replica_zero_only(strategy)
class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_one_gpu,
],
mode=["graph"]))
def testMirroredVariableAsStopGradient(self, distribution):
with distribution.scope():
inp = constant_op.constant(1.0)
x = variables.Variable(1.0)
y = inp*x
grads = gradients.gradients(x, y, stop_gradients=x)
self.assertIsNone(grads[0])
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["eager"]))
class FunctionTest(test.TestCase, parameterized.TestCase):
def testBackwardFunctionDevicePlacement(self, distribution):
with distribution.scope():
w = variable_scope.variable([1.5], name="w")
b = variable_scope.variable([0.5], name="b")
@def_function.function
def forward(x, w, b):
return x * w + b
x = array_ops.identity([1.0], name="x_useless")
concrete_forward = forward.get_concrete_function(x, w._primary, b._primary)
with distribution.scope():
def replica_fn():
with backprop.GradientTape() as t:
x = array_ops.identity([1.0], name="x")
loss = concrete_forward(x, w._get(), b._get()) - [1.0]
return t.gradient(loss, [w, b])
def step_fn():
return distribution.run(replica_fn)
context.enable_run_metadata()
g1, g2 = step_fn()
run_metadata = context.export_run_metadata()
context.disable_run_metadata()
self.assertEqual(self.evaluate(g1._primary), 1.0)
self.assertEqual(self.evaluate(g2._primary), 1.0)
# Verify that this node runs on both devices.
node_name = "gradients_mul_grad_mul_1_x"
devices_for_this_node = set()
for partition_graph in run_metadata.partition_graphs:
for node in partition_graph.node:
if node.name == node_name:
devices_for_this_node.add(node.device)
devices = [device_util.resolve("/device:GPU:0"),
device_util.resolve("/device:CPU:0")]
self.assertSetEqual(devices_for_this_node, set(devices))
  def testFunctionPreservesAutoGraph(self, distribution):
def f():
self.assertTrue(converter_testing.is_inside_generated_code())
return 1
with distribution.scope():
@def_function.function
def replica_fn():
return f()
distribution.run(replica_fn)
def _replica_id():
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
if not isinstance(replica_id, ops.Tensor):
replica_id = constant_op.constant(replica_id)
return array_ops.identity(replica_id)
def _replica_id_as_int():
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
if isinstance(replica_id, ops.Tensor):
replica_id = tensor_util.constant_value(replica_id)
return replica_id
if __name__ == "__main__":
test.main()