Merge pull request #3744 from suiyuan2009/fix-py_func-length-1

extract element from list when py_func's output type is a single tensorflow type
This commit is contained in:
Daniel W Mane 2016-09-06 13:30:24 -07:00 committed by GitHub
commit 891f8f73cc
2 changed files with 20 additions and 7 deletions

View File

@ -34,6 +34,13 @@ class PyOpTest(tf.test.TestCase):
def my_func(x, y): def my_func(x, y):
return np.sinh(x) + np.cosh(y) return np.sinh(x) + np.cosh(y)
# single type
with self.test_session():
x = tf.constant(1.0, tf.float32)
y = tf.constant(2.0, tf.float32)
z = tf.py_func(my_func, [x, y], tf.float32)
self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32))
# scalar # scalar
with self.test_session(): with self.test_session():
x = tf.constant(1.0, tf.float32) x = tf.constant(1.0, tf.float32)

View File

@ -132,8 +132,8 @@ def py_func(func, inp, Tout, stateful=True, name=None):
Args: Args:
func: A python function. func: A python function.
inp: A list of `Tensor`. inp: A list of `Tensor`.
Tout: A list of tensorflow data types indicating what `func` Tout: A list of tensorflow data types or a single tensorflow data type
returns. indicating what `func` returns.
stateful: A boolean indicating whether the function should be considered stateful: A boolean indicating whether the function should be considered
stateful or stateless. I.e. whether it, given the same input, will stateful or stateless. I.e. whether it, given the same input, will
return the same output and at the same time does not change state return the same output and at the same time does not change state
@ -142,7 +142,7 @@ def py_func(func, inp, Tout, stateful=True, name=None):
name: A name for the operation (optional). name: A name for the operation (optional).
Returns: Returns:
A list of `Tensor` which `func` computes. A list of `Tensor` or a single `Tensor` which `func` computes.
""" """
token = _py_funcs.insert(func) token = _py_funcs.insert(func)
# We tie the registered function's life-time with the current # We tie the registered function's life-time with the current
@ -162,14 +162,20 @@ def py_func(func, inp, Tout, stateful=True, name=None):
# the funcs registry. # the funcs registry.
g._cleanup_py_funcs_used_in_graph.append(cleanup) g._cleanup_py_funcs_used_in_graph.append(cleanup)
if isinstance(Tout, list):
is_list = True
else:
Tout = [Tout]
is_list = False
if stateful: if stateful:
return gen_script_ops._py_func(input=inp, token=token, Tout=Tout, name=name) result = gen_script_ops._py_func(
input=inp, token=token, Tout=Tout, name=name)
# pylint: enable=protected-access # pylint: enable=protected-access
else: else:
return gen_script_ops._py_func_stateless( result = gen_script_ops._py_func_stateless(
input=inp, token=token, Tout=Tout, input=inp, token=token, Tout=Tout, name=name)
name=name)
# pylint: enable=protected-access # pylint: enable=protected-access
return result if is_list else result[0]
ops.RegisterShape("PyFunc")(common_shapes.unknown_shape) ops.RegisterShape("PyFunc")(common_shapes.unknown_shape)