Expose PartitionedVariable.
* Add __len__ and __iter__ methods which work when the number of partition axes is equal to 1.
* Add error checks ensuring shapes match and saved info is available.
* Remove the privatization & add unit tests.
Change: 132933434
This commit is contained in:
parent
d5d16aa43b
commit
6de77d7f29
tensorflow/python
@ -499,7 +499,7 @@ class PartitionedVariablesTestCase(tf.test.TestCase):
|
||||
c = tf.constant(1.0)
|
||||
with tf.control_dependencies([c]):
|
||||
ops_before_concat = session.graph.get_operations()
|
||||
value = var_x.concat()
|
||||
value = var_x._concat() # pylint: disable=protected-access
|
||||
concat_ops = [op for op in session.graph.get_operations()
|
||||
if op not in ops_before_concat]
|
||||
|
||||
@ -507,7 +507,7 @@ class PartitionedVariablesTestCase(tf.test.TestCase):
|
||||
for ci in op.control_inputs]
|
||||
self.assertTrue(
|
||||
c.op in concat_control_inputs,
|
||||
"var_x.concat() should get control dependencies from its scope.")
|
||||
"var_x._concat() should get control dependencies from its scope.")
|
||||
tf.initialize_all_variables().run()
|
||||
self.assertAllClose(value.eval(), var_x.as_tensor().eval())
|
||||
|
||||
|
@ -26,6 +26,7 @@ import tensorflow as tf
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
class VariablesTestCase(tf.test.TestCase):
|
||||
@ -447,6 +448,84 @@ class ObsoleteIsInitializedTest(tf.test.TestCase):
|
||||
inited.op.run()
|
||||
|
||||
|
||||
class PartitionedVariableTest(tf.test.TestCase):
  """Exercises construction, iteration and validation of PartitionedVariable."""

  def testPartitionedVariable(self):
    """Two single-element partitions are re-sorted, counted and concatenated."""
    graph = tf.Graph()
    with graph.as_default():
      first = tf.Variable([0])
      second = tf.Variable([1])
      # pylint: disable=protected-access
      first_info = variables.Variable.SaveSliceInfo(
          first.name, [2], [0], [1])
      first._set_save_slice_info(first_info)
      second_info = variables.Variable.SaveSliceInfo(
          first.name, [2], [1], [1])
      second._set_save_slice_info(second_info)
      # pylint: enable=protected-access
      partitions = [2]

      # The list is deliberately handed over as [second, first]: the
      # container is expected to re-sort it to [first, second] using the
      # offsets recorded in the slice info.
      two_vars = variables.PartitionedVariable(
          name="two_vars",
          shape=[2],
          dtype=first.dtype,
          variable_list=[second, first],
          partitions=partitions)

      as_tensor = tf.convert_to_tensor(two_vars)
      partition_count = len(two_vars)
      partition_list = list(two_vars)

      self.assertEqual(2, partition_count)
      self.assertEqual([first, second], partition_list)
      self.assertEqual([2], as_tensor.get_shape())

  def testPartitionedVariableFailures(self):
    """Each invalid constructor call raises ValueError with a telling message."""
    graph = tf.Graph()
    with graph.as_default():
      # An empty variable_list is rejected.
      with self.assertRaisesRegexp(ValueError, "empty"):
        variables.PartitionedVariable(
            name="fail",
            shape=2,
            dtype=tf.int32,
            variable_list=[],
            partitions=[])

      # A variable without slice info is rejected.
      with self.assertRaisesRegexp(ValueError, "must have a save_slice_info"):
        lone = tf.Variable([0])
        partitions = [1]
        variables.PartitionedVariable(
            name="two_vars",
            shape=[1],
            dtype=lone.dtype,
            variable_list=[lone],
            partitions=partitions)

      # Slice info recording full shape [2] conflicts with shape=[3].
      with self.assertRaisesRegexp(ValueError, "full shapes must match"):
        first = tf.Variable([0])
        second = tf.Variable([1])
        # pylint: disable=protected-access
        first_info = variables.Variable.SaveSliceInfo(
            first.name, [2], [0], [1])
        first._set_save_slice_info(first_info)
        second_info = variables.Variable.SaveSliceInfo(
            first.name, [2], [1], [1])
        second._set_save_slice_info(second_info)
        # pylint: enable=protected-access
        partitions = [2]

        variables.PartitionedVariable(
            name="two_vars",
            shape=[3],
            dtype=first.dtype,
            variable_list=[second, first],
            partitions=partitions)

      # A partition count of zero is rejected.
      with self.assertRaisesRegexp(ValueError, "must be positive"):
        lone = tf.Variable([0])
        # pylint: disable=protected-access
        lone_info = variables.Variable.SaveSliceInfo(
            lone.name, [2], [0], [1])
        lone._set_save_slice_info(lone_info)
        # pylint: enable=protected-access
        partitions = [0]

        variables.PartitionedVariable(
            name="two_vars",
            shape=[2],
            dtype=lone.dtype,
            variable_list=[lone],
            partitions=partitions)
