Moves tf.random.experimental.Generator out of experimental status
PiperOrigin-RevId: 286257170 Change-Id: I656383638ba145405ed00ebb5279ebd900787fca
This commit is contained in:
parent
3378750f4c
commit
9986a75807
@ -63,7 +63,7 @@ PHILOX_STATE_SIZE = 3
|
||||
THREEFRY_STATE_SIZE = 2
|
||||
|
||||
|
||||
@tf_export("random.experimental.Algorithm")
|
||||
@tf_export("random.Algorithm", "random.experimental.Algorithm")
|
||||
class Algorithm(enum.Enum):
|
||||
PHILOX = 1
|
||||
THREEFRY = 2
|
||||
@ -183,16 +183,16 @@ def _convert_alg_to_int(alg):
|
||||
(alg, type(alg)))
|
||||
|
||||
|
||||
@tf_export("random.experimental.create_rng_state")
|
||||
@tf_export("random.create_rng_state", "random.experimental.create_rng_state")
|
||||
def create_rng_state(seed, alg):
|
||||
"""Creates a RNG state from an integer or a vector.
|
||||
|
||||
Example:
|
||||
|
||||
>>> tf.random.experimental.create_rng_state(
|
||||
>>> tf.random.create_rng_state(
|
||||
... 1234, "philox")
|
||||
array([1234, 0, 0])
|
||||
>>> tf.random.experimental.create_rng_state(
|
||||
>>> tf.random.create_rng_state(
|
||||
... [12, 34], "threefry")
|
||||
array([12, 34])
|
||||
|
||||
@ -279,7 +279,7 @@ def _create_variable(*args, **kwargs):
|
||||
return var
|
||||
|
||||
|
||||
@tf_export("random.experimental.Generator")
|
||||
@tf_export("random.Generator", "random.experimental.Generator")
|
||||
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
"""Random-number generator.
|
||||
|
||||
@ -287,7 +287,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
|
||||
Creating a generator from a seed:
|
||||
|
||||
>>> g = tf.random.experimental.Generator.from_seed(1234)
|
||||
>>> g = tf.random.Generator.from_seed(1234)
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
|
||||
array([[ 0.9356609 , 1.0854305 , -0.93788373],
|
||||
@ -295,7 +295,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
|
||||
Creating a generator from a non-deterministic state:
|
||||
|
||||
>>> g = tf.random.experimental.Generator.from_non_deterministic_state()
|
||||
>>> g = tf.random.Generator.from_non_deterministic_state()
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
|
||||
|
||||
@ -303,7 +303,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
(RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
|
||||
example:
|
||||
|
||||
>>> g = tf.random.experimental.Generator.from_seed(123, alg="philox")
|
||||
>>> g = tf.random.Generator.from_seed(123, alg="philox")
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
|
||||
array([[ 0.8673864 , -0.29899067, -0.9310337 ],
|
||||
@ -317,7 +317,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
random numbers are generated, the state of the generator will change. For
|
||||
example:
|
||||
|
||||
>>> g = tf.random.experimental.Generator.from_seed(1234)
|
||||
>>> g = tf.random.Generator.from_seed(1234)
|
||||
>>> g.state
|
||||
<tf.Variable ... numpy=array([1234, 0, 0])>
|
||||
>>> g.normal(shape=(2, 3))
|
||||
@ -329,7 +329,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
|
||||
There is also a global generator:
|
||||
|
||||
>>> g = tf.random.experimental.get_global_generator()
|
||||
>>> g = tf.random.get_global_generator()
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
|
||||
"""
|
||||
@ -351,8 +351,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
variable, the generator will reuse it instead of creating a new
|
||||
variable.
|
||||
alg: the RNG algorithm. Possible values are
|
||||
`tf.random.experimental.Algorithm.PHILOX` for the Philox algorithm and
|
||||
`tf.random.experimental.Algorithm.THREEFRY` for the ThreeFry algorithm
|
||||
`tf.random.Algorithm.PHILOX` for the Philox algorithm and
|
||||
`tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm
|
||||
(see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
|
||||
[https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
|
||||
The string names `"philox"` and `"threefry"` can also be used.
|
||||
@ -756,14 +756,14 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
# Probability of success.
|
||||
probs = [0.8]
|
||||
|
||||
rng = tf.random.experimental.Generator.from_seed(seed=234)
|
||||
rng = tf.random.Generator.from_seed(seed=234)
|
||||
binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)
|
||||
|
||||
|
||||
counts = ... # Shape [3, 1, 2]
|
||||
probs = ... # Shape [1, 4, 2]
|
||||
shape = [3, 4, 3, 4, 2]
|
||||
rng = tf.random.experimental.Generator.from_seed(seed=1717)
|
||||
rng = tf.random.Generator.from_seed(seed=1717)
|
||||
# Sample shape will be [3, 4, 3, 4, 2]
|
||||
binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs)
|
||||
```
|
||||
@ -897,7 +897,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
global_generator = None
|
||||
|
||||
|
||||
@tf_export("random.experimental.get_global_generator")
|
||||
@tf_export("random.get_global_generator",
|
||||
"random.experimental.get_global_generator")
|
||||
def get_global_generator():
|
||||
"""Retrieves the global generator.
|
||||
|
||||
@ -907,7 +908,7 @@ def get_global_generator():
|
||||
placed on a less-ideal device will incur performance regression.
|
||||
|
||||
Returns:
|
||||
The global `tf.random.experimental.Generator` object.
|
||||
The global `tf.random.Generator` object.
|
||||
"""
|
||||
global global_generator
|
||||
if global_generator is None:
|
||||
@ -916,7 +917,8 @@ def get_global_generator():
|
||||
return global_generator
|
||||
|
||||
|
||||
@tf_export("random.experimental.set_global_generator")
|
||||
@tf_export("random.set_global_generator",
|
||||
"random.experimental.set_global_generator")
|
||||
def set_global_generator(generator):
|
||||
"""Replaces the global generator with another `Generator` object.
|
||||
|
||||
|
@ -0,0 +1,12 @@
|
||||
path: "tensorflow.random.Algorithm"
|
||||
tf_class {
|
||||
is_instance: "<enum \'Algorithm\'>"
|
||||
member {
|
||||
name: "PHILOX"
|
||||
mtype: "<enum \'Algorithm\'>"
|
||||
}
|
||||
member {
|
||||
name: "THREEFRY"
|
||||
mtype: "<enum \'Algorithm\'>"
|
||||
}
|
||||
}
|
@ -0,0 +1,84 @@
|
||||
path: "tensorflow.random.Generator"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "algorithm"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "key"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'copy_from\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "binomial"
|
||||
argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_key_counter"
|
||||
argspec: "args=[\'cls\', \'key\', \'counter\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_non_deterministic_state"
|
||||
argspec: "args=[\'cls\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_seed"
|
||||
argspec: "args=[\'cls\', \'seed\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_state"
|
||||
argspec: "args=[\'cls\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_seeds"
|
||||
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "normal"
|
||||
argspec: "args=[\'self\', \'shape\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reset"
|
||||
argspec: "args=[\'self\', \'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_from_key_counter"
|
||||
argspec: "args=[\'self\', \'key\', \'counter\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_from_seed"
|
||||
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "skip"
|
||||
argspec: "args=[\'self\', \'delta\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "split"
|
||||
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "truncated_normal"
|
||||
argspec: "args=[\'self\', \'shape\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "uniform"
|
||||
argspec: "args=[\'self\', \'shape\', \'minval\', \'maxval\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "uniform_full_int"
|
||||
argspec: "args=[\'self\', \'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
|
||||
}
|
||||
}
|
@ -1,5 +1,13 @@
|
||||
path: "tensorflow.random"
|
||||
tf_module {
|
||||
member {
|
||||
name: "Algorithm"
|
||||
mtype: "<class \'enum.EnumMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "Generator"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
@ -12,6 +20,10 @@ tf_module {
|
||||
name: "categorical"
|
||||
argspec: "args=[\'logits\', \'num_samples\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "create_rng_state"
|
||||
argspec: "args=[\'seed\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "fixed_unigram_candidate_sampler"
|
||||
argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'vocab_file\', \'distortion\', \'num_reserved_ids\', \'num_shards\', \'shard\', \'unigrams\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'1.0\', \'0\', \'1\', \'0\', \'()\', \'None\', \'None\'], "
|
||||
@ -20,6 +32,10 @@ tf_module {
|
||||
name: "gamma"
|
||||
argspec: "args=[\'shape\', \'alpha\', \'beta\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_global_generator"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_seed"
|
||||
argspec: "args=[\'op_seed\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -44,6 +60,10 @@ tf_module {
|
||||
name: "poisson"
|
||||
argspec: "args=[\'lam\', \'shape\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_global_generator"
|
||||
argspec: "args=[\'generator\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_random_seed"
|
||||
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -0,0 +1,12 @@
|
||||
path: "tensorflow.random.Algorithm"
|
||||
tf_class {
|
||||
is_instance: "<enum \'Algorithm\'>"
|
||||
member {
|
||||
name: "PHILOX"
|
||||
mtype: "<enum \'Algorithm\'>"
|
||||
}
|
||||
member {
|
||||
name: "THREEFRY"
|
||||
mtype: "<enum \'Algorithm\'>"
|
||||
}
|
||||
}
|
@ -0,0 +1,84 @@
|
||||
path: "tensorflow.random.Generator"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "algorithm"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "key"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "state"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'copy_from\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "binomial"
|
||||
argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_key_counter"
|
||||
argspec: "args=[\'cls\', \'key\', \'counter\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_non_deterministic_state"
|
||||
argspec: "args=[\'cls\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_seed"
|
||||
argspec: "args=[\'cls\', \'seed\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_state"
|
||||
argspec: "args=[\'cls\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_seeds"
|
||||
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "normal"
|
||||
argspec: "args=[\'self\', \'shape\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reset"
|
||||
argspec: "args=[\'self\', \'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_from_key_counter"
|
||||
argspec: "args=[\'self\', \'key\', \'counter\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_from_seed"
|
||||
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "skip"
|
||||
argspec: "args=[\'self\', \'delta\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "split"
|
||||
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "truncated_normal"
|
||||
argspec: "args=[\'self\', \'shape\', \'mean\', \'stddev\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "uniform"
|
||||
argspec: "args=[\'self\', \'shape\', \'minval\', \'maxval\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "uniform_full_int"
|
||||
argspec: "args=[\'self\', \'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
|
||||
}
|
||||
}
|
@ -1,5 +1,13 @@
|
||||
path: "tensorflow.random"
|
||||
tf_module {
|
||||
member {
|
||||
name: "Algorithm"
|
||||
mtype: "<class \'enum.EnumMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "Generator"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
@ -12,6 +20,10 @@ tf_module {
|
||||
name: "categorical"
|
||||
argspec: "args=[\'logits\', \'num_samples\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "create_rng_state"
|
||||
argspec: "args=[\'seed\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "fixed_unigram_candidate_sampler"
|
||||
argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'vocab_file\', \'distortion\', \'num_reserved_ids\', \'num_shards\', \'shard\', \'unigrams\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'1.0\', \'0\', \'1\', \'0\', \'()\', \'None\', \'None\'], "
|
||||
@ -20,6 +32,10 @@ tf_module {
|
||||
name: "gamma"
|
||||
argspec: "args=[\'shape\', \'alpha\', \'beta\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_global_generator"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "learned_unigram_candidate_sampler"
|
||||
argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
@ -36,6 +52,10 @@ tf_module {
|
||||
name: "poisson"
|
||||
argspec: "args=[\'shape\', \'lam\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_global_generator"
|
||||
argspec: "args=[\'generator\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_seed"
|
||||
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user