Added model_variable and helpers to manage variables.

Change: 122727793
This commit is contained in:
A. Unique TensorFlower 2016-05-19 05:55:50 -08:00 committed by TensorFlower Gardener
parent 63c29c82b3
commit 144855b385
5 changed files with 906 additions and 54 deletions

View File

@ -18,13 +18,40 @@
@@assert_same_float_dtype @@assert_same_float_dtype
@@assert_scalar_int @@assert_scalar_int
@@convert_to_tensor_or_sparse_tensor @@convert_to_tensor_or_sparse_tensor
@@local_variable @@get_graph_from_inputs
@@is_numeric_tensor
@@is_non_decreasing
@@is_strictly_increasing
@@reduce_sum_n @@reduce_sum_n
@@safe_embedding_lookup_sparse @@safe_embedding_lookup_sparse
@@with_shape @@with_shape
@@with_same_shape @@with_same_shape
@@get_graph_from_inputs
## Arg_Scope
@@arg_scope
@@add_arg_scope
@@has_arg_scope
@@arg_scoped_arguments
## Variables
@@add_model_variable
@@assert_global_step
@@assert_or_get_global_step
@@create_global_step
@@get_global_step
@@get_or_create_global_step
@@get_local_variables
@@get_model_variables
@@get_unique_variable
@@get_variables_by_name
@@get_variables_by_suffix
@@get_variables_to_restore
@@get_variables
@@local_variable
@@model_variable
@@variable
@@VariableDeviceChooser
""" """
from __future__ import absolute_import from __future__ import absolute_import

View File