|
||||
|
||||
|
||||
class VariableContainerTest(tf.test.TestCase):
|
||||
|
||||
def testContainer(self):
|
||||
|
@ -306,5 +306,5 @@ def create_partitioned_variables(
|
||||
trainable=trainable,
|
||||
partitioner=partitioner,
|
||||
collections=collections)
|
||||
return partitioned_var._get_variable_list()
|
||||
return list(partitioned_var)
|
||||
# pylint: enable=protected-access
|
||||
|
@ -563,11 +563,11 @@ class _VariableStore(object):
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# pylint: disable=protected-access
|
||||
partitioned_var = variables._PartitionedVariable(name=name,
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
variable_list=vs,
|
||||
partitions=partitions)
|
||||
partitioned_var = variables.PartitionedVariable(name=name,
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
variable_list=vs,
|
||||
partitions=partitions)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
self._partitioned_vars[name] = partitioned_var
|
||||
|
@ -687,9 +687,19 @@ class Variable(object):
|
||||
"""Returns a `Variable` object created from `variable_def`."""
|
||||
return Variable(variable_def=variable_def)
|
||||
|
||||
# Experimental support for saving variables as slices of a larger variable.
|
||||
class SaveSliceInfo(object):
|
||||
"""Information on how to save this Variable as a slice."""
|
||||
"""Information on how to save this Variable as a slice.
|
||||
|
||||
Provides internal support for saving variables as slices of a larger
|
||||
variable. This API is not public and is subject to change.
|
||||
|
||||
Available properties:
|
||||
|
||||
* full_name
|
||||
* full_shape
|
||||
* var_offset
|
||||
* var_shape
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
full_name=None,
|
||||
@ -752,31 +762,115 @@ class Variable(object):
|
||||
"""
|
||||
self._save_slice_info = save_slice_info
|
||||
|
||||
def _get_save_slice_info(self):
  """Returns the object currently stored in `self._save_slice_info`."""
  return self._save_slice_info
|
||||
|
||||
class _PartitionedVariable(object):
|
||||
"""Wrapper around a list of partitioned `Variable`.
|
||||
|
||||
May get merged into the main `Variable` class.
|
||||
"""
|
||||
class PartitionedVariable(object):
|
||||
"""A container for partitioned `Variable` objects."""
|
||||
|
||||
class PartitionedVariableIterator(object):
  """Iterator over the `Variable` objects held by a partitioned variable.

  Walks the owner's `_variable_list` front to back, so the order of access
  is the order in which the owner stored its partitions.  This is what
  makes e.g. `list(partitioned_variable)` return a proper list.
  """

  def __init__(self, partitioned_variable):
    """Starts a fresh pass over `partitioned_variable`'s partitions."""
    self._partitioned_variable = partitioned_variable
    self._ix = 0

  def __iter__(self):
    """An iterator is its own iterable."""
    return self

  def next(self):
    """Returns the next partition, or raises StopIteration when exhausted."""
    # pylint: disable=protected-access
    partition_list = self._partitioned_variable._variable_list
    # pylint: enable=protected-access
    position = self._ix
    if position < len(partition_list):
      self._ix = position + 1
      return partition_list[position]
    raise StopIteration()

  def __next__(self):  # For python3 compatibility.
    return self.next()
|
||||
|
||||
def __init__(self, name, shape, dtype, variable_list, partitions):
  """Creates a new partitioned variable wrapper.

  Variables passed via the variable_list must contain a save_slice_info
  field.  Concatenation and iteration is in lexicographic order according
  to the var_offset property of the save_slice_info.

  Args:
    name: String. Overall name of the variables.
    shape: List of integers. Overall shape of the variables.
    dtype: Type of the variables.
    variable_list: List of `Variable` that comprise this partitioned variable.
    partitions: List of integers. Number of partitions for each dimension.

  Raises:
    TypeError: If `variable_list` is not a list of `Variable` objects, or
      `partitions` is not a list.
    ValueError: If `variable_list` is empty, or any variable lacks a
      save_slice_info, or `shape` and `partitions` differ in length, or the
      `Variable` shape information does not match `shape`, or `partitions`
      has invalid values.
  """
  if not isinstance(variable_list, (list, tuple)):
    raise TypeError(
        "variable_list is not a list or tuple: %s" % variable_list)
  if not isinstance(partitions, (list, tuple)):
    raise TypeError("partitions is not a list or tuple: %s" % partitions)
  if not all(p >= 1 for p in partitions):
    raise ValueError("partition values must be positive: %s" % partitions)
  if not variable_list:
    raise ValueError("variable_list may not be empty")
  for v in variable_list:
    if not isinstance(v, Variable):
      raise TypeError("Not all entries in variable_list are variables: %s"
                      % variable_list)
  # pylint: disable=protected-access
  if not all(v._get_save_slice_info() is not None for v in variable_list):
    raise ValueError("All variables must have a save_slice_info available: %s"
                     % [v.name for v in variable_list])
  if len(shape) != len(partitions):
    raise ValueError("len(shape) != len(partitions): %s vs. %s"
                     % (shape, partitions))
  # Check the recorded full shape of EVERY variable.  Iterating explicitly
  # matters: without the `for`, only the variable left over from the
  # isinstance loop above would be inspected.
  full_shapes = [v._get_save_slice_info().full_shape for v in variable_list]
  if not all(full_shape == shape for full_shape in full_shapes):
    raise ValueError(
        "All variables' full shapes must match shape: %s; "
        "but full shapes were: %s"
        % (shape, str(full_shapes)))
  # Sort the variable_list lexicographically according to var offset value.
  # The sorted copy is what gets stored; it must not be replaced by the
  # caller's original ordering afterwards.
  self._variable_list = sorted(
      variable_list, key=lambda v: v._get_save_slice_info().var_offset)
  # pylint: enable=protected-access

  self._name = name
  self._shape = shape
  self._dtype = dtype
  self._partitions = partitions
  self._as_tensor = None
