Python wrapper for the Graph Transform Tool

Change: 145009203
This commit is contained in:
Pete Warden 2017-01-19 14:21:59 -08:00 committed by TensorFlower Gardener
parent fa82a88606
commit f736991fd3
6 changed files with 251 additions and 0 deletions

View File

@ -2292,6 +2292,7 @@ tf_py_wrap_cc(
"util/port.i",
"util/py_checkpoint_reader.i",
"util/stat_summarizer.i",
"util/transform_graph.i",
],
deps = [
":cpp_shape_inference",
@ -2310,6 +2311,7 @@ tf_py_wrap_cc(
"//tensorflow/core:lib",
"//tensorflow/core/debug",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/tools/graph_transforms:transform_graph_lib",
"//tensorflow/tools/tfprof/internal:print_model_analysis",
"//util/python:python_headers",
] + tf_additional_lib_deps() + tf_additional_plugin_deps(),

View File

@ -38,3 +38,5 @@ limitations under the License.
%include "tensorflow/python/framework/cpp_shape_inference.i"
%include "tensorflow/python/util/kernel_registry.i"
%include "tensorflow/python/util/transform_graph.i"

View File

@ -0,0 +1,86 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include <std_string.i>
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/python/lib/core/py_func.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/tools/graph_transforms/transform_graph.h"
%}
%ignoreall
%unignore tensorflow;
%unignore TransformGraphWithStringInputs;
%{
string TransformGraphWithStringInputs(string graph_def_string,
string inputs_string,
string outputs_string,
string transforms_string,
TF_Status* out_status) {
tensorflow::GraphDef graph_def;
if (!graph_def.ParseFromString(graph_def_string)) {
Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument(
"Couldn't interpret input as a GraphDef"));
return "";
}
tensorflow::graph_transforms::TransformParameters params_list;
tensorflow::Status parse_status =
tensorflow::graph_transforms::ParseTransformParameters(
transforms_string, &params_list);
if (!parse_status.ok()) {
tensorflow::Set_TF_Status_from_Status(out_status, parse_status);
return "";
}
std::vector<string> inputs = tensorflow::str_util::Split(inputs_string, ',');
std::vector<string> outputs =
tensorflow::str_util::Split(outputs_string, ',');
tensorflow::Status transform_status =
tensorflow::graph_transforms::TransformGraph(
inputs, outputs, params_list, &graph_def);
if (!transform_status.ok()) {
tensorflow::Set_TF_Status_from_Status(out_status, transform_status);
return "";
}
string result;
if (!graph_def.SerializeToString(&result)) {
Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument(
"Couldn't serialize output as a GraphDef"));
return "";
}
Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
return result;
}
%}
string TransformGraphWithStringInputs(string graph_def_string,
string inputs_string,
string outputs_string,
string transforms_string,
TF_Status* out_status);
%unignoreall

View File

@ -9,6 +9,7 @@ load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
"tf_cc_test",
"tf_py_test",
)
exports_files(["LICENSE"])
@ -234,3 +235,25 @@ cc_binary(
"//tensorflow/core:lib",
],
)
py_library(
name = "transform_graph_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow/python:pywrap_tensorflow"],
)
tf_py_test(
name = "transform_graph_py_test",
size = "small",
srcs = ["python/transform_graph_test.py"],
additional_deps = [
":transform_graph_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
],
main = "python/transform_graph_test.py",
)

View File

@ -0,0 +1,53 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exposes the Python wrapper for graph transforms."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import, line-too-long
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.pywrap_tensorflow import TransformGraphWithStringInputs
def TransformGraph(input_graph_def, inputs, outputs, transforms):
"""Python wrapper for the Graph Transform Tool.
Gives access to all graph transforms available through the command line tool.
See documentation at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md
for full details of the options available.
Args:
input_graph_def: GraphDef object containing a model to be transformed.
inputs: List of node names for the model inputs.
outputs: List of node names for the model outputs.
transforms: List of strings containing transform names and parameters.
Returns:
New GraphDef with transforms applied.
"""
input_graph_def_string = input_graph_def.SerializeToString()
inputs_string = ",".join(inputs)
outputs_string = ",".join(outputs)
transforms_string = " ".join(transforms)
with errors.raise_exception_on_not_ok_status() as status:
output_graph_def_string = TransformGraphWithStringInputs(
input_graph_def_string, inputs_string, outputs_string,
transforms_string, status)
output_graph_def = graph_pb2.GraphDef()
output_graph_def.ParseFromString(output_graph_def_string)
return output_graph_def

View File

@ -0,0 +1,85 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for StatSummarizer Python wrapper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import test
from tensorflow.tools.graph_transforms import TransformGraph
class TransformGraphTest(test.TestCase):
# This test constructs a graph with a relu op that's not used by the normal
# inference path, and then tests that the strip_unused transform removes it as
# expected.
def testTransformGraph(self):
input_graph_def = graph_pb2.GraphDef()
const_op1 = input_graph_def.node.add()
const_op1.op = "Const"
const_op1.name = "const_op1"
const_op1.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(
type=dtypes.float32.as_datatype_enum))
const_op1.attr["value"].CopyFrom(
attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
[1, 2], dtypes.float32, [1, 2])))
const_op2 = input_graph_def.node.add()
const_op2.op = "Const"
const_op2.name = "const_op2"
const_op2.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(
type=dtypes.float32.as_datatype_enum))
const_op2.attr["value"].CopyFrom(
attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
[3, 4], dtypes.float32, [1, 2])))
# Create an add that has two constants as inputs.
add_op = input_graph_def.node.add()
add_op.op = "Add"
add_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue(
type=dtypes.float32.as_datatype_enum))
add_op.name = "add_op"
add_op.input.extend(["const_op1", "const_op2"])
# Create a relu that reads from the add.
relu_op = input_graph_def.node.add()
relu_op.op = "Relu"
relu_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue(
type=dtypes.float32.as_datatype_enum))
relu_op.name = "relu_op"
relu_op.input.extend(["add_op"])
# We're specifying that add_op is the final output, and so the relu isn't
# needed.
input_names = []
output_names = ["add_op"]
transforms = ["strip_unused_nodes"]
transformed_graph_def = TransformGraph(input_graph_def, input_names,
output_names, transforms)
# We expect that the relu is no longer present after running the transform.
for node in transformed_graph_def.node:
self.assertNotEqual("Relu", node.op)
if __name__ == "__main__":
test.main()