@ -14,7 +14,6 @@
# ============================================================================== # ==============================================================================
"""Tensor utility functions.""" """Tensor utility functions."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -27,14 +26,16 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
__all__ = [ __all__ = [
'assert_same_float_dtype', 'assert_scalar_int', 'assert_same_float_dtype',
'convert_to_tensor_or_sparse_tensor', 'local_variable', 'reduce_sum_n', 'assert_scalar_int',
'with_shape', 'with_same_shape', 'convert_to_tensor_or_sparse_tensor',
] 'reduce_sum_n',
'with_shape',
'with_same_shape']
def _assert_same_base_type(items, expected_type=None): def _assert_same_base_type(items, expected_type=None):
"""Asserts all items are of the same base type. r"""Asserts all items are of the same base type.
Args: Args:
items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
@ -110,23 +111,6 @@ def assert_scalar_int(tensor):
return tensor return tensor
# TODO(ptucker): Move to tf.variables?
def local_variable(initial_value, validate_shape=True, name=None):
"""Create variable and add it to `GraphKeys.LOCAL_VARIABLES` collection.
Args:
initial_value: See variables.Variable.__init__.
validate_shape: See variables.Variable.__init__.
name: See variables.Variable.__init__.
Returns:
New variable.
"""
return variables.Variable(
initial_value, trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
validate_shape=validate_shape, name=name)
def reduce_sum_n(tensors, name=None): def reduce_sum_n(tensors, name=None):
"""Reduce tensors to a scalar sum. """Reduce tensors to a scalar sum.

View File

@ -47,11 +47,6 @@
@tf.contrib.add_arg_scope @tf.contrib.add_arg_scope
def conv2d(*args, **kwargs) def conv2d(*args, **kwargs)
@@arg_scope
@@add_arg_scope
@@has_arg_scope
@@arg_scoped_arguments
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -59,8 +54,10 @@ from __future__ import print_function
import contextlib import contextlib
import functools import functools
__all__ = ['arg_scope', 'add_arg_scope', __all__ = ['arg_scope',
'has_arg_scope', 'arg_scoped_arguments'] 'add_arg_scope',
'has_arg_scope',
'arg_scoped_arguments']
_ARGSTACK = [{}] _ARGSTACK = [{}]

View File

@ -14,24 +14,37 @@
# ============================================================================== # ==============================================================================
"""Variable functions. """Variable functions.
@@assert_global_step
@@create_global_step
@@get_global_step
@@assert_or_get_global_step
@@local_variable
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
__all__ = [
'assert_global_step', 'create_global_step', 'get_global_step', __all__ = ['add_model_variable',
'assert_or_get_global_step', 'local_variable'] 'assert_global_step',
'assert_or_get_global_step',
'create_global_step',
'get_global_step',
'get_or_create_global_step',
'get_local_variables',
'get_model_variables',
'get_unique_variable',
'get_variables_by_name',
'get_variables_by_suffix',
'get_variables_to_restore',
'get_variables',
'local_variable',
'model_variable',
'variable',
'VariableDeviceChooser']
def assert_global_step(global_step_tensor): def assert_global_step(global_step_tensor):
@ -125,7 +138,6 @@ def create_global_step(graph=None):
Global step tensor. Global step tensor.
Raises: Raises:
TypeError: if `dtype` is invalid.
ValueError: if global step key is already defined. ValueError: if global step key is already defined.
""" """
graph = ops.get_default_graph() if graph is None else graph graph = ops.get_default_graph() if graph is None else graph
@ -133,10 +145,27 @@ def create_global_step(graph=None):
raise ValueError('"global_step" already exists.') raise ValueError('"global_step" already exists.')
# Create in proper graph and base name_scope. # Create in proper graph and base name_scope.
with graph.as_default() as g, g.name_scope(None): with graph.as_default() as g, g.name_scope(None):
result = variables.Variable( collections = [ops.GraphKeys.VARIABLES, ops.GraphKeys.GLOBAL_STEP]
0, trainable=False, dtype=dtypes.int64, name=ops.GraphKeys.GLOBAL_STEP) return variable(ops.GraphKeys.GLOBAL_STEP, shape=[], dtype=dtypes.int64,
graph.add_to_collection(ops.GraphKeys.GLOBAL_STEP, result) initializer=init_ops.zeros_initializer, trainable=False,
return result collections=collections)
def get_or_create_global_step(graph=None):
  """Returns the global step variable, creating it first if necessary.

  Args:
    graph: The graph in which to look up or create the global step. If
      missing, use default graph.

  Returns:
    The tensor representing the global step variable.
  """
  if graph is None:
    graph = ops.get_default_graph()
  existing_step = get_global_step(graph)
  if existing_step is not None:
    return existing_step
  # Nothing registered yet for this graph: create the step now.
  return create_global_step(graph)
def local_variable(initial_value, validate_shape=True, name=None): def local_variable(initial_value, validate_shape=True, name=None):
@ -154,3 +183,268 @@ def local_variable(initial_value, validate_shape=True, name=None):
collections=[ops.GraphKeys.LOCAL_VARIABLES], collections=[ops.GraphKeys.LOCAL_VARIABLES],
validate_shape=validate_shape, name=name) validate_shape=validate_shape, name=name)
@contrib_add_arg_scope
def variable(name, shape=None, dtype=dtypes.float32, initializer=None,
             regularizer=None, trainable=True, collections=None,
             caching_device=None, device=None):
  """Gets an existing variable with these parameters or creates a new one.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    collections: A list of collection names to which the Variable will be added.
      If None (or empty) it defaults to [tf.GraphKeys.VARIABLES]. Duplicate
      names are dropped, keeping the first occurrence.
    caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.
    device: Optional device to place the variable. It can be an string or a
      function that is called to get the device for the variable.

  Returns:
    The created or existing variable.
  """
  # Remove duplicates while keeping the caller's order. A `set` would also
  # de-duplicate, but its iteration order is arbitrary and the documented
  # contract is a list of collection names.
  unique_collections = []
  for collection_name in collections or [ops.GraphKeys.VARIABLES]:
    if collection_name not in unique_collections:
      unique_collections.append(collection_name)
  # A `device` of None is replaced by '' before being handed to ops.device.
  with ops.device(device or ''):
    return variable_scope.get_variable(name, shape=shape, dtype=dtype,
                                       initializer=initializer,
                                       regularizer=regularizer,
                                       trainable=trainable,
                                       collections=unique_collections,
                                       caching_device=caching_device)
# TODO(sguada) move it to ops.GraphKeys or to contrib.framework.GraphKeys
# Name of the collection that `model_variable` and `add_model_variable` add
# variables to.
MODEL_VARIABLES = '_model_variables_'
@contrib_add_arg_scope
def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
                   regularizer=None, trainable=True, collections=None,
                   caching_device=None, device=None):
  """Gets or creates a variable that is also registered as a model variable.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    regularizer: a (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    collections: A list of collection names to which the Variable will be added.
      The tf.GraphKeys.VARIABLES and MODEL_VARIABLES collections are always
      appended to whatever is passed here.
    caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.
    device: Optional device to place the variable. It can be an string or a
      function that is called to get the device for the variable.

  Returns:
    The created or existing variable.
  """
  # Work on a copy so the caller's list is never mutated, then append the two
  # collections every model variable must belong to.
  requested = list(collections or [])
  requested.extend([ops.GraphKeys.VARIABLES, MODEL_VARIABLES])
  return variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      trainable=trainable,
      collections=requested,
      caching_device=caching_device,
      device=device)
def add_model_variable(var):
  """Registers `var` in the MODEL_VARIABLES collection if it is not there yet.

  Args:
    var: a variable.
  """
  already_registered = var in ops.get_collection(MODEL_VARIABLES)
  if already_registered:
    return
  ops.add_to_collection(MODEL_VARIABLES, var)
def get_variables(scope=None, suffix=None, collection=ops.GraphKeys.VARIABLES):
  """Gets the list of variables, filtered by scope and/or suffix.

  When a suffix is given, a ':' is appended to it unless it already contains
  one, and the filter handed to `ops.get_collection` becomes
  `scope + '.*' + suffix`.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.
    collection: in which collection search for. Defaults to GraphKeys.VARIABLES.

  Returns:
    a list of variables in collection with scope and suffix.
  """
  name_filter = scope
  if suffix is not None:
    terminated_suffix = suffix if ':' in suffix else suffix + ':'
    name_filter = (scope or '') + '.*' + terminated_suffix
  return ops.get_collection(collection, name_filter)
def get_model_variables(scope=None, suffix=None):
  """Gets the variables in the MODEL_VARIABLES collection, optionally filtered.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in the collection with scope and suffix.
  """
  return get_variables(scope=scope, suffix=suffix, collection=MODEL_VARIABLES)
def get_local_variables(scope=None, suffix=None):
  """Gets the variables in the LOCAL_VARIABLES collection, optionally filtered.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.

  Returns:
    a list of variables in the collection with scope and suffix.
  """
  return get_variables(scope=scope, suffix=suffix,
                       collection=ops.GraphKeys.LOCAL_VARIABLES)
def get_variables_to_restore(include=None, exclude=None):
  """Gets the list of the variables to restore.

  Args:
    include: an optional list/tuple of scope strings for filtering which
      variables from the VARIABLES collection to include. None includes all
      the variables.
    exclude: an optional list/tuple of scope strings for filtering which
      variables from the VARIABLES collection to exclude. None excludes
      nothing.

  Returns:
    a list of variables to restore: everything selected by `include`, in
    lookup order, minus everything selected by `exclude`.

  Raises:
    TypeError: include or exclude is provided but is not a list or a tuple.
  """
  if include is not None and not isinstance(include, (list, tuple)):
    raise TypeError('include is provided but is not a list or a tuple.')
  if include is None:
    # No filter given: start from every variable.
    selected = get_variables()
  else:
    selected = []
    for include_scope in include:
      selected.extend(get_variables(include_scope))

  rejected = set()
  if exclude is not None:
    if not isinstance(exclude, (list, tuple)):
      raise TypeError('exclude is provided but is not a list or a tuple.')
    for exclude_scope in exclude:
      rejected.update(get_variables(exclude_scope))

  # Drop the excluded variables while keeping the inclusion order.
  return [var for var in selected if var not in rejected]
def get_variables_by_suffix(suffix, scope=None):
  """Gets the list of variables whose name ends with the given suffix.

  Args:
    suffix: suffix for filtering the variables to return.
    scope: an optional scope for filtering the variables to return.

  Returns:
    the list produced by `get_variables` for this scope and suffix.
  """
  return get_variables(scope, suffix)
def get_variables_by_name(given_name, scope=None):
  """Gets the list of variables that were given that name.

  The lookup is delegated to `get_variables` with a suffix that accepts
  either '/<given_name>:' or a name that starts with '<given_name>:'.

  Args:
    given_name: name given to the variable without any scope.
    scope: an optional scope for filtering the variables to return.

  Returns:
    the list produced by `get_variables` for the given name and scope.
  """
  name_pattern = ''.join(['/', given_name, ':|^', given_name, ':'])
  return get_variables(scope=scope, suffix=name_pattern)
def get_unique_variable(var_op_name):
  """Gets the variable uniquely identified by that var_op_name.

  Args:
    var_op_name: the full name of the variable op, including the scope.

  Returns:
    a tensorflow variable.

  Raises:
    ValueError: if no variable is found under `var_op_name`, or if none of
      the variables found has an op name exactly equal to `var_op_name`.
  """
  candidates = get_variables(scope=var_op_name)
  if not candidates:
    raise ValueError('Couldnt find variable %s' % var_op_name)

  # The scope lookup can also return variables nested under `var_op_name`, so
  # only an exact op-name match counts.
  for candidate in candidates:
    if candidate.op.name == var_op_name:
      return candidate
  # Interpolate the name into the message. Passing it as a second argument to
  # ValueError would leave the '%s' placeholder unformatted.
  raise ValueError('Variable %s does not uniquely identify a variable' %
                   var_op_name)
class VariableDeviceChooser(object):
  """Device chooser for variables.

  With `num_tasks > 0` the job is 'ps' and successive calls cycle through
  task ids 0 .. num_tasks - 1 in round-robin order. With `num_tasks == 0` no
  job or task is set, which allows plain GPU or CPU placement.
  """

  def __init__(self,
               num_tasks=0,
               device_type='CPU',
               device_index=0):
    """Initialize VariableDeviceChooser.

    Usage:
      To use with 2 parameter servers:
        VariableDeviceChooser(2)
      To use without parameter servers:
        VariableDeviceChooser()
        VariableDeviceChooser(device_type='GPU') # For GPU placement

    Args:
      num_tasks: number of tasks.
      device_type: Optional device type string (e.g. "CPU" or "GPU")
      device_index: int.  Optional device index.  Defaults to 0.
    """
    self._num_tasks = num_tasks
    self._device_type = device_type
    self._device_index = device_index
    # Round-robin cursor: the task id handed out by the next call.
    self._next_task_id = 0
    if num_tasks > 0:
      self._job_name = 'ps'
    else:
      self._job_name = None

  def __call__(self, op):
    """Returns the device string for the next variable; `op` is not read."""
    spec = tf_device.DeviceSpec(job=self._job_name,
                                device_type=self._device_type,
                                device_index=self._device_index)
    if self._num_tasks > 0:
      assigned_task = self._next_task_id
      # Advance the cursor, wrapping back to task 0 after the last task.
      self._next_task_id = (assigned_task + 1) % self._num_tasks
      spec.task = assigned_task
    return spec.to_string()

View File

@ -37,11 +37,57 @@ class LocalVariableTest(tf.test.TestCase):
tf.initialize_variables(variables).run() tf.initialize_variables(variables).run()
self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
def testLocalVariableNameAndShape(self):
  """A local variable gets the scoped op name, the shape, and is listed."""
  with self.test_session():
    with tf.variable_scope('A'):
      a = tf.contrib.framework.local_variable([1, 1, 1, 1, 1], name='a')
      self.assertEquals(a.op.name, 'A/a')
      self.assertListEqual(a.get_shape().as_list(), [5])
      self.assertListEqual([a], tf.contrib.framework.get_local_variables())
def testLocalVariableNotInAllVariables(self):
  """A local variable is in tf.local_variables() but not tf.all_variables()."""
  with self.test_session():
    with tf.variable_scope('A'):
      a = tf.contrib.framework.local_variable(0)
      self.assertFalse(a in tf.all_variables())
      self.assertTrue(a in tf.local_variables())
def testLocalVariableNotInVariablesToRestore(self):
  """A local variable is not returned by get_variables_to_restore()."""
  with self.test_session():
    with tf.variable_scope('A'):
      a = tf.contrib.framework.local_variable(0)
      self.assertFalse(a in tf.contrib.framework.get_variables_to_restore())
      self.assertTrue(a in tf.local_variables())
def testGetVariablesDontReturnsTransients(self):
  """get_variables() returns nothing for scopes holding only local variables."""
  with self.test_session():
    with tf.variable_scope('A'):
      tf.contrib.framework.local_variable(0)
    with tf.variable_scope('B'):
      tf.contrib.framework.local_variable(0)
    self.assertEquals([], tf.contrib.framework.get_variables('A'))
    self.assertEquals([], tf.contrib.framework.get_variables('B'))
def testGetLocalVariablesReturnsTransients(self):
  """get_local_variables(scope) returns the local variable of that scope."""
  with self.test_session():
    with tf.variable_scope('A'):
      a = tf.contrib.framework.local_variable(0)
    with tf.variable_scope('B'):
      b = tf.contrib.framework.local_variable(0)
    self.assertEquals([a], tf.contrib.framework.get_local_variables('A'))
    self.assertEquals([b], tf.contrib.framework.get_local_variables('B'))
def testInitializedVariableValue(self):
  """After initialize_local_variables() the variable holds its initial value."""
  with self.test_session() as sess:
    a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a')
    sess.run(tf.initialize_local_variables())
    self.assertAllEqual(a.eval(), [0]*5)
class GlobalStepTest(tf.test.TestCase): class GlobalStepTest(tf.test.TestCase):
def _assert_global_step(self, global_step, expected_dtype=tf.int64): def _assert_global_step(self, global_step, expected_dtype=tf.int64):
self.assertEquals("%s:0" % tf.GraphKeys.GLOBAL_STEP, global_step.name) self.assertEquals('%s:0' % tf.GraphKeys.GLOBAL_STEP, global_step.name)
self.assertEquals(expected_dtype, global_step.dtype.base_dtype) self.assertEquals(expected_dtype, global_step.dtype.base_dtype)
self.assertEquals([], global_step.get_shape().as_list()) self.assertEquals([], global_step.get_shape().as_list())
@ -51,10 +97,10 @@ class GlobalStepTest(tf.test.TestCase):
tf.Variable( tf.Variable(
0.0, trainable=False, dtype=tf.float32, name=tf.GraphKeys.GLOBAL_STEP) 0.0, trainable=False, dtype=tf.float32, name=tf.GraphKeys.GLOBAL_STEP)
self.assertRaisesRegexp( self.assertRaisesRegexp(
TypeError, "does not have integer type", TypeError, 'does not have integer type',
tf.contrib.framework.get_global_step) tf.contrib.framework.get_global_step)
self.assertRaisesRegexp( self.assertRaisesRegexp(
TypeError, "does not have integer type", TypeError, 'does not have integer type',
tf.contrib.framework.get_global_step, g) tf.contrib.framework.get_global_step, g)
def test_invalid_shape(self): def test_invalid_shape(self):
@ -63,10 +109,10 @@ class GlobalStepTest(tf.test.TestCase):
tf.Variable( tf.Variable(
[0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP) [0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP)
self.assertRaisesRegexp( self.assertRaisesRegexp(
TypeError, "not scalar", TypeError, 'not scalar',
tf.contrib.framework.get_global_step) tf.contrib.framework.get_global_step)
self.assertRaisesRegexp( self.assertRaisesRegexp(
TypeError, "not scalar", TypeError, 'not scalar',
tf.contrib.framework.get_global_step, g) tf.contrib.framework.get_global_step, g)
def test_create_global_step(self): def test_create_global_step(self):
@ -75,9 +121,9 @@ class GlobalStepTest(tf.test.TestCase):
global_step = tf.contrib.framework.create_global_step() global_step = tf.contrib.framework.create_global_step()
self._assert_global_step(global_step) self._assert_global_step(global_step)
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, "already exists", tf.contrib.framework.create_global_step) ValueError, 'already exists', tf.contrib.framework.create_global_step)
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, "already exists", tf.contrib.framework.create_global_step, ValueError, 'already exists', tf.contrib.framework.create_global_step,
g) g)
self._assert_global_step( self._assert_global_step(
tf.contrib.framework.create_global_step(tf.Graph())) tf.contrib.framework.create_global_step(tf.Graph()))
@ -92,6 +138,510 @@ class GlobalStepTest(tf.test.TestCase):
self._assert_global_step( self._assert_global_step(
tf.contrib.framework.get_global_step(g), expected_dtype=tf.int32) tf.contrib.framework.get_global_step(g), expected_dtype=tf.int32)
def test_get_or_create_global_step(self):
  """get_or_create_global_step() yields a valid step with or without a graph."""
  with tf.Graph().as_default() as g:
    # A fresh graph has no global step yet.
    self.assertEquals(None, tf.contrib.framework.get_global_step())
    self._assert_global_step(
        tf.contrib.framework.get_or_create_global_step())
    self._assert_global_step(
        tf.contrib.framework.get_or_create_global_step(g))
if __name__ == "__main__":
class VariablesTest(tf.test.TestCase):
  """Tests for `variable`, the `get_*` lookup helpers and device placement."""

  def testCreateVariable(self):
    """A created variable has the scoped op name and the requested shape."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
        self.assertEquals(a.op.name, 'A/a')
        self.assertListEqual(a.get_shape().as_list(), [5])

  def testGetVariables(self):
    """get_variables() returns all variables, or those of a given scope."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
      with tf.variable_scope('B'):
        b = tf.contrib.framework.variable('a', [5])
      self.assertEquals([a, b], tf.contrib.framework.get_variables())
      self.assertEquals([a], tf.contrib.framework.get_variables('A'))
      self.assertEquals([b], tf.contrib.framework.get_variables('B'))

  def testGetVariablesSuffix(self):
    """get_variables(suffix=...) selects variables by the end of their name."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
      with tf.variable_scope('A'):
        b = tf.contrib.framework.variable('b', [5])
      self.assertEquals([a], tf.contrib.framework.get_variables(suffix='a'))
      self.assertEquals([b], tf.contrib.framework.get_variables(suffix='b'))

  def testGetVariableWithSingleVar(self):
    """get_unique_variable() finds a variable by its full op name."""
    with self.test_session():
      with tf.variable_scope('parent'):
        a = tf.contrib.framework.variable('child', [5])
      self.assertEquals(
          a, tf.contrib.framework.get_unique_variable('parent/child'))

  def testGetVariableWithDistractors(self):
    """get_unique_variable() ignores variables nested under the same name."""
    with self.test_session():
      with tf.variable_scope('parent'):
        a = tf.contrib.framework.variable('child', [5])
        with tf.variable_scope('child'):
          tf.contrib.framework.variable('grandchild1', [7])
          tf.contrib.framework.variable('grandchild2', [9])
      self.assertEquals(
          a, tf.contrib.framework.get_unique_variable('parent/child'))

  def testGetVariableThrowsExceptionWithNoMatch(self):
    """get_unique_variable() raises ValueError when nothing matches."""
    var_name = 'cant_find_me'
    with self.test_session():
      with self.assertRaises(ValueError):
        tf.contrib.framework.get_unique_variable(var_name)

  def testGetThrowsExceptionWithChildrenButNoMatch(self):
    """get_unique_variable() raises ValueError when only children exist."""
    var_name = 'parent/child'
    with self.test_session():
      with tf.variable_scope(var_name):
        tf.contrib.framework.variable('grandchild1', [7])
        tf.contrib.framework.variable('grandchild2', [9])
      with self.assertRaises(ValueError):
        tf.contrib.framework.get_unique_variable(var_name)

  def testGetVariablesToRestore(self):
    """Without filters get_variables_to_restore() returns every variable."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
      with tf.variable_scope('B'):
        b = tf.contrib.framework.variable('a', [5])
      self.assertEquals([a, b],
                        tf.contrib.framework.get_variables_to_restore())

  def testIncludeGetVariablesToRestore(self):
    """An include list restricts the result to the listed scopes."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
      with tf.variable_scope('B'):
        b = tf.contrib.framework.variable('a', [5])
      self.assertEquals([a, b], tf.contrib.framework.get_variables())
      self.assertEquals([a],
                        tf.contrib.framework.get_variables_to_restore(['A']))

  def testExcludeGetVariablesToRestore(self):
    """An exclude list removes the variables of the listed scopes."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
      with tf.variable_scope('B'):
        b = tf.contrib.framework.variable('a', [5])
      self.assertEquals([a, b], tf.contrib.framework.get_variables())
      self.assertEquals([a],
                        tf.contrib.framework.get_variables_to_restore(
                            exclude=['B']))

  def testWrongIncludeGetVariablesToRestore(self):
    """An include scope matching no variable yields an empty list."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
      with tf.variable_scope('B'):
        b = tf.contrib.framework.variable('a', [5])
      self.assertEquals([a, b], tf.contrib.framework.get_variables())
      self.assertEquals([],
                        tf.contrib.framework.get_variables_to_restore(['a']))

  def testGetMixedVariablesToRestore(self):
    """Include entries may name individual variables in different scopes."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
        b = tf.contrib.framework.variable('b', [5])
      with tf.variable_scope('B'):
        c = tf.contrib.framework.variable('c', [5])
        d = tf.contrib.framework.variable('d', [5])
      self.assertEquals([a, b, c, d], tf.contrib.framework.get_variables())
      self.assertEquals([a, c],
                        tf.contrib.framework.get_variables_to_restore(
                            include=['A/a', 'B/c']))

  def testExcludeGetMixedVariablesToRestore(self):
    """Exclude entries may name individual variables in different scopes."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
        b = tf.contrib.framework.variable('b', [5])
      with tf.variable_scope('B'):
        c = tf.contrib.framework.variable('c', [5])
        d = tf.contrib.framework.variable('d', [5])
      self.assertEquals([a, b, c, d], tf.contrib.framework.get_variables())
      self.assertEquals([b, d],
                        tf.contrib.framework.get_variables_to_restore(
                            exclude=['A/a', 'B/c']))

  def testReuseVariable(self):
    """With reuse=True the same variable is returned and listed only once."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [])
      with tf.variable_scope('A', reuse=True):
        b = tf.contrib.framework.variable('a', [])
      self.assertEquals(a, b)
      self.assertListEqual([a], tf.contrib.framework.get_variables())

  def testVariableWithRegularizer(self):
    """The regularization loss is placed on the same device as the variable."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [], regularizer=tf.nn.l2_loss)
      loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
      self.assertDeviceEqual(loss.device, a.device)

  def testVariableWithRegularizerColocate(self):
    """Same as above when the variable has an explicit device."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [], device='gpu:0',
                                          regularizer=tf.nn.l2_loss)
      loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
      self.assertDeviceEqual(loss.device, a.device)

  def testVariableWithDevice(self):
    """The `device` argument determines the variable's device."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [], device='cpu:0')
        b = tf.contrib.framework.variable('b', [], device='cpu:1')
      self.assertDeviceEqual(a.device, 'cpu:0')
      self.assertDeviceEqual(b.device, 'cpu:1')

  def testVariableWithDeviceFromScope(self):
    """An enclosing tf.device applies unless `device` is passed explicitly."""
    with self.test_session():
      with tf.device('/cpu:0'):
        a = tf.contrib.framework.variable('a', [])
        b = tf.contrib.framework.variable('b', [], device='cpu:1')
      self.assertDeviceEqual(a.device, 'cpu:0')
      self.assertDeviceEqual(b.device, 'cpu:1')

  def testVariableWithDeviceFunction(self):
    """A device function set through arg_scope places variable and initial value."""
    class DevFn(object):
      # Returns 'cpu:0', 'cpu:1', ... on successive calls.

      def __init__(self):
        self.counter = -1

      def __call__(self, op):
        self.counter += 1
        return 'cpu:%d' % self.counter

    with self.test_session():
      with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                          device=DevFn()):
        a = tf.contrib.framework.variable('a', [])
        b = tf.contrib.framework.variable('b', [])
        c = tf.contrib.framework.variable('c', [], device='cpu:12')
        d = tf.contrib.framework.variable('d', [])
        with tf.device('cpu:99'):
          e_init = tf.constant(12)
        e = tf.contrib.framework.variable('e', initializer=e_init)
      self.assertDeviceEqual(a.device, 'cpu:0')
      self.assertDeviceEqual(a.initial_value.device, 'cpu:0')
      self.assertDeviceEqual(b.device, 'cpu:1')
      self.assertDeviceEqual(b.initial_value.device, 'cpu:1')
      self.assertDeviceEqual(c.device, 'cpu:12')
      self.assertDeviceEqual(c.initial_value.device, 'cpu:12')
      self.assertDeviceEqual(d.device, 'cpu:2')
      self.assertDeviceEqual(d.initial_value.device, 'cpu:2')
      self.assertDeviceEqual(e.device, 'cpu:3')
      self.assertDeviceEqual(e.initial_value.device, 'cpu:99')

  def testVariableWithReplicaDeviceSetter(self):
    """Placement of variables created under tf.train.replica_device_setter."""
    with self.test_session():
      with tf.device(tf.train.replica_device_setter(ps_tasks=2)):
        a = tf.contrib.framework.variable('a', [])
        b = tf.contrib.framework.variable('b', [])
        c = tf.contrib.framework.variable('c', [], device='cpu:12')
        d = tf.contrib.framework.variable('d', [])
        with tf.device('cpu:99'):
          e_init = tf.constant(12)
        e = tf.contrib.framework.variable('e', initializer=e_init)
      # The values below highlight how the replica_device_setter puts initial
      # values on the worker job, and how it merges explicit devices.
      self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
      self.assertDeviceEqual(a.initial_value.device, a.device)
      self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
      self.assertDeviceEqual(b.initial_value.device, b.device)
      self.assertDeviceEqual(c.device, '/job:ps/task:0/cpu:12')
      self.assertDeviceEqual(c.initial_value.device, c.device)
      self.assertDeviceEqual(d.device, '/job:ps/task:1/cpu:0')
      self.assertDeviceEqual(d.initial_value.device, d.device)
      self.assertDeviceEqual(e.device, '/job:ps/task:0/cpu:0')
      self.assertDeviceEqual(e.initial_value.device, '/job:worker/cpu:99')

  def testVariableWithVariableDeviceChooser(self):
    """VariableDeviceChooser(num_tasks=2) alternates between ps tasks 0 and 1."""
    with tf.Graph().as_default():
      device_fn = tf.contrib.framework.VariableDeviceChooser(num_tasks=2)
      with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                          device=device_fn):
        a = tf.contrib.framework.variable('a', [])
        b = tf.contrib.framework.variable('b', [])
        c = tf.contrib.framework.variable('c', [], device='cpu:12')
        d = tf.contrib.framework.variable('d', [])
        with tf.device('cpu:99'):
          e_init = tf.constant(12)
        e = tf.contrib.framework.variable('e', initializer=e_init)
      # The values below highlight how the VariableDeviceChooser puts initial
      # values on the same device as the variable job.
      self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
      self.assertDeviceEqual(a.initial_value.device, a.device)
      self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
      self.assertDeviceEqual(b.initial_value.device, b.device)
      self.assertDeviceEqual(c.device, '/cpu:12')
      self.assertDeviceEqual(c.initial_value.device, c.device)
      self.assertDeviceEqual(d.device, '/job:ps/task:0/cpu:0')
      self.assertDeviceEqual(d.initial_value.device, d.device)
      self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
      self.assertDeviceEqual(e.initial_value.device, '/cpu:99')

  def testVariableGPUPlacement(self):
    """VariableDeviceChooser(device_type='GPU') places variables on gpu:0."""
    with tf.Graph().as_default():
      device_fn = tf.contrib.framework.VariableDeviceChooser(device_type='GPU')
      with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                          device=device_fn):
        a = tf.contrib.framework.variable('a', [])
        b = tf.contrib.framework.variable('b', [])
        c = tf.contrib.framework.variable('c', [], device='cpu:12')
        d = tf.contrib.framework.variable('d', [])
        with tf.device('cpu:99'):
          e_init = tf.constant(12)
        e = tf.contrib.framework.variable('e', initializer=e_init)
      # The values below highlight how the VariableDeviceChooser puts initial
      # values on the same device as the variable job.
      self.assertDeviceEqual(a.device, '/gpu:0')
      self.assertDeviceEqual(a.initial_value.device, a.device)
      self.assertDeviceEqual(b.device, '/gpu:0')
      self.assertDeviceEqual(b.initial_value.device, b.device)
      self.assertDeviceEqual(c.device, '/cpu:12')
      self.assertDeviceEqual(c.initial_value.device, c.device)
      self.assertDeviceEqual(d.device, '/gpu:0')
      self.assertDeviceEqual(d.initial_value.device, d.device)
      self.assertDeviceEqual(e.device, '/gpu:0')
      self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