|
||||
|
||||
def concat(self):
|
||||
def __iter__(self):
|
||||
"""Return an iterable for accessing the underlying partition Variables."""
|
||||
return self.PartitionedVariableIterator(self)
|
||||
|
||||
def __len__(self):
|
||||
num_partition_axes = len(self._partition_axes())
|
||||
if num_partition_axes > 1:
|
||||
raise ValueError("Cannot get a length for %d > 1 partition axes"
|
||||
% num_partition_axes)
|
||||
return len(self._variable_list)
|
||||
|
||||
def _partition_axes(self):
|
||||
if all([p == 1 for p in self._partitions]):
|
||||
return [0]
|
||||
else:
|
||||
return [i for i, p in enumerate(self._partitions) if p > 1]
|
||||
|
||||
def _concat(self):
|
||||
"""Returns the overall concatenated value as a `Tensor`.
|
||||
|
||||
This is different from using the partitioned variable directly as a tensor
|
||||
@ -790,10 +884,13 @@ class _PartitionedVariable(object):
|
||||
with ops.name_scope(None):
|
||||
return array_ops.identity(self._variable_list[0], name=self._name)
|
||||
|
||||
if all([p < 2 for p in self._partitions]):
|
||||
partition_ix = 0
|
||||
else:
|
||||
partition_ix = [i for i, p in enumerate(self._partitions) if p > 1][0]
|
||||
partition_axes = self._partition_axes()
|
||||
|
||||
if len(partition_axes) > 1:
|
||||
raise NotImplementedError(
|
||||
"Cannot concatenate along more than one dimension: %s. "
|
||||
"Multi-axis partition concat is not supported" % str(partition_axes))
|
||||
partition_ix = partition_axes[0]
|
||||
|
||||
with ops.name_scope(self._name + "/ConcatPartitions/"):
|
||||
concatenated = array_ops.concat(partition_ix, self._variable_list)
|
||||
@ -815,7 +912,7 @@ class _PartitionedVariable(object):
|
||||
# Be sure to cache the concatenated tensor to not do extraneous
|
||||
# computations.
|
||||
with ops.control_dependencies(None):
|
||||
self._as_tensor = self.concat()
|
||||
self._as_tensor = self._concat()
|
||||
|
||||
return self._as_tensor
|
||||
|
||||
@ -828,7 +925,7 @@ class _PartitionedVariable(object):
|
||||
"of type '%s'" % (dtype.name, v.dtype.name))
|
||||
if as_ref:
|
||||
raise NotImplementedError(
|
||||
"_PartitionedVariable doesn't support being used as a reference.")
|
||||
"PartitionedVariable doesn't support being used as a reference.")
|
||||
else:
|
||||
return v.as_tensor()
|
||||
|
||||
@ -852,7 +949,7 @@ class _PartitionedVariable(object):
|
||||
def assign(self, value, use_locking=False):
|
||||
_ = value, use_locking
|
||||
raise NotImplementedError(
|
||||
"assign() has not been implemented for _PartitionedVariable.")
|
||||
"assign() has not been implemented for PartitionedVariable.")
|
||||
|
||||
|
||||
def all_variables():
|
||||
@ -1070,7 +1167,7 @@ ops.register_tensor_conversion_function(Variable,
|
||||
Variable._OverloadAllOperators()
|
||||
|
||||
ops.register_tensor_conversion_function(
|
||||
_PartitionedVariable, _PartitionedVariable._TensorConversionFunction)
|
||||
PartitionedVariable, PartitionedVariable._TensorConversionFunction)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
ops.register_dense_tensor_like_type(Variable)
|
||||
|
Loading…
Reference in New Issue
Block a user