Allows creating tf.random.Generator under distribution-strategy scopes. Different replicas will get different random-number streams.
All strategies are supported except for CentralStorageStrategy and ParameterServerStrategy. This CL also removes the CompositeTensor superclass from Generator. Generator is a wrapper around tf.Variable, and because tf.Variable is not a CompositeTensor, Generator can't be a CompositeTensor in theory. Previously we made it a CompositeTensor by returning Variable.handle, but that breaks down when the variable is a DistributedVariable (in cross-replica context). PiperOrigin-RevId: 350851648 Change-Id: I5f4d77ddb990557fcc9c7336987203ecdaec5b9a
This commit is contained in:
parent
70502be4a5
commit
587ac71f68
@ -24,6 +24,7 @@
|
|||||||
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
|
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
|
||||||
gathered at runtime to be used in embedding layer partitioning decisions.
|
gathered at runtime to be used in embedding layer partitioning decisions.
|
||||||
* `tf.keras.metrics.AUC` now support logit predictions.
|
* `tf.keras.metrics.AUC` now support logit predictions.
|
||||||
|
* Creating `tf.random.Generator` under `tf.distribute.Strategy` scopes is now allowed (except for `tf.distribute.experimental.CentralStorageStrategy` and `tf.distribute.experimental.ParameterServerStrategy`). Different replicas will get different random-number streams.
|
||||||
|
|
||||||
## Bug Fixes and Other Changes
|
## Bug Fixes and Other Changes
|
||||||
|
|
||||||
|
@ -587,6 +587,22 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
distribute_py_test(
|
||||||
|
name = "random_generator_test",
|
||||||
|
srcs = ["random_generator_test.py"],
|
||||||
|
main = "random_generator_test.py",
|
||||||
|
shard_count = 10,
|
||||||
|
tags = [
|
||||||
|
"multi_and_single_gpu",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:stateful_random_ops",
|
||||||
|
"//tensorflow/python/compat:v2_compat",
|
||||||
|
"//tensorflow/python/distribute:combinations",
|
||||||
|
"//tensorflow/python/distribute:strategy_combinations",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tpu_py_test(
|
tpu_py_test(
|
||||||
name = "tpu_strategy_test",
|
name = "tpu_strategy_test",
|
||||||
srcs = ["tpu_strategy_test.py"],
|
srcs = ["tpu_strategy_test.py"],
|
||||||
|
255
tensorflow/python/distribute/random_generator_test.py
Normal file
255
tensorflow/python/distribute/random_generator_test.py
Normal file
@ -0,0 +1,255 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests tf.random.Generator with distribution strategies."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.compat import v2_compat
|
||||||
|
from tensorflow.python.distribute import combinations as ds_combinations
|
||||||
|
from tensorflow.python.distribute import multi_process_runner
|
||||||
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.framework import config
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_combinations as combinations
|
||||||
|
from tensorflow.python.module import module
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import stateful_random_ops as rng
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.saved_model import load
|
||||||
|
from tensorflow.python.saved_model import save
|
||||||
|
from tensorflow.python.training.tracking import util as tracking_util
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_local_replicas(strat, values=None):
|
||||||
|
strat_name = type(strat).__name__
|
||||||
|
if "MultiWorker" in strat_name or "CollectiveAllReduceStrategy" in strat_name:
|
||||||
|
if values is None:
|
||||||
|
values = strat.run(lambda: constant_op.constant(0))
|
||||||
|
values = strat.experimental_local_results(values)
|
||||||
|
return len(values)
|
||||||
|
else:
|
||||||
|
return strat.num_replicas_in_sync
|
||||||
|
|
||||||
|
|
||||||
|
all_strategies = (strategy_combinations.all_strategies +
|
||||||
|
strategy_combinations.multiworker_strategies)
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(GeneratorTest, self).setUp()
|
||||||
|
v2_compat.enable_v2_behavior()
|
||||||
|
config.set_soft_device_placement(False)
|
||||||
|
|
||||||
|
def assertAllDifferent(self, tensors):
|
||||||
|
"""Checks that there are no duplicate elements anywhere among the tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensors: a list of tensors. They can have different shapes.
|
||||||
|
"""
|
||||||
|
values = [array_ops.reshape(t, shape=[-1]) for t in tensors]
|
||||||
|
values = array_ops.concat(values, axis=0)
|
||||||
|
values = self.evaluate(values)
|
||||||
|
values = values.tolist()
|
||||||
|
self.assertAllEqual(len(values), len(set(values)))
|
||||||
|
|
||||||
|
@ds_combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
strat=all_strategies,
|
||||||
|
mode=["eager"]))
|
||||||
|
def testCrossReplica(self, strat):
|
||||||
|
"""Tests that RNG can be properly advanced in cross-replica context."""
|
||||||
|
strat_name = type(strat).__name__
|
||||||
|
if "CentralStorage" in strat_name:
|
||||||
|
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||||
|
def read_values(dv):
|
||||||
|
return [v.read_value() for v in strat.experimental_local_results(dv)]
|
||||||
|
with strat.scope():
|
||||||
|
g = rng.Generator.from_seed(1)
|
||||||
|
s1 = read_values(g.state)
|
||||||
|
g.normal([3])
|
||||||
|
g.skip(4)
|
||||||
|
s2 = read_values(g.state)
|
||||||
|
self.assertNotAllEqual(s1[0], s2[0])
|
||||||
|
self.assertEqual(len(s1), len(s2))
|
||||||
|
for i in range(1, len(s1)):
|
||||||
|
self.assertAllEqual(s1[0], s1[i])
|
||||||
|
self.assertAllEqual(s2[0], s2[i])
|
||||||
|
|
||||||
|
@ds_combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
strat=all_strategies,
|
||||||
|
mode=["eager"],
|
||||||
|
seeded=[True, False]))
|
||||||
|
def testDistStrat(self, strat, seeded):
|
||||||
|
"""Tests RNG with distribution strategies."""
|
||||||
|
strat_name = type(strat).__name__
|
||||||
|
if "CentralStorage" in strat_name:
|
||||||
|
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||||
|
creators = {
|
||||||
|
True: functools.partial(rng.Generator.from_seed, 1234),
|
||||||
|
False: rng.Generator.from_non_deterministic_state,
|
||||||
|
}
|
||||||
|
shape = [3, 4]
|
||||||
|
dtype = dtypes.int32
|
||||||
|
creator = creators[seeded]
|
||||||
|
with strat.scope():
|
||||||
|
gen = creator()
|
||||||
|
@def_function.function
|
||||||
|
def f():
|
||||||
|
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||||
|
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||||
|
t = array_ops.stack([t1, t2])
|
||||||
|
return t
|
||||||
|
results = strat.run(f)
|
||||||
|
values = strat.experimental_local_results(results)
|
||||||
|
n = get_num_local_replicas(strat, values)
|
||||||
|
self.assertAllEqual(n, len(values))
|
||||||
|
self.assertAllDifferent(values)
|
||||||
|
|
||||||
|
@ds_combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
strat=all_strategies,
|
||||||
|
mode=["eager"]))
|
||||||
|
def testDistVarAsTFFunArg(self, strat):
|
||||||
|
"""Tests that RNG with dist variables can be used as tf.function's arg."""
|
||||||
|
strat_name = type(strat).__name__
|
||||||
|
if "CentralStorage" in strat_name:
|
||||||
|
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||||
|
shape = [3, 4]
|
||||||
|
dtype = dtypes.int32
|
||||||
|
with strat.scope():
|
||||||
|
gen = rng.Generator.from_seed(1234)
|
||||||
|
@def_function.function
|
||||||
|
def f(gen): # the main focus
|
||||||
|
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||||
|
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||||
|
t = array_ops.stack([t1, t2])
|
||||||
|
return t
|
||||||
|
@def_function.function # required by TPUStrategy.run
|
||||||
|
def g():
|
||||||
|
return f(gen)
|
||||||
|
for _ in range(2):
|
||||||
|
results = strat.run(g)
|
||||||
|
values = strat.experimental_local_results(results)
|
||||||
|
n = get_num_local_replicas(strat, values)
|
||||||
|
self.assertAllEqual(n, len(values))
|
||||||
|
self.assertAllDifferent(values)
|
||||||
|
|
||||||
|
@ds_combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
strat1=strategy_combinations.all_strategies,
|
||||||
|
strat2=strategy_combinations.all_strategies,
|
||||||
|
mode=["eager"]) +
|
||||||
|
combinations.combine(
|
||||||
|
strat1=strategy_combinations.multiworker_strategies,
|
||||||
|
strat2=[None],
|
||||||
|
mode=["eager"]))
|
||||||
|
def testDistStratRestore(self, strat1, strat2):
|
||||||
|
"""Tests checkpointing and restoring (to possibly different #replicas)."""
|
||||||
|
if strat2 is None:
|
||||||
|
strat2 = strat1
|
||||||
|
strat1_name = type(strat1).__name__
|
||||||
|
strat2_name = type(strat2).__name__
|
||||||
|
if "CentralStorage" in strat1_name or "CentralStorage" in strat2_name:
|
||||||
|
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||||
|
if "Default" in strat1_name or "Default" in strat2_name:
|
||||||
|
self.skipTest(
|
||||||
|
"We don't guarantee consistency between strategy and no-strategy.")
|
||||||
|
fname = os.path.join(self.get_temp_dir(), "checkpoint")
|
||||||
|
def uniform(strat, g):
|
||||||
|
@def_function.function
|
||||||
|
def f():
|
||||||
|
return g.uniform_full_int([3], dtype=dtypes.int32)
|
||||||
|
result = strat.run(f)
|
||||||
|
return strat.experimental_local_results(result)
|
||||||
|
with strat1.scope():
|
||||||
|
g1 = rng.Generator.from_seed(1)
|
||||||
|
with strat2.scope():
|
||||||
|
g2 = rng.Generator.from_seed(10)
|
||||||
|
cp1 = tracking_util.Checkpoint(g=g1)
|
||||||
|
cp2 = tracking_util.Checkpoint(g=g2)
|
||||||
|
def write_restore_compare():
|
||||||
|
cp1.write(fname)
|
||||||
|
r1 = uniform(strat1, g1)
|
||||||
|
cp2.restore(fname)
|
||||||
|
r2 = uniform(strat2, g2)
|
||||||
|
# Tests that overlapping replicas are properly restored.
|
||||||
|
n1 = get_num_local_replicas(strat1)
|
||||||
|
n2 = get_num_local_replicas(strat2)
|
||||||
|
n = min(n1, n2)
|
||||||
|
self.assertAllEqual(r1[:n], r2[:n])
|
||||||
|
# Run multiple times so that cp1.write is called in various RNG states
|
||||||
|
for _ in range(2):
|
||||||
|
write_restore_compare()
|
||||||
|
|
||||||
|
@ds_combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
strat=strategy_combinations.all_strategies,
|
||||||
|
mode=["eager"],
|
||||||
|
is_save_in_scope=[True, False]))
|
||||||
|
def testSavedModel(self, strat, is_save_in_scope):
|
||||||
|
strat_name = type(strat).__name__
|
||||||
|
if "CentralStorage" in strat_name:
|
||||||
|
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||||
|
|
||||||
|
class CustomModule(module.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CustomModule, self).__init__()
|
||||||
|
self.g = rng.Generator.from_seed(0)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def __call__(self):
|
||||||
|
return self.g.state
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def mutate(self):
|
||||||
|
self.g.normal([])
|
||||||
|
|
||||||
|
with strat.scope():
|
||||||
|
m = CustomModule()
|
||||||
|
m.mutate()
|
||||||
|
state_before = m()
|
||||||
|
path = os.path.join(self.get_temp_dir(), "saved_model")
|
||||||
|
if is_save_in_scope:
|
||||||
|
with strat.scope():
|
||||||
|
save.save(m, path)
|
||||||
|
else:
|
||||||
|
save.save(m, path)
|
||||||
|
with strat.scope():
|
||||||
|
m.mutate()
|
||||||
|
state_before_2 = m()
|
||||||
|
|
||||||
|
imported = load.load(path)
|
||||||
|
state_after = imported()
|
||||||
|
self.assertAllEqual(state_before, state_after)
|
||||||
|
imported.mutate()
|
||||||
|
state_after_2 = imported()
|
||||||
|
self.assertAllEqual(state_before_2, state_after_2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
multi_process_runner.test_main()
|
@ -39,7 +39,6 @@ from tensorflow.python.ops import image_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import stateful_random_ops
|
from tensorflow.python.ops import stateful_random_ops
|
||||||
from tensorflow.python.ops import stateless_random_ops
|
from tensorflow.python.ops import stateless_random_ops
|
||||||
from tensorflow.python.ops import variables
|
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
ResizeMethod = image_ops.ResizeMethod
|
ResizeMethod = image_ops.ResizeMethod
|
||||||
@ -1356,44 +1355,20 @@ class RandomWidth(PreprocessingLayer):
|
|||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/147877541, b/158339556): This class is added to temporarily enable
|
|
||||||
# creating generators within distribution strategies. Remove it when the proper
|
|
||||||
# API is in place.
|
|
||||||
class _RandomGenerator(stateful_random_ops.Generator):
|
|
||||||
"""A subclass that allows creation inside distribution strategies.
|
|
||||||
|
|
||||||
This is a temporary solution to allow creating tf.random.Generator inside
|
|
||||||
distribution strategies. It will be removed when proper API is in place.
|
|
||||||
|
|
||||||
All replicas will have the same RNG state and generate the same random
|
|
||||||
numbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO(b/157995497): Temporarily use primary variable handle inside cross
|
|
||||||
# replica context.
|
|
||||||
@property
|
|
||||||
def state(self):
|
|
||||||
"""The internal state of the RNG."""
|
|
||||||
state_var = self._state_var
|
|
||||||
try:
|
|
||||||
_ = getattr(state_var, 'handle')
|
|
||||||
return state_var
|
|
||||||
except ValueError:
|
|
||||||
return state_var.values[0]
|
|
||||||
|
|
||||||
def _create_variable(self, *args, **kwargs):
|
|
||||||
# This function does the same thing as the base class's namesake, except
|
|
||||||
# that it skips the distribution-strategy check. When we are inside a
|
|
||||||
# distribution-strategy scope, variables.Variable will pick a proper
|
|
||||||
# variable class (e.g. MirroredVariable).
|
|
||||||
return variables.Variable(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def make_generator(seed=None):
|
def make_generator(seed=None):
|
||||||
|
"""Creates a random generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: the seed to initialize the generator. If None, the generator will be
|
||||||
|
initialized non-deterministically.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A generator object.
|
||||||
|
"""
|
||||||
if seed:
|
if seed:
|
||||||
return _RandomGenerator.from_seed(seed)
|
return stateful_random_ops.Generator.from_seed(seed)
|
||||||
else:
|
else:
|
||||||
return _RandomGenerator.from_non_deterministic_state()
|
return stateful_random_ops.Generator.from_non_deterministic_state()
|
||||||
|
|
||||||
|
|
||||||
def get_interpolation(interpolation):
|
def get_interpolation(interpolation):
|
||||||
|
@ -41,6 +41,8 @@ class ImagePreprocessingDistributionTest(
|
|||||||
preprocessing_test_utils.PreprocessingLayerTest):
|
preprocessing_test_utils.PreprocessingLayerTest):
|
||||||
|
|
||||||
def test_distribution(self, distribution):
|
def test_distribution(self, distribution):
|
||||||
|
if "CentralStorage" in type(distribution).__name__:
|
||||||
|
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||||
# TODO(b/159738418): large image input causes OOM in ubuntu multi gpu.
|
# TODO(b/159738418): large image input causes OOM in ubuntu multi gpu.
|
||||||
np_images = np.random.random((32, 32, 32, 3)).astype(np.float32)
|
np_images = np.random.random((32, 32, 32, 3)).astype(np.float32)
|
||||||
image_dataset = dataset_ops.Dataset.from_tensor_slices(np_images).batch(
|
image_dataset = dataset_ops.Dataset.from_tensor_slices(np_images).batch(
|
||||||
|
@ -1082,8 +1082,6 @@ class RandomRotationTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def test_distribution_strategy(self):
|
def test_distribution_strategy(self):
|
||||||
"""Tests that RandomRotation can be created within distribution strategies.
|
"""Tests that RandomRotation can be created within distribution strategies.
|
||||||
|
|
||||||
And that replicas got the same random result.
|
|
||||||
"""
|
"""
|
||||||
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
|
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
|
||||||
with testing_utils.use_gpu():
|
with testing_utils.use_gpu():
|
||||||
@ -1093,7 +1091,6 @@ class RandomRotationTest(keras_parameterized.TestCase):
|
|||||||
output = strat.run(lambda: layer(input_images, training=True))
|
output = strat.run(lambda: layer(input_images, training=True))
|
||||||
values = output.values
|
values = output.values
|
||||||
self.assertAllEqual(2, len(values))
|
self.assertAllEqual(2, len(values))
|
||||||
self.assertAllClose(values[0], values[1], rtol=1e-5)
|
|
||||||
|
|
||||||
@testing_utils.run_v2_only
|
@testing_utils.run_v2_only
|
||||||
def test_config_with_custom_name(self):
|
def test_config_with_custom_name(self):
|
||||||
|
@ -25,17 +25,14 @@ import six
|
|||||||
|
|
||||||
from tensorflow.python.compat import compat
|
from tensorflow.python.compat import compat
|
||||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||||
|
from tensorflow.python.distribute import values_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import composite_tensor
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
|
||||||
from tensorflow.python.framework import type_spec
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gen_stateful_random_ops
|
from tensorflow.python.ops import gen_stateful_random_ops
|
||||||
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
@ -235,40 +232,15 @@ def _convert_to_state_tensor(t):
|
|||||||
return ops.convert_to_tensor(t, dtype=STATE_TYPE)
|
return ops.convert_to_tensor(t, dtype=STATE_TYPE)
|
||||||
|
|
||||||
|
|
||||||
class GeneratorSpec(type_spec.TypeSpec):
|
def get_replica_id():
|
||||||
"""TypeSpec for Generator."""
|
rctx = ds_context.get_replica_context()
|
||||||
|
if rctx is None:
|
||||||
def __init__(self, shape=None, dtype=None, alg=None):
|
return None
|
||||||
self.shape = shape
|
return rctx.replica_id_in_sync_group
|
||||||
self.dtype = dtype
|
|
||||||
self.alg = alg
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _component_specs(self):
|
|
||||||
return (tensor_spec.TensorSpec(shape=(), dtype=dtypes.resource),)
|
|
||||||
|
|
||||||
def _to_components(self, value):
|
|
||||||
return (value.state.handle,)
|
|
||||||
|
|
||||||
def _from_components(self, components):
|
|
||||||
assert isinstance(components, (list, tuple))
|
|
||||||
assert len(components) == 1
|
|
||||||
handle = components[0]
|
|
||||||
state_var = resource_variable_ops.BaseResourceVariable(
|
|
||||||
handle=handle, shape=self.shape, dtype=self.dtype,
|
|
||||||
trainable=False, handle_deleter=object(), handle_name="RNGVar")
|
|
||||||
return Generator(state=state_var, alg=self.alg)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def value_type(self):
|
|
||||||
return Generator
|
|
||||||
|
|
||||||
def _serialize(self):
|
|
||||||
return (self.shape, self.dtype, self.alg)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("random.Generator", "random.experimental.Generator")
|
@tf_export("random.Generator", "random.experimental.Generator")
|
||||||
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
class Generator(tracking.AutoTrackable):
|
||||||
"""Random-number generator.
|
"""Random-number generator.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -320,8 +292,137 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
>>> g = tf.random.get_global_generator()
|
>>> g = tf.random.get_global_generator()
|
||||||
>>> g.normal(shape=(2, 3))
|
>>> g.normal(shape=(2, 3))
|
||||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
|
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
|
||||||
|
|
||||||
|
When creating a generator inside a `tf.distribute.Strategy` scope, each
|
||||||
|
replica will get a different stream of random numbers.
|
||||||
|
|
||||||
|
Note: `tf.distribute.experimental.CentralStorageStrategy` and
|
||||||
|
`tf.distribute.experimental.ParameterServerStrategy` are not supported yet.
|
||||||
|
|
||||||
|
For example, in this code:
|
||||||
|
|
||||||
|
```
|
||||||
|
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
|
||||||
|
with strat.scope():
|
||||||
|
g = tf.random.Generator.from_seed(1)
|
||||||
|
def f():
|
||||||
|
return g.normal([])
|
||||||
|
results = strat.run(f).values
|
||||||
|
```
|
||||||
|
|
||||||
|
`results[0]` and `results[1]` will have different values.
|
||||||
|
|
||||||
|
If the generator is seeded (e.g. created via `Generator.from_seed`), the
|
||||||
|
random numbers will be determined by the seed, even though different replicas
|
||||||
|
get different numbers. One can think of a random number generated on a
|
||||||
|
replica as a hash of the replica ID and a "master" random number that may be
|
||||||
|
common to all replicas. Hence, the whole system is still deterministic.
|
||||||
|
|
||||||
|
(Note that the random numbers on different replicas are not correlated, even
|
||||||
|
if they are deterministically determined by the same seed. They are not
|
||||||
|
correlated in the sense that no matter what statistics one calculates on them,
|
||||||
|
there won't be any discernable correlation.)
|
||||||
|
|
||||||
|
Generators can be freely saved and restored using `tf.train.Checkpoint`. The
|
||||||
|
checkpoint can be restored in a distribution strategy with a different number
|
||||||
|
of replicas than the original strategy. If a replica ID is present in both the
|
||||||
|
original and the new distribution strategy, its state will be properly
|
||||||
|
restored (i.e. the random-number stream from the restored point will be the
|
||||||
|
same as that from the saving point) unless the replicas have already diverged
|
||||||
|
in their RNG call traces before saving (e.g. one replica has made one RNG call
|
||||||
|
while another has made two RNG calls). We don't have such guarantee if the
|
||||||
|
generator is saved in a strategy scope and restored outside of any strategy
|
||||||
|
scope, or vice versa.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_state(cls, state, alg):
|
||||||
|
"""Creates a generator from a state.
|
||||||
|
|
||||||
|
See `__init__` for description of `state` and `alg`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: the new state.
|
||||||
|
alg: the RNG algorithm.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The new generator.
|
||||||
|
"""
|
||||||
|
return cls(alg=alg, state=state)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_seed(cls, seed, alg=None):
|
||||||
|
"""Creates a generator from a seed.
|
||||||
|
|
||||||
|
A seed is a 1024-bit unsigned integer represented either as a Python
|
||||||
|
integer or a vector of integers. Seeds shorter than 1024-bit will be
|
||||||
|
padded. The padding, the internal structure of a seed and the way a seed
|
||||||
|
is converted to a state are all opaque (unspecified). The only semantics
|
||||||
|
specification of seeds is that two different seeds are likely to produce
|
||||||
|
two independent generators (but no guarantee).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: the seed for the RNG.
|
||||||
|
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
||||||
|
`__init__` for its possible values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The new generator.
|
||||||
|
"""
|
||||||
|
if alg is None:
|
||||||
|
# TODO(b/170668986): more sophisticated algorithm selection
|
||||||
|
alg = DEFAULT_ALGORITHM
|
||||||
|
alg = _convert_alg_to_int(alg)
|
||||||
|
state = create_rng_state(seed, alg)
|
||||||
|
return cls(state=state, alg=alg)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_non_deterministic_state(cls, alg=None):
|
||||||
|
"""Creates a generator by non-deterministically initializing its state.
|
||||||
|
|
||||||
|
The source of the non-determinism will be platform- and time-dependent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
||||||
|
`__init__` for its possible values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The new generator.
|
||||||
|
"""
|
||||||
|
if alg is None:
|
||||||
|
# TODO(b/170668986): more sophisticated algorithm selection
|
||||||
|
alg = DEFAULT_ALGORITHM
|
||||||
|
alg = _convert_alg_to_int(alg)
|
||||||
|
state = non_deterministic_ints(shape=[_get_state_size(alg)],
|
||||||
|
dtype=SEED_TYPE)
|
||||||
|
return cls(state=state, alg=alg)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_key_counter(cls, key, counter, alg):
|
||||||
|
"""Creates a generator from a key and a counter.
|
||||||
|
|
||||||
|
This constructor only applies if the algorithm is a counter-based algorithm.
|
||||||
|
See method `key` for the meaning of "key" and "counter".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: the key for the RNG, a scalar of type STATE_TYPE.
|
||||||
|
counter: a vector of dtype STATE_TYPE representing the initial counter for
|
||||||
|
the RNG, whose length is algorithm-specific.,
|
||||||
|
alg: the RNG algorithm. If None, it will be auto-selected. See
|
||||||
|
`__init__` for its possible values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The new generator.
|
||||||
|
"""
|
||||||
|
counter = _convert_to_state_tensor(counter)
|
||||||
|
key = _convert_to_state_tensor(key)
|
||||||
|
alg = _convert_alg_to_int(alg)
|
||||||
|
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
|
||||||
|
key.shape.assert_is_compatible_with([])
|
||||||
|
key = array_ops.reshape(key, [1])
|
||||||
|
state = array_ops.concat([counter, key], 0)
|
||||||
|
return cls(state=state, alg=alg)
|
||||||
|
|
||||||
def __init__(self, copy_from=None, state=None, alg=None):
|
def __init__(self, copy_from=None, state=None, alg=None):
|
||||||
"""Creates a generator.
|
"""Creates a generator.
|
||||||
|
|
||||||
@ -346,24 +447,26 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
The string names `"philox"` and `"threefry"` can also be used.
|
The string names `"philox"` and `"threefry"` can also be used.
|
||||||
Note `PHILOX` guarantees the same numbers are produced (given
|
Note `PHILOX` guarantees the same numbers are produced (given
|
||||||
the same random state) across all architectures (CPU, GPU, XLA etc).
|
the same random state) across all architectures (CPU, GPU, XLA etc).
|
||||||
|
|
||||||
Throws:
|
|
||||||
ValueError: if the generator is created inside a synchronous
|
|
||||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
|
||||||
because there is ambiguity on how to replicate a generator (e.g. should
|
|
||||||
it be copied so such each replica will get the same random numbers, or
|
|
||||||
should it be "split" into different generators that generate
|
|
||||||
different random numbers).
|
|
||||||
"""
|
"""
|
||||||
|
# TODO(b/175072242): Remove distribution-strategy dependencies in this file.
|
||||||
|
if ds_context.has_strategy():
|
||||||
|
self._distribution_strategy = ds_context.get_strategy()
|
||||||
|
else:
|
||||||
|
self._distribution_strategy = None
|
||||||
if copy_from is not None:
|
if copy_from is not None:
|
||||||
# All other arguments should be None
|
# All other arguments should be None
|
||||||
assert (alg or state) is None
|
assert (alg or state) is None
|
||||||
self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
|
self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
|
||||||
trainable=False)
|
trainable=False)
|
||||||
self._alg = copy_from.algorithm
|
self._alg = copy_from.algorithm
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert alg is not None and state is not None
|
assert alg is not None and state is not None
|
||||||
|
if ds_context.has_strategy():
|
||||||
|
strat_name = type(ds_context.get_strategy()).__name__
|
||||||
|
# TODO(b/174610856): Support CentralStorageStrategy and
|
||||||
|
# ParameterServerStrategy.
|
||||||
|
if "CentralStorage" in strat_name or "ParameterServer" in strat_name:
|
||||||
|
raise ValueError("%s is not supported yet" % strat_name)
|
||||||
alg = _convert_alg_to_int(alg)
|
alg = _convert_alg_to_int(alg)
|
||||||
if isinstance(state, variables.Variable):
|
if isinstance(state, variables.Variable):
|
||||||
_check_state_shape(state.shape, alg)
|
_check_state_shape(state.shape, alg)
|
||||||
@ -376,7 +479,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
self._alg = alg
|
self._alg = alg
|
||||||
|
|
||||||
def _create_variable(self, *args, **kwargs):
|
def _create_variable(self, *args, **kwargs):
|
||||||
"""Creates a variable, and check that it's not MirroredVariable.
|
"""Creates a variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*args: positional arguments passed along to `variables.Variable.
|
*args: positional arguments passed along to `variables.Variable.
|
||||||
@ -385,135 +488,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
Returns:
|
Returns:
|
||||||
The created variable.
|
The created variable.
|
||||||
"""
|
"""
|
||||||
if ds_context.has_strategy():
|
return variables.Variable(*args, **kwargs)
|
||||||
raise ValueError(
|
|
||||||
"Creating a generator within a strategy scope is disallowed, because "
|
|
||||||
"there is ambiguity on how to replicate a generator (e.g. should it "
|
|
||||||
"be copied so that each replica gets the same random numbers, or "
|
|
||||||
"'split' so that each replica gets different random numbers).")
|
|
||||||
# TODO(wangpeng): Link to the RNG guide for solutions in such cases.
|
|
||||||
var = variables.Variable(*args, **kwargs)
|
|
||||||
return var
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_state(cls, state, alg):
|
|
||||||
"""Creates a generator from a state.
|
|
||||||
|
|
||||||
See `__init__` for description of `state` and `alg`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: the new state.
|
|
||||||
alg: the RNG algorithm.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The new generator.
|
|
||||||
|
|
||||||
Throws:
|
|
||||||
ValueError: if the generator is created inside a synchronous
|
|
||||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
|
||||||
because there is ambiguity on how to replicate a generator (e.g. should
|
|
||||||
it be copied so such each replica will get the same random numbers, or
|
|
||||||
should it be "split" into different generators that generate
|
|
||||||
different random numbers).
|
|
||||||
"""
|
|
||||||
return cls(alg=alg, state=state)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_seed(cls, seed, alg=None):
|
|
||||||
"""Creates a generator from a seed.
|
|
||||||
|
|
||||||
A seed is a 1024-bit unsigned integer represented either as a Python
|
|
||||||
integer or a vector of integers. Seeds shorter than 1024-bit will be
|
|
||||||
padded. The padding, the internal structure of a seed and the way a seed
|
|
||||||
is converted to a state are all opaque (unspecified). The only semantics
|
|
||||||
specification of seeds is that two different seeds are likely to produce
|
|
||||||
two independent generators (but no guarantee).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: the seed for the RNG.
|
|
||||||
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
|
||||||
`__init__` for its possible values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The new generator.
|
|
||||||
|
|
||||||
Throws:
|
|
||||||
ValueError: if the generator is created inside a synchronous
|
|
||||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
|
||||||
because there is ambiguity on how to replicate a generator (e.g. should
|
|
||||||
it be copied so such each replica will get the same random numbers, or
|
|
||||||
should it be "split" into different generators that generate
|
|
||||||
different random numbers).
|
|
||||||
"""
|
|
||||||
if alg is None:
|
|
||||||
# TODO(wangpeng): more sophisticated algorithm selection
|
|
||||||
alg = DEFAULT_ALGORITHM
|
|
||||||
alg = _convert_alg_to_int(alg)
|
|
||||||
state = create_rng_state(seed, alg)
|
|
||||||
return cls(state=state, alg=alg)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_non_deterministic_state(cls, alg=None):
|
|
||||||
"""Creates a generator by non-deterministically initializing its state.
|
|
||||||
|
|
||||||
The source of the non-determinism will be platform- and time-dependent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
|
||||||
`__init__` for its possible values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The new generator.
|
|
||||||
|
|
||||||
Throws:
|
|
||||||
ValueError: if the generator is created inside a synchronous
|
|
||||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
|
||||||
because there is ambiguity on how to replicate a generator (e.g. should
|
|
||||||
it be copied so such each replica will get the same random numbers, or
|
|
||||||
should it be "split" into different generators that generate
|
|
||||||
different random numbers).
|
|
||||||
"""
|
|
||||||
if alg is None:
|
|
||||||
# TODO(wangpeng): more sophisticated algorithm selection
|
|
||||||
alg = DEFAULT_ALGORITHM
|
|
||||||
alg = _convert_alg_to_int(alg)
|
|
||||||
state = non_deterministic_ints(shape=[_get_state_size(alg)],
|
|
||||||
dtype=SEED_TYPE)
|
|
||||||
return cls(state=state, alg=alg)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_key_counter(cls, key, counter, alg):
|
|
||||||
"""Creates a generator from a key and a counter.
|
|
||||||
|
|
||||||
This constructor only applies if the algorithm is a counter-based algorithm.
|
|
||||||
See method `key` for the meaning of "key" and "counter".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: the key for the RNG, a scalar of type STATE_TYPE.
|
|
||||||
counter: a vector of dtype STATE_TYPE representing the initial counter for
|
|
||||||
the RNG, whose length is algorithm-specific.,
|
|
||||||
alg: the RNG algorithm. If None, it will be auto-selected. See
|
|
||||||
`__init__` for its possible values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The new generator.
|
|
||||||
|
|
||||||
Throws:
|
|
||||||
ValueError: if the generator is created inside a synchronous
|
|
||||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
|
||||||
because there is ambiguity on how to replicate a generator (e.g. should
|
|
||||||
it be copied so such each replica will get the same random numbers, or
|
|
||||||
should it be "split" into different generators that generate
|
|
||||||
different random numbers).
|
|
||||||
"""
|
|
||||||
counter = _convert_to_state_tensor(counter)
|
|
||||||
key = _convert_to_state_tensor(key)
|
|
||||||
alg = _convert_alg_to_int(alg)
|
|
||||||
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
|
|
||||||
key.shape.assert_is_compatible_with([])
|
|
||||||
key = array_ops.reshape(key, [1])
|
|
||||||
state = array_ops.concat([counter, key], 0)
|
|
||||||
return cls(state=state, alg=alg)
|
|
||||||
|
|
||||||
def reset(self, state):
|
def reset(self, state):
|
||||||
"""Resets the generator by a new state.
|
"""Resets the generator by a new state.
|
||||||
@ -556,11 +531,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
state = array_ops.concat([counter, key], 0)
|
state = array_ops.concat([counter, key], 0)
|
||||||
self._state_var.assign(state)
|
self._state_var.assign(state)
|
||||||
|
|
||||||
@property
|
|
||||||
def _type_spec(self):
|
|
||||||
return GeneratorSpec(shape=self.state.shape, dtype=self.state.dtype,
|
|
||||||
alg=self.algorithm)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self):
|
def state(self):
|
||||||
"""The internal state of the RNG."""
|
"""The internal state of the RNG."""
|
||||||
@ -614,15 +584,52 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
counter is an unspecified implementation detail.
|
counter is an unspecified implementation detail.
|
||||||
"""
|
"""
|
||||||
if compat.forward_compatible(2020, 10, 25):
|
if compat.forward_compatible(2020, 10, 25):
|
||||||
return gen_stateful_random_ops.rng_read_and_skip(
|
return self._skip(delta)
|
||||||
self.state.handle,
|
|
||||||
alg=math_ops.cast(self.algorithm, dtypes.int32),
|
|
||||||
delta=math_ops.cast(delta, dtypes.uint64))
|
|
||||||
gen_stateful_random_ops.rng_skip(
|
gen_stateful_random_ops.rng_skip(
|
||||||
self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
|
self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
|
||||||
math_ops.cast(delta, dtypes.int64))
|
math_ops.cast(delta, dtypes.int64))
|
||||||
# pylint: enable=g-doc-return-or-yield
|
# pylint: enable=g-doc-return-or-yield
|
||||||
|
|
||||||
|
def _skip_single_var(self, var, delta):
|
||||||
|
# TODO(wangpeng): Cache the cast algorithm instead of casting everytime.
|
||||||
|
return gen_stateful_random_ops.rng_read_and_skip(
|
||||||
|
var.handle, alg=math_ops.cast(self.algorithm, dtypes.int32),
|
||||||
|
delta=math_ops.cast(delta, dtypes.uint64))
|
||||||
|
|
||||||
|
def _skip(self, delta):
|
||||||
|
def update_fn(v):
|
||||||
|
return self._skip_single_var(v, delta)
|
||||||
|
# TODO(b/170515001): Always call strategy.extended.update after calling it
|
||||||
|
# from both replica context and cross-replica context is supported.
|
||||||
|
if values_util.is_saving_non_distributed():
|
||||||
|
# Assumes replica context with replica_id=0, since we only save the first
|
||||||
|
# replica.
|
||||||
|
return update_fn(self.state)
|
||||||
|
if self._distribution_strategy is not None:
|
||||||
|
with ds_context.enter_or_assert_strategy(self._distribution_strategy):
|
||||||
|
if ds_context.in_cross_replica_context():
|
||||||
|
# Code that operates on all replicas of a variable cannot be saved
|
||||||
|
# without retracing.
|
||||||
|
values_util.mark_as_unsaveable()
|
||||||
|
# In cross-replica context we need to use strategy.extended.update.
|
||||||
|
return ds_context.get_strategy().extended.update(
|
||||||
|
self.state, update_fn)
|
||||||
|
return update_fn(self.state)
|
||||||
|
|
||||||
|
def _preprocess_key(self, key):
|
||||||
|
if self._distribution_strategy is None:
|
||||||
|
return key
|
||||||
|
with ds_context.enter_or_assert_strategy(self._distribution_strategy):
|
||||||
|
replica_id = get_replica_id()
|
||||||
|
if replica_id is not None:
|
||||||
|
replica_id = array_ops.stack([replica_id, 0], axis=0)
|
||||||
|
replica_id = math_ops.cast(replica_id, dtypes.uint64)
|
||||||
|
# Conceptually: key = hash(key, replica_id)
|
||||||
|
key = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
|
||||||
|
shape=[1], key=key, counter=replica_id, dtype=dtypes.uint64,
|
||||||
|
alg=self.algorithm)
|
||||||
|
return key
|
||||||
|
|
||||||
def _prepare_key_counter(self, shape):
|
def _prepare_key_counter(self, shape):
|
||||||
delta = math_ops.reduce_prod(shape)
|
delta = math_ops.reduce_prod(shape)
|
||||||
counter_key = self.skip(delta)
|
counter_key = self.skip(delta)
|
||||||
@ -630,6 +637,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
|
counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
|
||||||
key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
|
key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
|
||||||
dtypes.uint64)
|
dtypes.uint64)
|
||||||
|
key = self._preprocess_key(key)
|
||||||
return key, counter
|
return key, counter
|
||||||
|
|
||||||
# The following functions return a tensor and as a side effect update
|
# The following functions return a tensor and as a side effect update
|
||||||
|
@ -18,11 +18,11 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
|
||||||
|
|
||||||
from tensorflow.python.distribute import values as dist_values
|
from tensorflow.python.distribute import values as dist_values
|
||||||
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
|
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
|
||||||
@ -33,7 +33,6 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.kernel_tests.random import util as \
|
from tensorflow.python.kernel_tests.random import util as \
|
||||||
random_test_util
|
random_test_util
|
||||||
@ -45,6 +44,7 @@ from tensorflow.python.ops import stateful_random_ops as \
|
|||||||
random
|
random
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.training.tracking import util as tracking_util
|
||||||
|
|
||||||
|
|
||||||
g_seeded = None
|
g_seeded = None
|
||||||
@ -636,49 +636,6 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(res1, res2)
|
self.assertAllEqual(res1, res2)
|
||||||
self.assertAllEqual(g1.state.read_value(), g2.state.read_value())
|
self.assertAllEqual(g1.state.read_value(), g2.state.read_value())
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def testFunArgAlgIsInt(self):
|
|
||||||
"""Tests that `algorithm` is `int` when reconstructed from composite tensor.
|
|
||||||
"""
|
|
||||||
@def_function.function
|
|
||||||
def f(g):
|
|
||||||
self.assertIsInstance(g.algorithm, six.integer_types)
|
|
||||||
return g.make_seeds(), g
|
|
||||||
gen = random.Generator.from_seed(123, alg="philox")
|
|
||||||
f(gen)
|
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def testLimitedRetracingWithCompositeTensors(self):
|
|
||||||
"""Tests that RNGs with the same shape/dtype won't cause retracing.
|
|
||||||
"""
|
|
||||||
trace_count = [0]
|
|
||||||
|
|
||||||
@def_function.function
|
|
||||||
def f(x):
|
|
||||||
trace_count[0] += 1
|
|
||||||
return x.normal([])
|
|
||||||
|
|
||||||
f(random.Generator.from_seed(1))
|
|
||||||
f(random.Generator.from_seed(2))
|
|
||||||
self.assertEqual(trace_count[0], 1)
|
|
||||||
|
|
||||||
def testMostSpecificCompatibleType(self):
|
|
||||||
"""Tests GeneratorSpec.most_specific_compatible_type.
|
|
||||||
"""
|
|
||||||
spec = random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int32)
|
|
||||||
res = spec.most_specific_compatible_type(
|
|
||||||
random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int32))
|
|
||||||
self.assertEqual(spec, res)
|
|
||||||
with self.assertRaisesWithPredicateMatch(ValueError, ""):
|
|
||||||
spec.most_specific_compatible_type(
|
|
||||||
tensor_spec.TensorSpec(shape=(2, 3), dtype=dtypes.int32))
|
|
||||||
with self.assertRaisesWithPredicateMatch(ValueError, ""):
|
|
||||||
spec.most_specific_compatible_type(
|
|
||||||
random.GeneratorSpec(shape=(2, 4), dtype=dtypes.int32))
|
|
||||||
with self.assertRaisesWithPredicateMatch(ValueError, ""):
|
|
||||||
spec.most_specific_compatible_type(
|
|
||||||
random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int64))
|
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testCreateOutsideMirroredStrat(self):
|
def testCreateOutsideMirroredStrat(self):
|
||||||
"""Tests RNG/MirrorStrategy interaction #1.
|
"""Tests RNG/MirrorStrategy interaction #1.
|
||||||
@ -702,29 +659,6 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(2, len(values))
|
self.assertAllEqual(2, len(values))
|
||||||
self.assertAllDifferent(values)
|
self.assertAllDifferent(values)
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def testMirroredStratParaSyncDisallowed(self):
|
|
||||||
"""Tests that generator creation in MirroredStrategy is disallowed.
|
|
||||||
"""
|
|
||||||
creators = [
|
|
||||||
lambda: random.Generator.from_seed(1234),
|
|
||||||
random.Generator.from_non_deterministic_state,
|
|
||||||
]
|
|
||||||
shape = [3, 4]
|
|
||||||
dtype = dtypes.int32
|
|
||||||
strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
|
|
||||||
for creator in creators:
|
|
||||||
with strat.scope():
|
|
||||||
with self.assertRaisesWithPredicateMatch(
|
|
||||||
ValueError, "disallowed"):
|
|
||||||
creator() # pylint: disable=cell-var-from-loop
|
|
||||||
def f():
|
|
||||||
gen = creator() # pylint: disable=cell-var-from-loop
|
|
||||||
return gen.uniform_full_int(shape=shape, dtype=dtype)
|
|
||||||
with self.assertRaisesWithPredicateMatch(
|
|
||||||
ValueError, "disallowed"):
|
|
||||||
strat.extended.call_for_each_replica(fn=f)
|
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testMirroredStratParaAsync(self):
|
def testMirroredStratParaAsync(self):
|
||||||
"""Tests RNG/MirrorStrategy interaction #2.
|
"""Tests RNG/MirrorStrategy interaction #2.
|
||||||
@ -764,6 +698,23 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
r2 = g.uniform_full_int(shape=shape, dtype=dtype)
|
r2 = g.uniform_full_int(shape=shape, dtype=dtype)
|
||||||
self.assertAllEqual(r1, r2)
|
self.assertAllEqual(r1, r2)
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testRestore(self):
|
||||||
|
"""Tests save and restore.
|
||||||
|
"""
|
||||||
|
fname = os.path.join(self.get_temp_dir(), "checkpoint")
|
||||||
|
g = random.Generator.from_seed(1)
|
||||||
|
cp = tracking_util.Checkpoint(g=g)
|
||||||
|
def write_restore_compare():
|
||||||
|
cp.write(fname)
|
||||||
|
r1 = g.uniform([], dtype=dtypes.uint32, minval=None)
|
||||||
|
cp.restore(fname)
|
||||||
|
r2 = g.uniform([], dtype=dtypes.uint32, minval=None)
|
||||||
|
self.assertAllEqual(r1, r2)
|
||||||
|
# Run multiple times so that cp.write is called in various RNG states
|
||||||
|
for _ in range(2):
|
||||||
|
write_restore_compare()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config.set_soft_device_placement(False)
|
config.set_soft_device_placement(False)
|
||||||
|
@ -3,7 +3,6 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "algorithm"
|
name: "algorithm"
|
||||||
|
@ -3,7 +3,6 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "algorithm"
|
name: "algorithm"
|
||||||
|
@ -3,7 +3,6 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "algorithm"
|
name: "algorithm"
|
||||||
|
@ -3,7 +3,6 @@ tf_class {
|
|||||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "algorithm"
|
name: "algorithm"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user