class ModelVariablesTest(tf.test.TestCase):
  """Tests for tf.contrib.framework.model_variable and its getters."""

  def testNameAndShape(self):
    """The variable is named under its scope and has the requested shape."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.model_variable('a', [5])
        self.assertEqual(a.op.name, 'A/a')
        self.assertListEqual(a.get_shape().as_list(), [5])
        self.assertListEqual(
            [a], tf.contrib.framework.get_model_variables('A'))

  def testNotInLocalVariables(self):
    """A model variable is in tf.all_variables() but not local_variables()."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.model_variable('a', [5])
        # assertIn/assertNotIn report the container contents on failure.
        self.assertIn(a, tf.all_variables())
        self.assertNotIn(a, tf.local_variables())

  def testGetVariablesReturns(self):
    """get_variables(scope) returns the model variable made in that scope."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.model_variable('a', [5])
      with tf.variable_scope('B'):
        b = tf.contrib.framework.model_variable('a', [5])
      self.assertEqual([a], tf.contrib.framework.get_variables('A'))
      self.assertEqual([b], tf.contrib.framework.get_variables('B'))

  def testGetModelVariables(self):
    """get_model_variables(scope) returns the variable made in that scope."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.model_variable('a', [5])
      with tf.variable_scope('B'):
        b = tf.contrib.framework.model_variable('a', [5])
      self.assertEqual([a], tf.contrib.framework.get_model_variables('A'))
      self.assertEqual([b], tf.contrib.framework.get_model_variables('B'))

  def testGetLocalVariables(self):
    """get_local_variables(scope) is empty when only model variables exist."""
    with self.test_session():
      with tf.variable_scope('A'):
        _ = tf.contrib.framework.model_variable('a', [5])
      with tf.variable_scope('B'):
        _ = tf.contrib.framework.model_variable('a', [5])
      self.assertEqual([], tf.contrib.framework.get_local_variables('A'))
      self.assertEqual([], tf.contrib.framework.get_local_variables('B'))

  def testInitializedVariableValue(self):
    """A model variable initialized with tf.ones evaluates to all ones."""
    with self.test_session() as sess:
      a = tf.contrib.framework.model_variable('a', [5], initializer=tf.ones)
      sess.run(tf.initialize_all_variables())
      self.assertAllEqual(a.eval(), [1] * 5)

  def testDeviceFn(self):
    """A callable `device` decides where each variable and its init live."""

    class DevFn(object):
      """Device function returning '/cpu:0', '/cpu:1', ... on each call."""

      def __init__(self):
        self.counter = -1

      def __call__(self, op):
        self.counter += 1
        return '/cpu:%d' % self.counter

    with tf.Graph().as_default():
      with tf.contrib.framework.arg_scope(
          [tf.contrib.framework.model_variable], device=DevFn()):
        a = tf.contrib.framework.model_variable('a', [5])
        b = tf.contrib.framework.model_variable('b', [20])
        self.assertDeviceEqual(a.device, '/cpu:0')
        self.assertDeviceEqual(a.initial_value.device, '/cpu:0')
        self.assertDeviceEqual(b.device, '/cpu:1')
        self.assertDeviceEqual(b.initial_value.device, '/cpu:1')

  def testVariableWithVariableDeviceChooser(self):
    """With a default VariableDeviceChooser, variables land on 'cpu:0'."""
    with tf.Graph().as_default():
      device_fn = tf.contrib.framework.VariableDeviceChooser()
      with tf.contrib.framework.arg_scope(
          [tf.contrib.framework.model_variable], device=device_fn):
        a = tf.contrib.framework.model_variable('a', [5])
        b = tf.contrib.framework.model_variable('b', [20])
        self.assertDeviceEqual(a.device, 'cpu:0')
        self.assertDeviceEqual(a.initial_value.device, a.device)
        self.assertDeviceEqual(b.device, 'cpu:0')
        self.assertDeviceEqual(b.initial_value.device, b.device)
