Bug fix: Make get_local_variable accessible in the opensource Tensorflow namespace and update doscstring.

Change: 143149885
This commit is contained in:
A. Unique TensorFlower 2016-12-28 19:36:53 -08:00 committed by TensorFlower Gardener
parent d322c0533d
commit a081f4b06f
5 changed files with 261 additions and 79 deletions

View File

@ -0,0 +1,84 @@
### `tf.get_local_variable(*args, **kwargs)` {#get_local_variable}
Gets an existing *local* variable or creates a new one.
Behavior is the same as in `get_variable`, except that variables are
added to the `LOCAL_VARIABLES` collection and `trainable` is set to
`False`.
This function prefixes the name with the current variable scope
and performs reuse checks. See the
[Variable Scope How To](../../how_tos/variable_scope/index.md)
for an extensive description of how reusing works. Here is a basic example:
```python
with tf.variable_scope("foo"):
v = tf.get_variable("v", [1]) # v.name == "foo/v:0"
w = tf.get_variable("w", [1]) # w.name == "foo/w:0"
with tf.variable_scope("foo", reuse=True)
v1 = tf.get_variable("v") # The same as v above.
```
If initializer is `None` (the default), the default initializer passed in
the variable scope will be used. If that one is `None` too, a
`uniform_unit_scaling_initializer` will be used. The initializer can also be
a Tensor, in which case the variable is initialized to this value and shape.
Similarly, if the regularizer is `None` (the default), the default regularizer
passed in the variable scope will be used (if that is `None` too,
then by default no regularization is performed).
If a partitioner is provided, a `PartitionedVariable` is returned.
Accessing this object as a `Tensor` returns the shards concatenated along
the partition axis.
Some useful partitioners are available. See, e.g.,
`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
##### Args:
* <b>`name`</b>: The name of the new or existing variable.
* <b>`shape`</b>: Shape of the new or existing variable.
* <b>`dtype`</b>: Type of the new or existing variable (defaults to `DT_FLOAT`).
* <b>`initializer`</b>: Initializer for the variable if one is created.
* <b>`regularizer`</b>: A (Tensor -> Tensor or None) function; the result of
applying it on a newly created variable will be added to the collection
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
* <b>`collections`</b>: List of graph collections keys to add the Variable to.
Defaults to `[GraphKeys.LOCAL_VARIABLES]` (see `tf.Variable`).
* <b>`caching_device`</b>: Optional device string or function describing where the
Variable should be cached for reading. Defaults to the Variable's
device. If not `None`, caches on another device. Typical use is to
cache on the device where the Ops using the Variable reside, to
deduplicate copying through `Switch` and other conditional statements.
* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape`
and `dtype` of the Variable to be created, and returns a list of
partitions for each axis (currently only one axis can be partitioned).
* <b>`validate_shape`</b>: If False, allows the variable to be initialized with a
value of unknown shape. If True, the default, the shape of initial_value
must be known.
* <b>`custom_getter`</b>: Callable that takes as a first argument the true getter, and
allows overwriting the internal get_variable method.
The signature of `custom_getter` should match that of this method,
but the most future-proof version will allow for changes:
`def custom_getter(getter, *args, **kwargs)`. Direct access to
all `get_variable` parameters is also allowed:
`def custom_getter(getter, name, *args, **kwargs)`. A simple identity
custom getter that simply creates variables with modified names is:
```python
def custom_getter(getter, name, *args, **kwargs):
return getter(name + '_suffix', *args, **kwargs)
```
##### Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a
partitioner was used).
##### Raises:
* <b>`ValueError`</b>: when creating a new variable and shape is not declared,
when violating reuse during variable creation, or when `initializer` dtype
and `dtype` don't match. Reuse is set inside `variable_scope`.

View File

