From e52dfa730a45d0e8e75fc64c587b74d937341fad Mon Sep 17 00:00:00 2001 From: suiyuan2009 <dzm1016397507@gmail.com> Date: Thu, 11 Aug 2016 13:31:57 +0800 Subject: [PATCH] make py_func return scalar when pass single type to output --- .../python/kernel_tests/py_func_test.py | 7 +++++++ tensorflow/python/ops/script_ops.py | 20 ++++++++++++------- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 3d23e4a8de0..f314712d7cb 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -34,6 +34,13 @@ class PyOpTest(tf.test.TestCase): def my_func(x, y): return np.sinh(x) + np.cosh(y) + # single type + with self.test_session(): + x = tf.constant(1.0, tf.float32) + y = tf.constant(2.0, tf.float32) + z = tf.py_func(my_func, [x, y], tf.float32) + self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32)) + # scalar with self.test_session(): x = tf.constant(1.0, tf.float32) diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 69b524a5a98..08c89b99ca8 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -132,8 +132,8 @@ def py_func(func, inp, Tout, stateful=True, name=None): Args: func: A python function. inp: A list of `Tensor`. - Tout: A list of tensorflow data types indicating what `func` - returns. + Tout: A list of tensorflow data types or a single tensorflow data type + indicating what `func` returns. stateful: A boolean indicating whether the function should be considered stateful or stateless. I.e. whether it, given the same input, will return the same output and at the same time does not change state @@ -142,7 +142,7 @@ def py_func(func, inp, Tout, stateful=True, name=None): name: A name for the operation (optional). Returns: - A list of `Tensor` which `func` computes. + A list of `Tensor` or a single `Tensor` which `func` computes. """ token = _py_funcs.insert(func) # We tie the registered function's life-time with the current @@ -162,14 +162,20 @@ def py_func(func, inp, Tout, stateful=True, name=None): # the funcs registry. g._cleanup_py_funcs_used_in_graph.append(cleanup) + if isinstance(Tout, list): + is_list = True + else: + Tout = [Tout] + is_list = False if stateful: - return gen_script_ops._py_func(input=inp, token=token, Tout=Tout, name=name) + result = gen_script_ops._py_func( + input=inp, token=token, Tout=Tout, name=name) # pylint: enable=protected-access else: - return gen_script_ops._py_func_stateless( - input=inp, token=token, Tout=Tout, - name=name) + result = gen_script_ops._py_func_stateless( + input=inp, token=token, Tout=Tout, name=name) # pylint: enable=protected-access + return result if is_list else result[0] ops.RegisterShape("PyFunc")(common_shapes.unknown_shape)