class GetVariablesCollections(tf.test.TestCase):
  """Tests for the `collections` argument of tf.contrib.framework.variable."""

  def testVariableCollection(self):
    """A single collection name given as a string receives the variable."""
    with self.test_session():
      a = tf.contrib.framework.variable('a', [], collections='A')
      b = tf.contrib.framework.variable('b', [], collections='B')
      self.assertEqual(a, tf.get_collection('A')[0])
      self.assertEqual(b, tf.get_collection('B')[0])

  def testVariableCollections(self):
    """A list of collection names adds the variable to each of them."""
    with self.test_session():
      a = tf.contrib.framework.variable('a', [], collections=['A', 'C'])
      b = tf.contrib.framework.variable('b', [], collections=['B', 'C'])
      self.assertEqual(a, tf.get_collection('A')[0])
      self.assertEqual(b, tf.get_collection('B')[0])
      self.assertListEqual([a, b], tf.get_collection('C'))

  def testVariableCollectionsWithArgScope(self):
    """`collections` set through arg_scope applies to every variable in it."""
    with self.test_session():
      with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                          collections='A'):
        a = tf.contrib.framework.variable('a', [])
        b = tf.contrib.framework.variable('b', [])
      self.assertListEqual([a, b], tf.get_collection('A'))

  def testVariableCollectionsWithArgScopeNested(self):
    """An inner arg_scope's `collections` is used for variables inside it."""
    with self.test_session():
      with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                          collections='A'):
        a = tf.contrib.framework.variable('a', [])
        with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                            collections='B'):
          b = tf.contrib.framework.variable('b', [])
      self.assertEqual(a, tf.get_collection('A')[0])
      self.assertEqual(b, tf.get_collection('B')[0])

  def testVariableCollectionsWithArgScopeNonNested(self):
    """Sibling arg_scopes do not leak `collections` into each other."""
    with self.test_session():
      with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                          collections='A'):
        a = tf.contrib.framework.variable('a', [])
      with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                          collections='B'):
        b = tf.contrib.framework.variable('b', [])
      # Created outside both scopes, so it must appear in neither collection.
      tf.contrib.framework.variable('c', [])
      self.assertListEqual([a], tf.get_collection('A'))
      self.assertListEqual([b], tf.get_collection('B'))

  def testVariableRestoreWithArgScopeNested(self):
    """Variables placed only in custom collections are not restored."""
    with self.test_session():
      a = tf.contrib.framework.variable('a', [])
      with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                          trainable=False,
                                          collections=['A', 'B']):
        b = tf.contrib.framework.variable('b', [])
      # Created after the arg_scope has exited, so only trainable=False applies.
      c = tf.contrib.framework.variable('c', [], trainable=False)
      self.assertEqual([a, c], tf.contrib.framework.get_variables_to_restore())
      self.assertEqual([a], tf.trainable_variables())
      self.assertEqual([b], tf.get_collection('A'))
      self.assertEqual([b], tf.get_collection('B'))
