Adds support to non-placeholder inputs in _graph_to_function_def.

Specifically, supports input ops with more than one output tensor.

PiperOrigin-RevId: 157640908
This commit is contained in:
Alexandre Passos 2017-05-31 15:05:52 -07:00 committed by TensorFlower Gardener
parent d310de4fac
commit c048e2938c
2 changed files with 33 additions and 5 deletions

View File

@ -73,15 +73,21 @@ def _get_op_def(op):
def _is_in_placeholders(op, func_arg_placeholders):
return op.values() and (op.values()[0].name in func_arg_placeholders)
"""Checks whether any output of this op is in func_arg_placeholders."""
return op.values() and any(x.name in func_arg_placeholders
for x in op.values())
def _create_input_dict(function_graph, func_arg_placeholders):
def _create_input_dict(function_graph,
func_arg_placeholders,
initial_value=None):
"""Create a mapping from graph tensor names to function tensor names."""
input_dict = {}
if initial_value is None:
input_dict = {}
else:
input_dict = dict(initial_value)
for op in function_graph.get_operations():
if _is_in_placeholders(op, func_arg_placeholders):
input_dict[op.values()[0].name] = op.values()[0].name
input_dict[op.name] = op.name
else:
op_def = _get_op_def(op)
@ -150,6 +156,10 @@ def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
used_names = set()
func.signature.input_arg.extend(
[_tensor_to_argdef(i, used_names=used_names) for i in inputs])
# Initializes the input map with all placeholder input tensors.
initial_dict = {}
for o, m in zip(inputs, func.signature.input_arg):
initial_dict[o.name] = m.name
if out_names is None:
used_names = set()
func.signature.output_arg.extend(
@ -165,7 +175,8 @@ def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
func.signature.output_arg.extend(
[_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
func_arg_placeholders = set([i.name for i in inputs])
input_dict = _create_input_dict(graph, func_arg_placeholders)
input_dict = _create_input_dict(graph, func_arg_placeholders,
initial_value=initial_dict)
for op in operations:
if _is_in_placeholders(op, func_arg_placeholders):

View File

@ -39,6 +39,7 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@ -817,6 +818,22 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(out1, np.linspace(2, 11, 10))
self.assertAllEqual(out2, np.linspace(2, 11, 10))
def testTwoInputsSameOp(self):
g = ops.Graph()
with g.as_default():
m = array_ops.placeholder(dtypes.float32)
s, u, v = linalg_ops.svd(m)
ss = math_ops.reduce_sum(s)
uu = math_ops.reduce_sum(u)
vv = math_ops.reduce_sum(v)
result = ss + uu + vv
f = function._graph_to_function_def(
g,
g.get_operations()[1:], # skip the placeholder
[s, u, v],
[result])
self.assertEqual(len(f.signature.input_arg), 3)
class FunctionsFromProtos(test.TestCase):