Added model_variable and helpers to manage variables.
Change: 122727793
This commit is contained in:
parent
63c29c82b3
commit
144855b385
@ -18,13 +18,40 @@
|
|||||||
@@assert_same_float_dtype
|
@@assert_same_float_dtype
|
||||||
@@assert_scalar_int
|
@@assert_scalar_int
|
||||||
@@convert_to_tensor_or_sparse_tensor
|
@@convert_to_tensor_or_sparse_tensor
|
||||||
@@local_variable
|
@@get_graph_from_inputs
|
||||||
|
@@is_numeric_tensor
|
||||||
|
@@is_non_decreasing
|
||||||
|
@@is_strictly_increasing
|
||||||
@@reduce_sum_n
|
@@reduce_sum_n
|
||||||
@@safe_embedding_lookup_sparse
|
@@safe_embedding_lookup_sparse
|
||||||
@@with_shape
|
@@with_shape
|
||||||
@@with_same_shape
|
@@with_same_shape
|
||||||
|
|
||||||
@@get_graph_from_inputs
|
|
||||||
|
## Arg_Scope
|
||||||
|
@@arg_scope
|
||||||
|
@@add_arg_scope
|
||||||
|
@@has_arg_scope
|
||||||
|
@@arg_scoped_arguments
|
||||||
|
|
||||||
|
## Variables
|
||||||
|
@@add_model_variable
|
||||||
|
@@assert_global_step
|
||||||
|
@@assert_or_get_global_step
|
||||||
|
@@create_global_step
|
||||||
|
@@get_global_step
|
||||||
|
@@get_or_create_global_step
|
||||||
|
@@get_local_variables
|
||||||
|
@@get_model_variables
|
||||||
|
@@get_unique_variable
|
||||||
|
@@get_variables_by_name
|
||||||
|
@@get_variables_by_suffix
|
||||||
|
@@get_variables_to_restore
|
||||||
|
@@get_variables
|
||||||
|
@@local_variable
|
||||||
|
@@model_variable
|
||||||
|
@@variable
|
||||||
|
@@VariableDeviceChooser
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
@ -14,7 +14,6 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""Tensor utility functions."""
|
"""Tensor utility functions."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -27,14 +26,16 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'assert_same_float_dtype', 'assert_scalar_int',
|
'assert_same_float_dtype',
|
||||||
'convert_to_tensor_or_sparse_tensor', 'local_variable', 'reduce_sum_n',
|
'assert_scalar_int',
|
||||||
'with_shape', 'with_same_shape',
|
'convert_to_tensor_or_sparse_tensor',
|
||||||
]
|
'reduce_sum_n',
|
||||||
|
'with_shape',
|
||||||
|
'with_same_shape']
|
||||||
|
|
||||||
|
|
||||||
def _assert_same_base_type(items, expected_type=None):
|
def _assert_same_base_type(items, expected_type=None):
|
||||||
"""Asserts all items are of the same base type.
|
r"""Asserts all items are of the same base type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
|
items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
|
||||||
@ -110,23 +111,6 @@ def assert_scalar_int(tensor):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
# TODO(ptucker): Move to tf.variables?
|
|
||||||
def local_variable(initial_value, validate_shape=True, name=None):
|
|
||||||
"""Create variable and add it to `GraphKeys.LOCAL_VARIABLES` collection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
initial_value: See variables.Variable.__init__.
|
|
||||||
validate_shape: See variables.Variable.__init__.
|
|
||||||
name: See variables.Variable.__init__.
|
|
||||||
Returns:
|
|
||||||
New variable.
|
|
||||||
"""
|
|
||||||
return variables.Variable(
|
|
||||||
initial_value, trainable=False,
|
|
||||||
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
|
||||||
validate_shape=validate_shape, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
def reduce_sum_n(tensors, name=None):
|
def reduce_sum_n(tensors, name=None):
|
||||||
"""Reduce tensors to a scalar sum.
|
"""Reduce tensors to a scalar sum.
|
||||||
|
|
||||||
|
@ -47,11 +47,6 @@
|
|||||||
|
|
||||||
@tf.contrib.add_arg_scope
|
@tf.contrib.add_arg_scope
|
||||||
def conv2d(*args, **kwargs)
|
def conv2d(*args, **kwargs)
|
||||||
|
|
||||||
@@arg_scope
|
|
||||||
@@add_arg_scope
|
|
||||||
@@has_arg_scope
|
|
||||||
@@arg_scoped_arguments
|
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -59,8 +54,10 @@ from __future__ import print_function
|
|||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
__all__ = ['arg_scope', 'add_arg_scope',
|
__all__ = ['arg_scope',
|
||||||
'has_arg_scope', 'arg_scoped_arguments']
|
'add_arg_scope',
|
||||||
|
'has_arg_scope',
|
||||||
|
'arg_scoped_arguments']
|
||||||
|
|
||||||
_ARGSTACK = [{}]
|
_ARGSTACK = [{}]
|
||||||
|
|
||||||
|
@ -14,24 +14,37 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""Variable functions.
|
"""Variable functions.
|
||||||
|
|
||||||
@@assert_global_step
|
|
||||||
@@create_global_step
|
|
||||||
@@get_global_step
|
|
||||||
@@assert_or_get_global_step
|
|
||||||
@@local_variable
|
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope
|
||||||
|
from tensorflow.python.framework import device as tf_device
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import init_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'assert_global_step', 'create_global_step', 'get_global_step',
|
__all__ = ['add_model_variable',
|
||||||
'assert_or_get_global_step', 'local_variable']
|
'assert_global_step',
|
||||||
|
'assert_or_get_global_step',
|
||||||
|
'create_global_step',
|
||||||
|
'get_global_step',
|
||||||
|
'get_or_create_global_step',
|
||||||
|
'get_local_variables',
|
||||||
|
'get_model_variables',
|
||||||
|
'get_unique_variable',
|
||||||
|
'get_variables_by_name',
|
||||||
|
'get_variables_by_suffix',
|
||||||
|
'get_variables_to_restore',
|
||||||
|
'get_variables',
|
||||||
|
'local_variable',
|
||||||
|
'model_variable',
|
||||||
|
'variable',
|
||||||
|
'VariableDeviceChooser']
|
||||||
|
|
||||||
|
|
||||||
def assert_global_step(global_step_tensor):
|
def assert_global_step(global_step_tensor):
|
||||||
@ -125,7 +138,6 @@ def create_global_step(graph=None):
|
|||||||
Global step tensor.
|
Global step tensor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: if `dtype` is invalid.
|
|
||||||
ValueError: if global step key is already defined.
|
ValueError: if global step key is already defined.
|
||||||
"""
|
"""
|
||||||
graph = ops.get_default_graph() if graph is None else graph
|
graph = ops.get_default_graph() if graph is None else graph
|
||||||
@ -133,10 +145,27 @@ def create_global_step(graph=None):
|
|||||||
raise ValueError('"global_step" already exists.')
|
raise ValueError('"global_step" already exists.')
|
||||||
# Create in proper graph and base name_scope.
|
# Create in proper graph and base name_scope.
|
||||||
with graph.as_default() as g, g.name_scope(None):
|
with graph.as_default() as g, g.name_scope(None):
|
||||||
result = variables.Variable(
|
collections = [ops.GraphKeys.VARIABLES, ops.GraphKeys.GLOBAL_STEP]
|
||||||
0, trainable=False, dtype=dtypes.int64, name=ops.GraphKeys.GLOBAL_STEP)
|
return variable(ops.GraphKeys.GLOBAL_STEP, shape=[], dtype=dtypes.int64,
|
||||||
graph.add_to_collection(ops.GraphKeys.GLOBAL_STEP, result)
|
initializer=init_ops.zeros_initializer, trainable=False,
|
||||||
return result
|
collections=collections)
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_global_step(graph=None):
|
||||||
|
"""Returns and create (if necessary) the global step variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph: The graph in which to create the global step. If missing, use default
|
||||||
|
graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the tensor representing the global step variable.
|
||||||
|
"""
|
||||||
|
graph = ops.get_default_graph() if graph is None else graph
|
||||||
|
globalstep = get_global_step(graph)
|
||||||
|
if globalstep is None:
|
||||||
|
globalstep = create_global_step(graph)
|
||||||
|
return globalstep
|
||||||
|
|
||||||
|
|
||||||
def local_variable(initial_value, validate_shape=True, name=None):
|
def local_variable(initial_value, validate_shape=True, name=None):
|
||||||
@ -154,3 +183,268 @@ def local_variable(initial_value, validate_shape=True, name=None):
|
|||||||
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
||||||
validate_shape=validate_shape, name=name)
|
validate_shape=validate_shape, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@contrib_add_arg_scope
|
||||||
|
def variable(name, shape=None, dtype=dtypes.float32, initializer=None,
|
||||||
|
regularizer=None, trainable=True, collections=None,
|
||||||
|
caching_device=None, device=None):
|
||||||
|
"""Gets an existing variable with these parameters or creates a new one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: the name of the new or existing variable.
|
||||||
|
shape: shape of the new or existing variable.
|
||||||
|
dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
|
||||||
|
initializer: initializer for the variable if one is created.
|
||||||
|
regularizer: a (Tensor -> Tensor or None) function; the result of
|
||||||
|
applying it on a newly created variable will be added to the collection
|
||||||
|
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
|
||||||
|
trainable: If `True` also add the variable to the graph collection
|
||||||
|
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
|
||||||
|
collections: A list of collection names to which the Variable will be added.
|
||||||
|
If None it would default to tf.GraphKeys.VARIABLES.
|
||||||
|
caching_device: Optional device string or function describing where the
|
||||||
|
Variable should be cached for reading. Defaults to the Variable's
|
||||||
|
device.
|
||||||
|
device: Optional device to place the variable. It can be an string or a
|
||||||
|
function that is called to get the device for the variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created or existing variable.
|
||||||
|
"""
|
||||||
|
collections = list(collections or [ops.GraphKeys.VARIABLES])
|
||||||
|
|
||||||
|
# Remove duplicates
|
||||||
|
collections = set(collections)
|
||||||
|
with ops.device(device or ''):
|
||||||
|
return variable_scope.get_variable(name, shape=shape, dtype=dtype,
|
||||||
|
initializer=initializer,
|
||||||
|
regularizer=regularizer,
|
||||||
|
trainable=trainable,
|
||||||
|
collections=collections,
|
||||||
|
caching_device=caching_device)
|
||||||
|
|
||||||
|
# TODO(sguada) move it to ops.GraphKeys or to contrib.framework.GraphKeys
|
||||||
|
# Collection containing all the variables created using model_variables.
|
||||||
|
MODEL_VARIABLES = '_model_variables_'
|
||||||
|
|
||||||
|
|
||||||
|
@contrib_add_arg_scope
|
||||||
|
def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
|
||||||
|
regularizer=None, trainable=True, collections=None,
|
||||||
|
caching_device=None, device=None):
|
||||||
|
"""Gets an existing model variable with these parameters or creates a new one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: the name of the new or existing variable.
|
||||||
|
shape: shape of the new or existing variable.
|
||||||
|
dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
|
||||||
|
initializer: initializer for the variable if one is created.
|
||||||
|
regularizer: a (Tensor -> Tensor or None) function; the result of
|
||||||
|
applying it on a newly created variable will be added to the collection
|
||||||
|
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
|
||||||
|
trainable: If `True` also add the variable to the graph collection
|
||||||
|
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
|
||||||
|
collections: A list of collection names to which the Variable will be added.
|
||||||
|
Note that the variable is always also added to the tf.GraphKeys.VARIABLES
|
||||||
|
and MODEL_VARIABLES collections.
|
||||||
|
caching_device: Optional device string or function describing where the
|
||||||
|
Variable should be cached for reading. Defaults to the Variable's
|
||||||
|
device.
|
||||||
|
device: Optional device to place the variable. It can be an string or a
|
||||||
|
function that is called to get the device for the variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created or existing variable.
|
||||||
|
"""
|
||||||
|
collections = list(collections or [])
|
||||||
|
|
||||||
|
# Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
|
||||||
|
collections += [ops.GraphKeys.VARIABLES, MODEL_VARIABLES]
|
||||||
|
return variable(name, shape=shape, dtype=dtype,
|
||||||
|
initializer=initializer, regularizer=regularizer,
|
||||||
|
trainable=trainable, collections=collections,
|
||||||
|
caching_device=caching_device, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def add_model_variable(var):
|
||||||
|
"""Adds a variable to the MODEL_VARIABLES collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
var: a variable.
|
||||||
|
"""
|
||||||
|
if var not in ops.get_collection(MODEL_VARIABLES):
|
||||||
|
ops.add_to_collection(MODEL_VARIABLES, var)
|
||||||
|
|
||||||
|
|
||||||
|
def get_variables(scope=None, suffix=None, collection=ops.GraphKeys.VARIABLES):
|
||||||
|
"""Gets the list of variables, filtered by scope and/or suffix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: an optional scope for filtering the variables to return.
|
||||||
|
suffix: an optional suffix for filtering the variables to return.
|
||||||
|
collection: in which collection search for. Defaults to GraphKeys.VARIABLES.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a list of variables in colelction with scope and suffix.
|
||||||
|
"""
|
||||||
|
if suffix is not None:
|
||||||
|
if ':' not in suffix:
|
||||||
|
suffix += ':'
|
||||||
|
scope = (scope or '') + '.*' + suffix
|
||||||
|
return ops.get_collection(collection, scope)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_variables(scope=None, suffix=None):
|
||||||
|
"""Gets the list of model variables, filtered by scope and/or suffix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: an optional scope for filtering the variables to return.
|
||||||
|
suffix: an optional suffix for filtering the variables to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a list of variables in colelction with scope and suffix.
|
||||||
|
"""
|
||||||
|
return get_variables(scope, suffix, MODEL_VARIABLES)
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_variables(scope=None, suffix=None):
|
||||||
|
"""Gets the list of model variables, filtered by scope and/or suffix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: an optional scope for filtering the variables to return.
|
||||||
|
suffix: an optional suffix for filtering the variables to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a list of variables in colelction with scope and suffix.
|
||||||
|
"""
|
||||||
|
return get_variables(scope, suffix, ops.GraphKeys.LOCAL_VARIABLES)
|
||||||
|
|
||||||
|
|
||||||
|
def get_variables_to_restore(include=None, exclude=None):
|
||||||
|
"""Gets the list of the variables to restore.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include: an optional list/tuple of scope strings for filtering which
|
||||||
|
variables from the VARIABLES collection to include. None would include all
|
||||||
|
the variables.
|
||||||
|
exclude: an optional list/tuple of scope strings for filtering which
|
||||||
|
variables from the VARIABLES collection to exclude. None it would not
|
||||||
|
exclude any.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a list of variables to restore.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: include or exclude is provided but is not a list or a tuple.
|
||||||
|
"""
|
||||||
|
if include is None:
|
||||||
|
# Include all variables.
|
||||||
|
vars_to_include = get_variables()
|
||||||
|
else:
|
||||||
|
if not isinstance(include, (list, tuple)):
|
||||||
|
raise TypeError('include is provided but is not a list or a tuple.')
|
||||||
|
vars_to_include = []
|
||||||
|
for scope in include:
|
||||||
|
vars_to_include += get_variables(scope)
|
||||||
|
vars_to_exclude = set()
|
||||||
|
if exclude is not None:
|
||||||
|
if not isinstance(exclude, (list, tuple)):
|
||||||
|
raise TypeError('exclude is provided but is not a list or a tuple.')
|
||||||
|
for scope in exclude:
|
||||||
|
vars_to_exclude |= set(get_variables(scope))
|
||||||
|
# Exclude the variables in vars_to_exclude
|
||||||
|
return [v for v in vars_to_include if v not in vars_to_exclude]
|
||||||
|
|
||||||
|
|
||||||
|
def get_variables_by_suffix(suffix, scope=None):
|
||||||
|
"""Gets the list of variables that end with the given suffix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
suffix: suffix for filtering the variables to return.
|
||||||
|
scope: an optional scope for filtering the variables to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a copied list of variables with the given name and prefix.
|
||||||
|
"""
|
||||||
|
return get_variables(scope=scope, suffix=suffix)
|
||||||
|
|
||||||
|
|
||||||
|
def get_variables_by_name(given_name, scope=None):
|
||||||
|
"""Gets the list of variables that were given that name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
given_name: name given to the variable without any scope.
|
||||||
|
scope: an optional scope for filtering the variables to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a copied list of variables with the given name and scope.
|
||||||
|
"""
|
||||||
|
suffix = '/' + given_name + ':|^' + given_name + ':'
|
||||||
|
return get_variables(scope=scope, suffix=suffix)
|
||||||
|
|
||||||
|
|
||||||
|
def get_unique_variable(var_op_name):
|
||||||
|
"""Gets the variable uniquely identified by that var_op_name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
var_op_name: the full name of the variable op, including the scope.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a tensorflow variable.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if no variable uniquely identified by the name exists.
|
||||||
|
"""
|
||||||
|
candidates = get_variables(scope=var_op_name)
|
||||||
|
if not candidates:
|
||||||
|
raise ValueError('Couldnt find variable %s' % var_op_name)
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate.op.name == var_op_name:
|
||||||
|
return candidate
|
||||||
|
raise ValueError('Variable %s does not uniquely identify a variable',
|
||||||
|
var_op_name)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableDeviceChooser(object):
|
||||||
|
"""Device chooser for variables.
|
||||||
|
|
||||||
|
When using a parameter server it will assign them in a round-robin fashion.
|
||||||
|
When not using a parameter server it allows GPU or CPU placement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_tasks=0,
|
||||||
|
device_type='CPU',
|
||||||
|
device_index=0):
|
||||||
|
"""Initialize VariableDeviceChooser.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
To use with 2 parameter servers:
|
||||||
|
VariableDeviceChooser(2)
|
||||||
|
|
||||||
|
To use without parameter servers:
|
||||||
|
VariableDeviceChooser()
|
||||||
|
VariableDeviceChooser(device_type='GPU') # For GPU placement
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_tasks: number of tasks.
|
||||||
|
device_type: Optional device type string (e.g. "CPU" or "GPU")
|
||||||
|
device_index: int. Optional device index. If left
|
||||||
|
unspecified, device represents 'any' device_index.
|
||||||
|
"""
|
||||||
|
self._job_name = 'ps' if num_tasks > 0 else None
|
||||||
|
self._device_type = device_type
|
||||||
|
self._device_index = device_index
|
||||||
|
self._num_tasks = num_tasks
|
||||||
|
self._next_task_id = 0
|
||||||
|
|
||||||
|
def __call__(self, op):
|
||||||
|
device_spec = tf_device.DeviceSpec(job=self._job_name,
|
||||||
|
device_type=self._device_type,
|
||||||
|
device_index=self._device_index)
|
||||||
|
if self._num_tasks > 0:
|
||||||
|
task_id = self._next_task_id
|
||||||
|
self._next_task_id = (self._next_task_id + 1) % self._num_tasks
|
||||||
|
device_spec.task = task_id
|
||||||
|
return device_spec.to_string()
|
||||||
|
@ -37,11 +37,57 @@ class LocalVariableTest(tf.test.TestCase):
|
|||||||
tf.initialize_variables(variables).run()
|
tf.initialize_variables(variables).run()
|
||||||
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
|
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
|
||||||
|
|
||||||
|
def testLocalVariableNameAndShape(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.local_variable([1, 1, 1, 1, 1], name='a')
|
||||||
|
self.assertEquals(a.op.name, 'A/a')
|
||||||
|
self.assertListEqual(a.get_shape().as_list(), [5])
|
||||||
|
self.assertListEqual([a], tf.contrib.framework.get_local_variables())
|
||||||
|
|
||||||
|
def testLocalVariableNotInAllVariables(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.local_variable(0)
|
||||||
|
self.assertFalse(a in tf.all_variables())
|
||||||
|
self.assertTrue(a in tf.local_variables())
|
||||||
|
|
||||||
|
def testLocalVariableNotInVariablesToRestore(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.local_variable(0)
|
||||||
|
self.assertFalse(a in tf.contrib.framework.get_variables_to_restore())
|
||||||
|
self.assertTrue(a in tf.local_variables())
|
||||||
|
|
||||||
|
def testGetVariablesDontReturnsTransients(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
tf.contrib.framework.local_variable(0)
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
tf.contrib.framework.local_variable(0)
|
||||||
|
self.assertEquals([], tf.contrib.framework.get_variables('A'))
|
||||||
|
self.assertEquals([], tf.contrib.framework.get_variables('B'))
|
||||||
|
|
||||||
|
def testGetLocalVariablesReturnsTransients(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.local_variable(0)
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
b = tf.contrib.framework.local_variable(0)
|
||||||
|
self.assertEquals([a], tf.contrib.framework.get_local_variables('A'))
|
||||||
|
self.assertEquals([b], tf.contrib.framework.get_local_variables('B'))
|
||||||
|
|
||||||
|
def testInitializedVariableValue(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a')
|
||||||
|
sess.run(tf.initialize_local_variables())
|
||||||
|
self.assertAllEqual(a.eval(), [0]*5)
|
||||||
|
|
||||||
|
|
||||||
class GlobalStepTest(tf.test.TestCase):
|
class GlobalStepTest(tf.test.TestCase):
|
||||||
|
|
||||||
def _assert_global_step(self, global_step, expected_dtype=tf.int64):
|
def _assert_global_step(self, global_step, expected_dtype=tf.int64):
|
||||||
self.assertEquals("%s:0" % tf.GraphKeys.GLOBAL_STEP, global_step.name)
|
self.assertEquals('%s:0' % tf.GraphKeys.GLOBAL_STEP, global_step.name)
|
||||||
self.assertEquals(expected_dtype, global_step.dtype.base_dtype)
|
self.assertEquals(expected_dtype, global_step.dtype.base_dtype)
|
||||||
self.assertEquals([], global_step.get_shape().as_list())
|
self.assertEquals([], global_step.get_shape().as_list())
|
||||||
|
|
||||||
@ -51,10 +97,10 @@ class GlobalStepTest(tf.test.TestCase):
|
|||||||
tf.Variable(
|
tf.Variable(
|
||||||
0.0, trainable=False, dtype=tf.float32, name=tf.GraphKeys.GLOBAL_STEP)
|
0.0, trainable=False, dtype=tf.float32, name=tf.GraphKeys.GLOBAL_STEP)
|
||||||
self.assertRaisesRegexp(
|
self.assertRaisesRegexp(
|
||||||
TypeError, "does not have integer type",
|
TypeError, 'does not have integer type',
|
||||||
tf.contrib.framework.get_global_step)
|
tf.contrib.framework.get_global_step)
|
||||||
self.assertRaisesRegexp(
|
self.assertRaisesRegexp(
|
||||||
TypeError, "does not have integer type",
|
TypeError, 'does not have integer type',
|
||||||
tf.contrib.framework.get_global_step, g)
|
tf.contrib.framework.get_global_step, g)
|
||||||
|
|
||||||
def test_invalid_shape(self):
|
def test_invalid_shape(self):
|
||||||
@ -63,10 +109,10 @@ class GlobalStepTest(tf.test.TestCase):
|
|||||||
tf.Variable(
|
tf.Variable(
|
||||||
[0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP)
|
[0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP)
|
||||||
self.assertRaisesRegexp(
|
self.assertRaisesRegexp(
|
||||||
TypeError, "not scalar",
|
TypeError, 'not scalar',
|
||||||
tf.contrib.framework.get_global_step)
|
tf.contrib.framework.get_global_step)
|
||||||
self.assertRaisesRegexp(
|
self.assertRaisesRegexp(
|
||||||
TypeError, "not scalar",
|
TypeError, 'not scalar',
|
||||||
tf.contrib.framework.get_global_step, g)
|
tf.contrib.framework.get_global_step, g)
|
||||||
|
|
||||||
def test_create_global_step(self):
|
def test_create_global_step(self):
|
||||||
@ -75,9 +121,9 @@ class GlobalStepTest(tf.test.TestCase):
|
|||||||
global_step = tf.contrib.framework.create_global_step()
|
global_step = tf.contrib.framework.create_global_step()
|
||||||
self._assert_global_step(global_step)
|
self._assert_global_step(global_step)
|
||||||
self.assertRaisesRegexp(
|
self.assertRaisesRegexp(
|
||||||
ValueError, "already exists", tf.contrib.framework.create_global_step)
|
ValueError, 'already exists', tf.contrib.framework.create_global_step)
|
||||||
self.assertRaisesRegexp(
|
self.assertRaisesRegexp(
|
||||||
ValueError, "already exists", tf.contrib.framework.create_global_step,
|
ValueError, 'already exists', tf.contrib.framework.create_global_step,
|
||||||
g)
|
g)
|
||||||
self._assert_global_step(
|
self._assert_global_step(
|
||||||
tf.contrib.framework.create_global_step(tf.Graph()))
|
tf.contrib.framework.create_global_step(tf.Graph()))
|
||||||
@ -92,6 +138,510 @@ class GlobalStepTest(tf.test.TestCase):
|
|||||||
self._assert_global_step(
|
self._assert_global_step(
|
||||||
tf.contrib.framework.get_global_step(g), expected_dtype=tf.int32)
|
tf.contrib.framework.get_global_step(g), expected_dtype=tf.int32)
|
||||||
|
|
||||||
|
def test_get_or_create_global_step(self):
|
||||||
|
with tf.Graph().as_default() as g:
|
||||||
|
self.assertEquals(None, tf.contrib.framework.get_global_step())
|
||||||
|
self._assert_global_step(
|
||||||
|
tf.contrib.framework.get_or_create_global_step())
|
||||||
|
self._assert_global_step(
|
||||||
|
tf.contrib.framework.get_or_create_global_step(g))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
class VariablesTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testCreateVariable(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
self.assertEquals(a.op.name, 'A/a')
|
||||||
|
self.assertListEqual(a.get_shape().as_list(), [5])
|
||||||
|
|
||||||
|
def testGetVariables(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
b = tf.contrib.framework.variable('a', [5])
|
||||||
|
self.assertEquals([a, b], tf.contrib.framework.get_variables())
|
||||||
|
self.assertEquals([a], tf.contrib.framework.get_variables('A'))
|
||||||
|
self.assertEquals([b], tf.contrib.framework.get_variables('B'))
|
||||||
|
|
||||||
|
def testGetVariablesSuffix(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
b = tf.contrib.framework.variable('b', [5])
|
||||||
|
self.assertEquals([a], tf.contrib.framework.get_variables(suffix='a'))
|
||||||
|
self.assertEquals([b], tf.contrib.framework.get_variables(suffix='b'))
|
||||||
|
|
||||||
|
def testGetVariableWithSingleVar(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('parent'):
|
||||||
|
a = tf.contrib.framework.variable('child', [5])
|
||||||
|
self.assertEquals(
|
||||||
|
a, tf.contrib.framework.get_unique_variable('parent/child'))
|
||||||
|
|
||||||
|
def testGetVariableWithDistractors(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('parent'):
|
||||||
|
a = tf.contrib.framework.variable('child', [5])
|
||||||
|
with tf.variable_scope('child'):
|
||||||
|
tf.contrib.framework.variable('grandchild1', [7])
|
||||||
|
tf.contrib.framework.variable('grandchild2', [9])
|
||||||
|
self.assertEquals(
|
||||||
|
a, tf.contrib.framework.get_unique_variable('parent/child'))
|
||||||
|
|
||||||
|
def testGetVariableThrowsExceptionWithNoMatch(self):
|
||||||
|
var_name = 'cant_find_me'
|
||||||
|
with self.test_session():
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.contrib.framework.get_unique_variable(var_name)
|
||||||
|
|
||||||
|
def testGetThrowsExceptionWithChildrenButNoMatch(self):
|
||||||
|
var_name = 'parent/child'
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope(var_name):
|
||||||
|
tf.contrib.framework.variable('grandchild1', [7])
|
||||||
|
tf.contrib.framework.variable('grandchild2', [9])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.contrib.framework.get_unique_variable(var_name)
|
||||||
|
|
||||||
|
def testGetVariablesToRestore(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
b = tf.contrib.framework.variable('a', [5])
|
||||||
|
self.assertEquals([a, b],
|
||||||
|
tf.contrib.framework.get_variables_to_restore())
|
||||||
|
|
||||||
|
def testIncludeGetVariablesToRestore(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
b = tf.contrib.framework.variable('a', [5])
|
||||||
|
self.assertEquals([a, b], tf.contrib.framework.get_variables())
|
||||||
|
self.assertEquals([a],
|
||||||
|
tf.contrib.framework.get_variables_to_restore(['A']))
|
||||||
|
|
||||||
|
def testExcludeGetVariablesToRestore(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
b = tf.contrib.framework.variable('a', [5])
|
||||||
|
self.assertEquals([a, b], tf.contrib.framework.get_variables())
|
||||||
|
self.assertEquals([a],
|
||||||
|
tf.contrib.framework.get_variables_to_restore(
|
||||||
|
exclude=['B']))
|
||||||
|
|
||||||
|
def testWrongIncludeGetVariablesToRestore(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
b = tf.contrib.framework.variable('a', [5])
|
||||||
|
self.assertEquals([a, b], tf.contrib.framework.get_variables())
|
||||||
|
self.assertEquals([],
|
||||||
|
tf.contrib.framework.get_variables_to_restore(['a']))
|
||||||
|
|
||||||
|
def testGetMixedVariablesToRestore(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
b = tf.contrib.framework.variable('b', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
c = tf.contrib.framework.variable('c', [5])
|
||||||
|
d = tf.contrib.framework.variable('d', [5])
|
||||||
|
self.assertEquals([a, b, c, d], tf.contrib.framework.get_variables())
|
||||||
|
self.assertEquals([a, c],
|
||||||
|
tf.contrib.framework.get_variables_to_restore(
|
||||||
|
include=['A/a', 'B/c']))
|
||||||
|
|
||||||
|
def testExcludeGetMixedVariablesToRestore(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
b = tf.contrib.framework.variable('b', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
c = tf.contrib.framework.variable('c', [5])
|
||||||
|
d = tf.contrib.framework.variable('d', [5])
|
||||||
|
self.assertEquals([a, b, c, d], tf.contrib.framework.get_variables())
|
||||||
|
self.assertEquals([b, d],
|
||||||
|
tf.contrib.framework.get_variables_to_restore(
|
||||||
|
exclude=['A/a', 'B/c']))
|
||||||
|
|
||||||
|
def testReuseVariable(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
with tf.variable_scope('A', reuse=True):
|
||||||
|
b = tf.contrib.framework.variable('a', [])
|
||||||
|
self.assertEquals(a, b)
|
||||||
|
self.assertListEqual([a], tf.contrib.framework.get_variables())
|
||||||
|
|
||||||
|
def testVariableWithRegularizer(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [], regularizer=tf.nn.l2_loss)
|
||||||
|
loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
|
||||||
|
self.assertDeviceEqual(loss.device, a.device)
|
||||||
|
|
||||||
|
def testVariableWithRegularizerColocate(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [], device='gpu:0',
|
||||||
|
regularizer=tf.nn.l2_loss)
|
||||||
|
loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
|
||||||
|
self.assertDeviceEqual(loss.device, a.device)
|
||||||
|
|
||||||
|
def testVariableWithDevice(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [], device='cpu:0')
|
||||||
|
b = tf.contrib.framework.variable('b', [], device='cpu:1')
|
||||||
|
self.assertDeviceEqual(a.device, 'cpu:0')
|
||||||
|
self.assertDeviceEqual(b.device, 'cpu:1')
|
||||||
|
|
||||||
|
def testVariableWithDeviceFromScope(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.device('/cpu:0'):
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
b = tf.contrib.framework.variable('b', [], device='cpu:1')
|
||||||
|
self.assertDeviceEqual(a.device, 'cpu:0')
|
||||||
|
self.assertDeviceEqual(b.device, 'cpu:1')
|
||||||
|
|
||||||
|
def testVariableWithDeviceFunction(self):
|
||||||
|
class DevFn(object):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.counter = -1
|
||||||
|
|
||||||
|
def __call__(self, op):
|
||||||
|
self.counter += 1
|
||||||
|
return 'cpu:%d' % self.counter
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||||
|
device=DevFn()):
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
b = tf.contrib.framework.variable('b', [])
|
||||||
|
c = tf.contrib.framework.variable('c', [], device='cpu:12')
|
||||||
|
d = tf.contrib.framework.variable('d', [])
|
||||||
|
with tf.device('cpu:99'):
|
||||||
|
e_init = tf.constant(12)
|
||||||
|
e = tf.contrib.framework.variable('e', initializer=e_init)
|
||||||
|
self.assertDeviceEqual(a.device, 'cpu:0')
|
||||||
|
self.assertDeviceEqual(a.initial_value.device, 'cpu:0')
|
||||||
|
self.assertDeviceEqual(b.device, 'cpu:1')
|
||||||
|
self.assertDeviceEqual(b.initial_value.device, 'cpu:1')
|
||||||
|
self.assertDeviceEqual(c.device, 'cpu:12')
|
||||||
|
self.assertDeviceEqual(c.initial_value.device, 'cpu:12')
|
||||||
|
self.assertDeviceEqual(d.device, 'cpu:2')
|
||||||
|
self.assertDeviceEqual(d.initial_value.device, 'cpu:2')
|
||||||
|
self.assertDeviceEqual(e.device, 'cpu:3')
|
||||||
|
self.assertDeviceEqual(e.initial_value.device, 'cpu:99')
|
||||||
|
|
||||||
|
def testVariableWithReplicaDeviceSetter(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.device(tf.train.replica_device_setter(ps_tasks=2)):
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
b = tf.contrib.framework.variable('b', [])
|
||||||
|
c = tf.contrib.framework.variable('c', [], device='cpu:12')
|
||||||
|
d = tf.contrib.framework.variable('d', [])
|
||||||
|
with tf.device('cpu:99'):
|
||||||
|
e_init = tf.constant(12)
|
||||||
|
e = tf.contrib.framework.variable('e', initializer=e_init)
|
||||||
|
# The values below highlight how the replica_device_setter puts initial
|
||||||
|
# values on the worker job, and how it merges explicit devices.
|
||||||
|
self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
|
||||||
|
self.assertDeviceEqual(a.initial_value.device, a.device)
|
||||||
|
self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
|
||||||
|
self.assertDeviceEqual(b.initial_value.device, b.device)
|
||||||
|
self.assertDeviceEqual(c.device, '/job:ps/task:0/cpu:12')
|
||||||
|
self.assertDeviceEqual(c.initial_value.device, c.device)
|
||||||
|
self.assertDeviceEqual(d.device, '/job:ps/task:1/cpu:0')
|
||||||
|
self.assertDeviceEqual(d.initial_value.device, d.device)
|
||||||
|
self.assertDeviceEqual(e.device, '/job:ps/task:0/cpu:0')
|
||||||
|
self.assertDeviceEqual(e.initial_value.device, '/job:worker/cpu:99')
|
||||||
|
|
||||||
|
def testVariableWithVariableDeviceChooser(self):
|
||||||
|
|
||||||
|
with tf.Graph().as_default():
|
||||||
|
device_fn = tf.contrib.framework.VariableDeviceChooser(num_tasks=2)
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||||
|
device=device_fn):
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
b = tf.contrib.framework.variable('b', [])
|
||||||
|
c = tf.contrib.framework.variable('c', [], device='cpu:12')
|
||||||
|
d = tf.contrib.framework.variable('d', [])
|
||||||
|
with tf.device('cpu:99'):
|
||||||
|
e_init = tf.constant(12)
|
||||||
|
e = tf.contrib.framework.variable('e', initializer=e_init)
|
||||||
|
# The values below highlight how the VariableDeviceChooser puts initial
|
||||||
|
# values on the same device as the variable job.
|
||||||
|
self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
|
||||||
|
self.assertDeviceEqual(a.initial_value.device, a.device)
|
||||||
|
self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
|
||||||
|
self.assertDeviceEqual(b.initial_value.device, b.device)
|
||||||
|
self.assertDeviceEqual(c.device, '/cpu:12')
|
||||||
|
self.assertDeviceEqual(c.initial_value.device, c.device)
|
||||||
|
self.assertDeviceEqual(d.device, '/job:ps/task:0/cpu:0')
|
||||||
|
self.assertDeviceEqual(d.initial_value.device, d.device)
|
||||||
|
self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
|
||||||
|
self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
|
||||||
|
|
||||||
|
def testVariableGPUPlacement(self):
|
||||||
|
|
||||||
|
with tf.Graph().as_default():
|
||||||
|
device_fn = tf.contrib.framework.VariableDeviceChooser(device_type='GPU')
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||||
|
device=device_fn):
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
b = tf.contrib.framework.variable('b', [])
|
||||||
|
c = tf.contrib.framework.variable('c', [], device='cpu:12')
|
||||||
|
d = tf.contrib.framework.variable('d', [])
|
||||||
|
with tf.device('cpu:99'):
|
||||||
|
e_init = tf.constant(12)
|
||||||
|
e = tf.contrib.framework.variable('e', initializer=e_init)
|
||||||
|
# The values below highlight how the VariableDeviceChooser puts initial
|
||||||
|
# values on the same device as the variable job.
|
||||||
|
self.assertDeviceEqual(a.device, '/gpu:0')
|
||||||
|
self.assertDeviceEqual(a.initial_value.device, a.device)
|
||||||
|
self.assertDeviceEqual(b.device, '/gpu:0')
|
||||||
|
self.assertDeviceEqual(b.initial_value.device, b.device)
|
||||||
|
self.assertDeviceEqual(c.device, '/cpu:12')
|
||||||
|
self.assertDeviceEqual(c.initial_value.device, c.device)
|
||||||
|
self.assertDeviceEqual(d.device, '/gpu:0')
|
||||||
|
self.assertDeviceEqual(d.initial_value.device, d.device)
|
||||||
|
self.assertDeviceEqual(e.device, '/gpu:0')
|
||||||
|
self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
|
||||||
|
|
||||||
|
|
||||||
|
class ModelVariablesTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testNameAndShape(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
self.assertEquals(a.op.name, 'A/a')
|
||||||
|
self.assertListEqual(a.get_shape().as_list(), [5])
|
||||||
|
self.assertListEqual([a], tf.contrib.framework.get_model_variables('A'))
|
||||||
|
|
||||||
|
def testNotInLocalVariables(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
self.assertTrue(a in tf.all_variables())
|
||||||
|
self.assertFalse(a in tf.local_variables())
|
||||||
|
|
||||||
|
def testGetVariablesReturns(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
b = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
self.assertEquals([a], tf.contrib.framework.get_variables('A'))
|
||||||
|
self.assertEquals([b], tf.contrib.framework.get_variables('B'))
|
||||||
|
|
||||||
|
def testGetModelVariables(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
b = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
self.assertEquals([a], tf.contrib.framework.get_model_variables('A'))
|
||||||
|
self.assertEquals([b], tf.contrib.framework.get_model_variables('B'))
|
||||||
|
|
||||||
|
def testGetLocalVariables(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
_ = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
_ = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
self.assertEquals([], tf.contrib.framework.get_local_variables('A'))
|
||||||
|
self.assertEquals([], tf.contrib.framework.get_local_variables('B'))
|
||||||
|
|
||||||
|
def testInitializedVariableValue(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
a = tf.contrib.framework.model_variable('a', [5], initializer=tf.ones)
|
||||||
|
sess.run(tf.initialize_all_variables())
|
||||||
|
self.assertAllEqual(a.eval(), [1]*5)
|
||||||
|
|
||||||
|
def testDeviceFn(self):
|
||||||
|
class DevFn(object):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.counter = -1
|
||||||
|
|
||||||
|
def __call__(self, op):
|
||||||
|
self.counter += 1
|
||||||
|
return '/cpu:%d' % self.counter
|
||||||
|
|
||||||
|
with tf.Graph().as_default():
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.model_variable],
|
||||||
|
device=DevFn()):
|
||||||
|
a = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
b = tf.contrib.framework.model_variable('b', [20])
|
||||||
|
self.assertDeviceEqual(a.device, '/cpu:0')
|
||||||
|
self.assertDeviceEqual(a.initial_value.device, '/cpu:0')
|
||||||
|
self.assertDeviceEqual(b.device, '/cpu:1')
|
||||||
|
self.assertDeviceEqual(b.initial_value.device, '/cpu:1')
|
||||||
|
|
||||||
|
def testVariableWithVariableDeviceChooser(self):
|
||||||
|
|
||||||
|
with tf.Graph().as_default():
|
||||||
|
device_fn = tf.contrib.framework.VariableDeviceChooser()
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.model_variable],
|
||||||
|
device=device_fn):
|
||||||
|
a = tf.contrib.framework.model_variable('a', [5])
|
||||||
|
b = tf.contrib.framework.model_variable('b', [20])
|
||||||
|
self.assertDeviceEqual(a.device, 'cpu:0')
|
||||||
|
self.assertDeviceEqual(a.initial_value.device, a.device)
|
||||||
|
self.assertDeviceEqual(b.device, 'cpu:0')
|
||||||
|
self.assertDeviceEqual(b.initial_value.device, b.device)
|
||||||
|
|
||||||
|
|
||||||
|
class GetVariablesCollections(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testVariableCollection(self):
|
||||||
|
with self.test_session():
|
||||||
|
a = tf.contrib.framework.variable('a', [], collections='A')
|
||||||
|
b = tf.contrib.framework.variable('b', [], collections='B')
|
||||||
|
self.assertEquals(a, tf.get_collection('A')[0])
|
||||||
|
self.assertEquals(b, tf.get_collection('B')[0])
|
||||||
|
|
||||||
|
def testVariableCollections(self):
|
||||||
|
with self.test_session():
|
||||||
|
a = tf.contrib.framework.variable('a', [], collections=['A', 'C'])
|
||||||
|
b = tf.contrib.framework.variable('b', [], collections=['B', 'C'])
|
||||||
|
self.assertEquals(a, tf.get_collection('A')[0])
|
||||||
|
self.assertEquals(b, tf.get_collection('B')[0])
|
||||||
|
self.assertListEqual([a, b], tf.get_collection('C'))
|
||||||
|
|
||||||
|
def testVariableCollectionsWithArgScope(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||||
|
collections='A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
b = tf.contrib.framework.variable('b', [])
|
||||||
|
self.assertListEqual([a, b], tf.get_collection('A'))
|
||||||
|
|
||||||
|
def testVariableCollectionsWithArgScopeNested(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||||
|
collections='A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||||
|
collections='B'):
|
||||||
|
b = tf.contrib.framework.variable('b', [])
|
||||||
|
self.assertEquals(a, tf.get_collection('A')[0])
|
||||||
|
self.assertEquals(b, tf.get_collection('B')[0])
|
||||||
|
|
||||||
|
def testVariableCollectionsWithArgScopeNonNested(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||||
|
collections='A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||||
|
collections='B'):
|
||||||
|
b = tf.contrib.framework.variable('b', [])
|
||||||
|
tf.contrib.framework.variable('c', [])
|
||||||
|
self.assertListEqual([a], tf.get_collection('A'))
|
||||||
|
self.assertListEqual([b], tf.get_collection('B'))
|
||||||
|
|
||||||
|
def testVariableRestoreWithArgScopeNested(self):
|
||||||
|
with self.test_session():
|
||||||
|
a = tf.contrib.framework.variable('a', [])
|
||||||
|
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||||
|
trainable=False,
|
||||||
|
collections=['A', 'B']):
|
||||||
|
b = tf.contrib.framework.variable('b', [])
|
||||||
|
c = tf.contrib.framework.variable('c', [], trainable=False)
|
||||||
|
self.assertEquals([a, c], tf.contrib.framework.get_variables_to_restore())
|
||||||
|
self.assertEquals([a], tf.trainable_variables())
|
||||||
|
self.assertEquals([b], tf.get_collection('A'))
|
||||||
|
self.assertEquals([b], tf.get_collection('B'))
|
||||||
|
|
||||||
|
|
||||||
|
class GetVariablesBySuffixTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testGetVariableGivenNameScoped(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
b = tf.contrib.framework.variable('b', [5])
|
||||||
|
self.assertEquals([a],
|
||||||
|
tf.contrib.framework.get_variables_by_suffix('a'))
|
||||||
|
self.assertEquals([b],
|
||||||
|
tf.contrib.framework.get_variables_by_suffix('b'))
|
||||||
|
|
||||||
|
def testGetVariableWithScope(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
fooa = tf.contrib.framework.variable('fooa', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
a2 = tf.contrib.framework.variable('a', [5])
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_suffix('a')
|
||||||
|
self.assertEquals([a, fooa, a2], matched_variables)
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_suffix('/a')
|
||||||
|
self.assertEquals([a, a2], matched_variables)
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_suffix(
|
||||||
|
'a', scope='A')
|
||||||
|
self.assertEquals([a, fooa], matched_variables)
|
||||||
|
|
||||||
|
def testGetVariableWithoutScope(self):
|
||||||
|
with self.test_session():
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
fooa = tf.contrib.framework.variable('fooa', [5])
|
||||||
|
b_a = tf.contrib.framework.variable('B/a', [5])
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_suffix('a')
|
||||||
|
self.assertEquals([a, fooa, b_a], matched_variables)
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_suffix('fooa')
|
||||||
|
self.assertEquals([fooa], matched_variables)
|
||||||
|
|
||||||
|
|
||||||
|
class GetVariablesByNameTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testGetVariableGivenNameScoped(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
b = tf.contrib.framework.variable('b', [5])
|
||||||
|
self.assertEquals([a], tf.contrib.framework.get_variables_by_name('a'))
|
||||||
|
self.assertEquals([b], tf.contrib.framework.get_variables_by_name('b'))
|
||||||
|
|
||||||
|
def testGetVariableWithScope(self):
|
||||||
|
with self.test_session():
|
||||||
|
with tf.variable_scope('A'):
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
fooa = tf.contrib.framework.variable('fooa', [5])
|
||||||
|
with tf.variable_scope('B'):
|
||||||
|
a2 = tf.contrib.framework.variable('a', [5])
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_name('a')
|
||||||
|
self.assertEquals([a, a2], matched_variables)
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_name('fooa')
|
||||||
|
self.assertEquals([fooa], matched_variables)
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_name('/a')
|
||||||
|
self.assertEquals([], matched_variables)
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_name('a',
|
||||||
|
scope='A')
|
||||||
|
self.assertEquals([a], matched_variables)
|
||||||
|
|
||||||
|
def testGetVariableWithoutScope(self):
|
||||||
|
with self.test_session():
|
||||||
|
a = tf.contrib.framework.variable('a', [5])
|
||||||
|
fooa = tf.contrib.framework.variable('fooa', [5])
|
||||||
|
b_a = tf.contrib.framework.variable('B/a', [5])
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_name('a')
|
||||||
|
for v in matched_variables:
|
||||||
|
print(v.name)
|
||||||
|
self.assertEquals([a, b_a], matched_variables)
|
||||||
|
matched_variables = tf.contrib.framework.get_variables_by_name('fooa')
|
||||||
|
self.assertEquals([fooa], matched_variables)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user