Bug fix: Make get_local_variable
accessible in the opensource Tensorflow namespace and update doscstring.
Change: 143149885
This commit is contained in:
parent
d322c0533d
commit
a081f4b06f
@ -0,0 +1,84 @@
|
|||||||
|
### `tf.get_local_variable(*args, **kwargs)` {#get_local_variable}
|
||||||
|
|
||||||
|
Gets an existing *local* variable or creates a new one.
|
||||||
|
|
||||||
|
Behavior is the same as in `get_variable`, except that variables are
|
||||||
|
added to the `LOCAL_VARIABLES` collection and `trainable` is set to
|
||||||
|
`False`.
|
||||||
|
This function prefixes the name with the current variable scope
|
||||||
|
and performs reuse checks. See the
|
||||||
|
[Variable Scope How To](../../how_tos/variable_scope/index.md)
|
||||||
|
for an extensive description of how reusing works. Here is a basic example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with tf.variable_scope("foo"):
|
||||||
|
v = tf.get_variable("v", [1]) # v.name == "foo/v:0"
|
||||||
|
w = tf.get_variable("w", [1]) # w.name == "foo/w:0"
|
||||||
|
with tf.variable_scope("foo", reuse=True)
|
||||||
|
v1 = tf.get_variable("v") # The same as v above.
|
||||||
|
```
|
||||||
|
|
||||||
|
If initializer is `None` (the default), the default initializer passed in
|
||||||
|
the variable scope will be used. If that one is `None` too, a
|
||||||
|
`uniform_unit_scaling_initializer` will be used. The initializer can also be
|
||||||
|
a Tensor, in which case the variable is initialized to this value and shape.
|
||||||
|
|
||||||
|
Similarly, if the regularizer is `None` (the default), the default regularizer
|
||||||
|
passed in the variable scope will be used (if that is `None` too,
|
||||||
|
then by default no regularization is performed).
|
||||||
|
|
||||||
|
If a partitioner is provided, a `PartitionedVariable` is returned.
|
||||||
|
Accessing this object as a `Tensor` returns the shards concatenated along
|
||||||
|
the partition axis.
|
||||||
|
|
||||||
|
Some useful partitioners are available. See, e.g.,
|
||||||
|
`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`name`</b>: The name of the new or existing variable.
|
||||||
|
* <b>`shape`</b>: Shape of the new or existing variable.
|
||||||
|
* <b>`dtype`</b>: Type of the new or existing variable (defaults to `DT_FLOAT`).
|
||||||
|
* <b>`initializer`</b>: Initializer for the variable if one is created.
|
||||||
|
* <b>`regularizer`</b>: A (Tensor -> Tensor or None) function; the result of
|
||||||
|
applying it on a newly created variable will be added to the collection
|
||||||
|
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
|
||||||
|
* <b>`collections`</b>: List of graph collections keys to add the Variable to.
|
||||||
|
Defaults to `[GraphKeys.LOCAL_VARIABLES]` (see `tf.Variable`).
|
||||||
|
* <b>`caching_device`</b>: Optional device string or function describing where the
|
||||||
|
Variable should be cached for reading. Defaults to the Variable's
|
||||||
|
device. If not `None`, caches on another device. Typical use is to
|
||||||
|
cache on the device where the Ops using the Variable reside, to
|
||||||
|
deduplicate copying through `Switch` and other conditional statements.
|
||||||
|
* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape`
|
||||||
|
and `dtype` of the Variable to be created, and returns a list of
|
||||||
|
partitions for each axis (currently only one axis can be partitioned).
|
||||||
|
* <b>`validate_shape`</b>: If False, allows the variable to be initialized with a
|
||||||
|
value of unknown shape. If True, the default, the shape of initial_value
|
||||||
|
must be known.
|
||||||
|
* <b>`custom_getter`</b>: Callable that takes as a first argument the true getter, and
|
||||||
|
allows overwriting the internal get_variable method.
|
||||||
|
The signature of `custom_getter` should match that of this method,
|
||||||
|
but the most future-proof version will allow for changes:
|
||||||
|
`def custom_getter(getter, *args, **kwargs)`. Direct access to
|
||||||
|
all `get_variable` parameters is also allowed:
|
||||||
|
`def custom_getter(getter, name, *args, **kwargs)`. A simple identity
|
||||||
|
custom getter that simply creates variables with modified names is:
|
||||||
|
```python
|
||||||
|
def custom_getter(getter, name, *args, **kwargs):
|
||||||
|
return getter(name + '_suffix', *args, **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
The created or existing `Variable` (or `PartitionedVariable`, if a
|
||||||
|
partitioner was used).
|
||||||
|
|
||||||
|
##### Raises:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`ValueError`</b>: when creating a new variable and shape is not declared,
|
||||||
|
when violating reuse during variable creation, or when `initializer` dtype
|
||||||
|
and `dtype` don't match. Reuse is set inside `variable_scope`.
|
||||||
|
|
@ -84,6 +84,7 @@
|
|||||||
* [`export_meta_graph`](../../api_docs/python/state_ops.md#export_meta_graph)
|
* [`export_meta_graph`](../../api_docs/python/state_ops.md#export_meta_graph)
|
||||||
* [`fixed_size_partitioner`](../../api_docs/python/state_ops.md#fixed_size_partitioner)
|
* [`fixed_size_partitioner`](../../api_docs/python/state_ops.md#fixed_size_partitioner)
|
||||||
* [`get_checkpoint_state`](../../api_docs/python/state_ops.md#get_checkpoint_state)
|
* [`get_checkpoint_state`](../../api_docs/python/state_ops.md#get_checkpoint_state)
|
||||||
|
* [`get_local_variable`](../../api_docs/python/state_ops.md#get_local_variable)
|
||||||
* [`get_variable`](../../api_docs/python/state_ops.md#get_variable)
|
* [`get_variable`](../../api_docs/python/state_ops.md#get_variable)
|
||||||
* [`get_variable_scope`](../../api_docs/python/state_ops.md#get_variable_scope)
|
* [`get_variable_scope`](../../api_docs/python/state_ops.md#get_variable_scope)
|
||||||
* [`global_variables`](../../api_docs/python/state_ops.md#global_variables)
|
* [`global_variables`](../../api_docs/python/state_ops.md#global_variables)
|
||||||
|
@ -2009,6 +2009,93 @@ Some useful partitioners are available. See, e.g.,
|
|||||||
##### Raises:
|
##### Raises:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`ValueError`</b>: when creating a new variable and shape is not declared,
|
||||||
|
when violating reuse during variable creation, or when `initializer` dtype
|
||||||
|
and `dtype` don't match. Reuse is set inside `variable_scope`.
|
||||||
|
|
||||||
|
|
||||||
|
- - -
|
||||||
|
|
||||||
|
### `tf.get_local_variable(*args, **kwargs)` {#get_local_variable}
|
||||||
|
|
||||||
|
Gets an existing *local* variable or creates a new one.
|
||||||
|
|
||||||
|
Behavior is the same as in `get_variable`, except that variables are
|
||||||
|
added to the `LOCAL_VARIABLES` collection and `trainable` is set to
|
||||||
|
`False`.
|
||||||
|
This function prefixes the name with the current variable scope
|
||||||
|
and performs reuse checks. See the
|
||||||
|
[Variable Scope How To](../../how_tos/variable_scope/index.md)
|
||||||
|
for an extensive description of how reusing works. Here is a basic example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with tf.variable_scope("foo"):
|
||||||
|
v = tf.get_variable("v", [1]) # v.name == "foo/v:0"
|
||||||
|
w = tf.get_variable("w", [1]) # w.name == "foo/w:0"
|
||||||
|
with tf.variable_scope("foo", reuse=True)
|
||||||
|
v1 = tf.get_variable("v") # The same as v above.
|
||||||
|
```
|
||||||
|
|
||||||
|
If initializer is `None` (the default), the default initializer passed in
|
||||||
|
the variable scope will be used. If that one is `None` too, a
|
||||||
|
`uniform_unit_scaling_initializer` will be used. The initializer can also be
|
||||||
|
a Tensor, in which case the variable is initialized to this value and shape.
|
||||||
|
|
||||||
|
Similarly, if the regularizer is `None` (the default), the default regularizer
|
||||||
|
passed in the variable scope will be used (if that is `None` too,
|
||||||
|
then by default no regularization is performed).
|
||||||
|
|
||||||
|
If a partitioner is provided, a `PartitionedVariable` is returned.
|
||||||
|
Accessing this object as a `Tensor` returns the shards concatenated along
|
||||||
|
the partition axis.
|
||||||
|
|
||||||
|
Some useful partitioners are available. See, e.g.,
|
||||||
|
`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`name`</b>: The name of the new or existing variable.
|
||||||
|
* <b>`shape`</b>: Shape of the new or existing variable.
|
||||||
|
* <b>`dtype`</b>: Type of the new or existing variable (defaults to `DT_FLOAT`).
|
||||||
|
* <b>`initializer`</b>: Initializer for the variable if one is created.
|
||||||
|
* <b>`regularizer`</b>: A (Tensor -> Tensor or None) function; the result of
|
||||||
|
applying it on a newly created variable will be added to the collection
|
||||||
|
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
|
||||||
|
* <b>`collections`</b>: List of graph collections keys to add the Variable to.
|
||||||
|
Defaults to `[GraphKeys.LOCAL_VARIABLES]` (see `tf.Variable`).
|
||||||
|
* <b>`caching_device`</b>: Optional device string or function describing where the
|
||||||
|
Variable should be cached for reading. Defaults to the Variable's
|
||||||
|
device. If not `None`, caches on another device. Typical use is to
|
||||||
|
cache on the device where the Ops using the Variable reside, to
|
||||||
|
deduplicate copying through `Switch` and other conditional statements.
|
||||||
|
* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape`
|
||||||
|
and `dtype` of the Variable to be created, and returns a list of
|
||||||
|
partitions for each axis (currently only one axis can be partitioned).
|
||||||
|
* <b>`validate_shape`</b>: If False, allows the variable to be initialized with a
|
||||||
|
value of unknown shape. If True, the default, the shape of initial_value
|
||||||
|
must be known.
|
||||||
|
* <b>`custom_getter`</b>: Callable that takes as a first argument the true getter, and
|
||||||
|
allows overwriting the internal get_variable method.
|
||||||
|
The signature of `custom_getter` should match that of this method,
|
||||||
|
but the most future-proof version will allow for changes:
|
||||||
|
`def custom_getter(getter, *args, **kwargs)`. Direct access to
|
||||||
|
all `get_variable` parameters is also allowed:
|
||||||
|
`def custom_getter(getter, name, *args, **kwargs)`. A simple identity
|
||||||
|
custom getter that simply creates variables with modified names is:
|
||||||
|
```python
|
||||||
|
def custom_getter(getter, name, *args, **kwargs):
|
||||||
|
return getter(name + '_suffix', *args, **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
The created or existing `Variable` (or `PartitionedVariable`, if a
|
||||||
|
partitioner was used).
|
||||||
|
|
||||||
|
##### Raises:
|
||||||
|
|
||||||
|
|
||||||
* <b>`ValueError`</b>: when creating a new variable and shape is not declared,
|
* <b>`ValueError`</b>: when creating a new variable and shape is not declared,
|
||||||
when violating reuse during variable creation, or when `initializer` dtype
|
when violating reuse during variable creation, or when `initializer` dtype
|
||||||
and `dtype` don't match. Reuse is set inside `variable_scope`.
|
and `dtype` don't match. Reuse is set inside `variable_scope`.
|
||||||
|
@ -54,6 +54,7 @@ TensorFlow provides several classes and operations that you can use to
|
|||||||
create variables contingent on certain conditions.
|
create variables contingent on certain conditions.
|
||||||
|
|
||||||
@@get_variable
|
@@get_variable
|
||||||
|
@@get_local_variable
|
||||||
@@VariableScope
|
@@VariableScope
|
||||||
@@variable_scope
|
@@variable_scope
|
||||||
@@variable_op_scope
|
@@variable_op_scope
|
||||||
|
@ -979,9 +979,16 @@ def get_variable(name,
|
|||||||
partitioner=None,
|
partitioner=None,
|
||||||
validate_shape=True,
|
validate_shape=True,
|
||||||
custom_getter=None):
|
custom_getter=None):
|
||||||
"""Gets an existing variable with these parameters or create a new one.
|
return get_variable_scope().get_variable(
|
||||||
|
_get_default_variable_store(), name, shape=shape, dtype=dtype,
|
||||||
|
initializer=initializer, regularizer=regularizer, trainable=trainable,
|
||||||
|
collections=collections, caching_device=caching_device,
|
||||||
|
partitioner=partitioner, validate_shape=validate_shape,
|
||||||
|
custom_getter=custom_getter)
|
||||||
|
get_variable_or_local_docstring = (
|
||||||
|
"""%s
|
||||||
|
|
||||||
This function prefixes the name with the current variable scope
|
%sThis function prefixes the name with the current variable scope
|
||||||
and performs reuse checks. See the
|
and performs reuse checks. See the
|
||||||
[Variable Scope How To](../../how_tos/variable_scope/index.md)
|
[Variable Scope How To](../../how_tos/variable_scope/index.md)
|
||||||
for an extensive description of how reusing works. Here is a basic example:
|
for an extensive description of how reusing works. Here is a basic example:
|
||||||
@ -1018,10 +1025,8 @@ def get_variable(name,
|
|||||||
regularizer: A (Tensor -> Tensor or None) function; the result of
|
regularizer: A (Tensor -> Tensor or None) function; the result of
|
||||||
applying it on a newly created variable will be added to the collection
|
applying it on a newly created variable will be added to the collection
|
||||||
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
|
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
|
||||||
trainable: If `True` also add the variable to the graph collection
|
%scollections: List of graph collections keys to add the Variable to.
|
||||||
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
|
Defaults to `[%s]` (see `tf.Variable`).
|
||||||
collections: List of graph collections keys to add the Variable to.
|
|
||||||
Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
|
|
||||||
caching_device: Optional device string or function describing where the
|
caching_device: Optional device string or function describing where the
|
||||||
Variable should be cached for reading. Defaults to the Variable's
|
Variable should be cached for reading. Defaults to the Variable's
|
||||||
device. If not `None`, caches on another device. Typical use is to
|
device. If not `None`, caches on another device. Typical use is to
|
||||||
@ -1054,13 +1059,13 @@ def get_variable(name,
|
|||||||
ValueError: when creating a new variable and shape is not declared,
|
ValueError: when creating a new variable and shape is not declared,
|
||||||
when violating reuse during variable creation, or when `initializer` dtype
|
when violating reuse during variable creation, or when `initializer` dtype
|
||||||
and `dtype` don't match. Reuse is set inside `variable_scope`.
|
and `dtype` don't match. Reuse is set inside `variable_scope`.
|
||||||
"""
|
""")
|
||||||
return get_variable_scope().get_variable(
|
get_variable.__doc__ = get_variable_or_local_docstring % (
|
||||||
_get_default_variable_store(), name, shape=shape, dtype=dtype,
|
"Gets an existing variable with these parameters or create a new one.",
|
||||||
initializer=initializer, regularizer=regularizer, trainable=trainable,
|
"",
|
||||||
collections=collections, caching_device=caching_device,
|
"trainable: If `True` also add the variable to the graph collection\n"
|
||||||
partitioner=partitioner, validate_shape=validate_shape,
|
" `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n",
|
||||||
custom_getter=custom_getter)
|
"GraphKeys.GLOBAL_VARIABLES")
|
||||||
|
|
||||||
|
|
||||||
@functools.wraps(get_variable)
|
@functools.wraps(get_variable)
|
||||||
@ -1070,10 +1075,14 @@ def get_local_variable(*args, **kwargs):
|
|||||||
kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES]
|
kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES]
|
||||||
else:
|
else:
|
||||||
kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES]
|
kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES]
|
||||||
get_local_variable.__doc__ = (
|
|
||||||
"Gets an existing local variable or creates a new one.\n\n" +
|
|
||||||
get_local_variable.__doc__)
|
|
||||||
return get_variable(*args, **kwargs)
|
return get_variable(*args, **kwargs)
|
||||||
|
get_local_variable.__doc__ = get_variable_or_local_docstring % (
|
||||||
|
"Gets an existing *local* variable or creates a new one.",
|
||||||
|
"Behavior is the same as in `get_variable`, except that variables are\n"
|
||||||
|
"added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
|
||||||
|
"`False`.\n",
|
||||||
|
"",
|
||||||
|
"GraphKeys.LOCAL_VARIABLES")
|
||||||
|
|
||||||
|
|
||||||
def _get_partitioned_variable(name,
|
def _get_partitioned_variable(name,
|
||||||
|
Loading…
Reference in New Issue
Block a user