Updated foldl, foldr, and map to work with TensorArray.

Change: 114350923
This commit is contained in:
Yuan Yu 2016-02-10 11:45:46 -08:00 committed by TensorFlower Gardener
parent d77ef35e13
commit cab186e7a6
3 changed files with 185 additions and 55 deletions

View File

@ -1160,7 +1160,6 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
// Propagates outputs along out edges, and puts newly ready nodes
// into the ready queue.
ready->clear();
{
FrameState* output_frame = input_frame;
int64 output_iter = input_iter;

View File

@ -1239,28 +1239,52 @@ class ControlFlowTest(tf.test.TestCase):
r = sess.run(r, feed_dict={v: 2.0})
self.assertAllClose(1024.0, r)
def testFold_1(self):
def testFoldl_Simple(self):
with self.test_session():
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
r = control_flow_ops.fold(
lambda a, x: tf.mul(tf.add(a, x), 2), elems, [1])
result = r.eval()
self.assertTrue(check_op_order(elems.graph))
self.assertAllEqual(np.array([208]), result)
def testFold_2(self):
r = control_flow_ops.foldl(
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
self.assertAllEqual(208, r.eval())
r = control_flow_ops.foldl(
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
self.assertAllEqual(880, r.eval())
def testFoldr_Simple(self):
with self.test_session():
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
ten = tf.convert_to_tensor(10)
def compute(a, x):
r = tf.mul(x, ten)
return tf.add(a, r)
r = control_flow_ops.foldr(
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
self.assertAllEqual(450, r.eval())
r = control_flow_ops.fold(compute, elems, [1])
result = r.eval()
self.assertTrue(check_op_order(elems.graph))
self.assertAllEqual([201], result)
r = control_flow_ops.foldr(
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
self.assertAllEqual(1282, r.eval())
def testFold_Grad(self):
with self.test_session():
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = tf.constant(2.0, name="v")
r = control_flow_ops.foldl(
lambda a, x: tf.mul(a, x), elems, initializer=v)
r = tf.gradients(r, v)[0]
self.assertAllEqual(720.0, r.eval())
r = control_flow_ops.foldr(
lambda a, x: tf.mul(a, x), elems, initializer=v)
r = tf.gradients(r, v)[0]
self.assertAllEqual(720.0, r.eval())
def testMap_Simple(self):
with self.test_session():
nums = [1, 2, 3, 4, 5, 6]
elems = tf.constant(nums, name="data")
r = control_flow_ops.map(
lambda x: tf.mul(tf.add(x, 3), 2), elems)
self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
def testOneValueCond(self):
with self.test_session():

View File

@ -1150,14 +1150,17 @@ def cond(pred, fn1, fn2, name=None):
return tensors of different types.
Example:
```python
x = constant(2)
y = constant(5)
def f1(): return constant(17)
def f2(): return constant(23)
x = tf.constant(2)
y = tf.constant(5)
def f1(): return tf.mul(x, 17)
def f2(): return tf.add(y, 23)
r = cond(math_ops.less(x, y), f1, f2)
# r is set to f1()
# r is set to f1().
# Operations in f2 (e.g., tf.add) are not executed.
```
"""
with ops.op_scope([pred], name, "cond") as name:
if not callable(fn1):
@ -1534,7 +1537,7 @@ def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
Example:
```python
i = Constant(0)
i = constant(0)
c = lambda i: math_ops.less(i, 10)
b = lambda i: math_ops.add(i, 1)
r = While(c, b, [i])
@ -1746,51 +1749,155 @@ def tuple(tensors, name=None, control_inputs=None):
return tpl
# TODO(yuanbyu): It would be nicer if we could have the distributed list
# support that Derek has been proposing.
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
def fold(fn, elems, elem_shape, name=None):
"""The fold operator on slices of a tensor.
def foldl(fn, elems, initializer=None, name=None):
"""The foldl operator on the unpacked tensors of a tensor.
This fold operator applies the function `fn` to slices of `elems` on
dimension 0. The shape of the slices is specified by `elem_shape`. `elems`
must contain at least one slice (`shape(elems)[0] / elem_shape[0] > 0`).
This foldl operator applies the function `fn` to a sequence of elements
from left to right. The elements are made of the tensors unpacked from
`elems`. If `initializer` is None, `elems` must contain at least one
element.
Args:
fn: The function to be performed on each slice of the tensor.
elems: The tensor to whose slices we want to apply `fn`.
elem_shape: The shape definition for the slices.
name: Optional name prefix for the returned tensors.
fn: The function to be performed.
elems: A tensor to be unpacked.
initializer: (optional) The initial value for the accumulator.
name: (optional) Name prefix for the returned tensors.
Returns:
A tensor resulting from applying `fn` consecutively on each slice of
`elems`.
A tensor resulting from applying `fn` consecutively on each
element/slice of `elems`, from left to right.
Raises:
TypeError: if `fn` is not callable.
Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = foldl(lambda a, x: a + x, elems)
```
"""
with ops.op_scope([elems], name, "fold") as name:
with ops.op_scope([elems], name, "foldl") as name:
if not callable(fn):
raise TypeError("fn must be callable.")
s0 = array_ops.shape(elems)[0]
d0 = elem_shape[0]
n = math_ops.div(s0, d0)
b1 = array_ops.zeros(array_ops.expand_dims(array_ops.rank(elems) - 1, 0),
dtype=dtypes.int32)
# Initialize the output with slice 0
b = array_ops.concat(0, [[0], b1])
o = array_ops.slice(elems, b, elem_shape)
i = ops.convert_to_tensor(d0)
# Convert elems to tensor array.
n = array_ops.shape(elems)[0]
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
dynamic_size=False)
elems_ta = elems_ta.unpack(elems)
def Compute(i, o):
b = array_ops.concat(0, [array_ops.expand_dims(i, 0), b1])
x = array_ops.slice(elems, b, elem_shape)
o = fn(o, x)
i = math_ops.add(i, d0)
return [i, o]
r = While(lambda i, o: math_ops.less(i, n), Compute, [i, o])
return r[1]
if initializer is None:
a = elems_ta.read(0)
i = constant_op.constant(1)
else:
a = ops.convert_to_tensor(initializer)
i = constant_op.constant(0)
def compute(i, a):
a = fn(a, elems_ta.read(i))
return [i + 1, a]
_, r_a = While(lambda i, a: i < n, compute, [i, a])
return r_a
def foldr(fn, elems, initializer=None, name=None):
"""The foldr operator operator on the unpacked tensors of a tensor.
This foldr operator applies the function `fn` to a sequence of elements
from right to left. The elements are made of the tensors unpacked from
`elems`. If `initializer` is None, `elems` must contain at least one
element.
Args:
fn: The function to be performed.
elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
initializer: (optional) The initial value for the accumulator.
use_tensor_array: (optional) use tensor_array if true.
name: (optional) Name prefix for the returned tensors.
Returns:
A tensor resulting from applying `fn` consecutively on each
element/slice of `elems`, from right to left.
Raises:
TypeError: if `fn` is not callable.
Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = foldr(lambda a, x: a + x, elems)
```
"""
with ops.op_scope([elems], name, "foldr") as name:
if not callable(fn):
raise TypeError("fn must be callable.")
# Convert elems to tensor array.
n = array_ops.shape(elems)[0]
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
dynamic_size=False)
elems_ta = elems_ta.unpack(elems)
if initializer is None:
i = n - 1
a = elems_ta.read(i)
else:
i = n
a = ops.convert_to_tensor(initializer)
def compute(i, a):
i -= 1
a = fn(a, elems_ta.read(i))
return [i, a]
_, r_a = While(lambda i, a: i > 0, compute, [i, a])
return r_a
def map(fn, elems, dtype=None, name=None):
"""The map operator on on the unpacked tensors of a tensor.
This map operator applies the function `fn` to a sequence of elements
from right to left. The elements are made of the tensors unpacked from
`elems`.
Args:
fn: The function to be performed.
elems: A tensor to be unpacked to apply `fn`.
dtype: (optional) The output type of `fn`.
name: (optional) Name prefix for the returned tensors.
Returns:
A tensor that packs the results of applying `fn` on each element
of `elems`.
Raises:
TypeError: if `fn` is not callable.
Example:
```python
elems = [1, 2, 3, 4, 5, 6]
squares = map(lambda x: x * x, elems)
```
"""
with ops.op_scope([elems], name, "map") as name:
if not callable(fn):
raise TypeError("fn must be callable.")
dtype = dtype if dtype else elems.dtype
# Convert elems to tensor array.
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=0,
dynamic_size=True)
elems_ta = elems_ta.unpack(elems)
n = elems_ta.size()
i = constant_op.constant(0)
acc_ta = tensor_array_ops.TensorArray(dtype=dtype, size=n)
def compute(i, a):
a = a.write(i, fn(elems_ta.read(i)))
i = math_ops.add(i, 1)
return [i, a]
_, r_a = While(lambda i, a: math_ops.less(i, n), compute, [i, acc_ta])
return r_a.pack()
def case(pred_fn_pairs, default, exclusive=False, name="case"):
@ -1943,10 +2050,10 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
ops.RegisterShape("Enter")(common_shapes.unchanged_shape)
ops.RegisterShape("Exit")(common_shapes.unknown_shape)
ops.RegisterShape("Exit")(common_shapes.unchanged_shape)
ops.RegisterShape("NextIteration")(common_shapes.unchanged_shape)
ops.RegisterShape("RefEnter")(common_shapes.unchanged_shape)
ops.RegisterShape("RefExit")(common_shapes.unknown_shape)
ops.RegisterShape("RefExit")(common_shapes.unchanged_shape)
ops.RegisterShape("RefNextIteration")(common_shapes.unchanged_shape)
ops.RegisterShape("ControlTrigger")(common_shapes.no_outputs)
ops.RegisterShape("NoOp")(common_shapes.no_outputs)