Add stack and unstack operations to tensor_array_ops.py. These operations are

exactly the same as pack and unpack and will replace pack and unpack going
forward.
Change: 140760260
This commit is contained in:
A. Unique TensorFlower 2016-12-01 12:05:14 -08:00 committed by TensorFlower Gardener
parent 434794582e
commit c4d507ac75
2 changed files with 51 additions and 31 deletions
tensorflow/python

View File

@ -63,7 +63,7 @@ class TensorArrayTest(tf.test.TestCase):
w1 = w0.write(1, convert([[6.0, 7.0]]))
w2 = w1.write(2, convert([[8.0, 9.0]]))
c0 = w2.pack()
c0 = w2.stack()
self.assertAllEqual(
convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval())
@ -123,7 +123,7 @@ class TensorArrayTest(tf.test.TestCase):
with self.assertRaisesOpError(
"Could not read from TensorArray index 1 "
"because it has not yet been written to."):
ta.write(0, [[4.0, 5.0]]).pack().eval()
ta.write(0, [[4.0, 5.0]]).stack().eval()
def testTensorArrayPackNotAllValuesAvailableFails(self):
self._testTensorArrayPackNotAllValuesAvailableFails()
@ -141,7 +141,7 @@ class TensorArrayTest(tf.test.TestCase):
convert = lambda x: np.asarray(x).astype(dtype)
# Unpack a vector into scalars
w0 = ta.unpack(convert([1.0, 2.0, 3.0]))
w0 = ta.unstack(convert([1.0, 2.0, 3.0]))
r0 = w0.read(0)
r1 = w0.read(1)
r2 = w0.read(2)
@ -155,7 +155,7 @@ class TensorArrayTest(tf.test.TestCase):
dtype=tf_dtype, tensor_array_name="foo", size=3)
# Unpack a matrix into vectors
w1 = ta.unpack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
r0 = w1.read(0)
r1 = w1.read(1)
r2 = w1.read(2)
@ -171,7 +171,7 @@ class TensorArrayTest(tf.test.TestCase):
dtype=tf_dtype, tensor_array_name="foo", size=3)
# Try unpacking an empty matrix, which should not cause an error.
w2 = ta.unpack(convert([[], [], []]))
w2 = ta.unstack(convert([[], [], []]))
r0 = w2.read(0)
r1 = w2.read(1)
r2 = w2.read(2)
@ -583,7 +583,7 @@ class TensorArrayTest(tf.test.TestCase):
w0 = ta.write(0, value_0)
w1 = w0.write(1, value_1)
p0 = w1.pack()
p0 = w1.stack()
r0 = w1.read(0)
s0 = w1.concat()
@ -610,7 +610,7 @@ class TensorArrayTest(tf.test.TestCase):
ta_readonce = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2)
w_readonce = ta_readonce.unpack(value)
w_readonce = ta_readonce.unstack(value)
r0_readonce = w_readonce.read(0)
with tf.control_dependencies([r0_readonce]):
r1_readonce = w_readonce.read(0)
@ -623,7 +623,7 @@ class TensorArrayTest(tf.test.TestCase):
ta_readtwice = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2,
clear_after_read=False)
w_readtwice = ta_readtwice.unpack(value)
w_readtwice = ta_readtwice.unstack(value)
r0_readtwice = w_readtwice.read(0)
with tf.control_dependencies([r0_readtwice]):
r1_readtwice = w_readtwice.read(0)
@ -638,7 +638,7 @@ class TensorArrayTest(tf.test.TestCase):
value = tf.constant([[1.0, -1.0], [10.0, -10.0]])
w = ta.unpack(value)
w = ta.unstack(value)
r0 = w.read(0)
r0_1 = w.read(0)
r1 = w.read(1)
@ -682,7 +682,7 @@ class TensorArrayTest(tf.test.TestCase):
value = tf.constant([[1.0, -1.0], [10.0, -10.0]])
w = ta.unpack(value)
w = ta.unstack(value)
r0 = w.read(0)
r1 = w.read(1)
@ -746,7 +746,7 @@ class TensorArrayTest(tf.test.TestCase):
tensor_shape.unknown_shape(),
tensor_shape.unknown_shape()),
parallel_iterations=3)
vout = h_final.pack()
vout = h_final.stack()
grad_val = -np.arange(3*5, dtype=np_dtype).reshape(3, 5)
v0_grad = tf.gradients([vout], [v0], [grad_val])[0]
@ -824,7 +824,7 @@ class TensorArrayTest(tf.test.TestCase):
return i + 1, acc.write(i, z)
_, acc2 = tf.while_loop(lambda i, acc: i < num_steps, fn, [i1, acc1])
r = acc2.pack()
r = acc2.stack()
grad = tf.gradients(r, [x])[0]
self.assertAllClose(31.0, grad.eval())
@ -926,7 +926,7 @@ class TensorArrayTest(tf.test.TestCase):
dtype=tf.float32, tensor_array_name="foo",
size=0, dynamic_size=True, infer_shape=True)
value = tf.constant([[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
w0 = ta.unpack(value)
w0 = ta.unstack(value)
r0 = w0.read(0)
self.assertAllEqual((2,), r0.get_shape())
@ -972,7 +972,7 @@ class TensorArrayTest(tf.test.TestCase):
with self.test_session(use_gpu=True) as session:
ta = tf.TensorArray(dtype=tf.float32, size=2)
x = tf.constant([2.0, 3.0])
w = ta.unpack(x)
w = ta.unstack(x)
r0 = w.read(0)
# calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0).
grad_r0 = tf.gradients(ys=[r0], xs=[x], grad_ys=[1.0])
@ -987,9 +987,9 @@ class TensorArrayTest(tf.test.TestCase):
ta = tf.TensorArray(dtype=tf.float32, size=3,
dynamic_size=True)
x = tf.constant([1.0, 2.0, 3.0])
w0 = ta.unpack(x)
w0 = ta.unstack(x)
w1 = w0.write(3, 4.0)
r = w1.pack()
r = w1.stack()
self.assertAllEqual(np.array([1.0, 2.0, 3.0, 4.0]), r.eval())
grad = tf.gradients(ys=[r], xs=[x])
self.assertAllEqual(np.array([1.0, 1.0, 1.0]),
@ -1021,7 +1021,7 @@ class TensorArrayTest(tf.test.TestCase):
"TensorArray has size zero, but element shape <unknown> is not fully "
"defined. Currently only static shapes are supported when packing "
"zero-size TensorArrays."):
ta.pack().eval()
ta.stack().eval()
def testTensorArrayEvalEmpty(self):
self._testTensorArrayEvalEmpty()
@ -1034,8 +1034,8 @@ class TensorArrayTest(tf.test.TestCase):
infer_shape=True)
self.assertEqual(0, ta.size().eval())
# Don't actually perform the pack. This stores the static shape.
ta.unpack(tf.zeros([0, 3, 5]))
packed = ta.pack()
ta.unstack(tf.zeros([0, 3, 5]))
packed = ta.stack()
self.assertAllEqual([0, 3, 5], packed.eval().shape)
# Concatenating zero tensors along their first dimension gives a
# first dimension of zero
@ -1075,7 +1075,7 @@ class TensorArrayTest(tf.test.TestCase):
values = tf.constant([[1.0*x, -1.0*x] for x in range(10)])
indices = tf.constant([1, 8])
w = ta.unpack(values)
w = ta.unstack(values)
g = w.gather(indices)
# Test combined gradients + aggregation of read(0)
@ -1119,11 +1119,11 @@ class TensorArrayTest(tf.test.TestCase):
self.assertEqual(ta.handle.device, "")
self.assertEqual(ta.flow.device, "")
with tf.device("/gpu:0"):
ta = ta.unpack([1.0, 2.0])
ta = ta.unstack([1.0, 2.0])
self.assertTrue("gpu:0" in ta.handle.device.lower())
self.assertTrue("gpu:0" in ta.flow.device.lower())
with tf.device("/gpu:1"):
ta = ta.unpack([1.0, 2.0])
ta = ta.unstack([1.0, 2.0])
self.assertTrue("gpu:0" in ta.handle.device.lower())
self.assertTrue("gpu:0" in ta.flow.device.lower())

View File

@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.deprecation import deprecated
def _maybe_set_device(handle_op, value_t):
@ -262,21 +263,30 @@ class TensorArray(object):
ta._elem_shape.append(val_shape)
return ta
def pack(self, name=None):
"""Return the values in the TensorArray as a packed `Tensor`.
def stack(self, name=None):
"""Return the values in the TensorArray as a stacked `Tensor`.
All of the values must have been written and their shapes must all match.
If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.
Args:
name: A name for the operation (optional).
Returns:
All the tensors in the TensorArray packed into one tensor.
All the tensors in the TensorArray stacked into one tensor.
"""
with ops.colocate_with(self._handle):
with ops.name_scope(name, "TensorArrayPack", [self._handle]):
with ops.name_scope(name, "TensorArrayStack", [self._handle]):
return self.gather(math_ops.range(0, self.size()), name=name)
@deprecated(
"2016-12-12",
"This op will be removed after the deprecation date. "
"Please switch to tf.stack.")
def pack(self, name=None):
return self.stack(name)
pack.__doc__ = stack.__doc__
def gather(self, indices, name=None):
"""Return selected values in the TensorArray as a packed `Tensor`.
@ -335,25 +345,35 @@ class TensorArray(object):
value.set_shape([None] + self._elem_shape[0].dims[1:])
return value
def unpack(self, value, name=None):
"""Pack the values of a `Tensor` in the TensorArray.
def unstack(self, value, name=None):
"""Unstack the values of a `Tensor` in the TensorArray.
If input value shapes have rank-`R`, then the output TensorArray will
contain elements whose shapes are rank-`(R-1)`.
Args:
value: (N+1)-D. Tensor of type `dtype`. The Tensor to unpack.
value: (N+1)-D. Tensor of type `dtype`. The Tensor to unstack.
name: A name for the operation (optional).
Returns:
A new TensorArray object with flow that ensures the unpack occurs.
A new TensorArray object with flow that ensures the unstack occurs.
Use this object all for subsequent operations.
Raises:
ValueError: if the shape inference fails.
"""
with ops.name_scope(name, "TensorArrayPack", [self._handle, value]):
with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
num_elements = array_ops.shape(value)[0]
return self.scatter(
indices=math_ops.range(0, num_elements), value=value, name=name)
@deprecated(
"2016-12-12",
"This op will be removed after the deprecation date. "
"Please switch to tf.unstack.")
def unpack(self, value, name=None):
return self.unstack(value, name)
unpack.__doc__ = unstack.__doc__
def scatter(self, indices, value, name=None):
"""Scatter the values of a `Tensor` in specific indices of a `TensorArray`.