108 lines
3.7 KiB
Python
108 lines
3.7 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
r"""Removes unneeded nodes from a GraphDef file.
|
|
|
|
This script is designed to help streamline models, by taking the input and
|
|
output nodes that will be used by an application and figuring out the smallest
|
|
set of operations that are required to run for those arguments. The resulting
|
|
minimal graph is then saved out.
|
|
|
|
The advantages of running this script are:
|
|
- You may be able to shrink the file size.
|
|
- Operations that are unsupported on your platform but still present can be
|
|
safely removed.
|
|
The resulting graph may not be as flexible as the original though, since any
|
|
input nodes that weren't explicitly mentioned may not be accessible any more.
|
|
|
|
An example of command-line usage is:
|
|
bazel build tensorflow/python/tools:strip_unused && \
|
|
bazel-bin/tensorflow/python/tools/strip_unused \
|
|
--input_graph=some_graph_def.pb \
|
|
--output_graph=/tmp/stripped_graph.pb \
|
|
--input_node_names=input0 \
|
|
--output_node_names=softmax
|
|
|
|
You can also look at strip_unused_test.py for an example of how to use it.
|
|
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import sys
|
|
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.platform import app
|
|
from tensorflow.python.tools import strip_unused_lib
|
|
|
|
FLAGS = None
|
|
|
|
|
|
def main(unused_args):
  """Strips the graph by passing the parsed flag values to strip_unused_lib.

  Reads every setting from the module-level FLAGS object and forwards them,
  positionally, to strip_unused_lib.strip_unused_from_files.

  Args:
    unused_args: Ignored; accepted only so the function can be used as the
      `main` callback given to app.run.
  """
  flags = FLAGS
  strip_unused_lib.strip_unused_from_files(
      flags.input_graph,
      flags.input_binary,
      flags.output_graph,
      flags.output_binary,
      flags.input_node_names,
      flags.output_node_names,
      flags.placeholder_type_enum)
|
|
|
|
|
|
if __name__ == '__main__':
  # Declare the command-line flags with argparse. Arguments argparse does not
  # recognise are collected separately and handed on to app.run at the end.
  parser = argparse.ArgumentParser()

  # Custom 'bool' type: only the text "true" (in any letter case) yields True;
  # every other value yields False.
  parser.register('type', 'bool', lambda text: text.lower() == 'true')

  # --- Input graph -----------------------------------------------------------
  parser.add_argument(
      '--input_graph',
      help="TensorFlow 'GraphDef' file to load.",
      default='',
      type=str)
  # nargs='?' with const=True lets the bare flag (no value) mean True.
  parser.add_argument(
      '--input_binary',
      help='Whether the input files are in binary format.',
      default=False,
      type='bool',
      const=True,
      nargs='?')

  # --- Output graph ----------------------------------------------------------
  parser.add_argument(
      '--output_graph',
      help="Output 'GraphDef' file name.",
      default='',
      type=str)
  # Binary output is the default; pass an explicit non-"true" value to disable.
  parser.add_argument(
      '--output_binary',
      help='Whether to write a binary format graph.',
      default=True,
      type='bool',
      const=True,
      nargs='?')

  # --- Node selection --------------------------------------------------------
  parser.add_argument(
      '--input_node_names',
      help='The name of the input nodes, comma separated.',
      default='',
      type=str)
  parser.add_argument(
      '--output_node_names',
      help='The name of the output nodes, comma separated.',
      default='',
      type=str)
  parser.add_argument(
      '--placeholder_type_enum',
      help='The AttrValue enum to use for placeholders.',
      default=dtypes.float32.as_datatype_enum,
      type=int)

  # Known flags land in FLAGS (read by main); the remainder is forwarded to
  # app.run together with the program name.
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
|