Merge pull request #3744 from suiyuan2009/fix-py_func-length-1
extract element from list when py_func's output type is a single tensorflow type
This commit is contained in:
commit
891f8f73cc
@ -34,6 +34,13 @@ class PyOpTest(tf.test.TestCase):
|
|||||||
def my_func(x, y):
|
def my_func(x, y):
|
||||||
return np.sinh(x) + np.cosh(y)
|
return np.sinh(x) + np.cosh(y)
|
||||||
|
|
||||||
|
# single type
|
||||||
|
with self.test_session():
|
||||||
|
x = tf.constant(1.0, tf.float32)
|
||||||
|
y = tf.constant(2.0, tf.float32)
|
||||||
|
z = tf.py_func(my_func, [x, y], tf.float32)
|
||||||
|
self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32))
|
||||||
|
|
||||||
# scalar
|
# scalar
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = tf.constant(1.0, tf.float32)
|
x = tf.constant(1.0, tf.float32)
|
||||||
|
@ -132,8 +132,8 @@ def py_func(func, inp, Tout, stateful=True, name=None):
|
|||||||
Args:
|
Args:
|
||||||
func: A python function.
|
func: A python function.
|
||||||
inp: A list of `Tensor`.
|
inp: A list of `Tensor`.
|
||||||
Tout: A list of tensorflow data types indicating what `func`
|
Tout: A list of tensorflow data types or a single tensorflow data type
|
||||||
returns.
|
indicating what `func` returns.
|
||||||
stateful: A boolean indicating whether the function should be considered
|
stateful: A boolean indicating whether the function should be considered
|
||||||
stateful or stateless. I.e. whether it, given the same input, will
|
stateful or stateless. I.e. whether it, given the same input, will
|
||||||
return the same output and at the same time does not change state
|
return the same output and at the same time does not change state
|
||||||
@ -142,7 +142,7 @@ def py_func(func, inp, Tout, stateful=True, name=None):
|
|||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of `Tensor` which `func` computes.
|
A list of `Tensor` or a single `Tensor` which `func` computes.
|
||||||
"""
|
"""
|
||||||
token = _py_funcs.insert(func)
|
token = _py_funcs.insert(func)
|
||||||
# We tie the registered function's life-time with the current
|
# We tie the registered function's life-time with the current
|
||||||
@ -162,14 +162,20 @@ def py_func(func, inp, Tout, stateful=True, name=None):
|
|||||||
# the funcs registry.
|
# the funcs registry.
|
||||||
g._cleanup_py_funcs_used_in_graph.append(cleanup)
|
g._cleanup_py_funcs_used_in_graph.append(cleanup)
|
||||||
|
|
||||||
|
if isinstance(Tout, list):
|
||||||
|
is_list = True
|
||||||
|
else:
|
||||||
|
Tout = [Tout]
|
||||||
|
is_list = False
|
||||||
if stateful:
|
if stateful:
|
||||||
return gen_script_ops._py_func(input=inp, token=token, Tout=Tout, name=name)
|
result = gen_script_ops._py_func(
|
||||||
|
input=inp, token=token, Tout=Tout, name=name)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
else:
|
else:
|
||||||
return gen_script_ops._py_func_stateless(
|
result = gen_script_ops._py_func_stateless(
|
||||||
input=inp, token=token, Tout=Tout,
|
input=inp, token=token, Tout=Tout, name=name)
|
||||||
name=name)
|
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
return result if is_list else result[0]
|
||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("PyFunc")(common_shapes.unknown_shape)
|
ops.RegisterShape("PyFunc")(common_shapes.unknown_shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user