All distribution strategies are supported except CentralStorageStrategy and ParameterServerStrategy. This CL also removes the CompositeTensor superclass from Generator. Generator is a wrapper around tf.Variable, and because tf.Variable is not a CompositeTensor, Generator cannot, in principle, be a CompositeTensor either. Previously we made it a CompositeTensor by returning Variable.handle, but that breaks down when the variable is a DistributedVariable (in cross-replica context). PiperOrigin-RevId: 350851648 Change-Id: I5f4d77ddb990557fcc9c7336987203ecdaec5b9a
256 lines
8.9 KiB
Python
256 lines
8.9 KiB
Python
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests tf.random.Generator with distribution strategies."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import functools
|
|
import os
|
|
|
|
from absl.testing import parameterized
|
|
|
|
from tensorflow.python.compat import v2_compat
|
|
from tensorflow.python.distribute import combinations as ds_combinations
|
|
from tensorflow.python.distribute import multi_process_runner
|
|
from tensorflow.python.distribute import strategy_combinations
|
|
from tensorflow.python.eager import def_function
|
|
from tensorflow.python.framework import config
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import test_combinations as combinations
|
|
from tensorflow.python.module import module
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import stateful_random_ops as rng
|
|
from tensorflow.python.platform import test
|
|
from tensorflow.python.saved_model import load
|
|
from tensorflow.python.saved_model import save
|
|
from tensorflow.python.training.tracking import util as tracking_util
|
|
|
|
|
|
def get_num_local_replicas(strat, values=None):
  """Returns the number of local replicas of `strat`.

  Strategies whose class name contains "MultiWorker" or
  "CollectiveAllReduceStrategy" are treated as multi-worker: for them the
  count is the number of entries in `strat.experimental_local_results(values)`.
  If `values` is None, a trivial `strat.run` call is made to produce some.
  Every other strategy simply reports `strat.num_replicas_in_sync`.

  Args:
    strat: a distribution strategy.
    values: optional output of `strat.run`; only consulted for multi-worker
      strategies.

  Returns:
    The number of local replicas.
  """
  class_name = type(strat).__name__
  multi_worker_markers = ("MultiWorker", "CollectiveAllReduceStrategy")
  if not any(marker in class_name for marker in multi_worker_markers):
    return strat.num_replicas_in_sync
  if values is None:
    # Produce a per-replica value just so its local results can be counted.
    values = strat.run(lambda: constant_op.constant(0))
  return len(strat.experimental_local_results(values))
|
|
|
|
|
|
# The strategies most tests below are parameterized over:
# `strategy_combinations.all_strategies` extended with the multi-worker ones.
all_strategies = (
    strategy_combinations.all_strategies
    + strategy_combinations.multiworker_strategies)
