diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 14302874441..82aa1f557ef 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -395,9 +395,10 @@ tf_cc_test( # :toco is the main public command-line tool exposing the functionality # of the :toco_tooling library. -tf_cc_binary( - name = "toco", - srcs = ["toco.cc"], +cc_library( + name = "toco_convert", + srcs = ["toco_convert.cc"], + hdrs = ["toco_convert.h"], visibility = ["//visibility:public"], deps = [ ":model", @@ -416,6 +417,51 @@ tf_cc_binary( ], ) +tf_cc_binary( + name = "toco", + srcs = ["toco.cc"], + visibility = ["//visibility:public"], + deps = [ + ":model", + ":model_cmdline_flags", + ":model_flags_proto_cc", + ":toco_cmdline_flags", + ":toco_convert", + ":toco_flags_proto_cc", + ":toco_port", + ":toco_tooling", + ":types_proto_cc", + "@com_google_absl//absl/strings", + "//tensorflow/core:lib", + # We cannot embed the core:ops dependency directly into :toco_tooling as + # it can conflict with downstream deps when toco is used as a library. + "//tensorflow/core:ops", + ], +) + +tf_cc_test( + name = "toco_convert_test", + srcs = ["toco_convert_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":model", + ":model_cmdline_flags", + ":model_flags_proto_cc", + ":toco_cmdline_flags", + ":toco_convert", + ":toco_flags_proto_cc", + ":toco_port", + ":toco_tooling", + ":types_proto_cc", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "//tensorflow/core:lib", + # We cannot embed the core:ops dependency directly into :toco_tooling as + # it can conflict with downstream deps when toco is used as a library. + "//tensorflow/core:ops", + ], +) + tf_cc_test( name = "toco_port_test", srcs = ["toco_port_test.cc"], diff --git a/tensorflow/lite/toco/export_tensorflow.cc b/tensorflow/lite/toco/export_tensorflow.cc index 1752745aaee..bdc3a5b0fb4 100644 --- a/tensorflow/lite/toco/export_tensorflow.cc +++ b/tensorflow/lite/toco/export_tensorflow.cc @@ -48,7 +48,8 @@ using tensorflow::TensorProto; namespace toco { namespace { -tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type) { +tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type, + const string& error_location) { switch (data_type) { case ArrayDataType::kBool: return tensorflow::DT_BOOL; @@ -66,14 +67,21 @@ tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type) { return tensorflow::DT_COMPLEX64; default: case ArrayDataType::kNone: - LOG(FATAL) << "Unsupported data type: " << static_cast(data_type); + LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type) + << "' in " << error_location; return tensorflow::DT_INVALID; } } +tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type, + const string& op_name) { + return GetTensorFlowDataType(data_type, "op '" + op_name + "'"); +} + tensorflow::DataType GetTensorFlowDataType(const Model& model, const string& array_name) { - return GetTensorFlowDataType(model.GetArray(array_name).data_type); + return GetTensorFlowDataType(model.GetArray(array_name).data_type, + "array '" + array_name + "'"); } // TensorFlow sometimes forbids what it calls "legacy scalars", @@ -1285,7 +1293,7 @@ void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, *range_op->add_input() = src_op.inputs[1]; *range_op->add_input() = src_op.inputs[2]; (*range_op->mutable_attr())["Tidx"].set_type( - GetTensorFlowDataType(src_op.dtype)); + GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0])); } void ConvertPackOperator(const Model& model, const PackOperator& src_op, @@ -1298,7 +1306,8 @@ void ConvertPackOperator(const Model& model, const PackOperator& src_op, } (*pack_op->mutable_attr())["axis"].set_i(src_op.axis); (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size()); - (*pack_op->mutable_attr())["T"].set_type(GetTensorFlowDataType(src_op.dtype)); + (*pack_op->mutable_attr())["T"].set_type( + GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0])); } void ConvertFillOperator(const Model& model, const FillOperator& src_op, @@ -1887,7 +1896,7 @@ void ConvertRandomUniformOperator(const Model& model, GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(shape_type); (*new_op->mutable_attr())["dtype"].set_type( - GetTensorFlowDataType(src_op.dtype)); + GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0])); (*new_op->mutable_attr())["seed"].set_i(src_op.seed); (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2); } diff --git a/tensorflow/lite/toco/toco.cc b/tensorflow/lite/toco/toco.cc index 9740015850a..4a3d6a58487 100644 --- a/tensorflow/lite/toco/toco.cc +++ b/tensorflow/lite/toco/toco.cc @@ -16,87 +16,9 @@ limitations under the License. #include #include -#include "absl/strings/string_view.h" -#include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model_cmdline_flags.h" -#include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_cmdline_flags.h" -#include "tensorflow/lite/toco/toco_flags.pb.h" -#include "tensorflow/lite/toco/toco_port.h" -#include "tensorflow/lite/toco/toco_tooling.h" -#include "tensorflow/lite/toco/toco_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" - -namespace toco { -namespace { - -// Checks the permissions of the output file to ensure it is writeable. -void CheckOutputFilePermissions(const Arg& output_file) { - QCHECK(output_file.specified()) << "Missing required flag --output_file.\n"; - QCHECK(port::file::Writable(output_file.value()).ok()) - << "Specified output_file is not writable: " << output_file.value() - << ".\n"; -} - -// Checks the permissions of the frozen model file. -void CheckFrozenModelPermissions(const Arg& input_file) { - QCHECK(input_file.specified()) << "Missing required flag --input_file.\n"; - QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok()) - << "Specified input_file does not exist: " << input_file.value() << ".\n"; - QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok()) - << "Specified input_file exists, but is not readable: " - << input_file.value() << ".\n"; -} - -// Reads the contents of the GraphDef from either the frozen graph file or the -// SavedModel directory. If it reads the SavedModel directory, it updates the -// ModelFlags and TocoFlags accordingly. -void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags, - string* graph_def_contents) { - port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n"); - - // Ensure savedmodel_directory is not set. - QCHECK(!parsed_toco_flags.savedmodel_directory.specified()) - << "Use `tensorflow/lite/python/tflite_convert` script with " - << "SavedModel directories.\n"; - - // Checks the input file permissions and reads the contents. - CheckFrozenModelPermissions(parsed_toco_flags.input_file); - CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), - graph_def_contents, port::file::Defaults()) - .ok()); -} - -tensorflow::Status ToolMain(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags) { - ModelFlags model_flags; - ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags); - - TocoFlags toco_flags; - ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags); - - string graph_def_contents; - ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags, - &model_flags, &graph_def_contents); - CheckOutputFilePermissions(parsed_toco_flags.output_file); - - std::unique_ptr model = - Import(toco_flags, model_flags, graph_def_contents); - Transform(toco_flags, model.get()); - string output_file_contents; - TF_RETURN_IF_ERROR(Export(toco_flags, *model, toco_flags.allow_custom_ops(), - &output_file_contents)); - TF_RETURN_IF_ERROR( - port::file::SetContents(parsed_toco_flags.output_file.value(), - output_file_contents, port::file::Defaults())); - return tensorflow::Status(); -} - -} // namespace -} // namespace toco +#include "tensorflow/lite/toco/toco_convert.h" int main(int argc, char** argv) { toco::string msg; @@ -126,6 +48,6 @@ int main(int argc, char** argv) { return 1; } toco::port::InitGoogle(argv[0], effective_argc, &effective_argv, true); - auto status = toco::ToolMain(parsed_toco_flags, parsed_model_flags); + auto status = toco::Convert(parsed_toco_flags, parsed_model_flags); return status.ok() ? 0 : -1; } diff --git a/tensorflow/lite/toco/toco_convert.cc b/tensorflow/lite/toco/toco_convert.cc new file mode 100644 index 00000000000..28e7b10ecd0 --- /dev/null +++ b/tensorflow/lite/toco/toco_convert.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_cmdline_flags.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/toco_tooling.h" +#include "tensorflow/lite/toco/toco_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { +namespace { + +// Checks the permissions of the output file to ensure it is writeable. +void CheckOutputFilePermissions(const Arg& output_file) { + QCHECK(output_file.specified()) << "Missing required flag --output_file.\n"; + QCHECK(port::file::Writable(output_file.value()).ok()) + << "Specified output_file is not writable: " << output_file.value() + << ".\n"; +} + +// Checks the permissions of the frozen model file. +void CheckFrozenModelPermissions(const Arg& input_file) { + QCHECK(input_file.specified()) << "Missing required flag --input_file.\n"; + QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok()) + << "Specified input_file does not exist: " << input_file.value() << ".\n"; + QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok()) + << "Specified input_file exists, but is not readable: " + << input_file.value() << ".\n"; +} + +// Reads the contents of the GraphDef from either the frozen graph file or the +// SavedModel directory. If it reads the SavedModel directory, it updates the +// ModelFlags and TocoFlags accordingly. +void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents) { + port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n"); + + // Ensure savedmodel_directory is not set. + QCHECK(!parsed_toco_flags.savedmodel_directory.specified()) + << "Use `tensorflow/lite/python/tflite_convert` script with " + << "SavedModel directories.\n"; + + // Checks the input file permissions and reads the contents. + CheckFrozenModelPermissions(parsed_toco_flags.input_file); + CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), + graph_def_contents, port::file::Defaults()) + .ok()); +} +} // namespace + +tensorflow::Status Convert(const string& graph_def_contents, + const TocoFlags& toco_flags, + const ModelFlags& model_flags, + string* output_file_contents) { + std::unique_ptr model = + Import(toco_flags, model_flags, graph_def_contents); + Transform(toco_flags, model.get()); + return Export(toco_flags, *model, toco_flags.allow_custom_ops(), + output_file_contents); +} + +tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags) { + ModelFlags model_flags; + ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags); + + TocoFlags toco_flags; + ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags); + + string graph_def_contents; + ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags, + &model_flags, &graph_def_contents); + CheckOutputFilePermissions(parsed_toco_flags.output_file); + + string output_file_contents; + TF_RETURN_IF_ERROR(Convert(graph_def_contents, toco_flags, model_flags, + &output_file_contents)); + + TF_RETURN_IF_ERROR( + port::file::SetContents(parsed_toco_flags.output_file.value(), + output_file_contents, port::file::Defaults())); + return tensorflow::Status(); +} + +} // namespace toco diff --git a/tensorflow/lite/toco/toco_convert.h b/tensorflow/lite/toco/toco_convert.h new file mode 100644 index 00000000000..ebbd336d3f5 --- /dev/null +++ b/tensorflow/lite/toco/toco_convert.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_ +#define TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/lite/toco/args.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" + +namespace toco { + +tensorflow::Status Convert(const string& graph_def_contents, + const TocoFlags& toco_flags, + const ModelFlags& model_flags, + string* output_file_contents); + +tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags); +} // namespace toco + +#endif // TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_ diff --git a/tensorflow/lite/toco/toco_convert_test.cc b/tensorflow/lite/toco/toco_convert_test.cc new file mode 100644 index 00000000000..c3c440db943 --- /dev/null +++ b/tensorflow/lite/toco/toco_convert_test.cc @@ -0,0 +1,173 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/toco/toco_convert.h" +#include +#include + +namespace toco { +namespace { + +TEST(TocoTest, MissingInputFile) { + ParsedTocoFlags toco_flags; + ParsedModelFlags model_flags; + EXPECT_DEATH(Convert(toco_flags, model_flags).ok(), + "Missing required flag --input_file"); +} + +TEST(TocoTest, BadInputFormat) { + TocoFlags toco_flags; + ModelFlags model_flags; + + string input; + string output; + + EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(), + "Unhandled input_format='FILE_FORMAT_UNKNOWN'"); +} + +TEST(TocoTest, MissingOuputArrays) { + TocoFlags toco_flags; + ModelFlags model_flags; + + toco_flags.set_input_format(TENSORFLOW_GRAPHDEF); + string input; + string output; + + EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(), + "This model does not define output arrays, so a --output_arrays " + "flag must be given on the command-line"); +} + +TEST(TocoTest, BadOutputArray) { + TocoFlags toco_flags; + ModelFlags model_flags; + + toco_flags.set_input_format(TENSORFLOW_GRAPHDEF); + model_flags.add_output_arrays("output1"); + string input; + string output; + + EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(), + "Specified output array .output1. is not produced by any op " + "in this graph. Is it a typo. To silence this message, pass " + "this flag: allow_nonexistent_arrays"); +} + +TEST(TocoTest, BadOutputFormat) { + TocoFlags toco_flags; + ModelFlags model_flags; + + toco_flags.set_input_format(TENSORFLOW_GRAPHDEF); + model_flags.add_output_arrays("output1"); + string input = R"GraphDef( + node { + name: "output1" + input: "input1" + input: "input2" + op: "Sub" + attr { key: "T" value { type: DT_FLOAT } } + } + )GraphDef"; + + string output; + + EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(), + "Unhandled output_format='FILE_FORMAT_UNKNOWN'"); +} + +TEST(TocoTest, SimpleFloatModel) { + TocoFlags toco_flags; + ModelFlags model_flags; + + toco_flags.set_input_format(TENSORFLOW_GRAPHDEF); + toco_flags.set_output_format(TENSORFLOW_GRAPHDEF); + + // Inputs are automatically selected (but that might not be a good idea). + model_flags.add_output_arrays("output1"); + string input = R"GraphDef( + node { + name: "input1" + op: "Placeholder" + attr { key: "dtype" value { type: DT_INT64 } } + } + node { + name: "input2" + op: "Placeholder" + attr { key: "dtype" value { type: DT_INT64 } } + } + node { + name: "output1" + input: "input1" + input: "input2" + op: "Sub" + attr { key: "T" value { type: DT_FLOAT } } + } + )GraphDef"; + + string output; + EXPECT_TRUE(Convert(input, toco_flags, model_flags, &output).ok()); + EXPECT_TRUE(!output.empty()); +} + +TEST(TocoTest, TransientStringTensors) { + TocoFlags toco_flags; + ModelFlags model_flags; + + toco_flags.set_input_format(TENSORFLOW_GRAPHDEF); + + // We need to do a couple of things to trigger the transient array + // initialization code: output format must support memory planning, and the + // input array must have a shape. + toco_flags.set_output_format(TFLITE); + + model_flags.add_output_arrays("output1"); + string input = R"GraphDef( + node { + name: "input1" + op: "Placeholder" + attr { key: "dtype" value { type: DT_STRING } } + attr { key: "shape" value { shape { dim { size:1 }}}} + } + node { + name: "indices1" + op: "Placeholder" + attr { key: "dtype" value { type: DT_INT64 } } + } + node { + name: "intermediate1" + op: "Gather" + input: "input1" + input: "indices1" + attr { key: "Tparams" value { type: DT_STRING } } + attr { key: "Tindices" value { type: DT_INT64 } } + } + node { + name: "output1" + op: "Gather" + input: "intermediate1" + input: "indices2" + attr { key: "Tparams" value { type: DT_STRING } } + attr { key: "Tindices" value { type: DT_INT64 } } + } + )GraphDef"; + + string output; + + EXPECT_TRUE(Convert(input, toco_flags, model_flags, &output).ok()); + EXPECT_TRUE(!output.empty()); +} + +} // namespace +} // namespace toco diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index 5f96e833fbf..d8b111d0379 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -210,7 +210,8 @@ std::unique_ptr Import(const TocoFlags& toco_flags, CheckInvariants(*model); break; default: - LOG(FATAL) << "Unhandled input_format"; + LOG(FATAL) << "Unhandled input_format='" + << FileFormat_Name(toco_flags.input_format()) << "'"; } LogDump(kLogLevelModelChanged, "AT IMPORT", *model); @@ -424,7 +425,8 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model, DumpGraphviz(model, output_file_contents); break; default: - LOG(FATAL) << "Unhandled output_format"; + LOG(FATAL) << "Unhandled output_format='" + << FileFormat_Name(toco_flags.output_format()) << "'"; } return tensorflow::Status(); } diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index cff387782f8..084169548e2 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -1770,6 +1770,14 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) { if (!array->has_shape()) { return false; } + + // The size of string tensors is rarely known ahead of time, so all transient + // tensors of this type will need to be dynamically allocated. + if (array->final_data_type == ArrayDataType::kString || + array->data_type == ArrayDataType::kString) { + return false; + } + return true; }