graph_to_function_def gets its own file
PiperOrigin-RevId: 163709410
Parent: 29550762bd
Commit: b876065afe
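For callers, the practical effect is that the former private helper function._graph_to_function_def becomes the public graph_to_function_def.graph_to_function_def in its own module. A minimal before/after sketch of a call site (assuming a function graph temp_graph plus inputs/outputs tensor lists are already in hand, as in the hunks below):

# Before this commit: private helper inside tensorflow/python/framework/function.py
#   definition = function._graph_to_function_def(
#       temp_graph, temp_graph.get_operations(), inputs, outputs)

# After this commit: dedicated module
from tensorflow.python.framework import graph_to_function_def

definition = graph_to_function_def.graph_to_function_def(
    temp_graph, temp_graph.get_operations(), inputs, outputs,
    out_names=None)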
@@ -14,6 +14,7 @@ py_library(
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:function",
+        "//tensorflow/python:graph_to_function_def",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
     ],
@@ -22,6 +22,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.framework import function
+from tensorflow.python.framework import graph_to_function_def
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variable_scope as vs
@@ -84,7 +85,7 @@ class _ExperimentalFuncGraph(function._FuncGraph):
     return op.outputs[tensor.value_index]

   def _add_op_and_parents(self, op):
-    op_def = function._get_op_def(op)
+    op_def = graph_to_function_def._get_op_def(op)
     if op_def.is_stateful:
       raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
                        "by value." % (op.name, op.type))
@@ -178,7 +179,7 @@ class _ExperimentalDefinedFunction(function._DefinedFunction):
     self._sub_functions = temp_graph._functions

     # Build the FunctionDef
-    self._definition = function._graph_to_function_def(
+    self._definition = graph_to_function_def.graph_to_function_def(
         temp_graph, temp_graph.get_operations(), inputs, outputs,
         out_names=self._out_names)

@@ -503,6 +503,7 @@ py_library(
         ":array_ops",
         ":dtypes",
        ":framework_ops",
+        ":graph_to_function_def",
         ":op_def_registry",
         ":util",
         ":variable_scope",
@@ -510,6 +511,16 @@ py_library(
     ],
 )

+py_library(
+    name = "graph_to_function_def",
+    srcs = ["framework/graph_to_function_def.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":op_def_registry",
+        "//tensorflow/core:protos_all_py",
+    ],
+)
+
 py_library(
     name = "graph_util",
     srcs = [
@@ -23,13 +23,11 @@ from __future__ import print_function

 import collections
 import hashlib
-import re

 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
-from tensorflow.core.framework import op_def_pb2
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import graph_to_function_def
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -371,7 +369,7 @@ class _DefinedFunction(object):
     # pylint: enable=protected-access

     # Build the FunctionDef
-    self._definition = _graph_to_function_def(
+    self._definition = graph_to_function_def.graph_to_function_def(
         temp_graph,
         temp_graph.get_operations(),
         inputs,
@@ -823,161 +821,6 @@ def _from_library(lib):
   return initialized.values()


-def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
-  """Returns `graph` as a `FunctionDef` protocol buffer.
-
-  This method creates a [`FunctionDef`](
-  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
-  protocol buffer that contains all the ops in `operations`. The
-  operations become the body of the function.
-
-  The arguments `inputs` and `outputs` will be listed as the inputs
-  and outputs tensors of the function. They must be lists of
-  tensors present in the graph. The lists can optionally be empty.
-
-  Args:
-    graph: Graph.
-    operations: the operations to put in the function. Must be a subset of
-      the operations in the graph.
-    inputs: List of tensors. Inputs to the function.
-    outputs: List of tensors. Outputs of the function.
-    out_names: Optional list of string names for the outputs.
-
-  Returns:
-    A FunctionDef protocol buffer.
-
-  Raises:
-    ValueError: if out_names is specified and the wrong length.
-  """
-  func = function_pb2.FunctionDef()
-  func.signature.name = "_"
-  used_names = set()
-  func.signature.input_arg.extend(
-      [_tensor_to_argdef(i, used_names=used_names) for i in inputs])
-  # Initializes the input map with all placeholder input tensors.
-  initial_dict = {}
-  for o, m in zip(inputs, func.signature.input_arg):
-    initial_dict[o.name] = m.name
-  if out_names is None:
-    used_names = set()
-    func.signature.output_arg.extend(
-        [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
-  elif len(outputs) != len(out_names):
-    raise ValueError(
-        "Length of out_names (%d) does not match number of outputs (%d): %s" %
-        (len(out_names), len(outputs), ", ".join(out_names)))
-  elif len(out_names) != len(set(out_names)):
-    raise ValueError(
-        "Must not have duplicates in out_names: %s" % ", ".join(out_names))
-  else:
-    func.signature.output_arg.extend(
-        [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
-  func_arg_placeholders = set([i.name for i in inputs])
-  input_dict = _create_input_dict(graph, func_arg_placeholders,
-                                  initial_value=initial_dict)
-
-  for op in operations:
-    if _is_in_placeholders(op, func_arg_placeholders):
-      continue
-    _add_op_node(op, func, input_dict)
-
-  if out_names is None:
-    for index, o in enumerate(outputs):
-      k = func.signature.output_arg[index].name
-      func.ret[k] = input_dict[o.name]
-  else:
-    for o, n in zip(outputs, out_names):
-      func.ret[n] = input_dict[o.name]
-
-  return func
-
-
-def _make_argname_from_tensor_name(name):
-  return re.sub(":0$", "", name).replace(":", "_o")
-
-
-def _tensor_to_argdef(t, name=None, used_names=None):
-  """Convert tensor t to an argdef, with a specified name or a unique name."""
-  arg = op_def_pb2.OpDef.ArgDef()
-  if name is None:
-    arg.name = _make_argname_from_tensor_name(t.name)
-    if used_names is not None:
-      if arg.name in used_names:
-        i = 0
-        while True:
-          new_name = "%s_U%d" % (arg.name, i)
-          if new_name not in used_names:
-            arg.name = new_name
-            break
-          i += 1
-      used_names.add(arg.name)
-  else:
-    arg.name = name
-  arg.type = t.dtype.as_datatype_enum
-  return arg
-
-
-def _get_node_def(op):
-  return op._node_def  # pylint: disable=protected-access
-
-
-def _get_op_def(op):
-  return op.op_def or op_def_registry.get_registered_ops()[op.type]
-
-
-def _is_in_placeholders(op, func_arg_placeholders):
-  """Checks whether any output of this op is in func_arg_placeholders."""
-  return op.values() and any(x.name in func_arg_placeholders
-                             for x in op.values())
-
-
-def _create_input_dict(function_graph,
-                       func_arg_placeholders,
-                       initial_value=None):
-  """Create a mapping from graph tensor names to function tensor names."""
-  if initial_value is None:
-    input_dict = {}
-  else:
-    input_dict = dict(initial_value)
-  for op in function_graph.get_operations():
-    if _is_in_placeholders(op, func_arg_placeholders):
-      input_dict[op.name] = op.name
-    else:
-      op_def = _get_op_def(op)
-      attrs = _get_node_def(op).attr
-      o = 0
-      for arg_def in op_def.output_arg:
-        if arg_def.number_attr:
-          num = attrs[arg_def.number_attr].i
-        elif arg_def.type_list_attr:
-          num = len(attrs[arg_def.type_list_attr].list.type)
-        else:
-          num = 1
-        for i in range(num):
-          result = "%s:%s:%d" % (op.name, arg_def.name, i)
-          input_dict[op.values()[o].name] = result
-          if o == 0:
-            input_dict[op.name] = result
-          o += 1
-  return input_dict
-
-
-def _add_op_node(op, func, input_dict):
-  """Converts an op to a function def node and add it to `func`."""
-  # Add an entry in func.node_def
-
-  # Note that extend() makes a copy in this case, see:
-  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
-  func.node_def.extend([_get_node_def(op)])
-  node_def = func.node_def[-1]
-  for i in range(len(node_def.input)):
-    if not node_def.input[i].startswith("^"):
-      assert node_def.input[i] in input_dict, ("%s missing from %s" %
-                                               (node_def.input[i],
-                                                input_dict.items()))
-      node_def.input[i] = input_dict[node_def.input[i]]
-
-
 def _parse_kwargs_as_attrs(func_name, **kwargs):
   """Parses **kwargs into a node's attributes."""
   attrs = {}
@@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import function
+from tensorflow.python.framework import graph_to_function_def
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
@@ -852,7 +853,7 @@ class FunctionTest(test.TestCase):
       uu = math_ops.reduce_sum(u)
       vv = math_ops.reduce_sum(v)
       result = ss + uu + vv
-    f = function._graph_to_function_def(
+    f = graph_to_function_def.graph_to_function_def(
         g,
         g.get_operations()[1:],  # skip the placeholder
         [s, u, v],
tensorflow/python/framework/graph_to_function_def.py (new file, 180 lines)
@@ -0,0 +1,180 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Utility to convert a Graph to a FunctionDef."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.core.framework import function_pb2
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.python.framework import op_def_registry
+
+
+def _make_argname_from_tensor_name(name):
+  return re.sub(":0$", "", name).replace(":", "_o")
+
+
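An aside on the name mangling just above (illustrative, derived from the regex; the example tensor names are not from the commit): a trailing ":0" is dropped and any remaining ":" becomes "_o".

# _make_argname_from_tensor_name("Placeholder:0")  ->  "Placeholder"
# _make_argname_from_tensor_name("split:1")        ->  "split_o1"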
+def _tensor_to_argdef(t, name=None, used_names=None):
+  """Convert tensor t to an argdef, with a specified name or a unique name."""
+  arg = op_def_pb2.OpDef.ArgDef()
+  if name is None:
+    arg.name = _make_argname_from_tensor_name(t.name)
+    if used_names is not None:
+      if arg.name in used_names:
+        i = 0
+        while True:
+          new_name = "%s_U%d" % (arg.name, i)
+          if new_name not in used_names:
+            arg.name = new_name
+            break
+          i += 1
+      used_names.add(arg.name)
+  else:
+    arg.name = name
+  arg.type = t.dtype.as_datatype_enum
+  return arg
+
+
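How the collision handling above plays out (a sketch; the example names are assumed, not from the commit):

# With used_names = {"x"}, a second tensor "x:0" gets the argdef name "x_U0";
# a third becomes "x_U1". Each chosen name is added back into used_names so
# later calls keep probing past it.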
+def _is_in_placeholders(op, func_arg_placeholders):
+  """Checks whether any output of this op is in func_arg_placeholders."""
+  return op.values() and any(x.name in func_arg_placeholders
+                             for x in op.values())
+
+
+def _get_node_def(op):
+  return op._node_def  # pylint: disable=protected-access
+
+
+def _get_op_def(op):
+  return op.op_def or op_def_registry.get_registered_ops()[op.type]
+
+
+def _create_input_dict(function_graph,
+                       func_arg_placeholders,
+                       initial_value=None):
+  """Create a mapping from graph tensor names to function tensor names."""
+  if initial_value is None:
+    input_dict = {}
+  else:
+    input_dict = dict(initial_value)
+  for op in function_graph.get_operations():
+    if _is_in_placeholders(op, func_arg_placeholders):
+      input_dict[op.name] = op.name
+    else:
+      op_def = _get_op_def(op)
+      attrs = _get_node_def(op).attr
+      o = 0
+      for arg_def in op_def.output_arg:
+        if arg_def.number_attr:
+          num = attrs[arg_def.number_attr].i
+        elif arg_def.type_list_attr:
+          num = len(attrs[arg_def.type_list_attr].list.type)
+        else:
+          num = 1
+        for i in range(num):
+          result = "%s:%s:%d" % (op.name, arg_def.name, i)
+          input_dict[op.values()[o].name] = result
+          if o == 0:
+            input_dict[op.name] = result
+          o += 1
+  return input_dict
+
+
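A concrete picture of the mapping built above (hypothetical op and arg names, not from the commit): for a graph op "my_op" whose OpDef declares a single output arg "output", both the tensor name and the bare op name map to the FunctionDef-style endpoint.

# input_dict["my_op:0"] == "my_op:output:0"
# input_dict["my_op"]   == "my_op:output:0"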
+def _add_op_node(op, func, input_dict):
+  """Converts an op to a function def node and add it to `func`."""
+  # Add an entry in func.node_def
+
+  # Note that extend() makes a copy in this case, see:
+  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
+  func.node_def.extend([_get_node_def(op)])
+  node_def = func.node_def[-1]
+  for i in range(len(node_def.input)):
+    if not node_def.input[i].startswith("^"):
+      assert node_def.input[i] in input_dict, ("%s missing from %s" %
+                                               (node_def.input[i],
+                                                input_dict.items()))
+      node_def.input[i] = input_dict[node_def.input[i]]
+
+
+def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
+  """Returns `graph` as a `FunctionDef` protocol buffer.
+
+  This method creates a [`FunctionDef`](
+  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
+  protocol buffer that contains all the ops in `operations`. The
+  operations become the body of the function.
+
+  The arguments `inputs` and `outputs` will be listed as the inputs
+  and outputs tensors of the function. They must be lists of
+  tensors present in the graph. The lists can optionally be empty.
+
+  Args:
+    graph: Graph.
+    operations: the operations to put in the function. Must be a subset of
+      the operations in the graph.
+    inputs: List of tensors. Inputs to the function.
+    outputs: List of tensors. Outputs of the function.
+    out_names: Optional list of string names for the outputs.
+
+  Returns:
+    A FunctionDef protocol buffer.
+
+  Raises:
+    ValueError: if out_names is specified and the wrong length.
+  """
+  func = function_pb2.FunctionDef()
+  func.signature.name = "_"
+  used_names = set()
+  func.signature.input_arg.extend(
+      [_tensor_to_argdef(i, used_names=used_names) for i in inputs])
+  # Initializes the input map with all placeholder input tensors.
+  initial_dict = {}
+  for o, m in zip(inputs, func.signature.input_arg):
+    initial_dict[o.name] = m.name
+  if out_names is None:
+    used_names = set()
+    func.signature.output_arg.extend(
+        [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
+  elif len(outputs) != len(out_names):
+    raise ValueError(
+        "Length of out_names (%d) does not match number of outputs (%d): %s" %
+        (len(out_names), len(outputs), ", ".join(out_names)))
+  elif len(out_names) != len(set(out_names)):
+    raise ValueError(
+        "Must not have duplicates in out_names: %s" % ", ".join(out_names))
+  else:
+    func.signature.output_arg.extend(
+        [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
+  func_arg_placeholders = set([i.name for i in inputs])
+  input_dict = _create_input_dict(graph, func_arg_placeholders,
+                                  initial_value=initial_dict)
+
+  for op in operations:
+    if _is_in_placeholders(op, func_arg_placeholders):
+      continue
+    _add_op_node(op, func, input_dict)
+
+  if out_names is None:
+    for index, o in enumerate(outputs):
+      k = func.signature.output_arg[index].name
+      func.ret[k] = input_dict[o.name]
+  else:
+    for o, n in zip(outputs, out_names):
+      func.ret[n] = input_dict[o.name]
+
+  return func
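A minimal usage sketch of the new module, modeled on the test change above (the graph construction and op/tensor names here are assumptions for illustration; only graph_to_function_def itself comes from this commit):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

g = ops.Graph()
with g.as_default():
  x = array_ops.placeholder(dtypes.float32, name="x")  # function input
  y = math_ops.square(x, name="y")                     # function body / output

# Skip the placeholder op itself (as the test above does) and export the rest.
fdef = graph_to_function_def.graph_to_function_def(
    g, g.get_operations()[1:], [x], [y], out_names=["y_out"])
print(fdef.signature)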