Adds support to non-placeholder inputs in _graph_to_function_def.
Specifically, supports input ops with more than one output tensor. PiperOrigin-RevId: 157640908
This commit is contained in:
parent
d310de4fac
commit
c048e2938c
@ -73,15 +73,21 @@ def _get_op_def(op):
|
||||
|
||||
|
||||
def _is_in_placeholders(op, func_arg_placeholders):
|
||||
return op.values() and (op.values()[0].name in func_arg_placeholders)
|
||||
"""Checks whether any output of this op is in func_arg_placeholders."""
|
||||
return op.values() and any(x.name in func_arg_placeholders
|
||||
for x in op.values())
|
||||
|
||||
|
||||
def _create_input_dict(function_graph, func_arg_placeholders):
|
||||
def _create_input_dict(function_graph,
|
||||
func_arg_placeholders,
|
||||
initial_value=None):
|
||||
"""Create a mapping from graph tensor names to function tensor names."""
|
||||
input_dict = {}
|
||||
if initial_value is None:
|
||||
input_dict = {}
|
||||
else:
|
||||
input_dict = dict(initial_value)
|
||||
for op in function_graph.get_operations():
|
||||
if _is_in_placeholders(op, func_arg_placeholders):
|
||||
input_dict[op.values()[0].name] = op.values()[0].name
|
||||
input_dict[op.name] = op.name
|
||||
else:
|
||||
op_def = _get_op_def(op)
|
||||
@ -150,6 +156,10 @@ def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
|
||||
used_names = set()
|
||||
func.signature.input_arg.extend(
|
||||
[_tensor_to_argdef(i, used_names=used_names) for i in inputs])
|
||||
# Initializes the input map with all placeholder input tensors.
|
||||
initial_dict = {}
|
||||
for o, m in zip(inputs, func.signature.input_arg):
|
||||
initial_dict[o.name] = m.name
|
||||
if out_names is None:
|
||||
used_names = set()
|
||||
func.signature.output_arg.extend(
|
||||
@ -165,7 +175,8 @@ def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
|
||||
func.signature.output_arg.extend(
|
||||
[_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
|
||||
func_arg_placeholders = set([i.name for i in inputs])
|
||||
input_dict = _create_input_dict(graph, func_arg_placeholders)
|
||||
input_dict = _create_input_dict(graph, func_arg_placeholders,
|
||||
initial_value=initial_dict)
|
||||
|
||||
for op in operations:
|
||||
if _is_in_placeholders(op, func_arg_placeholders):
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import gen_logging_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
@ -817,6 +818,22 @@ class FunctionTest(test.TestCase):
|
||||
self.assertAllEqual(out1, np.linspace(2, 11, 10))
|
||||
self.assertAllEqual(out2, np.linspace(2, 11, 10))
|
||||
|
||||
def testTwoInputsSameOp(self):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
m = array_ops.placeholder(dtypes.float32)
|
||||
s, u, v = linalg_ops.svd(m)
|
||||
ss = math_ops.reduce_sum(s)
|
||||
uu = math_ops.reduce_sum(u)
|
||||
vv = math_ops.reduce_sum(v)
|
||||
result = ss + uu + vv
|
||||
f = function._graph_to_function_def(
|
||||
g,
|
||||
g.get_operations()[1:], # skip the placeholder
|
||||
[s, u, v],
|
||||
[result])
|
||||
self.assertEqual(len(f.signature.input_arg), 3)
|
||||
|
||||
|
||||
class FunctionsFromProtos(test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user