Allow intermediate string tensors
PiperOrigin-RevId: 222298206
This commit is contained in:
parent
5c4efd9460
commit
760ef2cc7e
@ -395,9 +395,10 @@ tf_cc_test(
|
||||
|
||||
# :toco is the main public command-line tool exposing the functionality
|
||||
# of the :toco_tooling library.
|
||||
tf_cc_binary(
|
||||
name = "toco",
|
||||
srcs = ["toco.cc"],
|
||||
cc_library(
|
||||
name = "toco_convert",
|
||||
srcs = ["toco_convert.cc"],
|
||||
hdrs = ["toco_convert.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":model",
|
||||
@ -416,6 +417,51 @@ tf_cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "toco",
|
||||
srcs = ["toco.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":model",
|
||||
":model_cmdline_flags",
|
||||
":model_flags_proto_cc",
|
||||
":toco_cmdline_flags",
|
||||
":toco_convert",
|
||||
":toco_flags_proto_cc",
|
||||
":toco_port",
|
||||
":toco_tooling",
|
||||
":types_proto_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/core:lib",
|
||||
# We cannot embed the core:ops dependency directly into :toco_tooling as
|
||||
# it can conflict with downstream deps when toco is used as a library.
|
||||
"//tensorflow/core:ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "toco_convert_test",
|
||||
srcs = ["toco_convert_test.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":model",
|
||||
":model_cmdline_flags",
|
||||
":model_flags_proto_cc",
|
||||
":toco_cmdline_flags",
|
||||
":toco_convert",
|
||||
":toco_flags_proto_cc",
|
||||
":toco_port",
|
||||
":toco_tooling",
|
||||
":types_proto_cc",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/core:lib",
|
||||
# We cannot embed the core:ops dependency directly into :toco_tooling as
|
||||
# it can conflict with downstream deps when toco is used as a library.
|
||||
"//tensorflow/core:ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "toco_port_test",
|
||||
srcs = ["toco_port_test.cc"],
|
||||
|
@ -48,7 +48,8 @@ using tensorflow::TensorProto;
|
||||
namespace toco {
|
||||
namespace {
|
||||
|
||||
tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type) {
|
||||
tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type,
|
||||
const string& error_location) {
|
||||
switch (data_type) {
|
||||
case ArrayDataType::kBool:
|
||||
return tensorflow::DT_BOOL;
|
||||
@ -66,14 +67,21 @@ tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type) {
|
||||
return tensorflow::DT_COMPLEX64;
|
||||
default:
|
||||
case ArrayDataType::kNone:
|
||||
LOG(FATAL) << "Unsupported data type: " << static_cast<int>(data_type);
|
||||
LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type)
|
||||
<< "' in " << error_location;
|
||||
return tensorflow::DT_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type,
|
||||
const string& op_name) {
|
||||
return GetTensorFlowDataType(data_type, "op '" + op_name + "'");
|
||||
}
|
||||
|
||||
tensorflow::DataType GetTensorFlowDataType(const Model& model,
|
||||
const string& array_name) {
|
||||
return GetTensorFlowDataType(model.GetArray(array_name).data_type);
|
||||
return GetTensorFlowDataType(model.GetArray(array_name).data_type,
|
||||
"array '" + array_name + "'");
|
||||
}
|
||||
|
||||
// TensorFlow sometimes forbids what it calls "legacy scalars",
|
||||
@ -1285,7 +1293,7 @@ void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
|
||||
*range_op->add_input() = src_op.inputs[1];
|
||||
*range_op->add_input() = src_op.inputs[2];
|
||||
(*range_op->mutable_attr())["Tidx"].set_type(
|
||||
GetTensorFlowDataType(src_op.dtype));
|
||||
GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0]));
|
||||
}
|
||||
|
||||
void ConvertPackOperator(const Model& model, const PackOperator& src_op,
|
||||
@ -1298,7 +1306,8 @@ void ConvertPackOperator(const Model& model, const PackOperator& src_op,
|
||||
}
|
||||
(*pack_op->mutable_attr())["axis"].set_i(src_op.axis);
|
||||
(*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size());
|
||||
(*pack_op->mutable_attr())["T"].set_type(GetTensorFlowDataType(src_op.dtype));
|
||||
(*pack_op->mutable_attr())["T"].set_type(
|
||||
GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
|
||||
}
|
||||
|
||||
void ConvertFillOperator(const Model& model, const FillOperator& src_op,
|
||||
@ -1887,7 +1896,7 @@ void ConvertRandomUniformOperator(const Model& model,
|
||||
GetTensorFlowDataType(model, src_op.inputs[0]);
|
||||
(*new_op->mutable_attr())["T"].set_type(shape_type);
|
||||
(*new_op->mutable_attr())["dtype"].set_type(
|
||||
GetTensorFlowDataType(src_op.dtype));
|
||||
GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
|
||||
(*new_op->mutable_attr())["seed"].set_i(src_op.seed);
|
||||
(*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
|
||||
}
|
||||
|
@ -16,87 +16,9 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
#include "tensorflow/lite/toco/model_cmdline_flags.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_cmdline_flags.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_port.h"
|
||||
#include "tensorflow/lite/toco/toco_tooling.h"
|
||||
#include "tensorflow/lite/toco/toco_types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace toco {
|
||||
namespace {
|
||||
|
||||
// Checks the permissions of the output file to ensure it is writeable.
|
||||
void CheckOutputFilePermissions(const Arg<string>& output_file) {
|
||||
QCHECK(output_file.specified()) << "Missing required flag --output_file.\n";
|
||||
QCHECK(port::file::Writable(output_file.value()).ok())
|
||||
<< "Specified output_file is not writable: " << output_file.value()
|
||||
<< ".\n";
|
||||
}
|
||||
|
||||
// Checks the permissions of the frozen model file.
|
||||
void CheckFrozenModelPermissions(const Arg<string>& input_file) {
|
||||
QCHECK(input_file.specified()) << "Missing required flag --input_file.\n";
|
||||
QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok())
|
||||
<< "Specified input_file does not exist: " << input_file.value() << ".\n";
|
||||
QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok())
|
||||
<< "Specified input_file exists, but is not readable: "
|
||||
<< input_file.value() << ".\n";
|
||||
}
|
||||
|
||||
// Reads the contents of the GraphDef from either the frozen graph file or the
|
||||
// SavedModel directory. If it reads the SavedModel directory, it updates the
|
||||
// ModelFlags and TocoFlags accordingly.
|
||||
void ReadInputData(const ParsedTocoFlags& parsed_toco_flags,
|
||||
const ParsedModelFlags& parsed_model_flags,
|
||||
TocoFlags* toco_flags, ModelFlags* model_flags,
|
||||
string* graph_def_contents) {
|
||||
port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n");
|
||||
|
||||
// Ensure savedmodel_directory is not set.
|
||||
QCHECK(!parsed_toco_flags.savedmodel_directory.specified())
|
||||
<< "Use `tensorflow/lite/python/tflite_convert` script with "
|
||||
<< "SavedModel directories.\n";
|
||||
|
||||
// Checks the input file permissions and reads the contents.
|
||||
CheckFrozenModelPermissions(parsed_toco_flags.input_file);
|
||||
CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(),
|
||||
graph_def_contents, port::file::Defaults())
|
||||
.ok());
|
||||
}
|
||||
|
||||
tensorflow::Status ToolMain(const ParsedTocoFlags& parsed_toco_flags,
|
||||
const ParsedModelFlags& parsed_model_flags) {
|
||||
ModelFlags model_flags;
|
||||
ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags);
|
||||
|
||||
TocoFlags toco_flags;
|
||||
ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);
|
||||
|
||||
string graph_def_contents;
|
||||
ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags,
|
||||
&model_flags, &graph_def_contents);
|
||||
CheckOutputFilePermissions(parsed_toco_flags.output_file);
|
||||
|
||||
std::unique_ptr<Model> model =
|
||||
Import(toco_flags, model_flags, graph_def_contents);
|
||||
Transform(toco_flags, model.get());
|
||||
string output_file_contents;
|
||||
TF_RETURN_IF_ERROR(Export(toco_flags, *model, toco_flags.allow_custom_ops(),
|
||||
&output_file_contents));
|
||||
TF_RETURN_IF_ERROR(
|
||||
port::file::SetContents(parsed_toco_flags.output_file.value(),
|
||||
output_file_contents, port::file::Defaults()));
|
||||
return tensorflow::Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace toco
|
||||
#include "tensorflow/lite/toco/toco_convert.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
toco::string msg;
|
||||
@ -126,6 +48,6 @@ int main(int argc, char** argv) {
|
||||
return 1;
|
||||
}
|
||||
toco::port::InitGoogle(argv[0], effective_argc, &effective_argv, true);
|
||||
auto status = toco::ToolMain(parsed_toco_flags, parsed_model_flags);
|
||||
auto status = toco::Convert(parsed_toco_flags, parsed_model_flags);
|
||||
return status.ok() ? 0 : -1;
|
||||
}
|
||||
|
108
tensorflow/lite/toco/toco_convert.cc
Normal file
108
tensorflow/lite/toco/toco_convert.cc
Normal file
@ -0,0 +1,108 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
#include "tensorflow/lite/toco/model_cmdline_flags.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_cmdline_flags.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_port.h"
|
||||
#include "tensorflow/lite/toco/toco_tooling.h"
|
||||
#include "tensorflow/lite/toco/toco_types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace toco {
|
||||
namespace {
|
||||
|
||||
// Checks the permissions of the output file to ensure it is writeable.
|
||||
void CheckOutputFilePermissions(const Arg<string>& output_file) {
|
||||
QCHECK(output_file.specified()) << "Missing required flag --output_file.\n";
|
||||
QCHECK(port::file::Writable(output_file.value()).ok())
|
||||
<< "Specified output_file is not writable: " << output_file.value()
|
||||
<< ".\n";
|
||||
}
|
||||
|
||||
// Checks the permissions of the frozen model file.
|
||||
void CheckFrozenModelPermissions(const Arg<string>& input_file) {
|
||||
QCHECK(input_file.specified()) << "Missing required flag --input_file.\n";
|
||||
QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok())
|
||||
<< "Specified input_file does not exist: " << input_file.value() << ".\n";
|
||||
QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok())
|
||||
<< "Specified input_file exists, but is not readable: "
|
||||
<< input_file.value() << ".\n";
|
||||
}
|
||||
|
||||
// Reads the contents of the GraphDef from either the frozen graph file or the
|
||||
// SavedModel directory. If it reads the SavedModel directory, it updates the
|
||||
// ModelFlags and TocoFlags accordingly.
|
||||
void ReadInputData(const ParsedTocoFlags& parsed_toco_flags,
|
||||
const ParsedModelFlags& parsed_model_flags,
|
||||
TocoFlags* toco_flags, ModelFlags* model_flags,
|
||||
string* graph_def_contents) {
|
||||
port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n");
|
||||
|
||||
// Ensure savedmodel_directory is not set.
|
||||
QCHECK(!parsed_toco_flags.savedmodel_directory.specified())
|
||||
<< "Use `tensorflow/lite/python/tflite_convert` script with "
|
||||
<< "SavedModel directories.\n";
|
||||
|
||||
// Checks the input file permissions and reads the contents.
|
||||
CheckFrozenModelPermissions(parsed_toco_flags.input_file);
|
||||
CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(),
|
||||
graph_def_contents, port::file::Defaults())
|
||||
.ok());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
tensorflow::Status Convert(const string& graph_def_contents,
|
||||
const TocoFlags& toco_flags,
|
||||
const ModelFlags& model_flags,
|
||||
string* output_file_contents) {
|
||||
std::unique_ptr<Model> model =
|
||||
Import(toco_flags, model_flags, graph_def_contents);
|
||||
Transform(toco_flags, model.get());
|
||||
return Export(toco_flags, *model, toco_flags.allow_custom_ops(),
|
||||
output_file_contents);
|
||||
}
|
||||
|
||||
tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags,
|
||||
const ParsedModelFlags& parsed_model_flags) {
|
||||
ModelFlags model_flags;
|
||||
ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags);
|
||||
|
||||
TocoFlags toco_flags;
|
||||
ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);
|
||||
|
||||
string graph_def_contents;
|
||||
ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags,
|
||||
&model_flags, &graph_def_contents);
|
||||
CheckOutputFilePermissions(parsed_toco_flags.output_file);
|
||||
|
||||
string output_file_contents;
|
||||
TF_RETURN_IF_ERROR(Convert(graph_def_contents, toco_flags, model_flags,
|
||||
&output_file_contents));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
port::file::SetContents(parsed_toco_flags.output_file.value(),
|
||||
output_file_contents, port::file::Defaults()));
|
||||
return tensorflow::Status();
|
||||
}
|
||||
|
||||
} // namespace toco
|
34
tensorflow/lite/toco/toco_convert.h
Normal file
34
tensorflow/lite/toco/toco_convert.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_
|
||||
#define TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/lite/toco/args.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
|
||||
namespace toco {
|
||||
|
||||
tensorflow::Status Convert(const string& graph_def_contents,
|
||||
const TocoFlags& toco_flags,
|
||||
const ModelFlags& model_flags,
|
||||
string* output_file_contents);
|
||||
|
||||
tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags,
|
||||
const ParsedModelFlags& parsed_model_flags);
|
||||
} // namespace toco
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_
|
173
tensorflow/lite/toco/toco_convert_test.cc
Normal file
173
tensorflow/lite/toco/toco_convert_test.cc
Normal file
@ -0,0 +1,173 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/toco/toco_convert.h"
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace toco {
|
||||
namespace {
|
||||
|
||||
TEST(TocoTest, MissingInputFile) {
|
||||
ParsedTocoFlags toco_flags;
|
||||
ParsedModelFlags model_flags;
|
||||
EXPECT_DEATH(Convert(toco_flags, model_flags).ok(),
|
||||
"Missing required flag --input_file");
|
||||
}
|
||||
|
||||
TEST(TocoTest, BadInputFormat) {
|
||||
TocoFlags toco_flags;
|
||||
ModelFlags model_flags;
|
||||
|
||||
string input;
|
||||
string output;
|
||||
|
||||
EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(),
|
||||
"Unhandled input_format='FILE_FORMAT_UNKNOWN'");
|
||||
}
|
||||
|
||||
TEST(TocoTest, MissingOuputArrays) {
|
||||
TocoFlags toco_flags;
|
||||
ModelFlags model_flags;
|
||||
|
||||
toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
|
||||
string input;
|
||||
string output;
|
||||
|
||||
EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(),
|
||||
"This model does not define output arrays, so a --output_arrays "
|
||||
"flag must be given on the command-line");
|
||||
}
|
||||
|
||||
TEST(TocoTest, BadOutputArray) {
|
||||
TocoFlags toco_flags;
|
||||
ModelFlags model_flags;
|
||||
|
||||
toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
|
||||
model_flags.add_output_arrays("output1");
|
||||
string input;
|
||||
string output;
|
||||
|
||||
EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(),
|
||||
"Specified output array .output1. is not produced by any op "
|
||||
"in this graph. Is it a typo. To silence this message, pass "
|
||||
"this flag: allow_nonexistent_arrays");
|
||||
}
|
||||
|
||||
TEST(TocoTest, BadOutputFormat) {
|
||||
TocoFlags toco_flags;
|
||||
ModelFlags model_flags;
|
||||
|
||||
toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
|
||||
model_flags.add_output_arrays("output1");
|
||||
string input = R"GraphDef(
|
||||
node {
|
||||
name: "output1"
|
||||
input: "input1"
|
||||
input: "input2"
|
||||
op: "Sub"
|
||||
attr { key: "T" value { type: DT_FLOAT } }
|
||||
}
|
||||
)GraphDef";
|
||||
|
||||
string output;
|
||||
|
||||
EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(),
|
||||
"Unhandled output_format='FILE_FORMAT_UNKNOWN'");
|
||||
}
|
||||
|
||||
TEST(TocoTest, SimpleFloatModel) {
|
||||
TocoFlags toco_flags;
|
||||
ModelFlags model_flags;
|
||||
|
||||
toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
|
||||
toco_flags.set_output_format(TENSORFLOW_GRAPHDEF);
|
||||
|
||||
// Inputs are automatically selected (but that might not be a good idea).
|
||||
model_flags.add_output_arrays("output1");
|
||||
string input = R"GraphDef(
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr { key: "dtype" value { type: DT_INT64 } }
|
||||
}
|
||||
node {
|
||||
name: "input2"
|
||||
op: "Placeholder"
|
||||
attr { key: "dtype" value { type: DT_INT64 } }
|
||||
}
|
||||
node {
|
||||
name: "output1"
|
||||
input: "input1"
|
||||
input: "input2"
|
||||
op: "Sub"
|
||||
attr { key: "T" value { type: DT_FLOAT } }
|
||||
}
|
||||
)GraphDef";
|
||||
|
||||
string output;
|
||||
EXPECT_TRUE(Convert(input, toco_flags, model_flags, &output).ok());
|
||||
EXPECT_TRUE(!output.empty());
|
||||
}
|
||||
|
||||
TEST(TocoTest, TransientStringTensors) {
|
||||
TocoFlags toco_flags;
|
||||
ModelFlags model_flags;
|
||||
|
||||
toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
|
||||
|
||||
// We need to do a couple of things to trigger the transient array
|
||||
// initialization code: output format must support memory planning, and the
|
||||
// input array must have a shape.
|
||||
toco_flags.set_output_format(TFLITE);
|
||||
|
||||
model_flags.add_output_arrays("output1");
|
||||
string input = R"GraphDef(
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr { key: "dtype" value { type: DT_STRING } }
|
||||
attr { key: "shape" value { shape { dim { size:1 }}}}
|
||||
}
|
||||
node {
|
||||
name: "indices1"
|
||||
op: "Placeholder"
|
||||
attr { key: "dtype" value { type: DT_INT64 } }
|
||||
}
|
||||
node {
|
||||
name: "intermediate1"
|
||||
op: "Gather"
|
||||
input: "input1"
|
||||
input: "indices1"
|
||||
attr { key: "Tparams" value { type: DT_STRING } }
|
||||
attr { key: "Tindices" value { type: DT_INT64 } }
|
||||
}
|
||||
node {
|
||||
name: "output1"
|
||||
op: "Gather"
|
||||
input: "intermediate1"
|
||||
input: "indices2"
|
||||
attr { key: "Tparams" value { type: DT_STRING } }
|
||||
attr { key: "Tindices" value { type: DT_INT64 } }
|
||||
}
|
||||
)GraphDef";
|
||||
|
||||
string output;
|
||||
|
||||
EXPECT_TRUE(Convert(input, toco_flags, model_flags, &output).ok());
|
||||
EXPECT_TRUE(!output.empty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace toco
|
@ -210,7 +210,8 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
|
||||
CheckInvariants(*model);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled input_format";
|
||||
LOG(FATAL) << "Unhandled input_format='"
|
||||
<< FileFormat_Name(toco_flags.input_format()) << "'";
|
||||
}
|
||||
|
||||
LogDump(kLogLevelModelChanged, "AT IMPORT", *model);
|
||||
@ -424,7 +425,8 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
|
||||
DumpGraphviz(model, output_file_contents);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled output_format";
|
||||
LOG(FATAL) << "Unhandled output_format='"
|
||||
<< FileFormat_Name(toco_flags.output_format()) << "'";
|
||||
}
|
||||
return tensorflow::Status();
|
||||
}
|
||||
|
@ -1770,6 +1770,14 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
|
||||
if (!array->has_shape()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The size of string tensors is rarely known ahead of time, so all transient
|
||||
// tensors of this type will need to be dynamically allocated.
|
||||
if (array->final_data_type == ArrayDataType::kString ||
|
||||
array->data_type == ArrayDataType::kString) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user