# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest


def _device_str(d):
  return "/device:GPU:" + str(d)


def _nested_value(d):
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


def _make_mirrored_val(init_val=5.0):
  v = []
  devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, _ in zip(devices, ["v", "v/replica"]):
    with ops.device(d):
      v.append(constant_op.constant(init_val))
  return values_lib.Mirrored(v)


def _make_mirrored(distribution=None):
  v = []
  if distribution:
    devices = distribution.extended.worker_devices
  else:
    devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(
          variable_scope.get_variable(
              name=n, initializer=init, use_resource=True))
  if (distribution is not None) and isinstance(distribution, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUMirroredVariable
  else:
    var_cls = values_lib.MirroredVariable
  mirrored = var_cls(distribution, v, variable_scope.VariableAggregation.SUM)
  return mirrored


def mirrored_and_tpu_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"])


class DistributedValuesTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]))
  def testMakeDistributedValueFromTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    single_value = constant_op.constant(1)

    def value_fn(ctx):
      del ctx
      return single_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values),
        constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]))
  def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    array_value = np.array([1., 2., 3.])

    def value_fn(ctx):
      del ctx
      return array_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values).numpy(),
        [[1., 2., 3.]] * distribution.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]))
  def testMakeDistributedValueTupleConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)

    def value_fn(ctx):
      del ctx
      return tuple_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
    expected = tuple([v for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]))
  def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)

    def value_fn(ctx):
      per_replica = []
      for val in tuple_value:
        per_replica.append(val * ctx.replica_id_in_sync_group)
      return tuple(per_replica)

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0])
    expected = tuple([v * i for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  # NOTE(priyag): Cannot test this with MultiWorkerMirroredStrategy because
  # collective ops do not support SparseTensors.
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies_minus_default,
          mode=["eager"]))
  def testMakeDistributedValueSparseTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    def value_fn(ctx):
      del ctx
      return sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    local_results = distribution.experimental_local_results(distributed_values)
    for i in range(distribution.num_replicas_in_sync):
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(local_results[i]),
          [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]))
  def testMakeDistributedValueExtractFromArray(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    multiple_values = range(distribution.num_replicas_in_sync)

    def value_fn(ctx):
      return multiple_values[ctx.replica_id_in_sync_group]

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)
    expected = range(distribution.num_replicas_in_sync)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]))
  def testMakeDistributedValueAndRun(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    @def_function.function
    def run():
      multiple_values = range(distribution.num_replicas_in_sync)

      def value_fn(ctx):
        return multiple_values[ctx.replica_id_in_sync_group]

      distributed_values = (
          distribution.experimental_distribute_values_from_function(value_fn))

      def computation(x):
        return math_ops.square(x)

      outputs = ds_test_util.gather(
          distribution,
          distribution.run(computation, args=(distributed_values,)))
      return outputs

    results = run()
    expected = [i**2 for i in range(distribution.num_replicas_in_sync)]
    self.assertAllEqual(results, expected)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    def value_fn(ctx):
      del ctx
      return constant_op.constant(1.0)

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device,
                          "/job:localhost/replica:0/task:0/device:CPU:0")

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    worker_devices = distribution.extended.worker_devices

    def value_fn(ctx):
      # In a multi-client setup, worker_devices is just the devices on that
      # worker.
      worker_device_id = ctx.replica_id_in_sync_group % len(worker_devices)
      with ops.device(worker_devices[worker_device_id]):
        return array_ops.identity(1.0)

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device,
                          worker_devices[i])


class DistributedDelegateTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testGetAttr(self):

    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    self.assertEqual(7, v.x)
    with self.assertRaises(AttributeError):
      _ = v.y

  @test_util.run_in_graph_and_eager_modes
  def testOperatorOverride(self):
    v = values_lib.DistributedDelegate((7, 8))
    # v should act like int(7).
    self.assertEqual(8, v + 1)
    self.assertEqual(10, 3 + v)
    self.assertEqual(14, v + v)
    self.assertEqual(5, v - 2)
    self.assertEqual(6, 13 - v)
    self.assertEqual(0, v - v)
    self.assertEqual(14, v * 2)
    self.assertEqual(21, 3 * v)
    self.assertEqual(49, v * v)
    self.assertEqual(3.5, v / 2)
    self.assertEqual(1.5, 10.5 / v)
    self.assertEqual(3, v // 2)
    self.assertEqual(2, 15 // v)
    self.assertEqual(1, v % 2)
    self.assertEqual(2, 16 % v)
    # pylint: disable=g-generic-assert
    self.assertTrue(v < 12)
    self.assertTrue(v <= 12)
    self.assertFalse(v > 12)
    self.assertFalse(v >= 12)
    self.assertFalse(12 < v)
    self.assertFalse(12 <= v)
    self.assertTrue(12 > v)
    self.assertTrue(12 >= v)
    # pylint: enable=g-generic-assert
    self.assertEqual(3, v & 3)
    self.assertEqual(3, 11 & v)
    self.assertEqual(15, v | 8)
    self.assertEqual(23, 16 | v)
    self.assertEqual(4, v ^ 3)
    self.assertEqual(12, 11 ^ v)
    self.assertEqual(343, pow(v, 3))
    self.assertEqual(3, pow(v, 3, 10))
    self.assertEqual(128, pow(2, v))
    self.assertEqual(-7, -v)
    self.assertEqual(~7, ~v)
    self.assertEqual(7, abs(v))
    with self.assertRaises(TypeError):
      _ = v[2]

  @test_util.run_in_graph_and_eager_modes
  def testCopy(self):

    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    v_shallow_copy = copy.copy(v)
    self.assertEqual(v.x, v_shallow_copy.x)
    v_deep_copy = copy.deepcopy(v)
    self.assertEqual(v.x, v_deep_copy.x)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu
        ],
        synchronization=[
            variables_lib.VariableSynchronization.ON_READ,
            variables_lib.VariableSynchronization.ON_WRITE,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):

  def testExtendsVariable(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIsInstance(v, variables_lib.Variable)

  def testCheckpointing(self, distribution, synchronization, aggregation,
                        mode):
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")

    with distribution.scope():
      v = variables_lib.Variable(
          constant_op.constant([1., 2., 3., 4]),
          synchronization=synchronization,
          aggregation=aggregation)

    self.evaluate(v.initializer)
    before_save = self.evaluate(v.read_value())

    # Save the current values into a checkpoint.
    checkpoint = trackable_utils.Checkpoint(v=v)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
      save_path = checkpoint.save(prefix)

    # Assign inverted value.
    self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
    after_assign = self.evaluate(v.read_value())
    self.assertNotAllClose(before_save, after_assign)

    # Restore from the checkpoint.
    with self.test_session():
      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    after_restore = self.evaluate(v)
    self.assertAllClose(before_save, after_restore)

  def testTraceback(self, distribution, synchronization, aggregation):
    if context.executing_eagerly():
      self.skipTest("does not apply to eager")
    with distribution.scope():
      variable_scope.get_variable(
          name="testVar",
          initializer=1.,
          use_resource=True,
          synchronization=synchronization,
          aggregation=aggregation)
      with self.assertRaisesRegex(ValueError,
                                  "Variable testVar already exists"):
        variable_scope.get_variable(
            name="testVar",
            initializer=1.,
            use_resource=True,
            synchronization=synchronization,
            aggregation=aggregation)

  def testSelectReplica(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIs(v, distribute_utils.select_replica(0, v))

  def testIsTensorLike(self, distribution, synchronization, aggregation):
    if isinstance(distribution.extended,
                  tpu_strategy.TPUExtended) and context.executing_eagerly():
      self.skipTest("TPU doesn't support pure eager")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    # In cross replica context.
    self.assertIsInstance(v, core.Tensor)
    # In replica context.
    distribution.run(
        lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))

  def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
                                        aggregation):
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      if context.executing_eagerly():
        self.skipTest("TPU doesn't support pure eager")
      else:
        self.skipTest("b/152076846")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assert_is_tensor_like(v):
      # We can't use Python literals because they are treated as
      # non-distributed values, which is not allowed when aggregation is SUM.
      # See `cross_device_ops.reduce_non_distributed_value`.
      delta = array_ops.identity(1.)
      self.assertIsInstance(v.assign(delta), core.Tensor)
      self.assertIsInstance(v.assign_sub(delta), core.Tensor)
      self.assertIsInstance(v.assign_add(delta), core.Tensor)

    # In cross-replica context we return a PerReplica, which is not tensor-like
    # all the time yet.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        aggregation != variables_lib.VariableAggregation.SUM):
      assert_is_tensor_like(v)
    # In replica context.
    distribution.run(assert_is_tensor_like, args=(v,))

  def testDeepCopy(self, distribution, synchronization, aggregation):
    if not context.executing_eagerly():
      self.skipTest("deepcopy only supported in eager mode")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
      in_dist_copy = copy.deepcopy(v)

    out_dist_copy = copy.deepcopy(v)

    def assert_is_deep_copy(v1, v2):
      self.assertIsInstance(v2, type(v1))
      self.assertEqual(v1.aggregation, v2.aggregation)
      self.assertEqual(v1.distribute_strategy, v2.distribute_strategy)
      if isinstance(v1, ps_values.AggregatingVariable):
        self.assertIsInstance(v2.get(), type(v1.get()))
        self.assertNotEqual(id(v1.get()), id(v2.get()))
      else:
        if v1._policy:
          self.assertNotEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        else:
          self.assertEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        self.assertEqual(len(v1.values), len(v2.values))
        for (v1v, v2v) in zip(v1.values, v2.values):
          self.assertEqual(v1v.device, v2v.device)
          self.assertNotEqual(id(v1v), id(v2v))
        self.assertAllEqual(self.evaluate(v1.values), self.evaluate(v2.values))

    self.evaluate(variables_lib.global_variables_initializer())
    if not isinstance(distribution.extended, tpu_strategy.TPUExtended):
      distribution.run(assert_is_deep_copy, args=(v, in_dist_copy))
      distribution.run(assert_is_deep_copy, args=(v, out_dist_copy))

  def testAssignSignature(self, distribution, synchronization, aggregation):
    # This test verifies that assign*() can be called in the same way as normal
    # variables.
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assign():
      one = constant_op.constant(1.)
      v.assign(one, True, "assign", False)
      # TODO(b/154017756): SyncOnReadVariable.assign() doesn't support passing
      # value as a keyword argument.
      v.assign(one, use_locking=True, name="assign", read_value=False)
      v.assign_add(one, True, "assign", False)
      v.assign_add(one, use_locking=True, name="assign", read_value=False)
      v.assign_sub(one, True, "assign", False)
      v.assign_sub(one, use_locking=True, name="assign", read_value=False)
      # Return something for graph mode to fetch.
      return constant_op.constant(1)

    self.evaluate(variables_lib.global_variables_initializer())
    if not (synchronization == variables_lib.VariableSynchronization.ON_READ
            and aggregation == variables_lib.VariableAggregation.SUM):
      self.evaluate(distribution.experimental_local_results(assign()))
    if not (isinstance(distribution.extended, tpu_strategy.TPUExtended) and
            context.executing_eagerly()):
      self.evaluate(
          distribution.experimental_local_results(distribution.run(assign)))

  def testStrategyExtendedUpdate(self, distribution, synchronization,
                                 aggregation):
    if len(distribution.extended.parameter_devices) != 2:
      self.skipTest("n/a: needs exactly two parameter devices")
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    # Note that this is actually real usage. We're doing this in the optimizer
    # to work around the current restriction in strategy.extended.update().
    value = values_lib.Mirrored([1., 2.])

    assign_fn = lambda var, value: var.assign(value)
    self.evaluate(distribution.extended.update(v, assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    assign_add_fn = lambda var, value: var.assign_add(value)
    self.evaluate(distribution.extended.update(v, assign_add_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [2., 4.])

    assign_sub_fn = lambda var, value: var.assign_sub(value)
    self.evaluate(distribution.extended.update(v, assign_sub_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    read_assign_fn = lambda var, value: var.assign_add(var.value() +
                                                       var.read_value())
    self.evaluate(
        distribution.extended.update(v, read_assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [3., 6.])

  def testSaveNonDistributed(self, distribution, synchronization, aggregation):
    # This test verifies that the DistributedVariable behaves like the primary
    # variable when saving a non-distributed version of the model (the
    # default). The test asserts that the function traced under SaveContext has
    # no device annotations and only references the primary component of the
    # variable. Note: please avoid capturing other eager tensors in this test
    # to keep the assertion simple.
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("b/148689177: AggregatingVariable doesn't "
                    "conform to Variable interface well")

    # tf.function requires the return value to be Tensors, which is not always
    # the case for properties and methods of Variable, so we simply discard the
    # return values.
    def _discard_return(f):
      f()
      return

    def _test(f, v):
      # This verifies that the function under SaveContext:
      # - contains no device annotations.
      # - only references the primary component of the variable.
      g = def_function.function(lambda: _discard_return(f))
      options = save_options.SaveOptions(
          experimental_variable_policy=save_options.VariablePolicy.NONE)
      with save_context.save_context(options):
        # The graph should contain no device.
        graph = g.get_concrete_function().graph
      for op in graph.get_operations():
        self.assertEqual(op.device, "", msg=str(op))
      # The function should only capture the primary variable. Note that it
      # may not have captures, e.g. v.aggregation.
      captures = list(graph.captures)
      self.assertLessEqual(len(captures), 1)
      if graph.captures:
        self.assertIs(captures[0][0], v._primary.handle)

    def _assert(cond):
      return control_flow_ops.Assert(cond, [cond])

    with distribution.scope():
      # We use three variables for convenience reasons. They have no special
      # meaning.
      # - v is used whenever possible.
      # - w is used for scatter and gather, which require the variable to be
      #   non-scalar.
      # - y is used when the dtype needs to be integer. Note that aggregation
      #   cannot be MEAN for integers.
      v = variables_lib.Variable(
          0.,
          synchronization=synchronization,
          aggregation=aggregation,
          trainable=True)
      w = variables_lib.Variable([0., 0., 0.],
                                 synchronization=synchronization,
                                 aggregation=aggregation,
                                 trainable=True)
      if aggregation != variables_lib.VariableAggregation.MEAN:
        y = variables_lib.Variable(
            0, synchronization=synchronization, aggregation=aggregation)

    # pylint: disable=g-long-lambda

    # tf.Variable properties.
    _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
    _test(lambda: self.assertIs(v.constraint, None), v)
    # TODO(crccw): should we raise an error instead?
    _test(lambda: self.assertEqual(v.device, v._primary.device), v)
    _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
    if not context.executing_eagerly():
      _test(lambda: _assert(v.initial_value == 0), v)
    _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
    _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.op, v._primary.op), v)
    _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
    _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
    _test(lambda: self.assertTrue(v.trainable, True), v)

    # tf.Variable methods.
    _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
    # TODO(b/148689177): Implement batch_scatter_update.
    # count_up_to() is skipped since it's deprecated.
    # eval() is skipped since it shouldn't be called in a tf.function.
    # experimental_ref() is skipped since it's deprecated.
    # from_proto() is skipped since it shouldn't be called in a tf.function.
    # TODO(b/148689177): Implement gather_nd.
    _test(
        lambda: check_ops.assert_equal_v2(v.get_shape(),
                                          tensor_shape.TensorShape(())), v)
    # initialized_value() is skipped since it shouldn't be called in a
    # tf.function.
    # load() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
    # ref() is skipped since it shouldn't be called in a tf.function.
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
            [1., 0., 2.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
            [0.25, 0., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
            [0.25, 1., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
            [0.25, 0.5, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [0.5, 0.25, 1.]), w)
    # TODO(b/148689177): Implement scatter_nd_*
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [-1.5, -0.25, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_update(
                _make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [2., 0.5, 1.]), w)
    # set_shape() is skipped since ResourceVariable doesn't implement it.
    # to_proto() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

    # DistributedVariable should be treated as ResourceVariable, so it needs to
    # conform to the ResourceVariable interface as well.
    _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

    # Convert to tensor.
    _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)

    # Control dependency.
    def _with_control_dep():
      with ops.control_dependencies([v.assign(1.)]):
        return array_ops.identity(1)

    _test(_with_control_dep, v)

    # Operator overloads.
    _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
    _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
    _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
    _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
    _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
    _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
    _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
    _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
    _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(v / 2., dtypes.float32), 3.5), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(14. / v, dtypes.float32), 2.), v)
    _test(lambda: _assert(v < 12.), v)
    _test(lambda: _assert(v <= 12.), v)
    _test(lambda: _assert(not v > 12.), v)
    _test(lambda: _assert(not v >= 12.), v)
    _test(lambda: _assert(not 12. < v), v)
    _test(lambda: _assert(not 12. <= v), v)
    _test(lambda: _assert(12. > v), v)
    _test(lambda: _assert(12. >= v), v)
    _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
    _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
    _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

    # Operator overloads that only work for integers.
    if aggregation != variables_lib.VariableAggregation.MEAN:
      _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
      _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
      _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
      _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
      _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
      _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
      _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
      _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
      _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
      _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
      _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

    # Index.
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      # TODO(b/161572567): slice assignment doesn't work for TPU.
      _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
    else:
      _test(lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]),
            w)
      _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)

    # pylint: enable=g-long-lambda

  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("n/a: not applicable to AggregatingVariable")
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")
    with distribution.scope():
      v = variables_lib.Variable([1., 1.],
                                 synchronization=synchronization,
                                 aggregation=aggregation)
    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())
    export_dir = self.get_temp_dir()

    def _assert_unsaveable(f):
      # Ignore if it cannot be traced. Certain combinations are not supported
      # yet or not allowed.
      try:
        f = def_function.function(f).get_concrete_function()
      except (NotImplementedError, ValueError):
        return
      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
        save.save(v, export_dir, signatures=f)

    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
    # Reading an ON_READ variable should be unsaveable if either:
    # 1) CollectiveAllReduceStrategy, and aggregation is MEAN/SUM.
    # 2) aggregation is SUM.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        (aggregation == variables_lib.VariableAggregation.SUM or
         (isinstance(distribution.extended,
                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
          and aggregation == variables_lib.VariableAggregation.MEAN))):
      _assert_unsaveable(v.read_value)
      _assert_unsaveable(v.value)
      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
    else:
      # Otherwise reading a variable should be saveable.

      @def_function.function
      def f():
        v.read_value()
        v.value()
        return ops.convert_to_tensor(v)

      with self.cached_session():
        save.save(v, export_dir, signatures=f.get_concrete_function())


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.tpu_strategy,
        ],
        mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):

  def testPackedVariable(self, distribution):
    with distribution.scope():
      v0 = variables_lib.Variable(0.)
    self.assertIsNone(v0._packed_var)

    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v1 = variables_lib.Variable(0)
      self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)

    devices = v1._devices
    for i in range(1, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        v1.assign(i)
    val = v1._get()
    self.assertIsInstance(val, packed.PackedVarAndDevice)
    self.assertEqual(val.device, devices[0])
    self.assertEqual(self.evaluate(val.read_value()), 0)
    for i in range(0, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        val = v1._get()
        self.assertIsInstance(val, packed.PackedVarAndDevice)
        self.assertEqual(val.device, devices[i])
        self.assertEqual(self.evaluate(val.read_value()), i)

  def testIgnorePackedVariableInSaveContext(self, distribution):
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v = variables_lib.Variable(0)
      self.assertIsInstance(
          v._packed_variable, packed.PackedDistributedVariable)

    options = save_options.SaveOptions()
    with save_context.save_context(options):
      self.assertIsNone(v._packed_variable)


class MirroredVariableTest(test.TestCase, parameterized.TestCase):

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    mirrored = _make_mirrored()
    v = mirrored.values[0]
    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testVariableOnAnotherDevice(self):
    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    mirrored = values_lib.MirroredVariable(
        None, (v,), variable_scope.VariableAggregation.MEAN)

    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)


class MirroredVariableSaveRestoreTest(test.TestCase, parameterized.TestCase):

  def _assign_mirrored(self, v, new):
    for var, n in zip(v.values, new):
      self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  def _save_mirrored(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path = self._save(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])
    return save_path

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.))

      # Saves the current value of var, 3.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
    return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3. to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3., self.evaluate(var))

  def _restore_mirrored(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [7., 8.])

      # Restores the saved value of 3. to both variables.
      saver = saver_lib.Saver(var_list=[mirrored])
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreMirroredOneGraph(self, distribution):
    with self.cached_session() as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path, saver = self._save_return_saver(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])

      # Restores the saved value of 3. to both variables.
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")
    save_path = self._save_mirrored(distribution)
    self._restore_mirrored(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreNormal(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")
    save_path = self._save_mirrored(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")
    save_path = self._save_normal()
    self._restore_mirrored(save_path, distribution)


_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)


def _make_replica_local(method, strategy=None):
  if strategy is None:
    devices = ("/device:GPU:0", "/device:CPU:0")
  else:
    devices = strategy.extended.worker_devices
  v = []
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(variable_scope.get_variable(
          name=n, initializer=init, use_resource=True))

  if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUSyncOnReadVariable
  else:
    var_cls = values_lib.SyncOnReadVariable
  replica_local = var_cls(strategy, v, method)
  return v, replica_local


class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):

  def _assign_replica_local(self, v, new):
    for var, n in zip(v, new):
      with ops.device(var.device):
        self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")
    v, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.SUM)

    self.assertEqual(v[0].constraint, replica_local.constraint)
    self.assertEqual(v[0].name, replica_local.name)
    self.assertEqual(v[0].dtype, replica_local.dtype)
    self.assertEqual(v[0].shape, replica_local.shape)
    self.assertEqual(variable_scope.VariableAggregation.SUM,
                     replica_local.aggregation)

  @test_util.run_v2_only
  def testCanPassToDefFun(self):
    @def_function.function
    def add1(x):
      return x + 1

    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    replica_local = values_lib.SyncOnReadVariable(
        None, (v,), variable_scope.VariableAggregation.MEAN)
    self.assertEqual(2., self.evaluate(add1(replica_local)))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testTensorConversion(self, distribution):
    with context.graph_mode():
      _, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)
      converted = ops.convert_to_tensor(replica_local, as_ref=False)
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

      converted = ops.convert_to_tensor(replica_local, as_ref=True)
      # Resource variables are converted to tensors as well when as_ref is
      # True.
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

  @combinations.generate(combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ], mode=["eager"]))
  def testValueInCrossReplicaContext(self, distribution):
    value_list, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution)

    self.assertIsInstance(replica_local.value(), ops.Tensor)
    self.assertEqual(self.evaluate(replica_local.value()),
                     self.evaluate(value_list[0].value()))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 7.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 7. which gets divided equally
        # between the variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 3.5 to both variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _save_replica_local_mean(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
    return save_path

  def _save_replica_local_sum(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [1.5, 2.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 3.5
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
    return save_path

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.5))

      # Saves the current value of var, 3.5.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
    return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3.5 to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3.5, self.evaluate(var))

  def _restore_replica_local_mean(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5 to both variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _restore_replica_local_sum(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5 to both variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_replica_local_sum(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalSumRestoreNormal(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_sum(save_path, distribution)


class MirroredTest(test.TestCase):

  def testAddOp(self):
    if context.num_gpus() < 1:
      self.skipTest("A GPU is not available for this test.")
    mirrored_val = _make_mirrored_val(init_val=3.)
    self.assertEqual(self.evaluate(constant_op.constant(6.)),
                     self.evaluate(mirrored_val + mirrored_val))
    self.assertEqual(self.evaluate(constant_op.constant(4.)),
                     self.evaluate(mirrored_val + 1))
    self.assertEqual(self.evaluate(mirrored_val + 1),
                     self.evaluate(math_ops.add(mirrored_val, 1)))
    self.assertEqual(type(mirrored_val + 1),
                     type(math_ops.add(mirrored_val, 1)))


class PerReplicaTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpec(self):
    vals = (constant_op.constant(1.),)
    per_replica = values_lib.PerReplica(vals)

    spec = per_replica._type_spec
    self.assertEqual(spec._value_specs,
                     (tensor_spec.TensorSpec([], dtypes.float32),))

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecRoundTrip(self):
    vals = (constant_op.constant(1.),)
    per_replica = values_lib.PerReplica(vals)

    spec = per_replica._type_spec
    tensor_list = spec._to_components(per_replica)
    reconstructed = spec._from_components(tensor_list)

    self.assertAllEqual(per_replica.values, reconstructed.values)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecNest(self):
    vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
    per_replica = values_lib.PerReplica(vals)

    # Note: nest.map_structure exercises nest.flatten and
    # nest.pack_sequence_as.
    result = nest.map_structure(
        lambda t: t + 10, per_replica, expand_composites=True)
    self.assertLen(result.values, 2)
    self.assertAllEqual(result.values[0], 11.)
    self.assertAllEqual(result.values[1], [15., 16.0])

  @test_util.run_in_graph_and_eager_modes
  def testIsGraphTensor(self):
    per_replica = values_lib.PerReplica((constant_op.constant(1.),))
    for t in nest.flatten(per_replica, expand_composites=True):
      self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testDoesNotTriggerFunctionTracing(self):
    traces = []

    @def_function.function
    def f(x):
      traces.append(None)  # Only happens on trace.
      return x

    per_replica = values_lib.PerReplica((constant_op.constant(1.),))

    # Trace once.
    f(per_replica)
    self.assertNotEmpty(traces)
    del traces[:]

    per_replica_spec = per_replica._type_spec
    for _ in range(5):
      vals = per_replica_spec._to_components(per_replica)
      vals = [v * 2 for v in vals]
      per_replica = per_replica_spec._from_components(vals)

      output = f(per_replica)
      self.assertIsInstance(output, values_lib.PerReplica)
      self.assertAllEqual(output._values, per_replica._values)
      self.assertEmpty(traces)  # Make sure we're not re-tracing `f`.

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testFunctionCanReturnPerReplica(self):
    f = def_function.function(lambda x: x)
    x = values_lib.PerReplica((constant_op.constant(1.),))
    y = f(x)
    self.assertIsNot(x, y)
    nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
    self.assertEqual(x._type_spec, y._type_spec)

  @test_util.run_in_graph_and_eager_modes
  def testCondWithTensorValues(self):
    per_replica_1 = values_lib.PerReplica((constant_op.constant("a"),))
    per_replica_2 = values_lib.PerReplica((constant_op.constant(["b", "c"]),))
    condition = array_ops.placeholder_with_default(True, [])
    result = control_flow_ops.cond(
        condition, lambda: per_replica_1, lambda: per_replica_2)

    self.assertLen(result.values, 1)
    self.assertAllEqual(result.values[0], "a")

  @test_util.run_in_graph_and_eager_modes
  def testCondWithValuesConvertibleToTensor(self):
    per_replica_1 = values_lib.PerReplica(("a",))
    per_replica_2 = values_lib.PerReplica(("b",))
    condition = array_ops.placeholder_with_default(True, [])
    result = control_flow_ops.cond(
        condition, lambda: per_replica_1, lambda: per_replica_2)

    self.assertLen(result.values, 1)
    self.assertAllEqual(result.values[0], "a")

  @test_util.build_as_function_and_v1_graph
  def testCondWithValuesNotConvertibleToTensor(self):
    per_replica_1 = values_lib.PerReplica(({"a"},))
    per_replica_2 = values_lib.PerReplica(({"b", "c"},))
    condition = array_ops.placeholder(dtypes.bool, [])
    with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
      control_flow_ops.cond(
          condition, lambda: per_replica_1, lambda: per_replica_2)


def _make_index_slices(values, indices, dense_shape=None):
  if dense_shape:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)


if __name__ == "__main__":
  ds_test_util.main()