Expose PartitionedVariable.

* Add __len__ and __iter__ methods which work when the number of partition axes
  is equal to 1.
* Add error checks ensuring shapes match and saved info is available.
* Remove the privatization & add unit tests.
Change: 132933434
Eugene Brevdo 2016-09-12 15:28:04 -08:00 committed by TensorFlower Gardener
parent d5d16aa43b
commit 6de77d7f29
5 changed files with 203 additions and 27 deletions
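A minimal usage sketch of the newly public class, for orientation before the per-file diffs. It mirrors the testPartitionedVariable unit test added in this commit; the shape-[2] variable, the two-way split, and the call to the still-private _set_save_slice_info helper are illustrative only.

import tensorflow as tf
from tensorflow.python.ops import variables

v0 = tf.Variable([0])
v1 = tf.Variable([1])
# Each partition carries SaveSliceInfo describing its slice of the full [2] shape:
# (full_name, full_shape, var_offset, var_shape).
v0._set_save_slice_info(variables.Variable.SaveSliceInfo("two_vars", [2], [0], [1]))
v1._set_save_slice_info(variables.Variable.SaveSliceInfo("two_vars", [2], [1], [1]))

partitioned = variables.PartitionedVariable(
    name="two_vars", shape=[2], dtype=v0.dtype,
    variable_list=[v1, v0],  # deliberately out of order
    partitions=[2])

print(len(partitioned))          # 2 partitions along the single partitioned axis
print(list(partitioned))         # [v0, v1], re-sorted by var_offset
full = tf.convert_to_tensor(partitioned)  # concatenated tensor of shape [2]

Passing variable_list as [v1, v0] shows the constructor re-sorting by slice offset, which is the same ordering the new __iter__ and __len__ rely on.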

View File

@@ -499,7 +499,7 @@ class PartitionedVariablesTestCase(tf.test.TestCase):
c = tf.constant(1.0)
with tf.control_dependencies([c]):
ops_before_concat = session.graph.get_operations()
value = var_x.concat()
value = var_x._concat() # pylint: disable=protected-access
concat_ops = [op for op in session.graph.get_operations()
if op not in ops_before_concat]
@@ -507,7 +507,7 @@ class PartitionedVariablesTestCase(tf.test.TestCase):
for ci in op.control_inputs]
self.assertTrue(
c.op in concat_control_inputs,
"var_x.concat() should get control dependencies from its scope.")
"var_x._concat() should get control dependencies from its scope.")
tf.initialize_all_variables().run()
self.assertAllClose(value.eval(), var_x.as_tensor().eval())

View File

@@ -26,6 +26,7 @@ import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
class VariablesTestCase(tf.test.TestCase):
@@ -447,6 +448,84 @@ class ObsoleteIsInitializedTest(tf.test.TestCase):
inited.op.run()
class PartitionedVariableTest(tf.test.TestCase):
def testPartitionedVariable(self):
with tf.Graph().as_default():
v0 = tf.Variable([0])
v1 = tf.Variable([1])
v0._set_save_slice_info(variables.Variable.SaveSliceInfo(
v0.name, [2], [0], [1]))
v1._set_save_slice_info(variables.Variable.SaveSliceInfo(
v0.name, [2], [1], [1]))
partitions = [2]
# Pass variable_list as [v1, v0] to ensure they are properly
# re-sorted to [v0, v1] based on their slice info offsets.
partitioned_variable = variables.PartitionedVariable(
name="two_vars",
shape=[2],
dtype=v0.dtype,
variable_list=[v1, v0],
partitions=partitions)
concatenated = tf.convert_to_tensor(partitioned_variable)
num_partitions = len(partitioned_variable)
iterated_partitions = list(partitioned_variable)
self.assertEqual(2, num_partitions)
self.assertEqual([v0, v1], iterated_partitions)
self.assertEqual([2], concatenated.get_shape())
def testPartitionedVariableFailures(self):
with tf.Graph().as_default():
with self.assertRaisesRegexp(ValueError, "empty"):
variables.PartitionedVariable(
name="fail",
shape=2,
dtype=tf.int32,
variable_list=[],
partitions=[])
with self.assertRaisesRegexp(ValueError, "must have a save_slice_info"):
v0 = tf.Variable([0])
partitions = [1]
variables.PartitionedVariable(
name="two_vars",
shape=[1],
dtype=v0.dtype,
variable_list=[v0],
partitions=partitions)
with self.assertRaisesRegexp(ValueError, "full shapes must match"):
v0 = tf.Variable([0])
v1 = tf.Variable([1])
v0._set_save_slice_info(variables.Variable.SaveSliceInfo(
v0.name, [2], [0], [1]))
v1._set_save_slice_info(variables.Variable.SaveSliceInfo(
v0.name, [2], [1], [1]))
partitions = [2]
variables.PartitionedVariable(
name="two_vars",
shape=[3],
dtype=v0.dtype,
variable_list=[v1, v0],
partitions=partitions)
with self.assertRaisesRegexp(ValueError, "must be positive"):
v0 = tf.Variable([0])
v0._set_save_slice_info(variables.Variable.SaveSliceInfo(
v0.name, [2], [0], [1]))
partitions = [0]
variables.PartitionedVariable(
name="two_vars",
shape=[2],
dtype=v0.dtype,
variable_list=[v0],
partitions=partitions)
class VariableContainerTest(tf.test.TestCase):
def testContainer(self):

View File

@@ -306,5 +306,5 @@ def create_partitioned_variables(
trainable=trainable,
partitioner=partitioner,
collections=collections)
return partitioned_var._get_variable_list()
return list(partitioned_var)
# pylint: enable=protected-access

View File

@@ -563,11 +563,11 @@ class _VariableStore(object):
# pylint: enable=protected-access
# pylint: disable=protected-access
partitioned_var = variables._PartitionedVariable(name=name,
shape=shape,
dtype=dtype,
variable_list=vs,
partitions=partitions)
partitioned_var = variables.PartitionedVariable(name=name,
shape=shape,
dtype=dtype,
variable_list=vs,
partitions=partitions)
# pylint: enable=protected-access
self._partitioned_vars[name] = partitioned_var

View File

@@ -687,9 +687,19 @@ class Variable(object):
"""Returns a `Variable` object created from `variable_def`."""
return Variable(variable_def=variable_def)
# Experimental support for saving variables as slices of a larger variable.
class SaveSliceInfo(object):
"""Information on how to save this Variable as a slice."""
"""Information on how to save this Variable as a slice.
Provides internal support for saving variables as slices of a larger
variable. This API is not public and is subject to change.
Available properties:
* full_name
* full_shape
* var_offset
* var_shape
"""
def __init__(self,
full_name=None,
@@ -752,31 +762,115 @@ class Variable(object):
"""
self._save_slice_info = save_slice_info
def _get_save_slice_info(self):
return self._save_slice_info
class _PartitionedVariable(object):
"""Wrapper around a list of partitioned `Variable`.
May get merged into the main `Variable` class.
"""
class PartitionedVariable(object):
"""A container for partitioned `Variable` objects."""
class PartitionedVariableIterator(object):
"""An iterator that allows accessing the underlying `Variable` objects.
This iterator is necessary to control order of access when Variables
are not partitioned in a standard way along a single axis.
Allows e.g. `list(partitioned_variable)` to return a proper list.
"""
def __init__(self, partitioned_variable):
self._ix = 0
self._partitioned_variable = partitioned_variable
def __iter__(self):
return self
def __next__(self): # For python3 compatibility.
return self.next()
def next(self):
# pylint: disable=protected-access
if self._ix >= len(self._partitioned_variable._variable_list):
raise StopIteration()
variable = self._partitioned_variable._variable_list[self._ix]
# pylint: enable=protected-access
self._ix += 1
return variable
def __init__(self, name, shape, dtype, variable_list, partitions):
"""Creates a new partitioned variable wrapper.
Variables passed via the variable_list must contain a save_slice_info
field. Concatenation and iteration are in lexicographic order according
to the var_offset property of the save_slice_info.
Args:
name: Overall name of the variables.
shape: Overall shape of the variables.
name: String. Overall name of the variables.
shape: List of integers. Overall shape of the variables.
dtype: Type of the variables.
variable_list: List of `Variable` that comprise this partitioned variable.
partitions: List of number of partitions for each dimension.
partitions: List of integers. Number of partitions for each dimension.
Raises:
TypeError: If `variable_list` is not a list of `Variable` objects, or
`partitions` is not a list.
ValueError: If `variable_list` is empty, or the `Variable` shape
information does not match `shape`, or `partitions` has invalid values.
"""
if not isinstance(variable_list, (list, tuple)):
raise TypeError(
"variable_list is not a list or tuple: %s" % variable_list)
if not isinstance(partitions, (list, tuple)):
raise TypeError("partitions is not a list or tuple: %s" % partitions)
if not all([p >= 1 for p in partitions]):
raise ValueError("partition values must be positive: %s" % partitions)
if not variable_list:
raise ValueError("variable_list may not be empty")
for v in variable_list:
if not isinstance(v, Variable):
raise TypeError("Not all entries in variable_list are variables: %s"
% variable_list)
# Sort the variable_list lexicographically according to var offset value.
# pylint: disable=protected-access
if not all([v._get_save_slice_info() is not None for v in variable_list]):
raise ValueError("All variables must have a save_slice_info available: %s"
% [v.name for v in variable_list])
if len(shape) != len(partitions):
raise ValueError("len(shape) != len(partitions): %s vs. %s"
% (shape, partitions))
if not all([v._get_save_slice_info().full_shape == shape for v in variable_list]):
raise ValueError(
"All variables' full shapes must match shape: %s; "
"but full shapes were: %s"
% (shape, str([v._get_save_slice_info().full_shape for v in variable_list])))
self._variable_list = sorted(
variable_list, key=lambda v: v._get_save_slice_info().var_offset)
# pylint: enable=protected-access
self._name = name
self._shape = shape
self._dtype = dtype
self._variable_list = variable_list
self._partitions = partitions
self._as_tensor = None
def concat(self):
def __iter__(self):
"""Return an iterable for accessing the underlying partition Variables."""
return self.PartitionedVariableIterator(self)
def __len__(self):
num_partition_axes = len(self._partition_axes())
if num_partition_axes > 1:
raise ValueError("Cannot get a length for %d > 1 partition axes"
% num_partition_axes)
return len(self._variable_list)
def _partition_axes(self):
if all([p == 1 for p in self._partitions]):
return [0]
else:
return [i for i, p in enumerate(self._partitions) if p > 1]
def _concat(self):
"""Returns the overall concatenated value as a `Tensor`.
This is different from using the partitioned variable directly as a tensor
@@ -790,10 +884,13 @@ class _PartitionedVariable(object):
with ops.name_scope(None):
return array_ops.identity(self._variable_list[0], name=self._name)
if all([p < 2 for p in self._partitions]):
partition_ix = 0
else:
partition_ix = [i for i, p in enumerate(self._partitions) if p > 1][0]
partition_axes = self._partition_axes()
if len(partition_axes) > 1:
raise NotImplementedError(
"Cannot concatenate along more than one dimension: %s. "
"Multi-axis partition concat is not supported" % str(partition_axes))
partition_ix = partition_axes[0]
with ops.name_scope(self._name + "/ConcatPartitions/"):
concatenated = array_ops.concat(partition_ix, self._variable_list)
@@ -815,7 +912,7 @@ class _PartitionedVariable(object):
# Be sure to cache the concatenated tensor to not do extraneous
# computations.
with ops.control_dependencies(None):
self._as_tensor = self.concat()
self._as_tensor = self._concat()
return self._as_tensor
@@ -828,7 +925,7 @@ class _PartitionedVariable(object):
"of type '%s'" % (dtype.name, v.dtype.name))
if as_ref:
raise NotImplementedError(
"_PartitionedVariable doesn't support being used as a reference.")
"PartitionedVariable doesn't support being used as a reference.")
else:
return v.as_tensor()
@@ -852,7 +949,7 @@ class _PartitionedVariable(object):
def assign(self, value, use_locking=False):
_ = value, use_locking
raise NotImplementedError(
"assign() has not been implemented for _PartitionedVariable.")
"assign() has not been implemented for PartitionedVariable.")
def all_variables():
@@ -1070,7 +1167,7 @@ ops.register_tensor_conversion_function(Variable,
Variable._OverloadAllOperators()
ops.register_tensor_conversion_function(
_PartitionedVariable, _PartitionedVariable._TensorConversionFunction)
PartitionedVariable, PartitionedVariable._TensorConversionFunction)
# pylint: enable=protected-access
ops.register_dense_tensor_like_type(Variable)
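
As a closing note, a standalone sketch (plain Python, with a hypothetical partition_axes helper) of the axis-selection rule encoded in _partition_axes, __len__ and _concat above; the assertions restate the behaviour shown in the diff.

def partition_axes(partitions):
  # Mirrors PartitionedVariable._partition_axes(): a fully unpartitioned
  # variable is treated as partitioned along axis 0; otherwise every axis
  # that is split into more than one piece is reported.
  if all(p == 1 for p in partitions):
    return [0]
  return [i for i, p in enumerate(partitions) if p > 1]

assert partition_axes([1, 1]) == [0]      # one trivial axis: len() works, _concat() is an identity
assert partition_axes([1, 4, 1]) == [1]   # len() == 4, _concat() joins along axis 1
assert partition_axes([2, 3]) == [0, 1]   # two partition axes: len() and _concat() both raise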