Merge pull request #3744 from suiyuan2009/fix-py_func-length-1
extract element from list when py_func's output type is a single tensorflow type
This commit is contained in:
commit
891f8f73cc
@ -34,6 +34,13 @@ class PyOpTest(tf.test.TestCase):
|
||||
def my_func(x, y):
|
||||
return np.sinh(x) + np.cosh(y)
|
||||
|
||||
# single type
|
||||
with self.test_session():
|
||||
x = tf.constant(1.0, tf.float32)
|
||||
y = tf.constant(2.0, tf.float32)
|
||||
z = tf.py_func(my_func, [x, y], tf.float32)
|
||||
self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32))
|
||||
|
||||
# scalar
|
||||
with self.test_session():
|
||||
x = tf.constant(1.0, tf.float32)
|
||||
|
@ -132,8 +132,8 @@ def py_func(func, inp, Tout, stateful=True, name=None):
|
||||
Args:
|
||||
func: A python function.
|
||||
inp: A list of `Tensor`.
|
||||
Tout: A list of tensorflow data types indicating what `func`
|
||||
returns.
|
||||
Tout: A list of tensorflow data types or a single tensorflow data type
|
||||
indicating what `func` returns.
|
||||
stateful: A boolean indicating whether the function should be considered
|
||||
stateful or stateless. I.e. whether it, given the same input, will
|
||||
return the same output and at the same time does not change state
|
||||
@ -142,7 +142,7 @@ def py_func(func, inp, Tout, stateful=True, name=None):
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A list of `Tensor` which `func` computes.
|
||||
A list of `Tensor` or a single `Tensor` which `func` computes.
|
||||
"""
|
||||
token = _py_funcs.insert(func)
|
||||
# We tie the registered function's life-time with the current
|
||||
@ -162,14 +162,20 @@ def py_func(func, inp, Tout, stateful=True, name=None):
|
||||
# the funcs registry.
|
||||
g._cleanup_py_funcs_used_in_graph.append(cleanup)
|
||||
|
||||
if isinstance(Tout, list):
|
||||
is_list = True
|
||||
else:
|
||||
Tout = [Tout]
|
||||
is_list = False
|
||||
if stateful:
|
||||
return gen_script_ops._py_func(input=inp, token=token, Tout=Tout, name=name)
|
||||
result = gen_script_ops._py_func(
|
||||
input=inp, token=token, Tout=Tout, name=name)
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
return gen_script_ops._py_func_stateless(
|
||||
input=inp, token=token, Tout=Tout,
|
||||
name=name)
|
||||
result = gen_script_ops._py_func_stateless(
|
||||
input=inp, token=token, Tout=Tout, name=name)
|
||||
# pylint: enable=protected-access
|
||||
return result if is_list else result[0]
|
||||
|
||||
|
||||
ops.RegisterShape("PyFunc")(common_shapes.unknown_shape)
|
||||
|
Loading…
Reference in New Issue
Block a user