|
|
|
|
|
|
class GeneratorTest(test.TestCase, parameterized.TestCase):
  """Tests `rng.Generator` under various distribution strategies."""

  def setUp(self):
    super(GeneratorTest, self).setUp()
    v2_compat.enable_v2_behavior()
    # Disable soft placement so ops are not silently moved to another device.
    config.set_soft_device_placement(False)

  def _skipIfCentralStorage(self, *strats):
    """Skips the current test if any of `strats` is a CentralStorageStrategy.

    Args:
      *strats: the distribution strategies the test uses.
    """
    for strat in strats:
      if "CentralStorage" in type(strat).__name__:
        self.skipTest("Does not work with CentralStorageStrategy yet.")

  def assertAllDifferent(self, tensors):
    """Checks that there are no duplicate elements anywhere among the tensors.

    Args:
      tensors: a list of tensors. They can have different shapes.
    """
    # Flatten and concatenate so that duplicates are detected across tensors
    # as well as within each one.
    values = [array_ops.reshape(t, shape=[-1]) for t in tensors]
    values = array_ops.concat(values, axis=0)
    values = self.evaluate(values)
    values = values.tolist()
    self.assertEqual(len(values), len(set(values)))

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"]))
  def testCrossReplica(self, strat):
    """Tests that RNG can be properly advanced in cross-replica context."""
    self._skipIfCentralStorage(strat)

    def read_values(dv):
      return [v.read_value() for v in strat.experimental_local_results(dv)]

    with strat.scope():
      g = rng.Generator.from_seed(1)
      s1 = read_values(g.state)
      g.normal([3])
      g.skip(4)
      s2 = read_values(g.state)
    # The state must have advanced, and every local copy must stay in sync.
    self.assertNotAllEqual(s1[0], s2[0])
    self.assertEqual(len(s1), len(s2))
    for i in range(1, len(s1)):
      self.assertAllEqual(s1[0], s1[i])
      self.assertAllEqual(s2[0], s2[i])

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          seeded=[True, False]))
  def testDistStrat(self, strat, seeded):
    """Tests RNG with distribution strategies."""
    self._skipIfCentralStorage(strat)
    creators = {
        True: functools.partial(rng.Generator.from_seed, 1234),
        False: rng.Generator.from_non_deterministic_state,
    }
    shape = [3, 4]
    dtype = dtypes.int32
    creator = creators[seeded]
    with strat.scope():
      gen = creator()

      @def_function.function
      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      results = strat.run(f)
      values = strat.experimental_local_results(results)
      n = get_num_local_replicas(strat, values)
      self.assertEqual(n, len(values))
      # Different replicas (and successive calls) must draw different numbers.
      self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"]))
  def testDistVarAsTFFunArg(self, strat):
    """Tests that RNG with dist variables can be used as tf.function's arg."""
    self._skipIfCentralStorage(strat)
    shape = [3, 4]
    dtype = dtypes.int32
    with strat.scope():
      gen = rng.Generator.from_seed(1234)

      @def_function.function
      def f(gen):  # the main focus
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      @def_function.function  # required by TPUStrategy.run
      def g():
        return f(gen)

      for _ in range(2):
        results = strat.run(g)
        values = strat.experimental_local_results(results)
        n = get_num_local_replicas(strat, values)
        self.assertEqual(n, len(values))
        self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat1=strategy_combinations.all_strategies,
          strat2=strategy_combinations.all_strategies,
          mode=["eager"]) +
      combinations.combine(
          strat1=strategy_combinations.multiworker_strategies,
          strat2=[None],
          mode=["eager"]))
  def testDistStratRestore(self, strat1, strat2):
    """Tests checkpointing and restoring (to possibly different #replicas)."""
    if strat2 is None:
      strat2 = strat1
    self._skipIfCentralStorage(strat1, strat2)
    strat1_name = type(strat1).__name__
    strat2_name = type(strat2).__name__
    if "Default" in strat1_name or "Default" in strat2_name:
      self.skipTest(
          "We don't guarantee consistency between strategy and no-strategy.")
    fname = os.path.join(self.get_temp_dir(), "checkpoint")

    def uniform(strat, g):
      @def_function.function
      def f():
        return g.uniform_full_int([3], dtype=dtypes.int32)
      result = strat.run(f)
      return strat.experimental_local_results(result)

    with strat1.scope():
      g1 = rng.Generator.from_seed(1)
    with strat2.scope():
      g2 = rng.Generator.from_seed(10)
    cp1 = tracking_util.Checkpoint(g=g1)
    cp2 = tracking_util.Checkpoint(g=g2)

    def write_restore_compare():
      cp1.write(fname)
      r1 = uniform(strat1, g1)
      cp2.restore(fname)
      r2 = uniform(strat2, g2)
      # Tests that overlapping replicas are properly restored.
      n1 = get_num_local_replicas(strat1)
      n2 = get_num_local_replicas(strat2)
      n = min(n1, n2)
      self.assertAllEqual(r1[:n], r2[:n])

    # Run multiple times so that cp1.write is called in various RNG states
    for _ in range(2):
      write_restore_compare()

  @ds_combinations.generate(
      combinations.combine(
          strat=strategy_combinations.all_strategies,
          mode=["eager"],
          is_save_in_scope=[True, False]))
  def testSavedModel(self, strat, is_save_in_scope):
    """Tests that a Generator inside a Module survives SavedModel save/load.

    The module is saved inside or outside `strat.scope()` depending on
    `is_save_in_scope`. The loaded generator must start from the saved state
    and, after one `mutate()`, match the original module after its own
    `mutate()`.
    """
    self._skipIfCentralStorage(strat)

    class CustomModule(module.Module):

      def __init__(self):
        super(CustomModule, self).__init__()
        self.g = rng.Generator.from_seed(0)

      @def_function.function
      def __call__(self):
        return self.g.state

      @def_function.function
      def mutate(self):
        self.g.normal([])

    with strat.scope():
      m = CustomModule()
      m.mutate()
      state_before = m()
    path = os.path.join(self.get_temp_dir(), "saved_model")
    if is_save_in_scope:
      with strat.scope():
        save.save(m, path)
    else:
      save.save(m, path)
    # Advance the original once more to get the expected post-mutate state.
    with strat.scope():
      m.mutate()
      state_before_2 = m()

    imported = load.load(path)
    state_after = imported()
    self.assertAllEqual(state_before, state_after)
    imported.mutate()
    state_after_2 = imported()
    self.assertAllEqual(state_before_2, state_after_2)
|
|
|
|
|
|
if __name__ == "__main__":
  # The tests are launched through multi_process_runner's test entry point
  # rather than test.main(); presumably this is what lets the multi-worker
  # strategy combinations spawn their worker processes — TODO confirm.
  multi_process_runner.test_main()
|