@ -84,6 +84,7 @@
* [`export_meta_graph`](../../api_docs/python/state_ops.md#export_meta_graph) * [`export_meta_graph`](../../api_docs/python/state_ops.md#export_meta_graph)
* [`fixed_size_partitioner`](../../api_docs/python/state_ops.md#fixed_size_partitioner) * [`fixed_size_partitioner`](../../api_docs/python/state_ops.md#fixed_size_partitioner)
* [`get_checkpoint_state`](../../api_docs/python/state_ops.md#get_checkpoint_state) * [`get_checkpoint_state`](../../api_docs/python/state_ops.md#get_checkpoint_state)
* [`get_local_variable`](../../api_docs/python/state_ops.md#get_local_variable)
* [`get_variable`](../../api_docs/python/state_ops.md#get_variable) * [`get_variable`](../../api_docs/python/state_ops.md#get_variable)
* [`get_variable_scope`](../../api_docs/python/state_ops.md#get_variable_scope) * [`get_variable_scope`](../../api_docs/python/state_ops.md#get_variable_scope)
* [`global_variables`](../../api_docs/python/state_ops.md#global_variables) * [`global_variables`](../../api_docs/python/state_ops.md#global_variables)

View File

@ -2009,6 +2009,93 @@ Some useful partitioners are available. See, e.g.,
##### Raises: ##### Raises:
* <b>`ValueError`</b>: when creating a new variable and shape is not declared,
when violating reuse during variable creation, or when `initializer` dtype
and `dtype` don't match. Reuse is set inside `variable_scope`.
- - -
### `tf.get_local_variable(*args, **kwargs)` {#get_local_variable}
Gets an existing *local* variable or creates a new one.
Behavior is the same as in `get_variable`, except that variables are
added to the `LOCAL_VARIABLES` collection and `trainable` is set to
`False`.
This function prefixes the name with the current variable scope
and performs reuse checks. See the
[Variable Scope How To](../../how_tos/variable_scope/index.md)
for an extensive description of how reusing works. Here is a basic example:
```python
with tf.variable_scope("foo"):
v = tf.get_variable("v", [1]) # v.name == "foo/v:0"
w = tf.get_variable("w", [1]) # w.name == "foo/w:0"
with tf.variable_scope("foo", reuse=True)
v1 = tf.get_variable("v") # The same as v above.
```
If initializer is `None` (the default), the default initializer passed in
the variable scope will be used. If that one is `None` too, a
`uniform_unit_scaling_initializer` will be used. The initializer can also be
a Tensor, in which case the variable is initialized to this value and shape.
Similarly, if the regularizer is `None` (the default), the default regularizer
passed in the variable scope will be used (if that is `None` too,
then by default no regularization is performed).
If a partitioner is provided, a `PartitionedVariable` is returned.
Accessing this object as a `Tensor` returns the shards concatenated along
the partition axis.
Some useful partitioners are available. See, e.g.,
`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
##### Args:
* <b>`name`</b>: The name of the new or existing variable.
* <b>`shape`</b>: Shape of the new or existing variable.
* <b>`dtype`</b>: Type of the new or existing variable (defaults to `DT_FLOAT`).
* <b>`initializer`</b>: Initializer for the variable if one is created.
* <b>`regularizer`</b>: A (Tensor -> Tensor or None) function; the result of
applying it on a newly created variable will be added to the collection
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
* <b>`collections`</b>: List of graph collections keys to add the Variable to.
Defaults to `[GraphKeys.LOCAL_VARIABLES]` (see `tf.Variable`).
* <b>`caching_device`</b>: Optional device string or function describing where the
Variable should be cached for reading. Defaults to the Variable's
device. If not `None`, caches on another device. Typical use is to
cache on the device where the Ops using the Variable reside, to
deduplicate copying through `Switch` and other conditional statements.
* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape`
and `dtype` of the Variable to be created, and returns a list of
partitions for each axis (currently only one axis can be partitioned).
* <b>`validate_shape`</b>: If False, allows the variable to be initialized with a
value of unknown shape. If True, the default, the shape of initial_value
must be known.
* <b>`custom_getter`</b>: Callable that takes as a first argument the true getter, and
allows overwriting the internal get_variable method.
The signature of `custom_getter` should match that of this method,
but the most future-proof version will allow for changes:
`def custom_getter(getter, *args, **kwargs)`. Direct access to
all `get_variable` parameters is also allowed:
`def custom_getter(getter, name, *args, **kwargs)`. A simple identity
custom getter that simply creates variables with modified names is:
```python
def custom_getter(getter, name, *args, **kwargs):
return getter(name + '_suffix', *args, **kwargs)
```
##### Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a
partitioner was used).
##### Raises:
* <b>`ValueError`</b>: when creating a new variable and shape is not declared, * <b>`ValueError`</b>: when creating a new variable and shape is not declared,
when violating reuse during variable creation, or when `initializer` dtype when violating reuse during variable creation, or when `initializer` dtype
and `dtype` don't match. Reuse is set inside `variable_scope`. and `dtype` don't match. Reuse is set inside `variable_scope`.

View File

@ -54,6 +54,7 @@ TensorFlow provides several classes and operations that you can use to
create variables contingent on certain conditions. create variables contingent on certain conditions.
@@get_variable @@get_variable
@@get_local_variable
@@VariableScope @@VariableScope
@@variable_scope @@variable_scope
@@variable_op_scope @@variable_op_scope

View File

@ -979,38 +979,45 @@ def get_variable(name,
partitioner=None, partitioner=None,
validate_shape=True, validate_shape=True,
custom_getter=None): custom_getter=None):
"""Gets an existing variable with these parameters or create a new one. return get_variable_scope().get_variable(
_get_default_variable_store(), name, shape=shape, dtype=dtype,
initializer=initializer, regularizer=regularizer, trainable=trainable,
collections=collections, caching_device=caching_device,
partitioner=partitioner, validate_shape=validate_shape,
custom_getter=custom_getter)
get_variable_or_local_docstring = (
"""%s
This function prefixes the name with the current variable scope %sThis function prefixes the name with the current variable scope
and performs reuse checks. See the and performs reuse checks. See the
[Variable Scope How To](../../how_tos/variable_scope/index.md) [Variable Scope How To](../../how_tos/variable_scope/index.md)
for an extensive description of how reusing works. Here is a basic example: for an extensive description of how reusing works. Here is a basic example:
```python ```python
with tf.variable_scope("foo"): with tf.variable_scope("foo"):
v = tf.get_variable("v", [1]) # v.name == "foo/v:0" v = tf.get_variable("v", [1]) # v.name == "foo/v:0"
w = tf.get_variable("w", [1]) # w.name == "foo/w:0" w = tf.get_variable("w", [1]) # w.name == "foo/w:0"
with tf.variable_scope("foo", reuse=True) with tf.variable_scope("foo", reuse=True)
v1 = tf.get_variable("v") # The same as v above. v1 = tf.get_variable("v") # The same as v above.
``` ```
If initializer is `None` (the default), the default initializer passed in If initializer is `None` (the default), the default initializer passed in
the variable scope will be used. If that one is `None` too, a the variable scope will be used. If that one is `None` too, a
`uniform_unit_scaling_initializer` will be used. The initializer can also be `uniform_unit_scaling_initializer` will be used. The initializer can also be
a Tensor, in which case the variable is initialized to this value and shape. a Tensor, in which case the variable is initialized to this value and shape.
Similarly, if the regularizer is `None` (the default), the default regularizer Similarly, if the regularizer is `None` (the default), the default regularizer
passed in the variable scope will be used (if that is `None` too, passed in the variable scope will be used (if that is `None` too,
then by default no regularization is performed). then by default no regularization is performed).
If a partitioner is provided, a `PartitionedVariable` is returned. If a partitioner is provided, a `PartitionedVariable` is returned.
Accessing this object as a `Tensor` returns the shards concatenated along Accessing this object as a `Tensor` returns the shards concatenated along
the partition axis. the partition axis.
Some useful partitioners are available. See, e.g., Some useful partitioners are available. See, e.g.,
`variable_axis_size_partitioner` and `min_max_variable_partitioner`. `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
Args: Args:
name: The name of the new or existing variable. name: The name of the new or existing variable.
shape: Shape of the new or existing variable. shape: Shape of the new or existing variable.
dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
@ -1018,10 +1025,8 @@ def get_variable(name,
regularizer: A (Tensor -> Tensor or None) function; the result of regularizer: A (Tensor -> Tensor or None) function; the result of
applying it on a newly created variable will be added to the collection applying it on a newly created variable will be added to the collection
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
trainable: If `True` also add the variable to the graph collection %scollections: List of graph collections keys to add the Variable to.
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). Defaults to `[%s]` (see `tf.Variable`).
collections: List of graph collections keys to add the Variable to.
Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
caching_device: Optional device string or function describing where the caching_device: Optional device string or function describing where the
Variable should be cached for reading. Defaults to the Variable's Variable should be cached for reading. Defaults to the Variable's
device. If not `None`, caches on another device. Typical use is to device. If not `None`, caches on another device. Typical use is to
@ -1046,21 +1051,21 @@ def get_variable(name,
return getter(name + '_suffix', *args, **kwargs) return getter(name + '_suffix', *args, **kwargs)
``` ```
Returns: Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a The created or existing `Variable` (or `PartitionedVariable`, if a
partitioner was used). partitioner was used).
Raises: Raises:
ValueError: when creating a new variable and shape is not declared, ValueError: when creating a new variable and shape is not declared,
when violating reuse during variable creation, or when `initializer` dtype when violating reuse during variable creation, or when `initializer` dtype
and `dtype` don't match. Reuse is set inside `variable_scope`. and `dtype` don't match. Reuse is set inside `variable_scope`.
""" """)
return get_variable_scope().get_variable( get_variable.__doc__ = get_variable_or_local_docstring % (
_get_default_variable_store(), name, shape=shape, dtype=dtype, "Gets an existing variable with these parameters or create a new one.",
initializer=initializer, regularizer=regularizer, trainable=trainable, "",
collections=collections, caching_device=caching_device, "trainable: If `True` also add the variable to the graph collection\n"
partitioner=partitioner, validate_shape=validate_shape, " `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n",
custom_getter=custom_getter) "GraphKeys.GLOBAL_VARIABLES")
@functools.wraps(get_variable) @functools.wraps(get_variable)
@ -1070,10 +1075,14 @@ def get_local_variable(*args, **kwargs):
kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES] kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES]
else: else:
kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES] kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES]
get_local_variable.__doc__ = (
"Gets an existing local variable or creates a new one.\n\n" +
get_local_variable.__doc__)
return get_variable(*args, **kwargs) return get_variable(*args, **kwargs)
get_local_variable.__doc__ = get_variable_or_local_docstring % (
"Gets an existing *local* variable or creates a new one.",
"Behavior is the same as in `get_variable`, except that variables are\n"
"added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
"`False`.\n",
"",
"GraphKeys.LOCAL_VARIABLES")
def _get_partitioned_variable(name, def _get_partitioned_variable(name,