Add stack and unstack operations to tensor_array_ops.py. These operations
behave exactly like pack and unpack, which they will replace going forward;
pack and unpack remain for now as deprecated aliases.

Change: 140760260
parent 434794582e
commit c4d507ac75
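For orientation, a minimal sketch of the renamed API as the tests below
exercise it (a hedged example against the graph-and-session tf.TensorArray
interface of this era; the values are illustrative):

    import tensorflow as tf

    ta = tf.TensorArray(dtype=tf.float32, size=2)
    w0 = ta.write(0, [1.0, 2.0])
    w1 = w0.write(1, [3.0, 4.0])
    stacked = w1.stack()   # new name; previously w1.pack()

    with tf.Session() as sess:
      print(sess.run(stacked))  # => [[1. 2.] [3. 4.]]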
tensorflow/python
@@ -63,7 +63,7 @@ class TensorArrayTest(tf.test.TestCase):
       w1 = w0.write(1, convert([[6.0, 7.0]]))
       w2 = w1.write(2, convert([[8.0, 9.0]]))

-      c0 = w2.pack()
+      c0 = w2.stack()

       self.assertAllEqual(
           convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval())
@@ -123,7 +123,7 @@ class TensorArrayTest(tf.test.TestCase):
       with self.assertRaisesOpError(
           "Could not read from TensorArray index 1 "
           "because it has not yet been written to."):
-        ta.write(0, [[4.0, 5.0]]).pack().eval()
+        ta.write(0, [[4.0, 5.0]]).stack().eval()

   def testTensorArrayPackNotAllValuesAvailableFails(self):
     self._testTensorArrayPackNotAllValuesAvailableFails()
@@ -141,7 +141,7 @@ class TensorArrayTest(tf.test.TestCase):
       convert = lambda x: np.asarray(x).astype(dtype)

       # Unpack a vector into scalars
-      w0 = ta.unpack(convert([1.0, 2.0, 3.0]))
+      w0 = ta.unstack(convert([1.0, 2.0, 3.0]))
       r0 = w0.read(0)
       r1 = w0.read(1)
       r2 = w0.read(2)
@@ -155,7 +155,7 @@ class TensorArrayTest(tf.test.TestCase):
           dtype=tf_dtype, tensor_array_name="foo", size=3)

       # Unpack a matrix into vectors
-      w1 = ta.unpack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
+      w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
       r0 = w1.read(0)
       r1 = w1.read(1)
       r2 = w1.read(2)
@@ -171,7 +171,7 @@ class TensorArrayTest(tf.test.TestCase):
           dtype=tf_dtype, tensor_array_name="foo", size=3)

       # Try unpacking an empty matrix, which should not cause an error.
-      w2 = ta.unpack(convert([[], [], []]))
+      w2 = ta.unstack(convert([[], [], []]))
       r0 = w2.read(0)
       r1 = w2.read(1)
       r2 = w2.read(2)
@@ -583,7 +583,7 @@ class TensorArrayTest(tf.test.TestCase):

       w0 = ta.write(0, value_0)
       w1 = w0.write(1, value_1)
-      p0 = w1.pack()
+      p0 = w1.stack()
       r0 = w1.read(0)
       s0 = w1.concat()

@@ -610,7 +610,7 @@ class TensorArrayTest(tf.test.TestCase):
       ta_readonce = tf.TensorArray(
           dtype=tf.float32, tensor_array_name="foo", size=2)

-      w_readonce = ta_readonce.unpack(value)
+      w_readonce = ta_readonce.unstack(value)
       r0_readonce = w_readonce.read(0)
       with tf.control_dependencies([r0_readonce]):
         r1_readonce = w_readonce.read(0)
@@ -623,7 +623,7 @@ class TensorArrayTest(tf.test.TestCase):
       ta_readtwice = tf.TensorArray(
           dtype=tf.float32, tensor_array_name="foo", size=2,
           clear_after_read=False)
-      w_readtwice = ta_readtwice.unpack(value)
+      w_readtwice = ta_readtwice.unstack(value)
       r0_readtwice = w_readtwice.read(0)
       with tf.control_dependencies([r0_readtwice]):
         r1_readtwice = w_readtwice.read(0)
@@ -638,7 +638,7 @@ class TensorArrayTest(tf.test.TestCase):

       value = tf.constant([[1.0, -1.0], [10.0, -10.0]])

-      w = ta.unpack(value)
+      w = ta.unstack(value)
       r0 = w.read(0)
       r0_1 = w.read(0)
       r1 = w.read(1)
@@ -682,7 +682,7 @@ class TensorArrayTest(tf.test.TestCase):

       value = tf.constant([[1.0, -1.0], [10.0, -10.0]])

-      w = ta.unpack(value)
+      w = ta.unstack(value)
       r0 = w.read(0)
       r1 = w.read(1)

@@ -746,7 +746,7 @@ class TensorArrayTest(tf.test.TestCase):
               tensor_shape.unknown_shape(),
               tensor_shape.unknown_shape()),
           parallel_iterations=3)
-      vout = h_final.pack()
+      vout = h_final.stack()

       grad_val = -np.arange(3*5, dtype=np_dtype).reshape(3, 5)
       v0_grad = tf.gradients([vout], [v0], [grad_val])[0]
@@ -824,7 +824,7 @@ class TensorArrayTest(tf.test.TestCase):
         return i + 1, acc.write(i, z)
       _, acc2 = tf.while_loop(lambda i, acc: i < num_steps, fn, [i1, acc1])

-      r = acc2.pack()
+      r = acc2.stack()
       grad = tf.gradients(r, [x])[0]
       self.assertAllClose(31.0, grad.eval())

@@ -926,7 +926,7 @@ class TensorArrayTest(tf.test.TestCase):
           dtype=tf.float32, tensor_array_name="foo",
           size=0, dynamic_size=True, infer_shape=True)
       value = tf.constant([[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
-      w0 = ta.unpack(value)
+      w0 = ta.unstack(value)
       r0 = w0.read(0)
       self.assertAllEqual((2,), r0.get_shape())

@@ -972,7 +972,7 @@ class TensorArrayTest(tf.test.TestCase):
     with self.test_session(use_gpu=True) as session:
       ta = tf.TensorArray(dtype=tf.float32, size=2)
       x = tf.constant([2.0, 3.0])
-      w = ta.unpack(x)
+      w = ta.unstack(x)
       r0 = w.read(0)
       # calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0).
       grad_r0 = tf.gradients(ys=[r0], xs=[x], grad_ys=[1.0])
@@ -987,9 +987,9 @@ class TensorArrayTest(tf.test.TestCase):
       ta = tf.TensorArray(dtype=tf.float32, size=3,
                           dynamic_size=True)
       x = tf.constant([1.0, 2.0, 3.0])
-      w0 = ta.unpack(x)
+      w0 = ta.unstack(x)
       w1 = w0.write(3, 4.0)
-      r = w1.pack()
+      r = w1.stack()
       self.assertAllEqual(np.array([1.0, 2.0, 3.0, 4.0]), r.eval())
       grad = tf.gradients(ys=[r], xs=[x])
       self.assertAllEqual(np.array([1.0, 1.0, 1.0]),
@@ -1021,7 +1021,7 @@ class TensorArrayTest(tf.test.TestCase):
           "TensorArray has size zero, but element shape <unknown> is not fully "
           "defined. Currently only static shapes are supported when packing "
           "zero-size TensorArrays."):
-        ta.pack().eval()
+        ta.stack().eval()

   def testTensorArrayEvalEmpty(self):
     self._testTensorArrayEvalEmpty()
@@ -1034,8 +1034,8 @@ class TensorArrayTest(tf.test.TestCase):
           infer_shape=True)
       self.assertEqual(0, ta.size().eval())
       # Don't actually perform the pack. This stores the static shape.
-      ta.unpack(tf.zeros([0, 3, 5]))
-      packed = ta.pack()
+      ta.unstack(tf.zeros([0, 3, 5]))
+      packed = ta.stack()
       self.assertAllEqual([0, 3, 5], packed.eval().shape)
       # Concatenating zero tensors along their first dimension gives a
       # first dimension of zero
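The hunk above relies on shape inference rather than execution: unstacking a
zeros([0, 3, 5]) tensor never writes an element, but it records the static
element shape (3, 5), which is what lets stack() later produce an empty
[0, 3, 5] result. A NumPy analogy of the shape arithmetic (illustrative only):

    import numpy as np

    elem_shape = (3, 5)                    # element shape recorded at unstack time
    stacked = np.zeros((0,) + elem_shape)  # stacking zero such elements
    assert stacked.shape == (0, 3, 5)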
@@ -1075,7 +1075,7 @@ class TensorArrayTest(tf.test.TestCase):
       values = tf.constant([[1.0*x, -1.0*x] for x in range(10)])
       indices = tf.constant([1, 8])

-      w = ta.unpack(values)
+      w = ta.unstack(values)
       g = w.gather(indices)

       # Test combined gradients + aggregation of read(0)
@@ -1119,11 +1119,11 @@ class TensorArrayTest(tf.test.TestCase):
     self.assertEqual(ta.handle.device, "")
     self.assertEqual(ta.flow.device, "")
     with tf.device("/gpu:0"):
-      ta = ta.unpack([1.0, 2.0])
+      ta = ta.unstack([1.0, 2.0])
     self.assertTrue("gpu:0" in ta.handle.device.lower())
     self.assertTrue("gpu:0" in ta.flow.device.lower())
     with tf.device("/gpu:1"):
-      ta = ta.unpack([1.0, 2.0])
+      ta = ta.unstack([1.0, 2.0])
     self.assertTrue("gpu:0" in ta.handle.device.lower())
     self.assertTrue("gpu:0" in ta.flow.device.lower())
tensor_array_ops.py

@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_data_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util.deprecation import deprecated


 def _maybe_set_device(handle_op, value_t):
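The newly imported deprecated decorator is what keeps pack and unpack alive as
warning-emitting aliases in the hunks below. A simplified sketch of the alias
pattern it enables (illustrative only; the real decorator in
tensorflow.python.util.deprecation also rewrites the wrapped function's
docstring with the deprecation notice):

    import functools
    import warnings

    def deprecated(date, instructions):
      def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
          warnings.warn("%s is deprecated and will be removed after %s. %s"
                        % (func.__name__, date, instructions),
                        DeprecationWarning)
          return func(*args, **kwargs)
        return wrapper
      return decorator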
@@ -262,21 +263,30 @@ class TensorArray(object):
       ta._elem_shape.append(val_shape)
     return ta

-  def pack(self, name=None):
-    """Return the values in the TensorArray as a packed `Tensor`.
+  def stack(self, name=None):
+    """Return the values in the TensorArray as a stacked `Tensor`.

     All of the values must have been written and their shapes must all match.
     If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

     Args:
       name: A name for the operation (optional).

     Returns:
-      All the tensors in the TensorArray packed into one tensor.
+      All the tensors in the TensorArray stacked into one tensor.
     """
     with ops.colocate_with(self._handle):
-      with ops.name_scope(name, "TensorArrayPack", [self._handle]):
+      with ops.name_scope(name, "TensorArrayStack", [self._handle]):
         return self.gather(math_ops.range(0, self.size()), name=name)

+  @deprecated(
+      "2016-12-12",
+      "This op will be removed after the deprecation date. "
+      "Please switch to tf.stack.")
+  def pack(self, name=None):
+    return self.stack(name)
+  pack.__doc__ = stack.__doc__
+
   def gather(self, indices, name=None):
     """Return selected values in the TensorArray as a packed `Tensor`.
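The rank contract in the new stack docstring (rank-`R` elements produce a
rank-`(R+1)` result) matches NumPy's analogous np.stack; a quick illustration:

    import numpy as np

    elems = [np.zeros((2, 3)) for _ in range(4)]  # four rank-2 elements
    stacked = np.stack(elems)                     # one rank-3 result
    assert stacked.shape == (4, 2, 3)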
@@ -335,25 +345,35 @@ class TensorArray(object):
       value.set_shape([None] + self._elem_shape[0].dims[1:])
     return value

-  def unpack(self, value, name=None):
-    """Pack the values of a `Tensor` in the TensorArray.
+  def unstack(self, value, name=None):
+    """Unstack the values of a `Tensor` in the TensorArray.

+    If input value shapes have rank-`R`, then the output TensorArray will
+    contain elements whose shapes are rank-`(R-1)`.
     Args:
-      value: (N+1)-D. Tensor of type `dtype`. The Tensor to unpack.
+      value: (N+1)-D. Tensor of type `dtype`. The Tensor to unstack.
       name: A name for the operation (optional).

     Returns:
-      A new TensorArray object with flow that ensures the unpack occurs.
+      A new TensorArray object with flow that ensures the unstack occurs.
       Use this object all for subsequent operations.

     Raises:
       ValueError: if the shape inference fails.
     """
-    with ops.name_scope(name, "TensorArrayPack", [self._handle, value]):
+    with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
       num_elements = array_ops.shape(value)[0]
       return self.scatter(
           indices=math_ops.range(0, num_elements), value=value, name=name)

+  @deprecated(
+      "2016-12-12",
+      "This op will be removed after the deprecation date. "
+      "Please switch to tf.unstack.")
+  def unpack(self, value, name=None):
+    return self.unstack(value, name)
+  unpack.__doc__ = unstack.__doc__
+
   def scatter(self, indices, value, name=None):
     """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.
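Taken together, the two hunks show that the new names are thin compositions of
existing ops: unstack scatters value across indices [0, N) and stack gathers
them back out. A hedged equivalence check using the scatter and gather methods
visible above (illustrative, not part of the commit):

    import tensorflow as tf

    ta = tf.TensorArray(dtype=tf.float32, size=3)
    value = tf.constant([1.0, 2.0, 3.0])

    w = ta.scatter(indices=tf.range(0, 3), value=value)  # what unstack does
    stacked = w.gather(tf.range(0, 3))                   # what stack does

    with tf.Session() as sess:
      print(sess.run(stacked))  # => [1. 2. 3.]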