Added model_variable and helpers to manage variables.
Change: 122727793
This commit is contained in:
parent
63c29c82b3
commit
144855b385
@ -18,13 +18,40 @@
|
||||
@@assert_same_float_dtype
|
||||
@@assert_scalar_int
|
||||
@@convert_to_tensor_or_sparse_tensor
|
||||
@@local_variable
|
||||
@@get_graph_from_inputs
|
||||
@@is_numeric_tensor
|
||||
@@is_non_decreasing
|
||||
@@is_strictly_increasing
|
||||
@@reduce_sum_n
|
||||
@@safe_embedding_lookup_sparse
|
||||
@@with_shape
|
||||
@@with_same_shape
|
||||
|
||||
@@get_graph_from_inputs
|
||||
|
||||
## Arg_Scope
|
||||
@@arg_scope
|
||||
@@add_arg_scope
|
||||
@@has_arg_scope
|
||||
@@arg_scoped_arguments
|
||||
|
||||
## Variables
|
||||
@@add_model_variable
|
||||
@@assert_global_step
|
||||
@@assert_or_get_global_step
|
||||
@@create_global_step
|
||||
@@get_global_step
|
||||
@@get_or_create_global_step
|
||||
@@get_local_variables
|
||||
@@get_model_variables
|
||||
@@get_unique_variable
|
||||
@@get_variables_by_name
|
||||
@@get_variables_by_suffix
|
||||
@@get_variables_to_restore
|
||||
@@get_variables
|
||||
@@local_variable
|
||||
@@model_variable
|
||||
@@variable
|
||||
@@VariableDeviceChooser
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
@ -14,7 +14,6 @@
|
||||
# ==============================================================================
|
||||
|
||||
"""Tensor utility functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -27,14 +26,16 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
__all__ = [
|
||||
'assert_same_float_dtype', 'assert_scalar_int',
|
||||
'convert_to_tensor_or_sparse_tensor', 'local_variable', 'reduce_sum_n',
|
||||
'with_shape', 'with_same_shape',
|
||||
]
|
||||
'assert_same_float_dtype',
|
||||
'assert_scalar_int',
|
||||
'convert_to_tensor_or_sparse_tensor',
|
||||
'reduce_sum_n',
|
||||
'with_shape',
|
||||
'with_same_shape']
|
||||
|
||||
|
||||
def _assert_same_base_type(items, expected_type=None):
|
||||
"""Asserts all items are of the same base type.
|
||||
r"""Asserts all items are of the same base type.
|
||||
|
||||
Args:
|
||||
items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
|
||||
@ -110,23 +111,6 @@ def assert_scalar_int(tensor):
|
||||
return tensor
|
||||
|
||||
|
||||
# TODO(ptucker): Move to tf.variables?
|
||||
def local_variable(initial_value, validate_shape=True, name=None):
|
||||
"""Create variable and add it to `GraphKeys.LOCAL_VARIABLES` collection.
|
||||
|
||||
Args:
|
||||
initial_value: See variables.Variable.__init__.
|
||||
validate_shape: See variables.Variable.__init__.
|
||||
name: See variables.Variable.__init__.
|
||||
Returns:
|
||||
New variable.
|
||||
"""
|
||||
return variables.Variable(
|
||||
initial_value, trainable=False,
|
||||
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
||||
validate_shape=validate_shape, name=name)
|
||||
|
||||
|
||||
def reduce_sum_n(tensors, name=None):
|
||||
"""Reduce tensors to a scalar sum.
|
||||
|
||||
|
@ -47,11 +47,6 @@
|
||||
|
||||
@tf.contrib.add_arg_scope
|
||||
def conv2d(*args, **kwargs)
|
||||
|
||||
@@arg_scope
|
||||
@@add_arg_scope
|
||||
@@has_arg_scope
|
||||
@@arg_scoped_arguments
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -59,8 +54,10 @@ from __future__ import print_function
|
||||
import contextlib
|
||||
import functools
|
||||
|
||||
__all__ = ['arg_scope', 'add_arg_scope',
|
||||
'has_arg_scope', 'arg_scoped_arguments']
|
||||
__all__ = ['arg_scope',
|
||||
'add_arg_scope',
|
||||
'has_arg_scope',
|
||||
'arg_scoped_arguments']
|
||||
|
||||
_ARGSTACK = [{}]
|
||||
|
||||
|
@ -14,24 +14,37 @@
|
||||
# ==============================================================================
|
||||
|
||||
"""Variable functions.
|
||||
|
||||
@@assert_global_step
|
||||
@@create_global_step
|
||||
@@get_global_step
|
||||
@@assert_or_get_global_step
|
||||
@@local_variable
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope
|
||||
from tensorflow.python.framework import device as tf_device
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
__all__ = [
|
||||
'assert_global_step', 'create_global_step', 'get_global_step',
|
||||
'assert_or_get_global_step', 'local_variable']
|
||||
|
||||
__all__ = ['add_model_variable',
|
||||
'assert_global_step',
|
||||
'assert_or_get_global_step',
|
||||
'create_global_step',
|
||||
'get_global_step',
|
||||
'get_or_create_global_step',
|
||||
'get_local_variables',
|
||||
'get_model_variables',
|
||||
'get_unique_variable',
|
||||
'get_variables_by_name',
|
||||
'get_variables_by_suffix',
|
||||
'get_variables_to_restore',
|
||||
'get_variables',
|
||||
'local_variable',
|
||||
'model_variable',
|
||||
'variable',
|
||||
'VariableDeviceChooser']
|
||||
|
||||
|
||||
def assert_global_step(global_step_tensor):
|
||||
@ -125,7 +138,6 @@ def create_global_step(graph=None):
|
||||
Global step tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: if `dtype` is invalid.
|
||||
ValueError: if global step key is already defined.
|
||||
"""
|
||||
graph = ops.get_default_graph() if graph is None else graph
|
||||
@ -133,10 +145,27 @@ def create_global_step(graph=None):
|
||||
raise ValueError('"global_step" already exists.')
|
||||
# Create in proper graph and base name_scope.
|
||||
with graph.as_default() as g, g.name_scope(None):
|
||||
result = variables.Variable(
|
||||
0, trainable=False, dtype=dtypes.int64, name=ops.GraphKeys.GLOBAL_STEP)
|
||||
graph.add_to_collection(ops.GraphKeys.GLOBAL_STEP, result)
|
||||
return result
|
||||
collections = [ops.GraphKeys.VARIABLES, ops.GraphKeys.GLOBAL_STEP]
|
||||
return variable(ops.GraphKeys.GLOBAL_STEP, shape=[], dtype=dtypes.int64,
|
||||
initializer=init_ops.zeros_initializer, trainable=False,
|
||||
collections=collections)
|
||||
|
||||
|
||||
def get_or_create_global_step(graph=None):
|
||||
"""Returns and create (if necessary) the global step variable.
|
||||
|
||||
Args:
|
||||
graph: The graph in which to create the global step. If missing, use default
|
||||
graph.
|
||||
|
||||
Returns:
|
||||
the tensor representing the global step variable.
|
||||
"""
|
||||
graph = ops.get_default_graph() if graph is None else graph
|
||||
globalstep = get_global_step(graph)
|
||||
if globalstep is None:
|
||||
globalstep = create_global_step(graph)
|
||||
return globalstep
|
||||
|
||||
|
||||
def local_variable(initial_value, validate_shape=True, name=None):
|
||||
@ -154,3 +183,268 @@ def local_variable(initial_value, validate_shape=True, name=None):
|
||||
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
||||
validate_shape=validate_shape, name=name)
|
||||
|
||||
|
||||
@contrib_add_arg_scope
|
||||
def variable(name, shape=None, dtype=dtypes.float32, initializer=None,
|
||||
regularizer=None, trainable=True, collections=None,
|
||||
caching_device=None, device=None):
|
||||
"""Gets an existing variable with these parameters or creates a new one.
|
||||
|
||||
Args:
|
||||
name: the name of the new or existing variable.
|
||||
shape: shape of the new or existing variable.
|
||||
dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
|
||||
initializer: initializer for the variable if one is created.
|
||||
regularizer: a (Tensor -> Tensor or None) function; the result of
|
||||
applying it on a newly created variable will be added to the collection
|
||||
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
|
||||
trainable: If `True` also add the variable to the graph collection
|
||||
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
|
||||
collections: A list of collection names to which the Variable will be added.
|
||||
If None it would default to tf.GraphKeys.VARIABLES.
|
||||
caching_device: Optional device string or function describing where the
|
||||
Variable should be cached for reading. Defaults to the Variable's
|
||||
device.
|
||||
device: Optional device to place the variable. It can be an string or a
|
||||
function that is called to get the device for the variable.
|
||||
|
||||
Returns:
|
||||
The created or existing variable.
|
||||
"""
|
||||
collections = list(collections or [ops.GraphKeys.VARIABLES])
|
||||
|
||||
# Remove duplicates
|
||||
collections = set(collections)
|
||||
with ops.device(device or ''):
|
||||
return variable_scope.get_variable(name, shape=shape, dtype=dtype,
|
||||
initializer=initializer,
|
||||
regularizer=regularizer,
|
||||
trainable=trainable,
|
||||
collections=collections,
|
||||
caching_device=caching_device)
|
||||
|
||||
# TODO(sguada) move it to ops.GraphKeys or to contrib.framework.GraphKeys
|
||||
# Collection containing all the variables created using model_variables.
|
||||
MODEL_VARIABLES = '_model_variables_'
|
||||
|
||||
|
||||
@contrib_add_arg_scope
|
||||
def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
|
||||
regularizer=None, trainable=True, collections=None,
|
||||
caching_device=None, device=None):
|
||||
"""Gets an existing model variable with these parameters or creates a new one.
|
||||
|
||||
Args:
|
||||
name: the name of the new or existing variable.
|
||||
shape: shape of the new or existing variable.
|
||||
dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
|
||||
initializer: initializer for the variable if one is created.
|
||||
regularizer: a (Tensor -> Tensor or None) function; the result of
|
||||
applying it on a newly created variable will be added to the collection
|
||||
GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
|
||||
trainable: If `True` also add the variable to the graph collection
|
||||
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
|
||||
collections: A list of collection names to which the Variable will be added.
|
||||
Note that the variable is always also added to the tf.GraphKeys.VARIABLES
|
||||
and MODEL_VARIABLES collections.
|
||||
caching_device: Optional device string or function describing where the
|
||||
Variable should be cached for reading. Defaults to the Variable's
|
||||
device.
|
||||
device: Optional device to place the variable. It can be an string or a
|
||||
function that is called to get the device for the variable.
|
||||
|
||||
Returns:
|
||||
The created or existing variable.
|
||||
"""
|
||||
collections = list(collections or [])
|
||||
|
||||
# Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
|
||||
collections += [ops.GraphKeys.VARIABLES, MODEL_VARIABLES]
|
||||
return variable(name, shape=shape, dtype=dtype,
|
||||
initializer=initializer, regularizer=regularizer,
|
||||
trainable=trainable, collections=collections,
|
||||
caching_device=caching_device, device=device)
|
||||
|
||||
|
||||
def add_model_variable(var):
|
||||
"""Adds a variable to the MODEL_VARIABLES collection.
|
||||
|
||||
Args:
|
||||
var: a variable.
|
||||
"""
|
||||
if var not in ops.get_collection(MODEL_VARIABLES):
|
||||
ops.add_to_collection(MODEL_VARIABLES, var)
|
||||
|
||||
|
||||
def get_variables(scope=None, suffix=None, collection=ops.GraphKeys.VARIABLES):
|
||||
"""Gets the list of variables, filtered by scope and/or suffix.
|
||||
|
||||
Args:
|
||||
scope: an optional scope for filtering the variables to return.
|
||||
suffix: an optional suffix for filtering the variables to return.
|
||||
collection: in which collection search for. Defaults to GraphKeys.VARIABLES.
|
||||
|
||||
Returns:
|
||||
a list of variables in colelction with scope and suffix.
|
||||
"""
|
||||
if suffix is not None:
|
||||
if ':' not in suffix:
|
||||
suffix += ':'
|
||||
scope = (scope or '') + '.*' + suffix
|
||||
return ops.get_collection(collection, scope)
|
||||
|
||||
|
||||
def get_model_variables(scope=None, suffix=None):
|
||||
"""Gets the list of model variables, filtered by scope and/or suffix.
|
||||
|
||||
Args:
|
||||
scope: an optional scope for filtering the variables to return.
|
||||
suffix: an optional suffix for filtering the variables to return.
|
||||
|
||||
Returns:
|
||||
a list of variables in colelction with scope and suffix.
|
||||
"""
|
||||
return get_variables(scope, suffix, MODEL_VARIABLES)
|
||||
|
||||
|
||||
def get_local_variables(scope=None, suffix=None):
|
||||
"""Gets the list of model variables, filtered by scope and/or suffix.
|
||||
|
||||
Args:
|
||||
scope: an optional scope for filtering the variables to return.
|
||||
suffix: an optional suffix for filtering the variables to return.
|
||||
|
||||
Returns:
|
||||
a list of variables in colelction with scope and suffix.
|
||||
"""
|
||||
return get_variables(scope, suffix, ops.GraphKeys.LOCAL_VARIABLES)
|
||||
|
||||
|
||||
def get_variables_to_restore(include=None, exclude=None):
|
||||
"""Gets the list of the variables to restore.
|
||||
|
||||
Args:
|
||||
include: an optional list/tuple of scope strings for filtering which
|
||||
variables from the VARIABLES collection to include. None would include all
|
||||
the variables.
|
||||
exclude: an optional list/tuple of scope strings for filtering which
|
||||
variables from the VARIABLES collection to exclude. None it would not
|
||||
exclude any.
|
||||
|
||||
Returns:
|
||||
a list of variables to restore.
|
||||
|
||||
Raises:
|
||||
TypeError: include or exclude is provided but is not a list or a tuple.
|
||||
"""
|
||||
if include is None:
|
||||
# Include all variables.
|
||||
vars_to_include = get_variables()
|
||||
else:
|
||||
if not isinstance(include, (list, tuple)):
|
||||
raise TypeError('include is provided but is not a list or a tuple.')
|
||||
vars_to_include = []
|
||||
for scope in include:
|
||||
vars_to_include += get_variables(scope)
|
||||
vars_to_exclude = set()
|
||||
if exclude is not None:
|
||||
if not isinstance(exclude, (list, tuple)):
|
||||
raise TypeError('exclude is provided but is not a list or a tuple.')
|
||||
for scope in exclude:
|
||||
vars_to_exclude |= set(get_variables(scope))
|
||||
# Exclude the variables in vars_to_exclude
|
||||
return [v for v in vars_to_include if v not in vars_to_exclude]
|
||||
|
||||
|
||||
def get_variables_by_suffix(suffix, scope=None):
|
||||
"""Gets the list of variables that end with the given suffix.
|
||||
|
||||
Args:
|
||||
suffix: suffix for filtering the variables to return.
|
||||
scope: an optional scope for filtering the variables to return.
|
||||
|
||||
Returns:
|
||||
a copied list of variables with the given name and prefix.
|
||||
"""
|
||||
return get_variables(scope=scope, suffix=suffix)
|
||||
|
||||
|
||||
def get_variables_by_name(given_name, scope=None):
|
||||
"""Gets the list of variables that were given that name.
|
||||
|
||||
Args:
|
||||
given_name: name given to the variable without any scope.
|
||||
scope: an optional scope for filtering the variables to return.
|
||||
|
||||
Returns:
|
||||
a copied list of variables with the given name and scope.
|
||||
"""
|
||||
suffix = '/' + given_name + ':|^' + given_name + ':'
|
||||
return get_variables(scope=scope, suffix=suffix)
|
||||
|
||||
|
||||
def get_unique_variable(var_op_name):
|
||||
"""Gets the variable uniquely identified by that var_op_name.
|
||||
|
||||
Args:
|
||||
var_op_name: the full name of the variable op, including the scope.
|
||||
|
||||
Returns:
|
||||
a tensorflow variable.
|
||||
|
||||
Raises:
|
||||
ValueError: if no variable uniquely identified by the name exists.
|
||||
"""
|
||||
candidates = get_variables(scope=var_op_name)
|
||||
if not candidates:
|
||||
raise ValueError('Couldnt find variable %s' % var_op_name)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate.op.name == var_op_name:
|
||||
return candidate
|
||||
raise ValueError('Variable %s does not uniquely identify a variable',
|
||||
var_op_name)
|
||||
|
||||
|
||||
class VariableDeviceChooser(object):
|
||||
"""Device chooser for variables.
|
||||
|
||||
When using a parameter server it will assign them in a round-robin fashion.
|
||||
When not using a parameter server it allows GPU or CPU placement.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_tasks=0,
|
||||
device_type='CPU',
|
||||
device_index=0):
|
||||
"""Initialize VariableDeviceChooser.
|
||||
|
||||
Usage:
|
||||
To use with 2 parameter servers:
|
||||
VariableDeviceChooser(2)
|
||||
|
||||
To use without parameter servers:
|
||||
VariableDeviceChooser()
|
||||
VariableDeviceChooser(device_type='GPU') # For GPU placement
|
||||
|
||||
Args:
|
||||
num_tasks: number of tasks.
|
||||
device_type: Optional device type string (e.g. "CPU" or "GPU")
|
||||
device_index: int. Optional device index. If left
|
||||
unspecified, device represents 'any' device_index.
|
||||
"""
|
||||
self._job_name = 'ps' if num_tasks > 0 else None
|
||||
self._device_type = device_type
|
||||
self._device_index = device_index
|
||||
self._num_tasks = num_tasks
|
||||
self._next_task_id = 0
|
||||
|
||||
def __call__(self, op):
|
||||
device_spec = tf_device.DeviceSpec(job=self._job_name,
|
||||
device_type=self._device_type,
|
||||
device_index=self._device_index)
|
||||
if self._num_tasks > 0:
|
||||
task_id = self._next_task_id
|
||||
self._next_task_id = (self._next_task_id + 1) % self._num_tasks
|
||||
device_spec.task = task_id
|
||||
return device_spec.to_string()
|
||||
|
@ -37,11 +37,57 @@ class LocalVariableTest(tf.test.TestCase):
|
||||
tf.initialize_variables(variables).run()
|
||||
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
|
||||
|
||||
def testLocalVariableNameAndShape(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.local_variable([1, 1, 1, 1, 1], name='a')
|
||||
self.assertEquals(a.op.name, 'A/a')
|
||||
self.assertListEqual(a.get_shape().as_list(), [5])
|
||||
self.assertListEqual([a], tf.contrib.framework.get_local_variables())
|
||||
|
||||
def testLocalVariableNotInAllVariables(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.local_variable(0)
|
||||
self.assertFalse(a in tf.all_variables())
|
||||
self.assertTrue(a in tf.local_variables())
|
||||
|
||||
def testLocalVariableNotInVariablesToRestore(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.local_variable(0)
|
||||
self.assertFalse(a in tf.contrib.framework.get_variables_to_restore())
|
||||
self.assertTrue(a in tf.local_variables())
|
||||
|
||||
def testGetVariablesDontReturnsTransients(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
tf.contrib.framework.local_variable(0)
|
||||
with tf.variable_scope('B'):
|
||||
tf.contrib.framework.local_variable(0)
|
||||
self.assertEquals([], tf.contrib.framework.get_variables('A'))
|
||||
self.assertEquals([], tf.contrib.framework.get_variables('B'))
|
||||
|
||||
def testGetLocalVariablesReturnsTransients(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.local_variable(0)
|
||||
with tf.variable_scope('B'):
|
||||
b = tf.contrib.framework.local_variable(0)
|
||||
self.assertEquals([a], tf.contrib.framework.get_local_variables('A'))
|
||||
self.assertEquals([b], tf.contrib.framework.get_local_variables('B'))
|
||||
|
||||
def testInitializedVariableValue(self):
|
||||
with self.test_session() as sess:
|
||||
a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a')
|
||||
sess.run(tf.initialize_local_variables())
|
||||
self.assertAllEqual(a.eval(), [0]*5)
|
||||
|
||||
|
||||
class GlobalStepTest(tf.test.TestCase):
|
||||
|
||||
def _assert_global_step(self, global_step, expected_dtype=tf.int64):
|
||||
self.assertEquals("%s:0" % tf.GraphKeys.GLOBAL_STEP, global_step.name)
|
||||
self.assertEquals('%s:0' % tf.GraphKeys.GLOBAL_STEP, global_step.name)
|
||||
self.assertEquals(expected_dtype, global_step.dtype.base_dtype)
|
||||
self.assertEquals([], global_step.get_shape().as_list())
|
||||
|
||||
@ -51,10 +97,10 @@ class GlobalStepTest(tf.test.TestCase):
|
||||
tf.Variable(
|
||||
0.0, trainable=False, dtype=tf.float32, name=tf.GraphKeys.GLOBAL_STEP)
|
||||
self.assertRaisesRegexp(
|
||||
TypeError, "does not have integer type",
|
||||
TypeError, 'does not have integer type',
|
||||
tf.contrib.framework.get_global_step)
|
||||
self.assertRaisesRegexp(
|
||||
TypeError, "does not have integer type",
|
||||
TypeError, 'does not have integer type',
|
||||
tf.contrib.framework.get_global_step, g)
|
||||
|
||||
def test_invalid_shape(self):
|
||||
@ -63,10 +109,10 @@ class GlobalStepTest(tf.test.TestCase):
|
||||
tf.Variable(
|
||||
[0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP)
|
||||
self.assertRaisesRegexp(
|
||||
TypeError, "not scalar",
|
||||
TypeError, 'not scalar',
|
||||
tf.contrib.framework.get_global_step)
|
||||
self.assertRaisesRegexp(
|
||||
TypeError, "not scalar",
|
||||
TypeError, 'not scalar',
|
||||
tf.contrib.framework.get_global_step, g)
|
||||
|
||||
def test_create_global_step(self):
|
||||
@ -75,9 +121,9 @@ class GlobalStepTest(tf.test.TestCase):
|
||||
global_step = tf.contrib.framework.create_global_step()
|
||||
self._assert_global_step(global_step)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "already exists", tf.contrib.framework.create_global_step)
|
||||
ValueError, 'already exists', tf.contrib.framework.create_global_step)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "already exists", tf.contrib.framework.create_global_step,
|
||||
ValueError, 'already exists', tf.contrib.framework.create_global_step,
|
||||
g)
|
||||
self._assert_global_step(
|
||||
tf.contrib.framework.create_global_step(tf.Graph()))
|
||||
@ -92,6 +138,510 @@ class GlobalStepTest(tf.test.TestCase):
|
||||
self._assert_global_step(
|
||||
tf.contrib.framework.get_global_step(g), expected_dtype=tf.int32)
|
||||
|
||||
def test_get_or_create_global_step(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
self.assertEquals(None, tf.contrib.framework.get_global_step())
|
||||
self._assert_global_step(
|
||||
tf.contrib.framework.get_or_create_global_step())
|
||||
self._assert_global_step(
|
||||
tf.contrib.framework.get_or_create_global_step(g))
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
class VariablesTest(tf.test.TestCase):
|
||||
|
||||
def testCreateVariable(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
self.assertEquals(a.op.name, 'A/a')
|
||||
self.assertListEqual(a.get_shape().as_list(), [5])
|
||||
|
||||
def testGetVariables(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
with tf.variable_scope('B'):
|
||||
b = tf.contrib.framework.variable('a', [5])
|
||||
self.assertEquals([a, b], tf.contrib.framework.get_variables())
|
||||
self.assertEquals([a], tf.contrib.framework.get_variables('A'))
|
||||
self.assertEquals([b], tf.contrib.framework.get_variables('B'))
|
||||
|
||||
def testGetVariablesSuffix(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
with tf.variable_scope('A'):
|
||||
b = tf.contrib.framework.variable('b', [5])
|
||||
self.assertEquals([a], tf.contrib.framework.get_variables(suffix='a'))
|
||||
self.assertEquals([b], tf.contrib.framework.get_variables(suffix='b'))
|
||||
|
||||
def testGetVariableWithSingleVar(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('parent'):
|
||||
a = tf.contrib.framework.variable('child', [5])
|
||||
self.assertEquals(
|
||||
a, tf.contrib.framework.get_unique_variable('parent/child'))
|
||||
|
||||
def testGetVariableWithDistractors(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('parent'):
|
||||
a = tf.contrib.framework.variable('child', [5])
|
||||
with tf.variable_scope('child'):
|
||||
tf.contrib.framework.variable('grandchild1', [7])
|
||||
tf.contrib.framework.variable('grandchild2', [9])
|
||||
self.assertEquals(
|
||||
a, tf.contrib.framework.get_unique_variable('parent/child'))
|
||||
|
||||
def testGetVariableThrowsExceptionWithNoMatch(self):
|
||||
var_name = 'cant_find_me'
|
||||
with self.test_session():
|
||||
with self.assertRaises(ValueError):
|
||||
tf.contrib.framework.get_unique_variable(var_name)
|
||||
|
||||
def testGetThrowsExceptionWithChildrenButNoMatch(self):
|
||||
var_name = 'parent/child'
|
||||
with self.test_session():
|
||||
with tf.variable_scope(var_name):
|
||||
tf.contrib.framework.variable('grandchild1', [7])
|
||||
tf.contrib.framework.variable('grandchild2', [9])
|
||||
with self.assertRaises(ValueError):
|
||||
tf.contrib.framework.get_unique_variable(var_name)
|
||||
|
||||
def testGetVariablesToRestore(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
with tf.variable_scope('B'):
|
||||
b = tf.contrib.framework.variable('a', [5])
|
||||
self.assertEquals([a, b],
|
||||
tf.contrib.framework.get_variables_to_restore())
|
||||
|
||||
def testIncludeGetVariablesToRestore(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
with tf.variable_scope('B'):
|
||||
b = tf.contrib.framework.variable('a', [5])
|
||||
self.assertEquals([a, b], tf.contrib.framework.get_variables())
|
||||
self.assertEquals([a],
|
||||
tf.contrib.framework.get_variables_to_restore(['A']))
|
||||
|
||||
def testExcludeGetVariablesToRestore(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
with tf.variable_scope('B'):
|
||||
b = tf.contrib.framework.variable('a', [5])
|
||||
self.assertEquals([a, b], tf.contrib.framework.get_variables())
|
||||
self.assertEquals([a],
|
||||
tf.contrib.framework.get_variables_to_restore(
|
||||
exclude=['B']))
|
||||
|
||||
def testWrongIncludeGetVariablesToRestore(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
with tf.variable_scope('B'):
|
||||
b = tf.contrib.framework.variable('a', [5])
|
||||
self.assertEquals([a, b], tf.contrib.framework.get_variables())
|
||||
self.assertEquals([],
|
||||
tf.contrib.framework.get_variables_to_restore(['a']))
|
||||
|
||||
def testGetMixedVariablesToRestore(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
b = tf.contrib.framework.variable('b', [5])
|
||||
with tf.variable_scope('B'):
|
||||
c = tf.contrib.framework.variable('c', [5])
|
||||
d = tf.contrib.framework.variable('d', [5])
|
||||
self.assertEquals([a, b, c, d], tf.contrib.framework.get_variables())
|
||||
self.assertEquals([a, c],
|
||||
tf.contrib.framework.get_variables_to_restore(
|
||||
include=['A/a', 'B/c']))
|
||||
|
||||
def testExcludeGetMixedVariablesToRestore(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
b = tf.contrib.framework.variable('b', [5])
|
||||
with tf.variable_scope('B'):
|
||||
c = tf.contrib.framework.variable('c', [5])
|
||||
d = tf.contrib.framework.variable('d', [5])
|
||||
self.assertEquals([a, b, c, d], tf.contrib.framework.get_variables())
|
||||
self.assertEquals([b, d],
|
||||
tf.contrib.framework.get_variables_to_restore(
|
||||
exclude=['A/a', 'B/c']))
|
||||
|
||||
def testReuseVariable(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
with tf.variable_scope('A', reuse=True):
|
||||
b = tf.contrib.framework.variable('a', [])
|
||||
self.assertEquals(a, b)
|
||||
self.assertListEqual([a], tf.contrib.framework.get_variables())
|
||||
|
||||
def testVariableWithRegularizer(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [], regularizer=tf.nn.l2_loss)
|
||||
loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
|
||||
self.assertDeviceEqual(loss.device, a.device)
|
||||
|
||||
def testVariableWithRegularizerColocate(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [], device='gpu:0',
|
||||
regularizer=tf.nn.l2_loss)
|
||||
loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
|
||||
self.assertDeviceEqual(loss.device, a.device)
|
||||
|
||||
def testVariableWithDevice(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [], device='cpu:0')
|
||||
b = tf.contrib.framework.variable('b', [], device='cpu:1')
|
||||
self.assertDeviceEqual(a.device, 'cpu:0')
|
||||
self.assertDeviceEqual(b.device, 'cpu:1')
|
||||
|
||||
def testVariableWithDeviceFromScope(self):
|
||||
with self.test_session():
|
||||
with tf.device('/cpu:0'):
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
b = tf.contrib.framework.variable('b', [], device='cpu:1')
|
||||
self.assertDeviceEqual(a.device, 'cpu:0')
|
||||
self.assertDeviceEqual(b.device, 'cpu:1')
|
||||
|
||||
def testVariableWithDeviceFunction(self):
|
||||
class DevFn(object):
|
||||
|
||||
def __init__(self):
|
||||
self.counter = -1
|
||||
|
||||
def __call__(self, op):
|
||||
self.counter += 1
|
||||
return 'cpu:%d' % self.counter
|
||||
|
||||
with self.test_session():
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||
device=DevFn()):
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
b = tf.contrib.framework.variable('b', [])
|
||||
c = tf.contrib.framework.variable('c', [], device='cpu:12')
|
||||
d = tf.contrib.framework.variable('d', [])
|
||||
with tf.device('cpu:99'):
|
||||
e_init = tf.constant(12)
|
||||
e = tf.contrib.framework.variable('e', initializer=e_init)
|
||||
self.assertDeviceEqual(a.device, 'cpu:0')
|
||||
self.assertDeviceEqual(a.initial_value.device, 'cpu:0')
|
||||
self.assertDeviceEqual(b.device, 'cpu:1')
|
||||
self.assertDeviceEqual(b.initial_value.device, 'cpu:1')
|
||||
self.assertDeviceEqual(c.device, 'cpu:12')
|
||||
self.assertDeviceEqual(c.initial_value.device, 'cpu:12')
|
||||
self.assertDeviceEqual(d.device, 'cpu:2')
|
||||
self.assertDeviceEqual(d.initial_value.device, 'cpu:2')
|
||||
self.assertDeviceEqual(e.device, 'cpu:3')
|
||||
self.assertDeviceEqual(e.initial_value.device, 'cpu:99')
|
||||
|
||||
def testVariableWithReplicaDeviceSetter(self):
|
||||
with self.test_session():
|
||||
with tf.device(tf.train.replica_device_setter(ps_tasks=2)):
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
b = tf.contrib.framework.variable('b', [])
|
||||
c = tf.contrib.framework.variable('c', [], device='cpu:12')
|
||||
d = tf.contrib.framework.variable('d', [])
|
||||
with tf.device('cpu:99'):
|
||||
e_init = tf.constant(12)
|
||||
e = tf.contrib.framework.variable('e', initializer=e_init)
|
||||
# The values below highlight how the replica_device_setter puts initial
|
||||
# values on the worker job, and how it merges explicit devices.
|
||||
self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
|
||||
self.assertDeviceEqual(a.initial_value.device, a.device)
|
||||
self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
|
||||
self.assertDeviceEqual(b.initial_value.device, b.device)
|
||||
self.assertDeviceEqual(c.device, '/job:ps/task:0/cpu:12')
|
||||
self.assertDeviceEqual(c.initial_value.device, c.device)
|
||||
self.assertDeviceEqual(d.device, '/job:ps/task:1/cpu:0')
|
||||
self.assertDeviceEqual(d.initial_value.device, d.device)
|
||||
self.assertDeviceEqual(e.device, '/job:ps/task:0/cpu:0')
|
||||
self.assertDeviceEqual(e.initial_value.device, '/job:worker/cpu:99')
|
||||
|
||||
def testVariableWithVariableDeviceChooser(self):
|
||||
|
||||
with tf.Graph().as_default():
|
||||
device_fn = tf.contrib.framework.VariableDeviceChooser(num_tasks=2)
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||
device=device_fn):
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
b = tf.contrib.framework.variable('b', [])
|
||||
c = tf.contrib.framework.variable('c', [], device='cpu:12')
|
||||
d = tf.contrib.framework.variable('d', [])
|
||||
with tf.device('cpu:99'):
|
||||
e_init = tf.constant(12)
|
||||
e = tf.contrib.framework.variable('e', initializer=e_init)
|
||||
# The values below highlight how the VariableDeviceChooser puts initial
|
||||
# values on the same device as the variable job.
|
||||
self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
|
||||
self.assertDeviceEqual(a.initial_value.device, a.device)
|
||||
self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
|
||||
self.assertDeviceEqual(b.initial_value.device, b.device)
|
||||
self.assertDeviceEqual(c.device, '/cpu:12')
|
||||
self.assertDeviceEqual(c.initial_value.device, c.device)
|
||||
self.assertDeviceEqual(d.device, '/job:ps/task:0/cpu:0')
|
||||
self.assertDeviceEqual(d.initial_value.device, d.device)
|
||||
self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
|
||||
self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
|
||||
|
||||
def testVariableGPUPlacement(self):
|
||||
|
||||
with tf.Graph().as_default():
|
||||
device_fn = tf.contrib.framework.VariableDeviceChooser(device_type='GPU')
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||
device=device_fn):
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
b = tf.contrib.framework.variable('b', [])
|
||||
c = tf.contrib.framework.variable('c', [], device='cpu:12')
|
||||
d = tf.contrib.framework.variable('d', [])
|
||||
with tf.device('cpu:99'):
|
||||
e_init = tf.constant(12)
|
||||
e = tf.contrib.framework.variable('e', initializer=e_init)
|
||||
# The values below highlight how the VariableDeviceChooser puts initial
|
||||
# values on the same device as the variable job.
|
||||
self.assertDeviceEqual(a.device, '/gpu:0')
|
||||
self.assertDeviceEqual(a.initial_value.device, a.device)
|
||||
self.assertDeviceEqual(b.device, '/gpu:0')
|
||||
self.assertDeviceEqual(b.initial_value.device, b.device)
|
||||
self.assertDeviceEqual(c.device, '/cpu:12')
|
||||
self.assertDeviceEqual(c.initial_value.device, c.device)
|
||||
self.assertDeviceEqual(d.device, '/gpu:0')
|
||||
self.assertDeviceEqual(d.initial_value.device, d.device)
|
||||
self.assertDeviceEqual(e.device, '/gpu:0')
|
||||
self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
|
||||
|
||||
|
||||
class ModelVariablesTest(tf.test.TestCase):
|
||||
|
||||
def testNameAndShape(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.model_variable('a', [5])
|
||||
self.assertEquals(a.op.name, 'A/a')
|
||||
self.assertListEqual(a.get_shape().as_list(), [5])
|
||||
self.assertListEqual([a], tf.contrib.framework.get_model_variables('A'))
|
||||
|
||||
def testNotInLocalVariables(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.model_variable('a', [5])
|
||||
self.assertTrue(a in tf.all_variables())
|
||||
self.assertFalse(a in tf.local_variables())
|
||||
|
||||
def testGetVariablesReturns(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.model_variable('a', [5])
|
||||
with tf.variable_scope('B'):
|
||||
b = tf.contrib.framework.model_variable('a', [5])
|
||||
self.assertEquals([a], tf.contrib.framework.get_variables('A'))
|
||||
self.assertEquals([b], tf.contrib.framework.get_variables('B'))
|
||||
|
||||
def testGetModelVariables(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.model_variable('a', [5])
|
||||
with tf.variable_scope('B'):
|
||||
b = tf.contrib.framework.model_variable('a', [5])
|
||||
self.assertEquals([a], tf.contrib.framework.get_model_variables('A'))
|
||||
self.assertEquals([b], tf.contrib.framework.get_model_variables('B'))
|
||||
|
||||
def testGetLocalVariables(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
_ = tf.contrib.framework.model_variable('a', [5])
|
||||
with tf.variable_scope('B'):
|
||||
_ = tf.contrib.framework.model_variable('a', [5])
|
||||
self.assertEquals([], tf.contrib.framework.get_local_variables('A'))
|
||||
self.assertEquals([], tf.contrib.framework.get_local_variables('B'))
|
||||
|
||||
def testInitializedVariableValue(self):
|
||||
with self.test_session() as sess:
|
||||
a = tf.contrib.framework.model_variable('a', [5], initializer=tf.ones)
|
||||
sess.run(tf.initialize_all_variables())
|
||||
self.assertAllEqual(a.eval(), [1]*5)
|
||||
|
||||
def testDeviceFn(self):
|
||||
class DevFn(object):
|
||||
|
||||
def __init__(self):
|
||||
self.counter = -1
|
||||
|
||||
def __call__(self, op):
|
||||
self.counter += 1
|
||||
return '/cpu:%d' % self.counter
|
||||
|
||||
with tf.Graph().as_default():
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.model_variable],
|
||||
device=DevFn()):
|
||||
a = tf.contrib.framework.model_variable('a', [5])
|
||||
b = tf.contrib.framework.model_variable('b', [20])
|
||||
self.assertDeviceEqual(a.device, '/cpu:0')
|
||||
self.assertDeviceEqual(a.initial_value.device, '/cpu:0')
|
||||
self.assertDeviceEqual(b.device, '/cpu:1')
|
||||
self.assertDeviceEqual(b.initial_value.device, '/cpu:1')
|
||||
|
||||
def testVariableWithVariableDeviceChooser(self):
|
||||
|
||||
with tf.Graph().as_default():
|
||||
device_fn = tf.contrib.framework.VariableDeviceChooser()
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.model_variable],
|
||||
device=device_fn):
|
||||
a = tf.contrib.framework.model_variable('a', [5])
|
||||
b = tf.contrib.framework.model_variable('b', [20])
|
||||
self.assertDeviceEqual(a.device, 'cpu:0')
|
||||
self.assertDeviceEqual(a.initial_value.device, a.device)
|
||||
self.assertDeviceEqual(b.device, 'cpu:0')
|
||||
self.assertDeviceEqual(b.initial_value.device, b.device)
|
||||
|
||||
|
||||
class GetVariablesCollections(tf.test.TestCase):
|
||||
|
||||
def testVariableCollection(self):
|
||||
with self.test_session():
|
||||
a = tf.contrib.framework.variable('a', [], collections='A')
|
||||
b = tf.contrib.framework.variable('b', [], collections='B')
|
||||
self.assertEquals(a, tf.get_collection('A')[0])
|
||||
self.assertEquals(b, tf.get_collection('B')[0])
|
||||
|
||||
def testVariableCollections(self):
|
||||
with self.test_session():
|
||||
a = tf.contrib.framework.variable('a', [], collections=['A', 'C'])
|
||||
b = tf.contrib.framework.variable('b', [], collections=['B', 'C'])
|
||||
self.assertEquals(a, tf.get_collection('A')[0])
|
||||
self.assertEquals(b, tf.get_collection('B')[0])
|
||||
self.assertListEqual([a, b], tf.get_collection('C'))
|
||||
|
||||
def testVariableCollectionsWithArgScope(self):
|
||||
with self.test_session():
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||
collections='A'):
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
b = tf.contrib.framework.variable('b', [])
|
||||
self.assertListEqual([a, b], tf.get_collection('A'))
|
||||
|
||||
def testVariableCollectionsWithArgScopeNested(self):
|
||||
with self.test_session():
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||
collections='A'):
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||
collections='B'):
|
||||
b = tf.contrib.framework.variable('b', [])
|
||||
self.assertEquals(a, tf.get_collection('A')[0])
|
||||
self.assertEquals(b, tf.get_collection('B')[0])
|
||||
|
||||
def testVariableCollectionsWithArgScopeNonNested(self):
|
||||
with self.test_session():
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||
collections='A'):
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||
collections='B'):
|
||||
b = tf.contrib.framework.variable('b', [])
|
||||
tf.contrib.framework.variable('c', [])
|
||||
self.assertListEqual([a], tf.get_collection('A'))
|
||||
self.assertListEqual([b], tf.get_collection('B'))
|
||||
|
||||
def testVariableRestoreWithArgScopeNested(self):
|
||||
with self.test_session():
|
||||
a = tf.contrib.framework.variable('a', [])
|
||||
with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
|
||||
trainable=False,
|
||||
collections=['A', 'B']):
|
||||
b = tf.contrib.framework.variable('b', [])
|
||||
c = tf.contrib.framework.variable('c', [], trainable=False)
|
||||
self.assertEquals([a, c], tf.contrib.framework.get_variables_to_restore())
|
||||
self.assertEquals([a], tf.trainable_variables())
|
||||
self.assertEquals([b], tf.get_collection('A'))
|
||||
self.assertEquals([b], tf.get_collection('B'))
|
||||
|
||||
|
||||
class GetVariablesBySuffixTest(tf.test.TestCase):
|
||||
|
||||
def testGetVariableGivenNameScoped(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
b = tf.contrib.framework.variable('b', [5])
|
||||
self.assertEquals([a],
|
||||
tf.contrib.framework.get_variables_by_suffix('a'))
|
||||
self.assertEquals([b],
|
||||
tf.contrib.framework.get_variables_by_suffix('b'))
|
||||
|
||||
def testGetVariableWithScope(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
fooa = tf.contrib.framework.variable('fooa', [5])
|
||||
with tf.variable_scope('B'):
|
||||
a2 = tf.contrib.framework.variable('a', [5])
|
||||
matched_variables = tf.contrib.framework.get_variables_by_suffix('a')
|
||||
self.assertEquals([a, fooa, a2], matched_variables)
|
||||
matched_variables = tf.contrib.framework.get_variables_by_suffix('/a')
|
||||
self.assertEquals([a, a2], matched_variables)
|
||||
matched_variables = tf.contrib.framework.get_variables_by_suffix(
|
||||
'a', scope='A')
|
||||
self.assertEquals([a, fooa], matched_variables)
|
||||
|
||||
def testGetVariableWithoutScope(self):
|
||||
with self.test_session():
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
fooa = tf.contrib.framework.variable('fooa', [5])
|
||||
b_a = tf.contrib.framework.variable('B/a', [5])
|
||||
matched_variables = tf.contrib.framework.get_variables_by_suffix('a')
|
||||
self.assertEquals([a, fooa, b_a], matched_variables)
|
||||
matched_variables = tf.contrib.framework.get_variables_by_suffix('fooa')
|
||||
self.assertEquals([fooa], matched_variables)
|
||||
|
||||
|
||||
class GetVariablesByNameTest(tf.test.TestCase):
|
||||
|
||||
def testGetVariableGivenNameScoped(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
b = tf.contrib.framework.variable('b', [5])
|
||||
self.assertEquals([a], tf.contrib.framework.get_variables_by_name('a'))
|
||||
self.assertEquals([b], tf.contrib.framework.get_variables_by_name('b'))
|
||||
|
||||
def testGetVariableWithScope(self):
|
||||
with self.test_session():
|
||||
with tf.variable_scope('A'):
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
fooa = tf.contrib.framework.variable('fooa', [5])
|
||||
with tf.variable_scope('B'):
|
||||
a2 = tf.contrib.framework.variable('a', [5])
|
||||
matched_variables = tf.contrib.framework.get_variables_by_name('a')
|
||||
self.assertEquals([a, a2], matched_variables)
|
||||
matched_variables = tf.contrib.framework.get_variables_by_name('fooa')
|
||||
self.assertEquals([fooa], matched_variables)
|
||||
matched_variables = tf.contrib.framework.get_variables_by_name('/a')
|
||||
self.assertEquals([], matched_variables)
|
||||
matched_variables = tf.contrib.framework.get_variables_by_name('a',
|
||||
scope='A')
|
||||
self.assertEquals([a], matched_variables)
|
||||
|
||||
def testGetVariableWithoutScope(self):
|
||||
with self.test_session():
|
||||
a = tf.contrib.framework.variable('a', [5])
|
||||
fooa = tf.contrib.framework.variable('fooa', [5])
|
||||
b_a = tf.contrib.framework.variable('B/a', [5])
|
||||
matched_variables = tf.contrib.framework.get_variables_by_name('a')
|
||||
for v in matched_variables:
|
||||
print(v.name)
|
||||
self.assertEquals([a, b_a], matched_variables)
|
||||
matched_variables = tf.contrib.framework.get_variables_by_name('fooa')
|
||||
self.assertEquals([fooa], matched_variables)
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user