# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import itertools
import os
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model.model_utils import mode_keys
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest
class DistributedValuesTest(test.TestCase, parameterized.TestCase):
def testGetEager(self):
one = constant_op.constant(1)
two = constant_op.constant(2)
v = values.DistributedValues((one, two))
self.assertEqual(one, v._get())
with distribute_lib.ReplicaContext(None, 1):
self.assertEqual(two, v._get())
def testGetGraph(self):
with context.graph_mode(), ops.Graph().as_default():
one = constant_op.constant(1)
two = constant_op.constant(2)
v = values.DistributedValues((one, two))
self.assertEqual(one, v._get())
with distribute_lib.ReplicaContext(None, 1):
self.assertEqual(two, v._get())
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueFromTensor(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
single_value = constant_op.constant(1)
def value_fn(ctx):
del ctx
return single_value
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
self.assertAllEqual(
distribution.experimental_local_results(distributed_values),
constant_op.constant(1., shape=(distribution.num_replicas_in_sync,)))
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
array_value = np.array([1., 2., 3.])
def value_fn(ctx):
del ctx
return array_value
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
local_results = distribution.experimental_local_results(distributed_values)
self.assertLen(local_results, distribution.num_replicas_in_sync)
for result in local_results:
self.assertAllEqual(result, [1., 2., 3.])
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueTupleConstant(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
tuple_value = (1., 2., 3.)
def value_fn(ctx):
del ctx
return tuple_value
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
local_results = distribution.experimental_local_results(distributed_values)
for result in local_results:
self.assertAllEqual(result, (1., 2., 3.))
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
tuple_value = (1., 2., 3.)
def value_fn(ctx):
per_replica = []
for val in tuple_value:
per_replica.append(val * ctx.replica_id_in_sync_group)
return per_replica
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(
distribute_utils.select_replica(i, distributed_values),
(1. * i, 2. * i, 3. * i))
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueSparseTensor(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
def value_fn(ctx):
del ctx
return sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
local_results = distribution.experimental_local_results(distributed_values)
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(
sparse_ops.sparse_tensor_to_dense(local_results[i]),
[[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueExtractFromArray(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
multiple_values = range(distribution.num_replicas_in_sync)
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
local_results = distribution.experimental_local_results(distributed_values)
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(local_results[i], i)
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueAndRun(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
@def_function.function
def run():
multiple_values = range(distribution.num_replicas_in_sync)
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
def computation(x):
return math_ops.square(x)
outputs = distribution.experimental_local_results(
distribution.run(computation,
args=(distributed_values,)))
return outputs
local_results = run()
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(local_results[i], i**2)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
# TODO(b/137795644): support CentralStorageStrategy
# strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["eager"]))
def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
multiple_values = []
for i in range(distribution.num_replicas_in_sync):
multiple_values.append(constant_op.constant(1.0))
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(distributed_values._values[i].device,
"/job:localhost/replica:0/task:0/device:CPU:0")
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
# TODO(b/137795644): support CentralStorageStrategy
# strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["eager"]))
def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
worker_devices = distribution.extended.worker_devices
multiple_values = []
for i in range(distribution.num_replicas_in_sync):
with ops.device(worker_devices[i]):
multiple_values.append(array_ops.identity(1.0))
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(distributed_values._values[i].device,
worker_devices[i])
class DistributedDelegateTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testGetAttr(self):
class Foo(object):
def __init__(self, x):
self.x = x
v = values.DistributedDelegate((Foo(7), Foo(8)))
self.assertEqual(7, v.x)
with self.assertRaises(AttributeError):
_ = v.y
@test_util.run_in_graph_and_eager_modes
def testOperatorOverride(self):
v = values.DistributedDelegate((7, 8))
# v should act like int(7).
self.assertEqual(8, v + 1)
self.assertEqual(10, 3 + v)
self.assertEqual(14, v + v)
self.assertEqual(5, v - 2)
self.assertEqual(6, 13 - v)
self.assertEqual(0, v - v)
self.assertEqual(14, v * 2)
self.assertEqual(21, 3 * v)
self.assertEqual(49, v * v)
self.assertEqual(3.5, v / 2)
self.assertEqual(1.5, 10.5 / v)
self.assertEqual(3, v // 2)
self.assertEqual(2, 15 // v)
self.assertEqual(1, v % 2)
self.assertEqual(2, 16 % v)
# pylint: disable=g-generic-assert
self.assertTrue(v < 12)
self.assertTrue(v <= 12)
self.assertFalse(v > 12)
self.assertFalse(v >= 12)
self.assertFalse(12 < v)
self.assertFalse(12 <= v)
self.assertTrue(12 > v)
self.assertTrue(12 >= v)
# pylint: enable=g-generic-assert
self.assertEqual(3, v & 3)
self.assertEqual(3, 11 & v)
self.assertEqual(15, v | 8)
self.assertEqual(23, 16 | v)
self.assertEqual(4, v ^ 3)
self.assertEqual(12, 11 ^ v)
self.assertEqual(343, pow(v, 3))
self.assertEqual(3, pow(v, 3, 10))
self.assertEqual(128, pow(2, v))
self.assertEqual(-7, -v)
self.assertEqual(~7, ~v)
self.assertEqual(7, abs(v))
with self.assertRaises(TypeError):
_ = v[2]
@test_util.run_in_graph_and_eager_modes
def testCopy(self):
class Foo(object):
def __init__(self, x):
self.x = x
v = values.DistributedDelegate((Foo(7), Foo(8)))
v_shallow_copy = copy.copy(v)
self.assertEqual(v.x, v_shallow_copy.x)
v_deep_copy = copy.deepcopy(v)
self.assertEqual(v.x, v_deep_copy.x)
def _device_str(d):
return "/device:GPU:" + str(d)
def _nested_value(d):
return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)
def _make_mirrored_val(init_val=5.0):
v = []
devices = ["/device:GPU:0", "/device:CPU:0"]
for d, _ in zip(devices, ["v", "v/replica"]):
with ops.device(d):
v.append(constant_op.constant(init_val))
return values.Mirrored(v)
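# Builds a MirroredVariable with one component on GPU:0 and one on CPU:0
# (initialized to 1. and 2.) using SUM aggregation.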
def _make_mirrored():
v = []
devices = ["/device:GPU:0", "/device:CPU:0"]
for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
with ops.device(d):
v.append(variable_scope.get_variable(
name=n, initializer=init, use_resource=True))
mirrored = values.MirroredVariable(
None, v, variable_scope.VariableAggregation.SUM)
return mirrored
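# Strategy/mode combinations shared by several test classes below: mirrored
# (GPU + CPU) plus the TPU strategies, in both graph and eager modes.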
def mirrored_and_tpu_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"])
class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
def _is_per_replica(self, result, expected, klass=values.PerReplica):
self.assertIsInstance(result, klass)
for i, exp in enumerate(expected):
self.assertEqual(exp, result.values[i])
def testNested(self):
result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")))
self.assertIsInstance(result, tuple)
self.assertLen(result, 3)
self._is_per_replica(result[0], ["a1", "a2"])
self._is_per_replica(result[2], ["h1", "h2"])
self.assertIsInstance(result[1], list)
self.assertLen(result[1], 3)
self._is_per_replica(result[1][0], ["b1", "b2"])
self._is_per_replica(result[1][2], ["g1", "g2"])
self.assertIsInstance(result[1][1], dict)
self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
self._is_per_replica(result[1][1]["e"], ["f1", "f2"])
# Also test that we can undo the merge using select_replica()
self.assertEqual(_nested_value("1"),
distribute_utils.select_replica(0, result))
self.assertEqual(_nested_value("2"),
distribute_utils.select_replica(1, result))
# select_replica_mirrored() should fail due to non-mirrored values
with self.assertRaises(TypeError):
distribute_utils.select_replica_mirrored(0, result)
with self.assertRaises(TypeError):
distribute_utils.select_replica_mirrored(1, result)
def testRegroupKeepsDictBasedClass(self):
class DictBasedClass(dict):
"""Dummy class inherited from a dict."""
result = distribute_utils.regroup(
(DictBasedClass(a="a1", b="b1"), DictBasedClass(a="a2", b="b2")))
self.assertIsInstance(result, DictBasedClass)
self._is_per_replica(result["a"], ["a1", "a2"])
self._is_per_replica(result["b"], ["b1", "b2"])
def testWrapClass(self):
# Normally a mirrored value would be the same across devices, but
# for a test it is convenient to be able to tell the values apart.
result = distribute_utils.regroup((_nested_value("1"), _nested_value("2")),
values.Mirrored)
self.assertIsInstance(result, tuple)
self.assertLen(result, 3)
self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)
self.assertIsInstance(result[1], list)
self.assertLen(result[1], 3)
self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)
self.assertIsInstance(result[1][1], dict)
self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)
# Also test that we can undo the merge using select_replica()
self.assertEqual(_nested_value("1"),
distribute_utils.select_replica(0, result))
self.assertEqual(_nested_value("2"),
distribute_utils.select_replica(1, result))
# Values are marked as mirrored, so select_replica_mirrored() is allowed.
self.assertEqual(_nested_value("1"),
distribute_utils.select_replica_mirrored(0, result))
self.assertEqual(_nested_value("2"),
distribute_utils.select_replica_mirrored(1, result))
def testWrapAListOfTwoTuples(self):
result = distribute_utils.regroup([("1", "2"), ("3", "4")])
self.assertIsInstance(result, tuple)
self.assertLen(result, 2)
self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
self._is_per_replica(result[1], ("2", "4"), values.PerReplica)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_one_cpu,
],
mode=["graph", "eager"],
))
def testMirroredContainer(self, distribution):
with distribution.scope():
v = variable_scope.variable(
1., aggregation=variable_scope.VariableAggregation.SUM)
self.assertTrue(distribute_utils.is_distributed_variable(v))
self.assertTrue(distribute_utils.is_distributed_variable(
distribute_utils.regroup(v.values)))
def testSameId(self):
foo = object()
result = distribute_utils.regroup((("a", foo), ("b", foo)))
self.assertIsInstance(result, tuple)
self.assertLen(result, 2)
self._is_per_replica(result[0], ["a", "b"])
self.assertIs(foo, result[1])
# Test that select_replica() undoes the merge done by regroup().
result_0 = distribute_utils.select_replica(0, result)
self.assertIsInstance(result_0, tuple)
self.assertLen(result_0, 2)
self.assertEqual("a", result_0[0])
self.assertIs(foo, result_0[1])
result_1 = distribute_utils.select_replica(1, result)
self.assertIsInstance(result_1, tuple)
self.assertLen(result_1, 2)
self.assertEqual("b", result_1[0])
self.assertIs(foo, result_1[1])
def testOneDevice(self):
result = distribute_utils.regroup((_nested_value("1"),))
# On one device regroup() and select_replica() are basically identity.
self.assertEqual(_nested_value("1"), result)
self.assertEqual(_nested_value("1"),
distribute_utils.select_replica(0, result))
def testNamedTuple(self):
# We include toy implementations of Scaffold and EstimatorSpec to
# avoid a dependency on Estimator here.
class Scaffold(object):
pass
class EstimatorSpec(collections.namedtuple(
"EstimatorSpec", ["mode", "loss", "train_op", "scaffold"])):
def __new__(cls, mode, loss, train_op, scaffold=None):
return super(EstimatorSpec, cls).__new__(
cls, mode=mode, loss=loss, train_op=train_op,
scaffold=scaffold or Scaffold())
with context.graph_mode(), ops.Graph().as_default():
created_estimator_specs = []
for device_id in range(3):
spec = EstimatorSpec(
mode=mode_keys.EstimatorModeKeys.TRAIN,
loss=constant_op.constant(device_id / 2),
train_op=array_ops.identity(constant_op.constant(device_id)))
created_estimator_specs.append(spec)
merged_estimator_spec = distribute_utils.regroup(created_estimator_specs)
self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
merged_estimator_spec.mode)
for device_id in range(3):
self.assertEqual(created_estimator_specs[device_id].loss,
merged_estimator_spec.loss.values[device_id])
self.assertEqual(created_estimator_specs[device_id].train_op,
merged_estimator_spec.train_op.values[device_id])
# Scaffold is populated by `EstimatorSpec.__new__`.
self.assertEqual(created_estimator_specs[device_id].scaffold,
merged_estimator_spec.scaffold.values[device_id])
self.assertIsInstance(created_estimator_specs[device_id].scaffold,
Scaffold)
# Also test that we can undo the merge using select_replica()
self.assertEqual(created_estimator_specs[device_id],
distribute_utils.select_replica(
device_id, merged_estimator_spec))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
synchronization=[
variables_lib.VariableSynchronization.ON_READ,
variables_lib.VariableSynchronization.ON_WRITE,
],
aggregation=[
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
],
mode=["graph", "eager"]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):
def testExtendsVariable(self, distribution, synchronization, aggregation):
with distribution.scope():
v = variables_lib.Variable(
1., synchronization=synchronization, aggregation=aggregation)
self.assertIsInstance(v, variables_lib.Variable)
def testCheckpointing(self, distribution, synchronization, aggregation):
with distribution.scope():
v = variables_lib.Variable(
constant_op.constant([1., 2., 3., 4]),
synchronization=synchronization,
aggregation=aggregation)
self.evaluate(v.initializer)
before_save = self.evaluate(v.read_value())
# Save the current value into a checkpoint.
checkpoint = trackable_utils.Checkpoint(v=v)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
with self.test_session():
save_path = checkpoint.save(prefix)
# Assign the reversed value.
self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
after_assign = self.evaluate(v.read_value())
self.assertNotAllClose(before_save, after_assign)
# Restore from the checkpoint.
with self.test_session():
checkpoint.restore(save_path).assert_consumed().run_restore_ops()
after_restore = self.evaluate(v)
self.assertAllClose(before_save, after_restore)
def testTraceback(self, distribution, synchronization, aggregation):
if context.executing_eagerly():
self.skipTest("does not apply to eager")
with distribution.scope():
variable_scope.get_variable(
name="testVar",
initializer=1.,
use_resource=True,
synchronization=synchronization,
aggregation=aggregation)
with self.assertRaisesRegex(ValueError,
"Variable testVar already exists"):
variable_scope.get_variable(
name="testVar",
initializer=1.,
use_resource=True,
synchronization=synchronization,
aggregation=aggregation)
def testSelectReplica(self, distribution, synchronization, aggregation):
with distribution.scope():
v = variables_lib.Variable(
1., synchronization=synchronization, aggregation=aggregation)
self.assertIs(v, distribute_utils.select_replica(0, v))
def testIsTensorLike(self, distribution, synchronization, aggregation):
if isinstance(distribution.extended,
tpu_strategy.TPUExtended) and context.executing_eagerly():
self.skipTest("TPU doesn't support pure eager")
with distribution.scope():
v = variables_lib.Variable(
0., synchronization=synchronization, aggregation=aggregation)
# In cross replica context.
self.assertIsInstance(v, core.Tensor)
# In replica context.
distribution.run(
lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))
def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
aggregation):
if isinstance(distribution.extended, tpu_strategy.TPUExtended):
if context.executing_eagerly():
self.skipTest("TPU doesn't support pure eager")
else:
self.skipTest("b/152076846")
with distribution.scope():
v = variables_lib.Variable(
0., synchronization=synchronization, aggregation=aggregation)
def assert_is_tensor_like(v):
# We can't use Python literals because they are treated as non-distributed
# values, which is not allowed when aggregation is SUM. See
# `cross_device_ops.reduce_non_distributed_value`.
delta = array_ops.identity(1.)
self.assertIsInstance(v.assign(delta), core.Tensor)
self.assertIsInstance(v.assign_sub(delta), core.Tensor)
self.assertIsInstance(v.assign_add(delta), core.Tensor)
# In cross-replica context we return a PerReplica, which is not yet
# tensor-like in all cases.
if (synchronization == variables_lib.VariableSynchronization.ON_READ and
aggregation != variables_lib.VariableAggregation.SUM):
assert_is_tensor_like(v)
# In replica context.
distribution.run(assert_is_tensor_like, args=(v,))
def testAssignSignature(self, distribution, synchronization, aggregation):
# This test verifies that assign*() accepts the same argument signature as
# it does on normal (non-distributed) variables.
with distribution.scope():
v = variables_lib.Variable(
0., synchronization=synchronization, aggregation=aggregation)
def assign():
one = constant_op.constant(1.)
v.assign(one, True, "assign", False)
# TODO(b/154017756): SyncOnReadVariable.assign() doesn't support passing
# value as a keyword argument.
v.assign(one, use_locking=True, name="assign", read_value=False)
v.assign_add(one, True, "assign", False)
v.assign_add(one, use_locking=True, name="assign", read_value=False)
v.assign_sub(one, True, "assign", False)
v.assign_sub(one, use_locking=True, name="assign", read_value=False)
# Return something for graph mode to fetch.
return constant_op.constant(1)
self.evaluate(variables_lib.global_variables_initializer())
if not (synchronization == variables_lib.VariableSynchronization.ON_READ
and aggregation == variables_lib.VariableAggregation.SUM):
self.evaluate(distribution.experimental_local_results(assign()))
if not (isinstance(distribution.extended, tpu_strategy.TPUExtended) and
context.executing_eagerly()):
self.evaluate(
distribution.experimental_local_results(distribution.run(assign)))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.tpu_strategy,
],
mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
def testPackedVariable(self, distribution):
with distribution.scope():
v0 = variables_lib.Variable(0.)
self.assertIsNone(v0._packed_var)
distribution._enable_packed_variable_in_eager_mode = True
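# With packed variables enabled, variables created under the strategy scope
# expose a PackedDistributedVariable through `_packed_var`, which the rest
# of this test reads and assigns per replica.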
with distribution.scope():
v1 = variables_lib.Variable(0)
self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)
devices = v1._devices
for i in range(1, len(devices)):
with distribute_lib.ReplicaContext(distribution, i):
v1.assign(i)
val = v1._get()
self.assertIsInstance(val, packed.PackedVarAndDevice)
self.assertEqual(val.device, devices[0])
self.assertEqual(self.evaluate(val.read_value()), 0)
for i in range(0, len(devices)):
with distribute_lib.ReplicaContext(distribution, i):
val = v1._get()
self.assertIsInstance(val, packed.PackedVarAndDevice)
self.assertEqual(val.device, devices[i])
self.assertEqual(self.evaluate(val.read_value()), i)
class MirroredVariableTest(test.TestCase, parameterized.TestCase):
config = config_pb2.ConfigProto()
config.allow_soft_placement = True
@test_util.run_in_graph_and_eager_modes(config=config)
def testProperties(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
mirrored = _make_mirrored()
v = mirrored.values[0]
self.assertEqual(v.name, mirrored.name)
self.assertEqual(v.dtype, mirrored.dtype)
self.assertEqual(v.shape, mirrored.shape)
@test_util.run_in_graph_and_eager_modes(config=config)
def testVariableOnAnotherDevice(self):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
mirrored = values.MirroredVariable(
None, (v,), variable_scope.VariableAggregation.MEAN)
self.assertEqual(v.name, mirrored.name)
self.assertEqual(v.dtype, mirrored.dtype)
self.assertEqual(v.shape, mirrored.shape)
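# The helpers below assign per-component values to a mirrored variable and
# save/restore it through a `saver_lib.Saver`, returning the save path.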
def _assign_mirrored(self, v, new):
for var, n in zip(v.values, new):
self.evaluate(var.assign(n))
def _save_return_saver(self, sess, var):
saver = saver_lib.Saver(var_list=[var])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
return saver.save(sess, prefix), saver
def _save(self, sess, var):
save_path, _ = self._save_return_saver(sess, var)
return save_path
@test_util.run_in_graph_and_eager_modes(config=config)
def testSaveAndRestoreMirroredOneGraph(self):
if context.num_gpus() < 1 and context.executing_eagerly():
# Graph mode can work without a GPU because the Placer "moves" the
# variable to a CPU. In other words, if there is no GPU available but the
# user requested to create a variable on GPU, the Placer ignores the
# request and assigns the VarHandleOp to a CPU. This requires
# soft_placement, which is on by default.
self.skipTest("A GPU is not available for this test in eager mode.")
with self.cached_session(config=self.config) as sess:
mirrored = _make_mirrored()
v = mirrored.values
# Overwrite the initial values.
self._assign_mirrored(mirrored, [3., 4.])
# Saves the current value of v[0], 3.
save_path, saver = self._save_return_saver(sess, mirrored)
# Change the values between save and restore.
self._assign_mirrored(mirrored, [5., 6.])
# Restores the saved value of 3. to both variables.
saver.restore(sess, save_path)
self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
def _save_mirrored(self):
"""Save variables with mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
mirrored = _make_mirrored()
# Overwrite the initial values.
self._assign_mirrored(mirrored, [3., 4.])
# Saves the current value of v[0], 3.
save_path = self._save(sess, mirrored)
# Change the values between save and restore.
self._assign_mirrored(mirrored, [5., 6.])
return save_path
def _save_normal(self):
"""Save variables without mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
# Overwrite the initial value.
self.evaluate(var.assign(3.))
# Saves the current value of var, 3.
save_path = self._save(sess, var)
# Change the values between save and restore.
self.evaluate(var.assign(5.))
return save_path
def _restore_normal(self, save_path):
"""Restore to variables without mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=7., use_resource=True)
# Overwrite the initial value.
self.evaluate(var.assign(8.))
# Restores the saved value of 3. to `var`.
saver = saver_lib.Saver(var_list=[var])
saver.restore(sess, save_path)
self.assertEqual(3., self.evaluate(var))
def _restore_mirrored(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
mirrored = _make_mirrored()
v = mirrored.values
# Overwrite the initial values.
self._assign_mirrored(mirrored, [7., 8.])
# Restores the saved value of 3. to both variables.
saver = saver_lib.Saver(var_list=[mirrored])
saver.restore(sess, save_path)
self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
@test_util.run_in_graph_and_eager_modes(config=config)
def testSaveMirroredRestoreMirrored(self):
if context.num_gpus() < 1 and context.executing_eagerly():
# Graph mode can work without a GPU because the Placer "moves" the
# variable to a CPU. In other words, if there is no GPU available but the
# user requested to create a variable on GPU, the Placer ignores the
# request and assigns the VarHandleOp to a CPU. This requires
# soft_placement, which is on by default.
self.skipTest("A GPU is not available for this test in eager mode.")
save_path = self._save_mirrored()
self._restore_mirrored(save_path)
@test_util.run_in_graph_and_eager_modes(config=config)
def testSaveMirroredRestoreNormal(self):
if context.num_gpus() < 1 and context.executing_eagerly():
# Graph mode can work without a GPU because the Placer "moves" the
# variable to a CPU. In other words, if there is no GPU available but the
# user requested to create a variable on GPU, the Placer ignores the
# request and assigns the VarHandleOp to a CPU. This requires
# soft_placement, which is on by default.
self.skipTest("A GPU is not available for this test in eager mode.")
save_path = self._save_mirrored()
self._restore_normal(save_path)
@test_util.run_in_graph_and_eager_modes(config=config)
def testSaveNormalRestoreMirrored(self):
if context.num_gpus() < 1 and context.executing_eagerly():
# Graph mode can work without a GPU because the Placer "moves" the
# variable to a CPU. In other words, if there is no GPU available but the
# user requested to create a variable on GPU, the Placer ignores the
# request and assigns the VarHandleOp to a CPU. This requires
# soft_placement, which is on by default.
self.skipTest("A GPU is not available for this test in eager mode.")
save_path = self._save_normal()
self._restore_mirrored(save_path)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_gpu,
],
mode=["graph"]))
def testFetchAMirroredVariable(self, distribution):
with self.session(graph=ops.Graph()) as sess, distribution.scope():
with ops.device("/device:GPU:0"):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
mirrored = values.MirroredVariable(
distribution, (v,), variable_scope.VariableAggregation.MEAN)
sess.run(variables_lib.global_variables_initializer())
sess.run({"complicated": mirrored})
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["eager"]))
def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
with distribution.scope():
v = variables_lib.Variable(1.0, name="foo")
@def_function.function
def mytest():
def model_fn():
v.assign(5.0)
return v.read_value()
return distribution.run(model_fn)
mytest()
self.assertAllEqual([5.0, 5.0], self.evaluate(v.values))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"]))
def testValueInReplicaContext(self, distribution):
with distribution.scope():
v = variables_lib.Variable(
1., aggregation=variables_lib.VariableAggregation.MEAN)
self.evaluate(variables_lib.global_variables_initializer())
@def_function.function
def f():
with ops.control_dependencies([v.assign_add(1.)]):
return v.value()
results = self.evaluate(
distribution.experimental_local_results(
distribution.run(f)))
for value in results:
self.assertEqual(2., value)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"]))
def testAssignOutOfScope(self, distribution):
with distribution.scope():
mirrored = variables_lib.Variable(1.)
self.evaluate(mirrored.assign(3.))
self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
for component in mirrored.values:
self.assertEqual(self.evaluate(component.read_value()), 3.)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
def testAssignAggregationMeanDTypeNonFloat(self, distribution):
with distribution.scope():
v = variables_lib.Variable(
1,
aggregation=variable_scope.VariableAggregation.MEAN,
dtype=dtypes.int32)
self.evaluate(v.initializer)
@def_function.function
def assign():
ctx = distribution_strategy_context.get_replica_context()
return v.assign(ctx.replica_id_in_sync_group)
# assign() with a distributed value is disallowed in replica context.
with self.assertRaisesRegex(ValueError,
"Cannot update non-float variables"):
self.evaluate(
distribution.experimental_local_results(
distribution.run(assign)))
# assign() with the same value on every replica is allowed in replica context.
@def_function.function
def assign_same():
return v.assign(2)
self.evaluate(
distribution.experimental_local_results(
distribution.run(assign_same)))
self.assertEqual(self.evaluate(v.read_value()), 2)
# assign() with a mirrored variable is allowed in replica context.
with distribution.scope():
v2 = variables_lib.Variable(
3,
aggregation=variable_scope.VariableAggregation.SUM,
dtype=dtypes.int32)
self.evaluate(v2.initializer)
@def_function.function
def assign_mirrored():
return v.assign(v2)
self.evaluate(
distribution.experimental_local_results(
distribution.run(assign_mirrored)))
self.assertEqual(self.evaluate(v.read_value()), 3)
# assign() is allowed in cross-replica context.
with distribution.scope():
self.evaluate(v.assign(4))
self.assertEqual(self.evaluate(v.read_value()), 4)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["eager"]))
def testInitializedToSameValueInsideEagerRun(self, distribution):
v = [None]
@def_function.function
def step():
def f():
if v[0] is None:
v[0] = variables_lib.Variable(random_ops.random_normal([]))
distribution.run(f)
context.set_global_seed(None)
step()
vals = self.evaluate(v[0].values)
self.assertAllEqual(vals[0], vals[1])
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"]))
def testAggregationOnlyFirstReplica(self, distribution):
with distribution.scope():
v = variable_scope.variable(
15.,
synchronization=variables_lib.VariableSynchronization.ON_WRITE,
aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
self.evaluate(variables_lib.global_variables_initializer())
@def_function.function
def assign():
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return v.assign(math_ops.cast(replica_id, dtypes.float32))
per_replica_results = self.evaluate(distribution.experimental_local_results(
distribution.run(assign)))
# The per-replica values should always match the first replica's value.
self.assertAllEqual(
array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["eager"]))
def testInitScope(self, distribution):
class C(object):
pass
obj = C()
obj.w = None
obj.v = None
@def_function.function
def assign():
with ops.init_scope():
if obj.w is None:
obj.w = variables_lib.Variable(
0, aggregation=variables_lib.VariableAggregation.MEAN)
obj.v = variables_lib.Variable(
obj.w.read_value(),
aggregation=variables_lib.VariableAggregation.MEAN)
return obj.v.assign_add(2)
per_replica_results = self.evaluate(
distribution.experimental_local_results(distribution.run(assign)))
self.assertAllEqual([2, 2], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
mode=["eager"]))
def testOperatorOverride(self, distribution):
with distribution.scope():
v = variable_scope.variable(
1, aggregation=variables_lib.VariableAggregation.MEAN)
self.assertEqual(2, self.evaluate(v + 1))
@def_function.function
def add():
return v + 1
per_replica_results = self.evaluate(
distribution.experimental_local_results(distribution.run(add)))
self.assertAllEqual([2, 2], per_replica_results)
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testAssignAdd(self, distribution):
with distribution.scope():
v = variable_scope.variable(
1, aggregation=variables_lib.VariableAggregation.MEAN)
self.evaluate(variables_lib.global_variables_initializer())
@def_function.function
def assign():
return v.assign_add(2)
per_replica_results = self.evaluate(
distribution.experimental_local_results(distribution.run(assign)))
# Each replica reads the updated value 3 after assign_add(2).
self.assertAllEqual([3, 3], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
def testScatterSub(self, distribution):
with distribution.scope():
v = variables_lib.Variable(
[0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
self.evaluate(v.initializer)
@def_function.function
def scatter_sub():
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
value = indexed_slices.IndexedSlices(
values=array_ops.stack([
math_ops.cast(replica_id, dtypes.float32),
math_ops.cast(replica_id + 1, dtypes.float32)
]),
indices=array_ops.stack([replica_id, replica_id + 1]),
dense_shape=(3,))
return v.scatter_sub(value)
per_replica_results = self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_sub)))
self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
def testScatterAdd(self, distribution):
with distribution.scope():
v = variables_lib.Variable(
[0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
self.evaluate(v.initializer)
@def_function.function
def scatter_add():
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
value = indexed_slices.IndexedSlices(
values=array_ops.stack([replica_id, replica_id + 1]),
indices=array_ops.stack([replica_id, replica_id + 1]),
dense_shape=(3,))
return v.scatter_add(value)
per_replica_results = self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_add)))
self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
def testScatterDiv(self, distribution):
with distribution.scope():
v = variables_lib.Variable(
[1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
self.evaluate(v.initializer)
@def_function.function
def scatter_div():
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
value = indexed_slices.IndexedSlices(
values=array_ops.reshape(replica_id + 2, [1]),
indices=array_ops.reshape(replica_id, [1]),
dense_shape=(3,))
return v.scatter_div(value)
per_replica_results = self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_div)))
self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
def testScatterMul(self, distribution):
with distribution.scope():
v = variables_lib.Variable(
[2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
self.evaluate(v.initializer)
@def_function.function
def scatter_mul():
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
value = indexed_slices.IndexedSlices(
values=array_ops.reshape(
math_ops.cast(replica_id + 2, dtypes.float32), [1]),
indices=array_ops.reshape(replica_id, [1]),
dense_shape=(3,))
return v.scatter_mul(value)
per_replica_results = self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_mul)))
self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
def testScatterMin(self, distribution):
with distribution.scope():
v1 = variables_lib.Variable(
[0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
v2 = variables_lib.Variable(
[0, 2, 0],
aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
self.evaluate(variables_lib.global_variables_initializer())
@def_function.function
def scatter_min(v):
value = indexed_slices.IndexedSlices(
values=array_ops.identity([1]),
indices=array_ops.identity([1]),
dense_shape=(3,))
return v.scatter_min(value)
with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_min, args=(v1,))))
per_replica_results = self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_min, args=(v2,))))
self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
def testScatterMax(self, distribution):
with distribution.scope():
v1 = variables_lib.Variable(
[0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
v2 = variables_lib.Variable(
[0, 0, 0],
aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
self.evaluate(variables_lib.global_variables_initializer())
@def_function.function
def scatter_max(v):
value = indexed_slices.IndexedSlices(
values=array_ops.identity([1]),
indices=array_ops.identity([0]),
dense_shape=(3,))
return v.scatter_max(value)
with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_max, args=(v1,))))
per_replica_results = self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_max, args=(v2,))))
self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
def testScatterUpdate(self, distribution):
with distribution.scope():
v1 = variables_lib.Variable(
[0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
v2 = variables_lib.Variable(
[0, 0, 0],
aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
self.evaluate(variables_lib.global_variables_initializer())
@def_function.function
def scatter_update(v):
value = indexed_slices.IndexedSlices(
values=array_ops.identity([3]),
indices=array_ops.identity([1]),
dense_shape=(3,))
return v.scatter_update(value)
with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_update, args=(v1,))))
per_replica_results = self.evaluate(
distribution.experimental_local_results(
distribution.run(scatter_update, args=(v2,))))
self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"]))
def testScatterOpsInCrossReplicaContext(self, distribution):
with distribution.scope():
v1 = variables_lib.Variable(
[1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
v2 = variables_lib.Variable([1, 1, 1])
self.evaluate(variables_lib.global_variables_initializer())
value = indexed_slices.IndexedSlices(
values=array_ops.identity([2]),
indices=array_ops.identity([0]),
dense_shape=(3,))
with distribution.scope():
self.evaluate(v1.scatter_add(value))
self.assertAllEqual([3, 1, 1], self.evaluate(v1.read_value()))
self.evaluate(v2.scatter_min(value))
self.assertAllEqual([1, 1, 1], self.evaluate(v2.read_value()))
_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
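# Builds per-device component variables and wraps them in a
# SyncOnReadVariable (TPUSyncOnReadVariable under TPU strategies) with the
# given aggregation `method`. Returns both the components and the wrapper.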
def _make_replica_local(method, strategy=None):
if strategy is None:
devices = ("/device:GPU:0", "/device:CPU:0")
else:
devices = strategy.extended.worker_devices
v = []
for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
with ops.device(d):
v.append(variable_scope.get_variable(
name=n, initializer=init, use_resource=True))
if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
var_cls = tpu_values.TPUSyncOnReadVariable
else:
var_cls = values.SyncOnReadVariable
replica_local = var_cls(strategy, v, method)
return v, replica_local
class SyncOnReadVariablePropertiesTest(test.TestCase):
config = config_pb2.ConfigProto()
config.allow_soft_placement = True
@test_util.run_in_graph_and_eager_modes(config=config)
def testProperties(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM)
self.assertEqual(v[0].constraint, replica_local.constraint)
self.assertEqual(v[0].name, replica_local.name)
self.assertEqual(v[0].dtype, replica_local.dtype)
self.assertEqual(v[0].shape, replica_local.shape)
self.assertEqual(variable_scope.VariableAggregation.SUM,
replica_local.aggregation)
@test_util.run_v2_only
def testCanPassToDefFun(self):
@def_function.function
def add1(x):
return x + 1
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
replica_local = values.SyncOnReadVariable(
None, (v,), variable_scope.VariableAggregation.MEAN)
self.assertEqual(2., self.evaluate(add1(replica_local)))
# TODO(b/144432582): Add variable aggregation type to combinations to simplify
# tests.
def strategy_and_run_tf_function_combinations():
# Test the combination of different strategies and whether a tf.function
# is passed into strategy.run.
return combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode=["graph", "eager"],
experimental_run_tf_function=[True, False]) + combinations.combine(
distribution=[
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_packed_var,
],
mode=["graph", "eager"],
experimental_run_tf_function=[True])
class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
def _assign_replica_local(self, v, new):
for var, n in zip(v, new):
with ops.device(var.device):
self.evaluate(var.assign(n))
def _save_return_saver(self, sess, var):
saver = saver_lib.Saver(var_list=[var])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
return saver.save(sess, prefix), saver
def _save(self, sess, var):
save_path, _ = self._save_return_saver(sess, var)
return save_path
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testTensorConversion(self, distribution):
with context.graph_mode():
_, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM, distribution)
converted = ops.convert_to_tensor(replica_local, as_ref=False)
self.assertIsInstance(converted, ops.Tensor)
self.assertEqual(converted.dtype, replica_local.dtype)
converted = ops.convert_to_tensor(replica_local, as_ref=True)
# Resource variables are converted to tensors as well when as_ref is True.
self.assertIsInstance(converted, ops.Tensor)
self.assertEqual(converted.dtype, replica_local.dtype)
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
with self.cached_session() as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [3., 4.])
with distribution.scope():
# Saves the current value of v[0] + v[1], 7.
save_path, saver = self._save_return_saver(sess, replica_local)
# Change the values between save and restore.
self._assign_replica_local(v, [5., 6.])
# Restores the saved value of 7. which gets divided equally
# between the variables.
saver.restore(sess, save_path)
self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
with self.cached_session() as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.MEAN, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [3., 4.])
with distribution.scope():
# Saves the current value of (v[0] + v[1])/2, 3.5.
save_path, saver = self._save_return_saver(sess, replica_local)
# Change the values between save and restore.
self._assign_replica_local(v, [5., 6.])
# Restores the saved value of 3.5 to both variables.
saver.restore(sess, save_path)
self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
def _save_replica_local_mean(self, distribution):
"""Save variables with mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.MEAN, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [3., 4.])
with distribution.scope():
# Saves the current value of (v[0] + v[1])/2, 3.5
save_path = self._save(sess, replica_local)
# Change the values between save and restore.
self._assign_replica_local(v, [5., 6.])
return save_path
def _save_replica_local_sum(self, distribution):
"""Save variables with mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [1.5, 2.])
with distribution.scope():
# Saves the current value of v[0] + v[1], 3.5
save_path = self._save(sess, replica_local)
# Change the values between save and restore.
self._assign_replica_local(v, [5., 6.])
return save_path
def _save_normal(self):
"""Save variables without mirroring, returns save_path."""
with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
# Overwrite the initial value.
self.evaluate(var.assign(3.5))
# Saves the current value of var, 3.5.
save_path = self._save(sess, var)
# Change the values between save and restore.
self.evaluate(var.assign(5.))
return save_path
def _restore_normal(self, save_path):
"""Restore to variables without mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=7., use_resource=True)
# Overwrite the initial value.
self.evaluate(var.assign(8.))
# Restores the saved value of 3.5 to `var`.
saver = saver_lib.Saver(var_list=[var])
saver.restore(sess, save_path)
self.assertEqual(3.5, self.evaluate(var))
def _restore_replica_local_mean(self, save_path, distribution):
"""Restore to variables with mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.MEAN, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [7., 8.])
with distribution.scope():
# Restores the saved value of 3.5 to both variables.
saver = saver_lib.Saver(var_list=[replica_local])
saver.restore(sess, save_path)
self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
def _restore_replica_local_sum(self, save_path, distribution):
"""Restore to variables with mirroring in a fresh graph."""
with self.session(graph=ops.Graph()) as sess:
v, replica_local = _make_replica_local(
variable_scope.VariableAggregation.SUM, distribution)
# Overwrite the initial values.
self._assign_replica_local(v, [7., 8.])
with distribution.scope():
# Restores the saved value of 3.5 to both variables.
saver = saver_lib.Saver(var_list=[replica_local])
saver.restore(sess, save_path)
self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
save_path = self._save_replica_local_mean(distribution)
self._restore_replica_local_mean(save_path, distribution)
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
save_path = self._save_replica_local_sum(distribution)
self._restore_replica_local_sum(save_path, distribution)
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
save_path = self._save_replica_local_mean(distribution)
self._restore_normal(save_path)
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testSaveReplicaLocalSumRestoreNormal(self, distribution):
save_path = self._save_replica_local_sum(distribution)
self._restore_normal(save_path)
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testSaveNormalRestoreReplicaLocalMean(self, distribution):
save_path = self._save_normal()
self._restore_replica_local_mean(save_path, distribution)
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testSaveNormalRestoreReplicaLocalSum(self, distribution):
save_path = self._save_normal()
self._restore_replica_local_sum(save_path, distribution)
@combinations.generate(strategy_and_run_tf_function_combinations())
def testAssign(self, distribution, experimental_run_tf_function):
def assign(fn, v, update_value, cross_replica):
update_fn = lambda: getattr(v, fn)(update_value)
if cross_replica:
return update_fn()
else:
if experimental_run_tf_function:
update_fn = def_function.function(update_fn)
return distribution.experimental_local_results(
distribution.run(update_fn))
updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
aggregations = [
variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
options = list(
x for x in itertools.product(updates, aggregations, [True, False]))
for update, aggregation, cross_replica in options:
# VariableAggregation.SUM in cross-replica mode is tested below;
# VariableAggregation.NONE in cross-replica mode is not supported.
if cross_replica and aggregation in [
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.NONE,
]:
continue
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
fn, update_value = update
self.evaluate(assign(fn, v, update_value, cross_replica))
for component in v._values:
self.assertAllEqual(self.evaluate(component.read_value()),
self.evaluate(array_ops.ones_like(component)))
@combinations.generate(strategy_and_run_tf_function_combinations())
def testAssignDtypeConversion(self, distribution,
experimental_run_tf_function):
def assign(fn, v, update_value, cross_replica):
update_fn = lambda: getattr(v, fn)(update_value)
if cross_replica:
return update_fn()
else:
if experimental_run_tf_function:
update_fn = def_function.function(update_fn)
return distribution.experimental_local_results(
distribution.run(update_fn))
updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
aggregations = [
variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
options = list(
x for x in itertools.product(updates, aggregations, [True, False]))
for update, aggregation, cross_replica in options:
# VariableAggregation.SUM in cross-replica mode is tested below,
# VariableAggregation.NONE in cross-replica mode is not supported.
if cross_replica and aggregation in [
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.NONE,
]:
continue
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
fn, update_value = update
self.evaluate(assign(fn, v, update_value, cross_replica))
for component in v._values:
self.assertAllEqual(self.evaluate(component.read_value()),
self.evaluate(array_ops.ones_like(component)))
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testAssignWithAggregationSum(self, distribution):
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.SUM)
self.evaluate(variables_lib.global_variables_initializer())
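    # A cross-replica assign to a SUM-aggregated sync-on-read variable splits
    # the value across replicas, so each component below should read 1.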
self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
for component in v._values:
self.assertAllEqual(self.evaluate(component.read_value()),
self.evaluate(array_ops.ones_like(component)))
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testAssignAddSubWithAggregationSum(self, distribution):
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.SUM)
self.evaluate(variables_lib.global_variables_initializer())
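    # Unlike a plain assign, assign_add and assign_sub are rejected in
    # cross-replica context when the aggregation is SUM.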
with self.assertRaisesRegex(
ValueError, "SyncOnReadVariable does not support "):
self.evaluate(v.assign_add(1.))
with self.assertRaisesRegex(
ValueError, "SyncOnReadVariable does not support "):
self.evaluate(v.assign_sub(1.))
@combinations.generate(strategy_and_run_tf_function_combinations())
def testReadValueInReplicaContext(self, distribution,
experimental_run_tf_function):
aggregations = [
variables_lib.VariableAggregation.NONE,
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
for aggregation in aggregations:
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
if experimental_run_tf_function:
read_var_fn = def_function.function(v.read_value)
else:
read_var_fn = v.read_value
results = self.evaluate(
distribution.experimental_local_results(
distribution.run(read_var_fn)))
for component, value in zip(v._values, results):
self.assertAllEqual(self.evaluate(component.read_value()), value)
@combinations.generate(strategy_and_run_tf_function_combinations())
def testReadValueInCrossReplicaContext(self, distribution,
experimental_run_tf_function):
aggregations = [
variables_lib.VariableAggregation.SUM,
# variables_lib.VariableAggregation.MEAN,
# variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
for aggregation in aggregations:
if isinstance(distribution, _TPU_STRATEGIES):
resolver = tpu_cluster_resolver.TPUClusterResolver("")
tpu_strategy_util.initialize_tpu_system(resolver)
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
def assign(v=v):
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return v.assign(math_ops.cast(replica_id, dtypes.float32))
if experimental_run_tf_function:
assign = def_function.function(assign)
self.evaluate(
distribution.experimental_local_results(distribution.run(assign)))
num_replicas = distribution.num_replicas_in_sync
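      # Each replica assigned its own id, so the aggregate sum over replicas
      # is 0 + 1 + ... + (n - 1) = n * (n - 1) / 2.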
sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
if aggregation == variables_lib.VariableAggregation.SUM:
expected = sum_of_replica_values
elif aggregation == variables_lib.VariableAggregation.MEAN:
expected = sum_of_replica_values / num_replicas
else:
expected = 0
self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
self.assertEqual(expected, self.evaluate(v.value()), aggregation)
self.assertEqual(expected, self.evaluate(v), aggregation)
self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
aggregation)
# TODO(b/145574622): Re-enable this test once ReduceOp argument is
# respected on GPUs.
@combinations.generate(strategy_and_run_tf_function_combinations())
def disable_testAllReduce(self, distribution,
experimental_run_tf_function):
with distribution.scope():
v = variable_scope.variable(
2.,
synchronization=variables_lib.VariableSynchronization.ON_WRITE,
aggregation=variables_lib.VariableAggregation.MEAN)
self.evaluate(variables_lib.global_variables_initializer())
def all_reduce():
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
dtypes.float32)
if experimental_run_tf_function:
all_reduce = def_function.function(all_reduce)
per_replica_results = self.evaluate(
distribution.experimental_local_results(distribution.run(all_reduce)))
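    # Every replica should see the same all-reduced SUM (2.0 * num_replicas)
    # plus its own replica id.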
expected_result = []
for i in range(distribution.num_replicas_in_sync):
expected_result.append(2.0 * distribution.num_replicas_in_sync +
1.0 * i)
self.assertEqual(per_replica_results, tuple(expected_result))
@combinations.generate(strategy_and_run_tf_function_combinations())
def testAssignPerReplicaBeforeRead(self, distribution,
experimental_run_tf_function):
aggregations = [
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
]
for aggregation in aggregations:
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(variables_lib.global_variables_initializer())
def assign(var=v):
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return var.assign(math_ops.cast(replica_id, dtypes.float32))
if experimental_run_tf_function:
assign = def_function.function(assign)
per_replica_results = self.evaluate(
distribution.experimental_local_results(distribution.run(assign)))
expected_result = []
for i in range(distribution.num_replicas_in_sync):
expected_result.append(1.0 * i)
self.assertEqual(per_replica_results, tuple(expected_result))
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
with distribution.scope():
v = variable_scope.variable(
0.,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.NONE)
self.evaluate(variables_lib.global_variables_initializer())
with self.assertRaisesRegex(
ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
self.evaluate(v.read_value())
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testInitializedToSameValueInsideEagerRun(self, distribution):
if not context.executing_eagerly(): self.skipTest("eager only")
v = [None]
@def_function.function
def step():
def f():
if v[0] is None:
v[0] = variables_lib.Variable(
random_ops.random_normal([]),
synchronization=variables_lib.VariableSynchronization.ON_READ)
distribution.run(f)
context.set_global_seed(None)
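    # Clear the global seed so identical values across replicas cannot come
    # from deterministic seeding of random_normal.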
step()
vals = self.evaluate(v[0].values)
self.assertAllEqual(vals[0], vals[1])
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.tpu_strategy,
],
mode=["eager"]))
def testOperatorOverride(self, distribution):
with distribution.scope():
v = variable_scope.variable(
0.0,
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=variables_lib.VariableAggregation.MEAN)
@def_function.function
def assign():
ctx = distribution_strategy_context.get_replica_context()
replica_id = ctx.replica_id_in_sync_group
return v.assign(math_ops.cast(replica_id, dtypes.float32))
# Assign different replicas with different values.
distribution.run(assign)
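    # The two replicas now hold 0. and 1.; reading with MEAN aggregation in
    # cross-replica context averages them to 0.5, so v + 1 == 1.5.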
self.assertEqual(1.5, self.evaluate(v + 1))
@def_function.function
def add():
return v + 1
per_replica_results = self.evaluate(
distribution.experimental_local_results(distribution.run(add)))
self.assertAllEqual([1, 2], per_replica_results)


@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
aggregation=[
variables_lib.VariableAggregation.MEAN,
variables_lib.VariableAggregation.SUM,
variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
],
mode=["graph", "eager"]))
class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
def testScatterSub(self, distribution, aggregation):
with distribution.scope():
v = variables_lib.Variable(
[1., 1., 1.],
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(v.initializer)
delta = values.PerReplica([
indexed_slices.IndexedSlices(
values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
indexed_slices.IndexedSlices(
values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
])
with self.assertRaises(NotImplementedError):
self.evaluate(distribution.run(v.scatter_sub, args=(delta,)))
def testScatterAdd(self, distribution, aggregation):
with distribution.scope():
v = variables_lib.Variable(
[1., 1., 1.],
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(v.initializer)
delta = values.PerReplica([
indexed_slices.IndexedSlices(
values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
indexed_slices.IndexedSlices(
values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
])
with self.assertRaises(NotImplementedError):
self.evaluate(distribution.run(v.scatter_add, args=(delta,)))
def testScatterDiv(self, distribution, aggregation):
with distribution.scope():
v = variables_lib.Variable(
[2., 6., 1.],
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(v.initializer)
delta = values.PerReplica([
indexed_slices.IndexedSlices(
values=[[2.], [2.]], indices=[0, 1], dense_shape=(3,)),
indexed_slices.IndexedSlices(
values=[[3.], [3.]], indices=[1, 2], dense_shape=(3,)),
])
with self.assertRaises(NotImplementedError):
self.evaluate(distribution.run(v.scatter_div, args=(delta,)))
def testScatterMul(self, distribution, aggregation):
with distribution.scope():
v = variables_lib.Variable(
[2., 1., 1.],
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(v.initializer)
delta = values.PerReplica([
indexed_slices.IndexedSlices(
values=[[2.], [3.]], indices=[0, 1], dense_shape=(3,)),
indexed_slices.IndexedSlices(
values=[[4.], [5.]], indices=[1, 2], dense_shape=(3,)),
])
with self.assertRaises(NotImplementedError):
self.evaluate(distribution.run(v.scatter_mul, args=(delta,)))
def testScatterMin(self, distribution, aggregation):
with distribution.scope():
v = variables_lib.Variable(
[3., 4., 5.],
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(v.initializer)
delta = values.PerReplica([
indexed_slices.IndexedSlices(
values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
indexed_slices.IndexedSlices(
values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
])
with self.assertRaises(NotImplementedError):
self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
def testScatterMax(self, distribution, aggregation):
with distribution.scope():
v = variables_lib.Variable(
[3., 4., 5.],
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(v.initializer)
delta = values.PerReplica([
indexed_slices.IndexedSlices(
values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
indexed_slices.IndexedSlices(
values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
])
with self.assertRaises(NotImplementedError):
self.evaluate(distribution.run(v.scatter_max, args=(delta,)))
def testScatterUpdate(self, distribution, aggregation):
with distribution.scope():
v = variables_lib.Variable(
[0., 0., 0.],
synchronization=variables_lib.VariableSynchronization.ON_READ,
aggregation=aggregation)
self.evaluate(v.initializer)
delta = values.PerReplica([
indexed_slices.IndexedSlices(
values=[[1.], [2.]], indices=[0, 1], dense_shape=(3,)),
indexed_slices.IndexedSlices(
values=[[3.], [4.]], indices=[1, 2], dense_shape=(3,)),
])
    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(v.scatter_update, args=(delta,)))


class MirroredTest(test.TestCase):
def testAddOp(self):
if context.num_gpus() < 1:
self.skipTest("A GPU is not available for this test.")
mirrored_val = _make_mirrored_val(init_val=3.)
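    # Every component holds 3., so arithmetic on the mirrored value behaves
    # like arithmetic on a scalar 3. tensor.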
self.assertEqual(self.evaluate(constant_op.constant(6.)),
self.evaluate(mirrored_val + mirrored_val))
self.assertEqual(self.evaluate(constant_op.constant(4.)),
self.evaluate(mirrored_val + 1))
self.assertEqual(self.evaluate(mirrored_val + 1),
self.evaluate(math_ops.add(mirrored_val, 1)))
self.assertEqual(type(mirrored_val + 1),
type(math_ops.add(mirrored_val, 1)))


class PerReplicaTest(test.TestCase, parameterized.TestCase):
def testTypeSpec(self):
vals = (constant_op.constant(1.),)
per_replica = values.PerReplica(vals)
spec = per_replica._type_spec
self.assertEqual(spec._value_specs,
(tensor_spec.TensorSpec([], dtypes.float32),))
def testTypeSpecRoundTrip(self):
vals = (constant_op.constant(1.),)
per_replica = values.PerReplica(vals)
spec = per_replica._type_spec
tensor_list = spec._to_components(per_replica)
reconstructed = spec._from_components(tensor_list)
self.assertAllEqual(per_replica.values, reconstructed.values)
def testTypeSpecNest(self):
vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
per_replica = values.PerReplica(vals)
# Note: nest.map_structure exercises nest.flatten and
# nest.pack_sequence_as.
result = nest.map_structure(
lambda t: t + 10, per_replica, expand_composites=True)
self.assertLen(result.values, 2)
self.assertAllEqual(result.values[0], 11.)
self.assertAllEqual(result.values[1], [15., 16.0])
@test_util.run_in_graph_and_eager_modes
def testIsGraphTensor(self):
per_replica = values.PerReplica((constant_op.constant(1.),))
for t in nest.flatten(per_replica, expand_composites=True):
self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())
def testDoesNotTriggerFunctionTracing(self):
traces = []
@def_function.function
def f(x):
traces.append(None) # Only happens on trace.
return x
per_replica = values.PerReplica((constant_op.constant(1.),))
# Trace once.
f(per_replica)
self.assertNotEmpty(traces)
del traces[:]
per_replica_spec = per_replica._type_spec
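    # Round-tripping through the same TypeSpec yields compatible PerReplica
    # values, so the concrete function traced above should be reused.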
for _ in range(5):
vals = per_replica_spec._to_components(per_replica)
vals = [v * 2 for v in vals]
per_replica = per_replica_spec._from_components(vals)
output = f(per_replica)
self.assertIsInstance(output, values.PerReplica)
self.assertAllEqual(output._values, per_replica._values)
self.assertEmpty(traces) # Make sure we're not re-tracing `f`.
def testFunctionCanReturnPerReplica(self):
f = def_function.function(lambda x: x)
x = values.PerReplica((constant_op.constant(1.),))
y = f(x)
self.assertIsNot(x, y)
nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
self.assertEqual(x._type_spec, y._type_spec)
@test_util.run_in_graph_and_eager_modes
def testCondWithTensorValues(self):
per_replica_1 = values.PerReplica((constant_op.constant("a"),))
per_replica_2 = values.PerReplica((constant_op.constant(["b", "c"]),))
condition = array_ops.placeholder_with_default(True, [])
result = control_flow_ops.cond(
condition, lambda: per_replica_1, lambda: per_replica_2)
self.assertLen(result.values, 1)
self.assertAllEqual(result.values[0], "a")
@test_util.run_in_graph_and_eager_modes
def testCondWithValuesConvertibleToTensor(self):
per_replica_1 = values.PerReplica(("a",))
per_replica_2 = values.PerReplica(("b",))
condition = array_ops.placeholder_with_default(True, [])
result = control_flow_ops.cond(
condition, lambda: per_replica_1, lambda: per_replica_2)
self.assertLen(result.values, 1)
self.assertAllEqual(result.values[0], "a")
@test_util.build_as_function_and_v1_graph
def testCondWithValuesNotConvertibleToTensor(self):
per_replica_1 = values.PerReplica(({"a"},))
per_replica_2 = values.PerReplica(({"b", "c"},))
condition = array_ops.placeholder(dtypes.bool, [])
with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
control_flow_ops.cond(
condition, lambda: per_replica_1, lambda: per_replica_2)


def _make_index_slices(vals, indices, dense_shape=None):
if dense_shape:
dense_shape = array_ops.identity(dense_shape)
return indexed_slices.IndexedSlices(
array_ops.identity(vals), array_ops.identity(indices), dense_shape)


if __name__ == "__main__":
test.main()