class GetVariablesBySuffixTest(tf.test.TestCase):
  """Tests for tf.contrib.framework.get_variables_by_suffix."""

  def testGetVariableGivenNameScoped(self):
    """Variables inside a variable_scope are found by their name suffix."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
        b = tf.contrib.framework.variable('b', [5])
        self.assertEqual([a],
                         tf.contrib.framework.get_variables_by_suffix('a'))
        self.assertEqual([b],
                         tf.contrib.framework.get_variables_by_suffix('b'))

  def testGetVariableWithScope(self):
    """Suffix matching across scopes, with and without a `scope` filter."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
        fooa = tf.contrib.framework.variable('fooa', [5])
      with tf.variable_scope('B'):
        a2 = tf.contrib.framework.variable('a', [5])
      # A bare 'a' suffix also matches 'fooa'.
      matched_variables = tf.contrib.framework.get_variables_by_suffix('a')
      self.assertEqual([a, fooa, a2], matched_variables)
      # A leading '/' restricts the match to variables named exactly 'a'.
      matched_variables = tf.contrib.framework.get_variables_by_suffix('/a')
      self.assertEqual([a, a2], matched_variables)
      matched_variables = tf.contrib.framework.get_variables_by_suffix(
          'a', scope='A')
      self.assertEqual([a, fooa], matched_variables)

  def testGetVariableWithoutScope(self):
    """Suffix matching for variables created without a variable_scope."""
    with self.test_session():
      a = tf.contrib.framework.variable('a', [5])
      fooa = tf.contrib.framework.variable('fooa', [5])
      b_a = tf.contrib.framework.variable('B/a', [5])
      matched_variables = tf.contrib.framework.get_variables_by_suffix('a')
      self.assertEqual([a, fooa, b_a], matched_variables)
      matched_variables = tf.contrib.framework.get_variables_by_suffix('fooa')
      self.assertEqual([fooa], matched_variables)
