Updated foldl, foldr, and map to work with TensorArray.
Change: 114350923
This commit is contained in:
parent
d77ef35e13
commit
cab186e7a6
@ -1160,7 +1160,6 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
|
|||||||
// Propagates outputs along out edges, and puts newly ready nodes
|
// Propagates outputs along out edges, and puts newly ready nodes
|
||||||
// into the ready queue.
|
// into the ready queue.
|
||||||
ready->clear();
|
ready->clear();
|
||||||
|
|
||||||
{
|
{
|
||||||
FrameState* output_frame = input_frame;
|
FrameState* output_frame = input_frame;
|
||||||
int64 output_iter = input_iter;
|
int64 output_iter = input_iter;
|
||||||
|
|||||||
@ -1239,28 +1239,52 @@ class ControlFlowTest(tf.test.TestCase):
|
|||||||
r = sess.run(r, feed_dict={v: 2.0})
|
r = sess.run(r, feed_dict={v: 2.0})
|
||||||
self.assertAllClose(1024.0, r)
|
self.assertAllClose(1024.0, r)
|
||||||
|
|
||||||
def testFold_1(self):
|
def testFoldl_Simple(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
||||||
r = control_flow_ops.fold(
|
|
||||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems, [1])
|
|
||||||
result = r.eval()
|
|
||||||
self.assertTrue(check_op_order(elems.graph))
|
|
||||||
self.assertAllEqual(np.array([208]), result)
|
|
||||||
|
|
||||||
def testFold_2(self):
|
r = control_flow_ops.foldl(
|
||||||
|
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
|
||||||
|
self.assertAllEqual(208, r.eval())
|
||||||
|
|
||||||
|
r = control_flow_ops.foldl(
|
||||||
|
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
|
||||||
|
self.assertAllEqual(880, r.eval())
|
||||||
|
|
||||||
|
def testFoldr_Simple(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
||||||
ten = tf.convert_to_tensor(10)
|
|
||||||
|
|
||||||
def compute(a, x):
|
r = control_flow_ops.foldr(
|
||||||
r = tf.mul(x, ten)
|
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
|
||||||
return tf.add(a, r)
|
self.assertAllEqual(450, r.eval())
|
||||||
|
|
||||||
r = control_flow_ops.fold(compute, elems, [1])
|
r = control_flow_ops.foldr(
|
||||||
result = r.eval()
|
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
|
||||||
self.assertTrue(check_op_order(elems.graph))
|
self.assertAllEqual(1282, r.eval())
|
||||||
self.assertAllEqual([201], result)
|
|
||||||
|
def testFold_Grad(self):
|
||||||
|
with self.test_session():
|
||||||
|
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
||||||
|
v = tf.constant(2.0, name="v")
|
||||||
|
|
||||||
|
r = control_flow_ops.foldl(
|
||||||
|
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||||
|
r = tf.gradients(r, v)[0]
|
||||||
|
self.assertAllEqual(720.0, r.eval())
|
||||||
|
|
||||||
|
r = control_flow_ops.foldr(
|
||||||
|
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||||
|
r = tf.gradients(r, v)[0]
|
||||||
|
self.assertAllEqual(720.0, r.eval())
|
||||||
|
|
||||||
|
def testMap_Simple(self):
|
||||||
|
with self.test_session():
|
||||||
|
nums = [1, 2, 3, 4, 5, 6]
|
||||||
|
elems = tf.constant(nums, name="data")
|
||||||
|
r = control_flow_ops.map(
|
||||||
|
lambda x: tf.mul(tf.add(x, 3), 2), elems)
|
||||||
|
self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
|
||||||
|
|
||||||
def testOneValueCond(self):
|
def testOneValueCond(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
|||||||
@ -1150,14 +1150,17 @@ def cond(pred, fn1, fn2, name=None):
|
|||||||
return tensors of different types.
|
return tensors of different types.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
x = constant(2)
|
x = tf.constant(2)
|
||||||
y = constant(5)
|
y = tf.constant(5)
|
||||||
def f1(): return constant(17)
|
def f1(): return tf.mul(x, 17)
|
||||||
def f2(): return constant(23)
|
def f2(): return tf.add(y, 23)
|
||||||
r = cond(math_ops.less(x, y), f1, f2)
|
r = cond(math_ops.less(x, y), f1, f2)
|
||||||
# r is set to f1()
|
# r is set to f1().
|
||||||
|
# Operations in f2 (e.g., tf.add) are not executed.
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with ops.op_scope([pred], name, "cond") as name:
|
with ops.op_scope([pred], name, "cond") as name:
|
||||||
if not callable(fn1):
|
if not callable(fn1):
|
||||||
@ -1534,7 +1537,7 @@ def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
i = Constant(0)
|
i = constant(0)
|
||||||
c = lambda i: math_ops.less(i, 10)
|
c = lambda i: math_ops.less(i, 10)
|
||||||
b = lambda i: math_ops.add(i, 1)
|
b = lambda i: math_ops.add(i, 1)
|
||||||
r = While(c, b, [i])
|
r = While(c, b, [i])
|
||||||
@ -1746,51 +1749,155 @@ def tuple(tensors, name=None, control_inputs=None):
|
|||||||
return tpl
|
return tpl
|
||||||
|
|
||||||
|
|
||||||
# TODO(yuanbyu): It would be nicer if we could have the distributed list
|
|
||||||
# support that Derek has been proposing.
|
|
||||||
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
|
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
|
||||||
def fold(fn, elems, elem_shape, name=None):
|
def foldl(fn, elems, initializer=None, name=None):
|
||||||
"""The fold operator on slices of a tensor.
|
"""The foldl operator on the unpacked tensors of a tensor.
|
||||||
|
|
||||||
This fold operator applies the function `fn` to slices of `elems` on
|
This foldl operator applies the function `fn` to a sequence of elements
|
||||||
dimension 0. The shape of the slices is specified by `elem_shape`. `elems`
|
from left to right. The elements are made of the tensors unpacked from
|
||||||
must contain at least one slice (`shape(elems)[0] / elem_shape[0] > 0`).
|
`elems`. If `initializer` is None, `elems` must contain at least one
|
||||||
|
element.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fn: The function to be performed on each slice of the tensor.
|
fn: The function to be performed.
|
||||||
elems: The tensor to whose slices we want to apply `fn`.
|
elems: A tensor to be unpacked.
|
||||||
elem_shape: The shape definition for the slices.
|
initializer: (optional) The initial value for the accumulator.
|
||||||
name: Optional name prefix for the returned tensors.
|
name: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tensor resulting from applying `fn` consecutively on each slice of
|
A tensor resulting from applying `fn` consecutively on each
|
||||||
`elems`.
|
element/slice of `elems`, from left to right.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: if `fn` is not callable.
|
TypeError: if `fn` is not callable.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
sum = foldl(lambda a, x: a + x, elems)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
with ops.op_scope([elems], name, "fold") as name:
|
with ops.op_scope([elems], name, "foldl") as name:
|
||||||
if not callable(fn):
|
if not callable(fn):
|
||||||
raise TypeError("fn must be callable.")
|
raise TypeError("fn must be callable.")
|
||||||
|
|
||||||
s0 = array_ops.shape(elems)[0]
|
# Convert elems to tensor array.
|
||||||
d0 = elem_shape[0]
|
n = array_ops.shape(elems)[0]
|
||||||
n = math_ops.div(s0, d0)
|
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||||
b1 = array_ops.zeros(array_ops.expand_dims(array_ops.rank(elems) - 1, 0),
|
dynamic_size=False)
|
||||||
dtype=dtypes.int32)
|
elems_ta = elems_ta.unpack(elems)
|
||||||
# Initialize the output with slice 0
|
|
||||||
b = array_ops.concat(0, [[0], b1])
|
|
||||||
o = array_ops.slice(elems, b, elem_shape)
|
|
||||||
i = ops.convert_to_tensor(d0)
|
|
||||||
|
|
||||||
def Compute(i, o):
|
if initializer is None:
|
||||||
b = array_ops.concat(0, [array_ops.expand_dims(i, 0), b1])
|
a = elems_ta.read(0)
|
||||||
x = array_ops.slice(elems, b, elem_shape)
|
i = constant_op.constant(1)
|
||||||
o = fn(o, x)
|
else:
|
||||||
i = math_ops.add(i, d0)
|
a = ops.convert_to_tensor(initializer)
|
||||||
return [i, o]
|
i = constant_op.constant(0)
|
||||||
r = While(lambda i, o: math_ops.less(i, n), Compute, [i, o])
|
|
||||||
return r[1]
|
def compute(i, a):
|
||||||
|
a = fn(a, elems_ta.read(i))
|
||||||
|
return [i + 1, a]
|
||||||
|
_, r_a = While(lambda i, a: i < n, compute, [i, a])
|
||||||
|
return r_a
|
||||||
|
|
||||||
|
|
||||||
|
def foldr(fn, elems, initializer=None, name=None):
|
||||||
|
"""The foldr operator operator on the unpacked tensors of a tensor.
|
||||||
|
|
||||||
|
This foldr operator applies the function `fn` to a sequence of elements
|
||||||
|
from right to left. The elements are made of the tensors unpacked from
|
||||||
|
`elems`. If `initializer` is None, `elems` must contain at least one
|
||||||
|
element.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The function to be performed.
|
||||||
|
elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
|
||||||
|
initializer: (optional) The initial value for the accumulator.
|
||||||
|
use_tensor_array: (optional) use tensor_array if true.
|
||||||
|
name: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor resulting from applying `fn` consecutively on each
|
||||||
|
element/slice of `elems`, from right to left.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `fn` is not callable.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
sum = foldr(lambda a, x: a + x, elems)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
with ops.op_scope([elems], name, "foldr") as name:
|
||||||
|
if not callable(fn):
|
||||||
|
raise TypeError("fn must be callable.")
|
||||||
|
|
||||||
|
# Convert elems to tensor array.
|
||||||
|
n = array_ops.shape(elems)[0]
|
||||||
|
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||||
|
dynamic_size=False)
|
||||||
|
elems_ta = elems_ta.unpack(elems)
|
||||||
|
|
||||||
|
if initializer is None:
|
||||||
|
i = n - 1
|
||||||
|
a = elems_ta.read(i)
|
||||||
|
else:
|
||||||
|
i = n
|
||||||
|
a = ops.convert_to_tensor(initializer)
|
||||||
|
def compute(i, a):
|
||||||
|
i -= 1
|
||||||
|
a = fn(a, elems_ta.read(i))
|
||||||
|
return [i, a]
|
||||||
|
_, r_a = While(lambda i, a: i > 0, compute, [i, a])
|
||||||
|
return r_a
|
||||||
|
|
||||||
|
|
||||||
|
def map(fn, elems, dtype=None, name=None):
|
||||||
|
"""The map operator on on the unpacked tensors of a tensor.
|
||||||
|
|
||||||
|
This map operator applies the function `fn` to a sequence of elements
|
||||||
|
from right to left. The elements are made of the tensors unpacked from
|
||||||
|
`elems`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The function to be performed.
|
||||||
|
elems: A tensor to be unpacked to apply `fn`.
|
||||||
|
dtype: (optional) The output type of `fn`.
|
||||||
|
name: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor that packs the results of applying `fn` on each element
|
||||||
|
of `elems`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `fn` is not callable.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
squares = map(lambda x: x * x, elems)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
with ops.op_scope([elems], name, "map") as name:
|
||||||
|
if not callable(fn):
|
||||||
|
raise TypeError("fn must be callable.")
|
||||||
|
dtype = dtype if dtype else elems.dtype
|
||||||
|
|
||||||
|
# Convert elems to tensor array.
|
||||||
|
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=0,
|
||||||
|
dynamic_size=True)
|
||||||
|
elems_ta = elems_ta.unpack(elems)
|
||||||
|
|
||||||
|
n = elems_ta.size()
|
||||||
|
i = constant_op.constant(0)
|
||||||
|
acc_ta = tensor_array_ops.TensorArray(dtype=dtype, size=n)
|
||||||
|
def compute(i, a):
|
||||||
|
a = a.write(i, fn(elems_ta.read(i)))
|
||||||
|
i = math_ops.add(i, 1)
|
||||||
|
return [i, a]
|
||||||
|
_, r_a = While(lambda i, a: math_ops.less(i, n), compute, [i, acc_ta])
|
||||||
|
return r_a.pack()
|
||||||
|
|
||||||
|
|
||||||
def case(pred_fn_pairs, default, exclusive=False, name="case"):
|
def case(pred_fn_pairs, default, exclusive=False, name="case"):
|
||||||
@ -1943,10 +2050,10 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
|
|||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("Enter")(common_shapes.unchanged_shape)
|
ops.RegisterShape("Enter")(common_shapes.unchanged_shape)
|
||||||
ops.RegisterShape("Exit")(common_shapes.unknown_shape)
|
ops.RegisterShape("Exit")(common_shapes.unchanged_shape)
|
||||||
ops.RegisterShape("NextIteration")(common_shapes.unchanged_shape)
|
ops.RegisterShape("NextIteration")(common_shapes.unchanged_shape)
|
||||||
ops.RegisterShape("RefEnter")(common_shapes.unchanged_shape)
|
ops.RegisterShape("RefEnter")(common_shapes.unchanged_shape)
|
||||||
ops.RegisterShape("RefExit")(common_shapes.unknown_shape)
|
ops.RegisterShape("RefExit")(common_shapes.unchanged_shape)
|
||||||
ops.RegisterShape("RefNextIteration")(common_shapes.unchanged_shape)
|
ops.RegisterShape("RefNextIteration")(common_shapes.unchanged_shape)
|
||||||
ops.RegisterShape("ControlTrigger")(common_shapes.no_outputs)
|
ops.RegisterShape("ControlTrigger")(common_shapes.no_outputs)
|
||||||
ops.RegisterShape("NoOp")(common_shapes.no_outputs)
|
ops.RegisterShape("NoOp")(common_shapes.no_outputs)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user