graph_to_function_def gets its own file

PiperOrigin-RevId: 163709410
This commit is contained in:
Alexandre Passos 2017-07-31 10:13:16 -07:00 committed by TensorFlower Gardener
parent 29550762bd
commit b876065afe
6 changed files with 199 additions and 162 deletions

View File

@ -14,6 +14,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
],

View File

@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
@ -84,7 +85,7 @@ class _ExperimentalFuncGraph(function._FuncGraph):
return op.outputs[tensor.value_index]
def _add_op_and_parents(self, op):
op_def = function._get_op_def(op)
op_def = graph_to_function_def._get_op_def(op)
if op_def.is_stateful:
raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
"by value." % (op.name, op.type))
@ -178,7 +179,7 @@ class _ExperimentalDefinedFunction(function._DefinedFunction):
self._sub_functions = temp_graph._functions
# Build the FunctionDef
self._definition = function._graph_to_function_def(
self._definition = graph_to_function_def.graph_to_function_def(
temp_graph, temp_graph.get_operations(), inputs, outputs,
out_names=self._out_names)

View File

@ -503,6 +503,7 @@ py_library(
":array_ops",
":dtypes",
":framework_ops",
":graph_to_function_def",
":op_def_registry",
":util",
":variable_scope",
@ -510,6 +511,16 @@ py_library(
],
)
py_library(
name = "graph_to_function_def",
srcs = ["framework/graph_to_function_def.py"],
srcs_version = "PY2AND3",
deps = [
":op_def_registry",
"//tensorflow/core:protos_all_py",
],
)
py_library(
name = "graph_util",
srcs = [

View File

@ -23,13 +23,11 @@ from __future__ import print_function
import collections
import hashlib
import re
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
@ -371,7 +369,7 @@ class _DefinedFunction(object):
# pylint: enable=protected-access
# Build the FunctionDef
self._definition = _graph_to_function_def(
self._definition = graph_to_function_def.graph_to_function_def(
temp_graph,
temp_graph.get_operations(),
inputs,
@ -823,161 +821,6 @@ def _from_library(lib):
return initialized.values()
def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
"""Returns `graph` as a `FunctionDef` protocol buffer.
This method creates a [`FunctionDef`](
https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
protocol buffer that contains all the ops in `operations`. The
operations become the body of the function.
The arguments `inputs` and `outputs` will be listed as the inputs
and outputs tensors of the function. They must be lists of
tensors present in the graph. The lists can optionally be empty.
Args:
graph: Graph.
operations: the operations to put in the function. Must be a subset of
the operations in the graph.
inputs: List of tensors. Inputs to the function.
outputs: List of tensors. Outputs of the function.
out_names: Optional list of string names for the outputs.
Returns:
A FunctionDef protocol buffer.
Raises:
ValueError: if out_names is specified and the wrong length.
"""
func = function_pb2.FunctionDef()
func.signature.name = "_"
used_names = set()
func.signature.input_arg.extend(
[_tensor_to_argdef(i, used_names=used_names) for i in inputs])
# Initializes the input map with all placeholder input tensors.
initial_dict = {}
for o, m in zip(inputs, func.signature.input_arg):
initial_dict[o.name] = m.name
if out_names is None:
used_names = set()
func.signature.output_arg.extend(
[_tensor_to_argdef(o, used_names=used_names) for o in outputs])
elif len(outputs) != len(out_names):
raise ValueError(
"Length of out_names (%d) does not match number of outputs (%d): %s" %
(len(out_names), len(outputs), ", ".join(out_names)))
elif len(out_names) != len(set(out_names)):
raise ValueError(
"Must not have duplicates in out_names: %s" % ", ".join(out_names))
else:
func.signature.output_arg.extend(
[_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
func_arg_placeholders = set([i.name for i in inputs])
input_dict = _create_input_dict(graph, func_arg_placeholders,
initial_value=initial_dict)
for op in operations:
if _is_in_placeholders(op, func_arg_placeholders):
continue
_add_op_node(op, func, input_dict)
if out_names is None:
for index, o in enumerate(outputs):
k = func.signature.output_arg[index].name
func.ret[k] = input_dict[o.name]
else:
for o, n in zip(outputs, out_names):
func.ret[n] = input_dict[o.name]
return func
def _make_argname_from_tensor_name(name):
return re.sub(":0$", "", name).replace(":", "_o")
def _tensor_to_argdef(t, name=None, used_names=None):
"""Convert tensor t to an argdef, with a specified name or a unique name."""
arg = op_def_pb2.OpDef.ArgDef()
if name is None:
arg.name = _make_argname_from_tensor_name(t.name)
if used_names is not None:
if arg.name in used_names:
i = 0
while True:
new_name = "%s_U%d" % (arg.name, i)
if new_name not in used_names:
arg.name = new_name
break
i += 1
used_names.add(arg.name)
else:
arg.name = name
arg.type = t.dtype.as_datatype_enum
return arg
def _get_node_def(op):
return op._node_def # pylint: disable=protected-access
def _get_op_def(op):
return op.op_def or op_def_registry.get_registered_ops()[op.type]
def _is_in_placeholders(op, func_arg_placeholders):
"""Checks whether any output of this op is in func_arg_placeholders."""
return op.values() and any(x.name in func_arg_placeholders
for x in op.values())
def _create_input_dict(function_graph,
func_arg_placeholders,
initial_value=None):
"""Create a mapping from graph tensor names to function tensor names."""
if initial_value is None:
input_dict = {}
else:
input_dict = dict(initial_value)
for op in function_graph.get_operations():
if _is_in_placeholders(op, func_arg_placeholders):
input_dict[op.name] = op.name
else:
op_def = _get_op_def(op)
attrs = _get_node_def(op).attr
o = 0
for arg_def in op_def.output_arg:
if arg_def.number_attr:
num = attrs[arg_def.number_attr].i
elif arg_def.type_list_attr:
num = len(attrs[arg_def.type_list_attr].list.type)
else:
num = 1
for i in range(num):
result = "%s:%s:%d" % (op.name, arg_def.name, i)
input_dict[op.values()[o].name] = result
if o == 0:
input_dict[op.name] = result
o += 1
return input_dict
def _add_op_node(op, func, input_dict):
"""Converts an op to a function def node and add it to `func`."""
# Add an entry in func.node_def
# Note that extend() makes a copy in this case, see:
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
func.node_def.extend([_get_node_def(op)])
node_def = func.node_def[-1]
for i in range(len(node_def.input)):
if not node_def.input[i].startswith("^"):
assert node_def.input[i] in input_dict, ("%s missing from %s" %
(node_def.input[i],
input_dict.items()))
node_def.input[i] = input_dict[node_def.input[i]]
def _parse_kwargs_as_attrs(func_name, **kwargs):
"""Parses **kwargs into a node's attributes."""
attrs = {}

View File

@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@ -852,7 +853,7 @@ class FunctionTest(test.TestCase):
uu = math_ops.reduce_sum(u)
vv = math_ops.reduce_sum(v)
result = ss + uu + vv
f = function._graph_to_function_def(
f = graph_to_function_def.graph_to_function_def(
g,
g.get_operations()[1:], # skip the placeholder
[s, u, v],

View File

@ -0,0 +1,180 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility to convert a Graph to a FunctionDef."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import op_def_registry
def _make_argname_from_tensor_name(name):
return re.sub(":0$", "", name).replace(":", "_o")
def _tensor_to_argdef(t, name=None, used_names=None):
"""Convert tensor t to an argdef, with a specified name or a unique name."""
arg = op_def_pb2.OpDef.ArgDef()
if name is None:
arg.name = _make_argname_from_tensor_name(t.name)
if used_names is not None:
if arg.name in used_names:
i = 0
while True:
new_name = "%s_U%d" % (arg.name, i)
if new_name not in used_names:
arg.name = new_name
break
i += 1
used_names.add(arg.name)
else:
arg.name = name
arg.type = t.dtype.as_datatype_enum
return arg
def _is_in_placeholders(op, func_arg_placeholders):
"""Checks whether any output of this op is in func_arg_placeholders."""
return op.values() and any(x.name in func_arg_placeholders
for x in op.values())
def _get_node_def(op):
return op._node_def # pylint: disable=protected-access
def _get_op_def(op):
return op.op_def or op_def_registry.get_registered_ops()[op.type]
def _create_input_dict(function_graph,
func_arg_placeholders,
initial_value=None):
"""Create a mapping from graph tensor names to function tensor names."""
if initial_value is None:
input_dict = {}
else:
input_dict = dict(initial_value)
for op in function_graph.get_operations():
if _is_in_placeholders(op, func_arg_placeholders):
input_dict[op.name] = op.name
else:
op_def = _get_op_def(op)
attrs = _get_node_def(op).attr
o = 0
for arg_def in op_def.output_arg:
if arg_def.number_attr:
num = attrs[arg_def.number_attr].i
elif arg_def.type_list_attr:
num = len(attrs[arg_def.type_list_attr].list.type)
else:
num = 1
for i in range(num):
result = "%s:%s:%d" % (op.name, arg_def.name, i)
input_dict[op.values()[o].name] = result
if o == 0:
input_dict[op.name] = result
o += 1
return input_dict
def _add_op_node(op, func, input_dict):
"""Converts an op to a function def node and add it to `func`."""
# Add an entry in func.node_def
# Note that extend() makes a copy in this case, see:
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
func.node_def.extend([_get_node_def(op)])
node_def = func.node_def[-1]
for i in range(len(node_def.input)):
if not node_def.input[i].startswith("^"):
assert node_def.input[i] in input_dict, ("%s missing from %s" %
(node_def.input[i],
input_dict.items()))
node_def.input[i] = input_dict[node_def.input[i]]
def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
"""Returns `graph` as a `FunctionDef` protocol buffer.
This method creates a [`FunctionDef`](
https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
protocol buffer that contains all the ops in `operations`. The
operations become the body of the function.
The arguments `inputs` and `outputs` will be listed as the inputs
and outputs tensors of the function. They must be lists of
tensors present in the graph. The lists can optionally be empty.
Args:
graph: Graph.
operations: the operations to put in the function. Must be a subset of
the operations in the graph.
inputs: List of tensors. Inputs to the function.
outputs: List of tensors. Outputs of the function.
out_names: Optional list of string names for the outputs.
Returns:
A FunctionDef protocol buffer.
Raises:
ValueError: if out_names is specified and the wrong length.
"""
func = function_pb2.FunctionDef()
func.signature.name = "_"
used_names = set()
func.signature.input_arg.extend(
[_tensor_to_argdef(i, used_names=used_names) for i in inputs])
# Initializes the input map with all placeholder input tensors.
initial_dict = {}
for o, m in zip(inputs, func.signature.input_arg):
initial_dict[o.name] = m.name
if out_names is None:
used_names = set()
func.signature.output_arg.extend(
[_tensor_to_argdef(o, used_names=used_names) for o in outputs])
elif len(outputs) != len(out_names):
raise ValueError(
"Length of out_names (%d) does not match number of outputs (%d): %s" %
(len(out_names), len(outputs), ", ".join(out_names)))
elif len(out_names) != len(set(out_names)):
raise ValueError(
"Must not have duplicates in out_names: %s" % ", ".join(out_names))
else:
func.signature.output_arg.extend(
[_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
func_arg_placeholders = set([i.name for i in inputs])
input_dict = _create_input_dict(graph, func_arg_placeholders,
initial_value=initial_dict)
for op in operations:
if _is_in_placeholders(op, func_arg_placeholders):
continue
_add_op_node(op, func, input_dict)
if out_names is None:
for index, o in enumerate(outputs):
k = func.signature.output_arg[index].name
func.ret[k] = input_dict[o.name]
else:
for o, n in zip(outputs, out_names):
func.ret[n] = input_dict[o.name]
return func