class GetVariablesByNameTest(tf.test.TestCase):
  """Tests for tf.contrib.framework.get_variables_by_name."""

  def testGetVariableGivenNameScoped(self):
    """Variables inside a variable_scope are found by their given name."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
        b = tf.contrib.framework.variable('b', [5])
        self.assertEqual([a], tf.contrib.framework.get_variables_by_name('a'))
        self.assertEqual([b], tf.contrib.framework.get_variables_by_name('b'))

  def testGetVariableWithScope(self):
    """Name matching across scopes, with and without a `scope` filter."""
    with self.test_session():
      with tf.variable_scope('A'):
        a = tf.contrib.framework.variable('a', [5])
        fooa = tf.contrib.framework.variable('fooa', [5])
      with tf.variable_scope('B'):
        a2 = tf.contrib.framework.variable('a', [5])
      # Unlike suffix matching, the name 'a' does not match 'fooa'.
      matched_variables = tf.contrib.framework.get_variables_by_name('a')
      self.assertEqual([a, a2], matched_variables)
      matched_variables = tf.contrib.framework.get_variables_by_name('fooa')
      self.assertEqual([fooa], matched_variables)
      # A name with a leading '/' is expected to match nothing.
      matched_variables = tf.contrib.framework.get_variables_by_name('/a')
      self.assertEqual([], matched_variables)
      matched_variables = tf.contrib.framework.get_variables_by_name(
          'a', scope='A')
      self.assertEqual([a], matched_variables)

  def testGetVariableWithoutScope(self):
    """Name matching for variables created without a variable_scope."""
    with self.test_session():
      a = tf.contrib.framework.variable('a', [5])
      fooa = tf.contrib.framework.variable('fooa', [5])
      b_a = tf.contrib.framework.variable('B/a', [5])
      matched_variables = tf.contrib.framework.get_variables_by_name('a')
      self.assertEqual([a, b_a], matched_variables)
      matched_variables = tf.contrib.framework.get_variables_by_name('fooa')
      self.assertEqual([fooa], matched_variables)
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()