Qualify uses of std::string

PiperOrigin-RevId: 317319501
Change-Id: Ib75a31ad89fa1a6bda81450f2ab5ba07d7338ada
This commit is contained in:
A. Unique TensorFlower 2020-06-19 09:15:04 -07:00 committed by TensorFlower Gardener
parent 07c54454ee
commit 16cb89bd7b
12 changed files with 133 additions and 126 deletions

View File

@ -52,7 +52,7 @@ using ::tflite::Tensor;
namespace { namespace {
// Check if a TensorFlow Op is a control flow op by its name. // Check if a TensorFlow Op is a control flow op by its name.
bool IsControlFlowOp(const string& tensorflow_op) { bool IsControlFlowOp(const std::string& tensorflow_op) {
// Technically this is equivalent to `::tensorflow::Node::IsControlFlow()`. // Technically this is equivalent to `::tensorflow::Node::IsControlFlow()`.
// It requires to construct a `::tensorflow::Graph` to use that helper // It requires to construct a `::tensorflow::Graph` to use that helper
// function, so we simply hardcode the list of control flow ops here. // function, so we simply hardcode the list of control flow ops here.
@ -68,7 +68,7 @@ bool IsControlFlowOp(const string& tensorflow_op) {
} }
// Check if a TensorFlow Op is unsupported by the Flex runtime. // Check if a TensorFlow Op is unsupported by the Flex runtime.
bool IsUnsupportedFlexOp(const string& tensorflow_op) { bool IsUnsupportedFlexOp(const std::string& tensorflow_op) {
if (IsControlFlowOp(tensorflow_op)) { if (IsControlFlowOp(tensorflow_op)) {
return true; return true;
} }
@ -82,14 +82,14 @@ bool IsUnsupportedFlexOp(const string& tensorflow_op) {
} }
// Map from operator name to TF Lite enum value, for all builtins. // Map from operator name to TF Lite enum value, for all builtins.
const std::map<string, BuiltinOperator>& GetBuiltinOpsMap() { const std::map<std::string, BuiltinOperator>& GetBuiltinOpsMap() {
static std::map<string, BuiltinOperator>* builtin_ops = nullptr; static std::map<std::string, BuiltinOperator>* builtin_ops = nullptr;
if (builtin_ops == nullptr) { if (builtin_ops == nullptr) {
builtin_ops = new std::map<string, BuiltinOperator>(); builtin_ops = new std::map<std::string, BuiltinOperator>();
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
BuiltinOperator op = static_cast<BuiltinOperator>(i); BuiltinOperator op = static_cast<BuiltinOperator>(i);
string name = EnumNameBuiltinOperator(op); std::string name = EnumNameBuiltinOperator(op);
if (op != BuiltinOperator_CUSTOM && !name.empty()) { if (op != BuiltinOperator_CUSTOM && !name.empty()) {
(*builtin_ops)[name] = op; (*builtin_ops)[name] = op;
} }
@ -99,10 +99,10 @@ const std::map<string, BuiltinOperator>& GetBuiltinOpsMap() {
} }
void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder, void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
string* file_contents) { std::string* file_contents) {
const uint8_t* buffer = builder.GetBufferPointer(); const uint8_t* buffer = builder.GetBufferPointer();
int size = builder.GetSize(); int size = builder.GetSize();
*file_contents = string(reinterpret_cast<const char*>(buffer), size); *file_contents = std::string(reinterpret_cast<const char*>(buffer), size);
} }
} // Anonymous namespace. } // Anonymous namespace.
@ -115,7 +115,7 @@ OperatorKey::OperatorKey(
bool enable_select_tf_ops) { bool enable_select_tf_ops) {
// Get the op name (by Toco definition). // Get the op name (by Toco definition).
const ::toco::Operator& op = *op_signature.op; const ::toco::Operator& op = *op_signature.op;
string name = HelpfulOperatorTypeName(op); std::string name = HelpfulOperatorTypeName(op);
bool is_builtin = false; bool is_builtin = false;
const auto& builtin_ops = GetBuiltinOpsMap(); const auto& builtin_ops = GetBuiltinOpsMap();
@ -146,7 +146,7 @@ OperatorKey::OperatorKey(
is_flex_op_ = true; is_flex_op_ = true;
flex_tensorflow_op_ = tensorflow_op; flex_tensorflow_op_ = tensorflow_op;
custom_code_ = custom_code_ =
string(::tflite::kFlexCustomCodePrefix) + flex_tensorflow_op_; std::string(::tflite::kFlexCustomCodePrefix) + flex_tensorflow_op_;
} else { } else {
custom_code_ = tensorflow_op; custom_code_ = tensorflow_op;
} }
@ -158,7 +158,7 @@ OperatorKey::OperatorKey(
is_flex_op_ = true; is_flex_op_ = true;
flex_tensorflow_op_ = name; flex_tensorflow_op_ = name;
custom_code_ = custom_code_ =
string(::tflite::kFlexCustomCodePrefix) + flex_tensorflow_op_; std::string(::tflite::kFlexCustomCodePrefix) + flex_tensorflow_op_;
} else { } else {
// If Flex is disabled or the original TensorFlow NodeDef isn't available, // If Flex is disabled or the original TensorFlow NodeDef isn't available,
// we produce a custom op. This gives developers a chance to implement // we produce a custom op. This gives developers a chance to implement
@ -175,7 +175,7 @@ OperatorKey::OperatorKey(
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
// First find a list of unique array names. // First find a list of unique array names.
std::set<string> names; std::set<std::string> names;
for (const auto& array_pair : model.GetArrayMap()) { for (const auto& array_pair : model.GetArrayMap()) {
names.insert(array_pair.first); names.insert(array_pair.first);
} }
@ -218,7 +218,7 @@ Offset<Vector<Offset<Tensor>>> ExportTensors(
std::map<int, Offset<Tensor>> ordered_tensors; std::map<int, Offset<Tensor>> ordered_tensors;
for (const auto& array_pair : model.GetArrayMap()) { for (const auto& array_pair : model.GetArrayMap()) {
const string& tensor_name = array_pair.first; const std::string& tensor_name = array_pair.first;
const toco::Array& array = *array_pair.second; const toco::Array& array = *array_pair.second;
int buffer_index = buffers_to_write->size(); int buffer_index = buffers_to_write->size();
@ -283,7 +283,7 @@ Offset<Vector<int32_t>> ExportOutputTensors(
const Model& model, const details::TensorsMap& tensors_map, const Model& model, const details::TensorsMap& tensors_map,
FlatBufferBuilder* builder) { FlatBufferBuilder* builder) {
std::vector<int32_t> outputs; std::vector<int32_t> outputs;
for (const string& output : model.flags.output_arrays()) { for (const std::string& output : model.flags.output_arrays()) {
outputs.push_back(tensors_map.at(output)); outputs.push_back(tensors_map.at(output));
} }
return builder->CreateVector<int32_t>(outputs); return builder->CreateVector<int32_t>(outputs);
@ -295,10 +295,10 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder, const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
const ExportParams& params) { const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins. // Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops; std::map<std::string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
BuiltinOperator op = static_cast<BuiltinOperator>(i); BuiltinOperator op = static_cast<BuiltinOperator>(i);
string name = EnumNameBuiltinOperator(op); std::string name = EnumNameBuiltinOperator(op);
if (op != BuiltinOperator_CUSTOM && !name.empty()) { if (op != BuiltinOperator_CUSTOM && !name.empty()) {
builtin_ops[name] = op; builtin_ops[name] = op;
} }
@ -349,13 +349,13 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
std::vector<Offset<Operator>> op_vector; std::vector<Offset<Operator>> op_vector;
for (const auto& op : model.operators) { for (const auto& op : model.operators) {
std::vector<int32_t> inputs; std::vector<int32_t> inputs;
for (const string& input : op->inputs) { for (const std::string& input : op->inputs) {
// -1 is the ID for optional tensor in TFLite output // -1 is the ID for optional tensor in TFLite output
int id = model.IsOptionalArray(input) ? -1 : tensors_map.at(input); int id = model.IsOptionalArray(input) ? -1 : tensors_map.at(input);
inputs.push_back(id); inputs.push_back(id);
} }
std::vector<int32_t> outputs; std::vector<int32_t> outputs;
for (const string& output : op->outputs) { for (const std::string& output : op->outputs) {
outputs.push_back(tensors_map.at(output)); outputs.push_back(tensors_map.at(output));
} }
const toco::OperatorSignature op_signature = {op.get(), &model}; const toco::OperatorSignature op_signature = {op.get(), &model};
@ -428,15 +428,15 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
return builder->CreateVector(buffer_vector); return builder->CreateVector(buffer_vector);
} }
tensorflow::Status Export(const Model& model, string* output_file_contents, tensorflow::Status Export(const Model& model, std::string* output_file_contents,
const ExportParams& params) { const ExportParams& params) {
const auto ops_by_type = BuildOperatorByTypeMap(params.enable_select_tf_ops); const auto ops_by_type = BuildOperatorByTypeMap(params.enable_select_tf_ops);
return Export(model, output_file_contents, params, ops_by_type); return Export(model, output_file_contents, params, ops_by_type);
} }
void ParseControlFlowErrors(std::set<string>* custom_ops, void ParseControlFlowErrors(std::set<std::string>* custom_ops,
std::vector<string>* error_msgs) { std::vector<std::string>* error_msgs) {
std::set<string> unsupported_control_flow_ops; std::set<std::string> unsupported_control_flow_ops;
// Check if unsupported ops contains control flow ops. It's impossible // Check if unsupported ops contains control flow ops. It's impossible
// to implement these ops as custom ops at the moment. // to implement these ops as custom ops at the moment.
for (const auto& op : *custom_ops) { for (const auto& op : *custom_ops) {
@ -471,10 +471,10 @@ void ExportModelVersionBuffer(
} }
tensorflow::Status Export( tensorflow::Status Export(
const Model& model, string* output_file_contents, const Model& model, std::string* output_file_contents,
const ExportParams& params, const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
for (const string& input_array : model.GetInvalidInputArrays()) { for (const std::string& input_array : model.GetInvalidInputArrays()) {
if (model.HasArray(input_array)) { if (model.HasArray(input_array)) {
return tensorflow::errors::InvalidArgument( return tensorflow::errors::InvalidArgument(
absl::StrCat("Placeholder ", input_array, absl::StrCat("Placeholder ", input_array,
@ -509,11 +509,11 @@ tensorflow::Status Export(
} }
// The set of used builtin ops. // The set of used builtin ops.
std::set<string> builtin_ops; std::set<std::string> builtin_ops;
// The set of custom ops (not including Flex ops). // The set of custom ops (not including Flex ops).
std::set<string> custom_ops; std::set<std::string> custom_ops;
// The set of Flex ops which are not supported. // The set of Flex ops which are not supported.
std::set<string> unsupported_flex_ops; std::set<std::string> unsupported_flex_ops;
for (const auto& it : operators_map) { for (const auto& it : operators_map) {
const details::OperatorKey& key = it.first; const details::OperatorKey& key = it.first;
@ -540,7 +540,7 @@ tensorflow::Status Export(
"40-tflite-op-request.md\n and pasting the following:\n\n"; "40-tflite-op-request.md\n and pasting the following:\n\n";
}; };
std::vector<string> error_msgs; std::vector<std::string> error_msgs;
ParseControlFlowErrors(&custom_ops, &error_msgs); ParseControlFlowErrors(&custom_ops, &error_msgs);
// Remove ExpandDims and ReorderAxes from unimplemented list unless they // Remove ExpandDims and ReorderAxes from unimplemented list unless they
@ -549,7 +549,7 @@ tensorflow::Status Export(
// transformation is unable to run because the output shape is not // transformation is unable to run because the output shape is not
// defined. This causes unnecessary confusion during model conversion // defined. This causes unnecessary confusion during model conversion
// time. // time.
std::set<string> custom_ops_final; std::set<std::string> custom_ops_final;
for (const auto& op_type : custom_ops) { for (const auto& op_type : custom_ops) {
if (op_type != "ReorderAxes" && op_type != "ExpandDims") { if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
custom_ops_final.insert(op_type); custom_ops_final.insert(op_type);

View File

@ -35,19 +35,19 @@ struct ExportParams {
// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
// result in the given string. // result in the given string.
tensorflow::Status Export(const Model& model, string* output_file_contents, tensorflow::Status Export(const Model& model, std::string* output_file_contents,
const ExportParams& params); const ExportParams& params);
// Export API with custom TFLite operator mapping. // Export API with custom TFLite operator mapping.
tensorflow::Status Export( tensorflow::Status Export(
const Model& model, string* output_file_contents, const Model& model, std::string* output_file_contents,
const ExportParams& params, const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type); const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
// This is for backward-compatibility. // This is for backward-compatibility.
// TODO(ycling): Remove the deprecated entry functions. // TODO(ycling): Remove the deprecated entry functions.
inline void Export(const Model& model, bool allow_custom_ops, inline void Export(const Model& model, bool allow_custom_ops,
bool quantize_weights, string* output_file_contents) { bool quantize_weights, std::string* output_file_contents) {
ExportParams params; ExportParams params;
params.allow_custom_ops = allow_custom_ops; params.allow_custom_ops = allow_custom_ops;
params.quantize_weights = params.quantize_weights =
@ -60,7 +60,7 @@ inline void Export(const Model& model, bool allow_custom_ops,
// TODO(ycling): Remove the deprecated entry functions. // TODO(ycling): Remove the deprecated entry functions.
inline void Export( inline void Export(
const Model& model, bool allow_custom_ops, bool quantize_weights, const Model& model, bool allow_custom_ops, bool quantize_weights,
string* output_file_contents, std::string* output_file_contents,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
ExportParams params; ExportParams params;
params.allow_custom_ops = allow_custom_ops; params.allow_custom_ops = allow_custom_ops;
@ -72,7 +72,7 @@ inline void Export(
// This is for backward-compatibility. // This is for backward-compatibility.
// TODO(ycling): Remove the deprecated entry functions. // TODO(ycling): Remove the deprecated entry functions.
inline void Export(const Model& model, string* output_file_contents) { inline void Export(const Model& model, std::string* output_file_contents) {
ExportParams params; ExportParams params;
params.allow_custom_ops = true; params.allow_custom_ops = true;
auto status = Export(model, output_file_contents, params); auto status = Export(model, output_file_contents, params);
@ -82,7 +82,7 @@ inline void Export(const Model& model, string* output_file_contents) {
namespace details { namespace details {
// A map from tensor name to its final position in the TF Lite buffer. // A map from tensor name to its final position in the TF Lite buffer.
using TensorsMap = std::unordered_map<string, int>; using TensorsMap = std::unordered_map<std::string, int>;
// A key to identify an operator. // A key to identify an operator.
// Only when `type` is `kUnsupported`, `custom_code` is filled to // Only when `type` is `kUnsupported`, `custom_code` is filled to

View File

@ -34,13 +34,13 @@ using ::testing::HasSubstr;
class ExportTest : public ::testing::Test { class ExportTest : public ::testing::Test {
protected: protected:
void ResetOperators() { input_model_.operators.clear(); } void ResetOperators() { input_model_.operators.clear(); }
void AddTensorsByName(std::initializer_list<string> names) { void AddTensorsByName(std::initializer_list<std::string> names) {
for (const string& name : names) { for (const std::string& name : names) {
input_model_.GetOrCreateArray(name); input_model_.GetOrCreateArray(name);
} }
} }
void AddOperatorsByName(std::initializer_list<string> names) { void AddOperatorsByName(std::initializer_list<std::string> names) {
for (const string& name : names) { for (const std::string& name : names) {
if (name == "Conv") { if (name == "Conv") {
auto* op = new ConvOperator; auto* op = new ConvOperator;
op->padding.type = PaddingType::kSame; op->padding.type = PaddingType::kSame;
@ -153,14 +153,15 @@ class ExportTest : public ::testing::Test {
} }
tensorflow::Status ExportAndReturnStatus(const ExportParams& params) { tensorflow::Status ExportAndReturnStatus(const ExportParams& params) {
string result; std::string result;
return Export(input_model_, &result, params); return Export(input_model_, &result, params);
} }
std::vector<string> ExportAndSummarizeOperators(const ExportParams& params) { std::vector<std::string> ExportAndSummarizeOperators(
std::vector<string> names; const ExportParams& params) {
std::vector<std::string> names;
string result; std::string result;
auto status = Export(input_model_, &result, params); auto status = Export(input_model_, &result, params);
if (!status.ok()) { if (!status.ok()) {
LOG(INFO) << status.error_message(); LOG(INFO) << status.error_message();
@ -171,10 +172,12 @@ class ExportTest : public ::testing::Test {
for (const ::tflite::OperatorCode* opcode : *model->operator_codes()) { for (const ::tflite::OperatorCode* opcode : *model->operator_codes()) {
if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
names.push_back(string("builtin:") + ::tflite::EnumNameBuiltinOperator( names.push_back(
opcode->builtin_code())); std::string("builtin:") +
::tflite::EnumNameBuiltinOperator(opcode->builtin_code()));
} else { } else {
names.push_back(string("custom:") + opcode->custom_code()->c_str()); names.push_back(std::string("custom:") +
opcode->custom_code()->c_str());
} }
} }
@ -185,7 +188,7 @@ class ExportTest : public ::testing::Test {
const ExportParams& params) { const ExportParams& params) {
std::vector<uint32_t> indices; std::vector<uint32_t> indices;
string result; std::string result;
if (!Export(input_model_, &result, params).ok()) return indices; if (!Export(input_model_, &result, params).ok()) return indices;
auto* model = ::tflite::GetModel(result.data()); auto* model = ::tflite::GetModel(result.data());
@ -257,7 +260,7 @@ TEST_F(ExportTest, ExportMinRuntime) {
params.enable_select_tf_ops = false; params.enable_select_tf_ops = false;
params.quantize_weights = QuantizedBufferType::NONE; params.quantize_weights = QuantizedBufferType::NONE;
string output; std::string output;
auto status = Export(input_model_, &output, params); auto status = Export(input_model_, &output, params);
auto* model = ::tflite::GetModel(output.data()); auto* model = ::tflite::GetModel(output.data());
EXPECT_EQ(model->metadata()->size(), 1); EXPECT_EQ(model->metadata()->size(), 1);
@ -265,7 +268,8 @@ TEST_F(ExportTest, ExportMinRuntime) {
auto buf = model->metadata()->Get(0)->buffer(); auto buf = model->metadata()->Get(0)->buffer();
auto* buffer = (*model->buffers())[buf]; auto* buffer = (*model->buffers())[buf];
auto* array = buffer->data(); auto* array = buffer->data();
string version(reinterpret_cast<const char*>(array->data()), array->size()); std::string version(reinterpret_cast<const char*>(array->data()),
array->size());
EXPECT_EQ(version, "1.6.0"); EXPECT_EQ(version, "1.6.0");
} }
@ -275,7 +279,7 @@ TEST_F(ExportTest, ExportEmptyMinRuntime) {
ExportParams params; ExportParams params;
params.allow_custom_ops = true; params.allow_custom_ops = true;
string output; std::string output;
auto status = Export(input_model_, &output, params); auto status = Export(input_model_, &output, params);
auto* model = ::tflite::GetModel(output.data()); auto* model = ::tflite::GetModel(output.data());
EXPECT_EQ(model->metadata()->size(), 1); EXPECT_EQ(model->metadata()->size(), 1);
@ -283,7 +287,8 @@ TEST_F(ExportTest, ExportEmptyMinRuntime) {
auto buf = model->metadata()->Get(0)->buffer(); auto buf = model->metadata()->Get(0)->buffer();
auto* buffer = (*model->buffers())[buf]; auto* buffer = (*model->buffers())[buf];
auto* array = buffer->data(); auto* array = buffer->data();
string version(reinterpret_cast<const char*>(array->data()), array->size()); std::string version(reinterpret_cast<const char*>(array->data()),
array->size());
EXPECT_EQ(version, ""); EXPECT_EQ(version, "");
} }
@ -296,7 +301,7 @@ TEST_F(ExportTest, UnsupportedControlFlowErrors) {
// The model contains control flow ops which are not convertible, so we should // The model contains control flow ops which are not convertible, so we should
// check the returned error message. // check the returned error message.
string output; std::string output;
const auto ops_by_type = BuildOperatorByTypeMap(); const auto ops_by_type = BuildOperatorByTypeMap();
auto status = Export(input_model_, &output, params, ops_by_type); auto status = Export(input_model_, &output, params, ops_by_type);
EXPECT_EQ(status.error_message(), EXPECT_EQ(status.error_message(),
@ -318,7 +323,7 @@ TEST_F(ExportTest, UnsupportedOpsAndNeedEnableFlex) {
params.allow_custom_ops = false; params.allow_custom_ops = false;
params.enable_select_tf_ops = false; params.enable_select_tf_ops = false;
string output; std::string output;
const auto ops_by_type = BuildOperatorByTypeMap(); const auto ops_by_type = BuildOperatorByTypeMap();
auto status = Export(input_model_, &output, params, ops_by_type); auto status = Export(input_model_, &output, params, ops_by_type);
EXPECT_EQ( EXPECT_EQ(
@ -348,7 +353,7 @@ TEST_F(ExportTest, UnsupportedOpsNeedCustomImplementation) {
params.allow_custom_ops = false; params.allow_custom_ops = false;
params.enable_select_tf_ops = true; params.enable_select_tf_ops = true;
string output; std::string output;
const auto ops_by_type = BuildOperatorByTypeMap(); const auto ops_by_type = BuildOperatorByTypeMap();
auto status = Export(input_model_, &output, params, ops_by_type); auto status = Export(input_model_, &output, params, ops_by_type);
EXPECT_EQ( EXPECT_EQ(
@ -378,7 +383,7 @@ TEST_F(ExportTest, UnsupportedControlFlowAndCustomOpsErrors) {
// The model contains control flow ops which are not convertible, so we should // The model contains control flow ops which are not convertible, so we should
// check the returned error message. // check the returned error message.
string output; std::string output;
const auto ops_by_type = BuildOperatorByTypeMap(); const auto ops_by_type = BuildOperatorByTypeMap();
auto status = Export(input_model_, &output, params, ops_by_type); auto status = Export(input_model_, &output, params, ops_by_type);
EXPECT_EQ( EXPECT_EQ(
@ -407,11 +412,11 @@ TEST_F(ExportTest, UnsupportedControlFlowAndCustomOpsErrors) {
TEST_F(ExportTest, QuantizeWeights) { TEST_F(ExportTest, QuantizeWeights) {
// Sanity check for quantize_weights parameter. // Sanity check for quantize_weights parameter.
BuildQuantizableTestModel(); BuildQuantizableTestModel();
string unquantized_result; std::string unquantized_result;
Export(input_model_, true, /*quantize_weights*/ false, &unquantized_result); Export(input_model_, true, /*quantize_weights*/ false, &unquantized_result);
BuildQuantizableTestModel(); BuildQuantizableTestModel();
string quantized_result; std::string quantized_result;
Export(input_model_, true, /*quantize_weights*/ true, &quantized_result); Export(input_model_, true, /*quantize_weights*/ true, &quantized_result);
// The quantized models should be smaller. // The quantized models should be smaller.
@ -443,12 +448,13 @@ class OpSetsTest : public ExportTest {
} }
} }
std::vector<string> ImportExport(std::initializer_list<string> op_names) { std::vector<std::string> ImportExport(
std::initializer_list<std::string> op_names) {
ResetOperators(); ResetOperators();
if (!import_all_ops_as_unsupported_) { if (!import_all_ops_as_unsupported_) {
AddOperatorsByName(op_names); AddOperatorsByName(op_names);
} else { } else {
for (const string& name : op_names) { for (const std::string& name : op_names) {
auto* op = new TensorFlowUnsupportedOperator; auto* op = new TensorFlowUnsupportedOperator;
op->tensorflow_op = name; op->tensorflow_op = name;
input_model_.operators.emplace_back(op); input_model_.operators.emplace_back(op);
@ -644,7 +650,7 @@ TEST_F(VersionedOpExportTest, Export) {
AddConvOp(false); AddConvOp(false);
AddConvOp(true); AddConvOp(true);
string result; std::string result;
const auto ops_by_type = BuildFakeOperatorByTypeMap(); const auto ops_by_type = BuildFakeOperatorByTypeMap();
Export(input_model_, true, false, &result, ops_by_type); Export(input_model_, true, false, &result, ops_by_type);

View File

@ -99,7 +99,7 @@ void ImportTensors(const ::tflite::Model& input_model, Model* model) {
void ImportOperators( void ImportOperators(
const ::tflite::Model& input_model, const ::tflite::Model& input_model,
const std::map<string, std::unique_ptr<BaseOperator>>& ops_by_name, const std::map<std::string, std::unique_ptr<BaseOperator>>& ops_by_name,
const details::TensorsTable& tensors_table, const details::TensorsTable& tensors_table,
const details::OperatorsTable& operators_table, Model* model) { const details::OperatorsTable& operators_table, Model* model) {
// TODO(aselle): add support for multiple subgraphs. // TODO(aselle): add support for multiple subgraphs.
@ -112,12 +112,12 @@ void ImportOperators(
LOG(FATAL) << "Index " << index << " must be between zero and " LOG(FATAL) << "Index " << index << " must be between zero and "
<< operators_table.size(); << operators_table.size();
} }
string opname = operators_table.at(index); std::string opname = operators_table.at(index);
// Find and use the appropriate operator deserialization factory. // Find and use the appropriate operator deserialization factory.
std::unique_ptr<Operator> new_op = nullptr; std::unique_ptr<Operator> new_op = nullptr;
if (ops_by_name.count(opname) == 0) { if (ops_by_name.count(opname) == 0) {
string effective_opname = "TENSORFLOW_UNSUPPORTED"; std::string effective_opname = "TENSORFLOW_UNSUPPORTED";
if (ops_by_name.count(effective_opname) == 0) { if (ops_by_name.count(effective_opname) == 0) {
LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found."; LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found.";
} }
@ -147,10 +147,10 @@ void ImportOperators(
auto input_index = inputs->Get(i); auto input_index = inputs->Get(i);
// input_index == -1 indicates optional tensor. // input_index == -1 indicates optional tensor.
if (input_index != -1) { if (input_index != -1) {
const string& input_name = tensors_table.at(input_index); const std::string& input_name = tensors_table.at(input_index);
op->inputs.push_back(input_name); op->inputs.push_back(input_name);
} else { } else {
const string& tensor_name = const std::string& tensor_name =
toco::AvailableArrayName(*model, "OptionalTensor"); toco::AvailableArrayName(*model, "OptionalTensor");
model->CreateOptionalArray(tensor_name); model->CreateOptionalArray(tensor_name);
op->inputs.push_back(tensor_name); op->inputs.push_back(tensor_name);
@ -159,7 +159,7 @@ void ImportOperators(
auto outputs = input_op->outputs(); auto outputs = input_op->outputs();
for (int i = 0; i < outputs->Length(); i++) { for (int i = 0; i < outputs->Length(); i++) {
auto output_index = outputs->Get(i); auto output_index = outputs->Get(i);
const string& output_name = tensors_table.at(output_index); const std::string& output_name = tensors_table.at(output_index);
op->outputs.push_back(output_name); op->outputs.push_back(output_name);
} }
} }
@ -173,7 +173,7 @@ void ImportIOTensors(const ModelFlags& model_flags,
auto inputs = (*input_model.subgraphs())[0]->inputs(); auto inputs = (*input_model.subgraphs())[0]->inputs();
if (inputs) { if (inputs) {
for (int input : *inputs) { for (int input : *inputs) {
const string& input_name = tensors_table.at(input); const std::string& input_name = tensors_table.at(input);
model->flags.add_input_arrays()->set_name(input_name); model->flags.add_input_arrays()->set_name(input_name);
} }
} }
@ -184,7 +184,7 @@ void ImportIOTensors(const ModelFlags& model_flags,
auto outputs = (*input_model.subgraphs())[0]->outputs(); auto outputs = (*input_model.subgraphs())[0]->outputs();
if (outputs) { if (outputs) {
for (int output : *outputs) { for (int output : *outputs) {
const string& output_name = tensors_table.at(output); const std::string& output_name = tensors_table.at(output);
model->flags.add_output_arrays(output_name); model->flags.add_output_arrays(output_name);
} }
} }
@ -199,7 +199,7 @@ bool Verify(const void* buf, size_t len) {
} // namespace } // namespace
std::unique_ptr<Model> Import(const ModelFlags& model_flags, std::unique_ptr<Model> Import(const ModelFlags& model_flags,
const string& input_file_contents) { const std::string& input_file_contents) {
::tflite::AlwaysTrueResolver r; ::tflite::AlwaysTrueResolver r;
if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(), if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(),
r, ::tflite::DefaultErrorReporter())) { r, ::tflite::DefaultErrorReporter())) {

View File

@ -24,17 +24,17 @@ namespace tflite {
// Parse the given string as TF Lite flatbuffer and return a new tf.mini model. // Parse the given string as TF Lite flatbuffer and return a new tf.mini model.
std::unique_ptr<Model> Import(const ModelFlags &model_flags, std::unique_ptr<Model> Import(const ModelFlags &model_flags,
const string &input_file_contents); const std::string &input_file_contents);
namespace details { namespace details {
// The names of all tensors found in a TF Lite model. // The names of all tensors found in a TF Lite model.
using TensorsTable = std::vector<string>; using TensorsTable = std::vector<std::string>;
// The names of all operators found in TF Lite model. If the operator is // The names of all operators found in TF Lite model. If the operator is
// builtin, the string representation of the corresponding enum value is used // builtin, the string representation of the corresponding enum value is used
// as name. // as name.
using OperatorsTable = std::vector<string>; using OperatorsTable = std::vector<std::string>;
void LoadTensorsTable(const ::tflite::Model &input_model, void LoadTensorsTable(const ::tflite::Model &input_model,
TensorsTable *tensors_table); TensorsTable *tensors_table);

View File

@ -134,9 +134,9 @@ class ImportTest : public ::testing::Test {
input_model_ = ::tflite::GetModel(builder_.GetBufferPointer()); input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
} }
string InputModelAsString() { std::string InputModelAsString() {
return string(reinterpret_cast<char*>(builder_.GetBufferPointer()), return std::string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
builder_.GetSize()); builder_.GetSize());
} }
flatbuffers::FlatBufferBuilder builder_; flatbuffers::FlatBufferBuilder builder_;
const ::tflite::Model* input_model_ = nullptr; const ::tflite::Model* input_model_ = nullptr;

View File

@ -29,7 +29,7 @@ namespace tflite {
// Deprecated and please register new ops/versions in // Deprecated and please register new ops/versions in
// tflite/tools/versioning/op_version.cc". // tflite/tools/versioning/op_version.cc".
string GetMinimumRuntimeVersionForModel(const Model& model) { std::string GetMinimumRuntimeVersionForModel(const Model& model) {
// Use this as the placeholder string if a particular op is not yet included // Use this as the placeholder string if a particular op is not yet included
// in any Tensorflow's RC/Final release source package. Once that op is // in any Tensorflow's RC/Final release source package. Once that op is
// included in the release, please update this with the real version string. // included in the release, please update this with the real version string.
@ -37,8 +37,8 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
// A map from the version key of an op to its minimum runtime version. // A map from the version key of an op to its minimum runtime version.
// For example, {{kAveragePool, 1}, "1.5.0"}, means the 1st version of // For example, {{kAveragePool, 1}, "1.5.0"}, means the 1st version of
// AveragePool requires a minimum TF Lite runtime version '1.5.0`. // AveragePool requires a minimum TF Lite runtime version '1.5.0`.
static const std::map<std::pair<OperatorType, int>, string>* op_version_map = static const std::map<std::pair<OperatorType, int>, std::string>*
new std::map<std::pair<OperatorType, int>, string>({ op_version_map = new std::map<std::pair<OperatorType, int>, std::string>({
{{OperatorType::kAveragePool, 1}, "1.5.0"}, {{OperatorType::kAveragePool, 1}, "1.5.0"},
{{OperatorType::kAveragePool, 2}, "1.14.0"}, {{OperatorType::kAveragePool, 2}, "1.14.0"},
{{OperatorType::kAveragePool, 3}, kPendingReleaseOpVersion}, {{OperatorType::kAveragePool, 3}, kPendingReleaseOpVersion},
@ -253,7 +253,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
tflite::BuildOperatorByTypeMap(false /*enable_select_tf_ops=*/); tflite::BuildOperatorByTypeMap(false /*enable_select_tf_ops=*/);
OperatorSignature op_signature; OperatorSignature op_signature;
op_signature.model = &model; op_signature.model = &model;
string model_min_version; std::string model_min_version;
for (const auto& op : model.operators) { for (const auto& op : model.operators) {
if (op_types_map.find(op->type) == op_types_map.end()) continue; if (op_types_map.find(op->type) == op_types_map.end()) continue;
op_signature.op = op.get(); op_signature.op = op.get();

View File

@ -27,9 +27,9 @@ TEST(OpVersionTest, MinimumVersionForSameOpVersions) {
Model model; Model model;
// Float convolutional kernel is introduced since '1.5.0'. // Float convolutional kernel is introduced since '1.5.0'.
std::unique_ptr<ConvOperator> conv(new ConvOperator()); std::unique_ptr<ConvOperator> conv(new ConvOperator());
const string conv_input = "conv_input"; const std::string conv_input = "conv_input";
const string conv_filter = "conv_filter"; const std::string conv_filter = "conv_filter";
const string conv_output = "conv_output"; const std::string conv_output = "conv_output";
conv->inputs.push_back(conv_input); conv->inputs.push_back(conv_input);
conv->inputs.push_back(conv_filter); conv->inputs.push_back(conv_filter);
conv->outputs.push_back(conv_output); conv->outputs.push_back(conv_output);
@ -44,8 +44,8 @@ TEST(OpVersionTest, MinimumVersionForSameOpVersions) {
// Float softmax kernel is introduced since '1.5.0'. // Float softmax kernel is introduced since '1.5.0'.
std::unique_ptr<SoftmaxOperator> softmax(new SoftmaxOperator()); std::unique_ptr<SoftmaxOperator> softmax(new SoftmaxOperator());
const string softmax_input = "softmax_input"; const std::string softmax_input = "softmax_input";
const string softmax_output = "softmax_output"; const std::string softmax_output = "softmax_output";
softmax->inputs.push_back(softmax_input); softmax->inputs.push_back(softmax_input);
softmax->outputs.push_back(softmax_output); softmax->outputs.push_back(softmax_output);
array_map[softmax_input] = std::unique_ptr<Array>(new Array); array_map[softmax_input] = std::unique_ptr<Array>(new Array);
@ -60,9 +60,9 @@ TEST(OpVersionTest, MinimumVersionForMultipleOpVersions) {
Model model; Model model;
// Dilated DepthWiseConvolution is introduced since '1.12.0'. // Dilated DepthWiseConvolution is introduced since '1.12.0'.
std::unique_ptr<DepthwiseConvOperator> conv(new DepthwiseConvOperator()); std::unique_ptr<DepthwiseConvOperator> conv(new DepthwiseConvOperator());
const string conv_input = "conv_input"; const std::string conv_input = "conv_input";
const string conv_filter = "conv_filter"; const std::string conv_filter = "conv_filter";
const string conv_output = "conv_output"; const std::string conv_output = "conv_output";
conv->inputs.push_back(conv_input); conv->inputs.push_back(conv_input);
conv->inputs.push_back(conv_filter); conv->inputs.push_back(conv_filter);
conv->outputs.push_back(conv_output); conv->outputs.push_back(conv_output);
@ -77,10 +77,10 @@ TEST(OpVersionTest, MinimumVersionForMultipleOpVersions) {
// FullyConnected op with kShuffled4x16Int8 weight format is introduced from // FullyConnected op with kShuffled4x16Int8 weight format is introduced from
// '1.10.0'. // '1.10.0'.
std::unique_ptr<FullyConnectedOperator> fc(new FullyConnectedOperator()); std::unique_ptr<FullyConnectedOperator> fc(new FullyConnectedOperator());
const string fc_input = "fc_input"; const std::string fc_input = "fc_input";
const string fc_weights = "fc_weights"; const std::string fc_weights = "fc_weights";
const string fc_bias = "fc_bias"; const std::string fc_bias = "fc_bias";
const string fc_output = "fc_output"; const std::string fc_output = "fc_output";
fc->inputs.push_back(fc_input); fc->inputs.push_back(fc_input);
fc->inputs.push_back(fc_weights); fc->inputs.push_back(fc_weights);
fc->inputs.push_back(fc_bias); fc->inputs.push_back(fc_bias);
@ -121,10 +121,10 @@ TEST(OpVersionTest, MinimumVersionForMixedOpVersions) {
// FullyConnected op with kShuffled4x16Int8 weight format is introduced from // FullyConnected op with kShuffled4x16Int8 weight format is introduced from
// '1.10.0'. // '1.10.0'.
std::unique_ptr<FullyConnectedOperator> fc(new FullyConnectedOperator()); std::unique_ptr<FullyConnectedOperator> fc(new FullyConnectedOperator());
const string fc_input = "fc_input"; const std::string fc_input = "fc_input";
const string fc_weights = "fc_weights"; const std::string fc_weights = "fc_weights";
const string fc_bias = "fc_bias"; const std::string fc_bias = "fc_bias";
const string fc_output = "fc_output"; const std::string fc_output = "fc_output";
fc->inputs.push_back(fc_input); fc->inputs.push_back(fc_input);
fc->inputs.push_back(fc_weights); fc->inputs.push_back(fc_weights);
fc->inputs.push_back(fc_bias); fc->inputs.push_back(fc_bias);

View File

@ -238,7 +238,7 @@ class SpaceToBatchND
TocoOperator* op) const override {} TocoOperator* op) const override {}
int GetVersion(const OperatorSignature& op_signature) const override { int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0]; const std::string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name); const Array& input_array = op_signature.model->GetArray(input_name);
::tflite::OpSignature op_sig = ::tflite::OpSignature op_sig =
GetVersioningOpSig(builtin_op(), op_signature); GetVersioningOpSig(builtin_op(), op_signature);
@ -268,8 +268,8 @@ class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
} }
int GetVersion(const OperatorSignature& op_signature) const override { int GetVersion(const OperatorSignature& op_signature) const override {
const string& input1_name = op_signature.op->inputs[0]; const std::string& input1_name = op_signature.op->inputs[0];
const string& input2_name = op_signature.op->inputs[1]; const std::string& input2_name = op_signature.op->inputs[1];
const Array& input1_array = op_signature.model->GetArray(input1_name); const Array& input1_array = op_signature.model->GetArray(input1_name);
const Array& input2_array = op_signature.model->GetArray(input2_name); const Array& input2_array = op_signature.model->GetArray(input2_name);
::tflite::OpSignature op_sig = ::tflite::OpSignature op_sig =
@ -305,8 +305,8 @@ class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
} }
int GetVersion(const OperatorSignature& op_signature) const override { int GetVersion(const OperatorSignature& op_signature) const override {
const string& input1_name = op_signature.op->inputs[0]; const std::string& input1_name = op_signature.op->inputs[0];
const string& input2_name = op_signature.op->inputs[1]; const std::string& input2_name = op_signature.op->inputs[1];
const Array& input1_array = op_signature.model->GetArray(input1_name); const Array& input1_array = op_signature.model->GetArray(input1_name);
const Array& input2_array = op_signature.model->GetArray(input2_name); const Array& input2_array = op_signature.model->GetArray(input2_name);
::tflite::OpSignature op_sig = ::tflite::OpSignature op_sig =
@ -339,7 +339,7 @@ class BatchToSpaceND
TocoOperator* op) const override {} TocoOperator* op) const override {}
int GetVersion(const OperatorSignature& op_signature) const override { int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0]; const std::string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name); const Array& input_array = op_signature.model->GetArray(input_name);
::tflite::OpSignature op_sig = ::tflite::OpSignature op_sig =
GetVersioningOpSig(builtin_op(), op_signature); GetVersioningOpSig(builtin_op(), op_signature);
@ -662,9 +662,9 @@ class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
} }
int GetVersion(const OperatorSignature& op_signature) const override { int GetVersion(const OperatorSignature& op_signature) const override {
const string& input1_name = op_signature.op->inputs[0]; const std::string& input1_name = op_signature.op->inputs[0];
const string& input2_name = op_signature.op->inputs[1]; const std::string& input2_name = op_signature.op->inputs[1];
const string& output_name = op_signature.op->outputs[0]; const std::string& output_name = op_signature.op->outputs[0];
const Array& input1_array = op_signature.model->GetArray(input1_name); const Array& input1_array = op_signature.model->GetArray(input1_name);
const Array& input2_array = op_signature.model->GetArray(input2_name); const Array& input2_array = op_signature.model->GetArray(input2_name);
const Array& output_array = op_signature.model->GetArray(output_name); const Array& output_array = op_signature.model->GetArray(output_name);
@ -1440,7 +1440,7 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
} }
int GetVersion(const OperatorSignature& op_signature) const override { int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0]; const std::string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name); const Array& input_array = op_signature.model->GetArray(input_name);
// If the op take int8/uint8 input, it is version 2. // If the op take int8/uint8 input, it is version 2.
if (input_array.data_type == ArrayDataType::kInt8 || if (input_array.data_type == ArrayDataType::kInt8 ||
@ -1577,7 +1577,7 @@ class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
}; };
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions( std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
const string& tensorflow_node_def) { const std::string& tensorflow_node_def) {
auto fbb = absl::make_unique<flexbuffers::Builder>(); auto fbb = absl::make_unique<flexbuffers::Builder>();
::tensorflow::NodeDef node_def; ::tensorflow::NodeDef node_def;
@ -1597,7 +1597,7 @@ std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
class TensorFlowUnsupported : public BaseOperator { class TensorFlowUnsupported : public BaseOperator {
public: public:
TensorFlowUnsupported(const string& name, OperatorType type, TensorFlowUnsupported(const std::string& name, OperatorType type,
bool enable_select_tf_ops) bool enable_select_tf_ops)
: BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {} : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
@ -1676,7 +1676,7 @@ class TensorFlowUnsupported : public BaseOperator {
case tensorflow::AttrValue::kList: case tensorflow::AttrValue::kList:
if (attr.list().s_size() > 0) { if (attr.list().s_size() > 0) {
auto start = fbb->StartVector(key); auto start = fbb->StartVector(key);
for (const string& v : attr.list().s()) { for (const std::string& v : attr.list().s()) {
fbb->Add(v); fbb->Add(v);
} }
fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
@ -1736,10 +1736,11 @@ class TensorFlowUnsupported : public BaseOperator {
break; break;
case flexbuffers::FBT_BOOL: case flexbuffers::FBT_BOOL:
(*attr)[key].set_b(value.AsBool()); (*attr)[key].set_b(value.AsBool());
if (string(key) == "_output_quantized") { if (std::string(key) == "_output_quantized") {
op->quantized = value.AsBool(); op->quantized = value.AsBool();
} }
if (string(key) == "_support_output_type_float_in_quantized_op") { if (std::string(key) ==
"_support_output_type_float_in_quantized_op") {
op->support_output_type_float_in_quantized_op = value.AsBool(); op->support_output_type_float_in_quantized_op = value.AsBool();
} }
break; break;
@ -2095,9 +2096,9 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
return result; return result;
} }
std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
bool enable_select_tf_ops) { bool enable_select_tf_ops) {
std::map<string, std::unique_ptr<BaseOperator>> result; std::map<std::string, std::unique_ptr<BaseOperator>> result;
std::vector<std::unique_ptr<BaseOperator>> ops = std::vector<std::unique_ptr<BaseOperator>> ops =
BuildOperatorList(enable_select_tf_ops); BuildOperatorList(enable_select_tf_ops);
@ -2109,7 +2110,7 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
} }
bool ShouldExportAsFlexOp(bool enable_select_tf_ops, bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
const string& tensorflow_op_name) { const std::string& tensorflow_op_name) {
// If Flex ops aren't allow at all, simply return false. // If Flex ops aren't allow at all, simply return false.
if (!enable_select_tf_ops) { if (!enable_select_tf_ops) {
return false; return false;

View File

@ -30,7 +30,7 @@ class BaseOperator;
// Return a map contained all know TF Lite Operators, keyed by their names. // Return a map contained all know TF Lite Operators, keyed by their names.
// TODO(ycling): The pattern to propagate parameters (e.g. enable_select_tf_ops) // TODO(ycling): The pattern to propagate parameters (e.g. enable_select_tf_ops)
// is ugly here. Consider refactoring. // is ugly here. Consider refactoring.
std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
bool enable_select_tf_ops = false); bool enable_select_tf_ops = false);
// Return a map contained all know TF Lite Operators, keyed by the type of // Return a map contained all know TF Lite Operators, keyed by the type of
@ -41,7 +41,7 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
// Write the custom option FlexBuffer with a serialized TensorFlow NodeDef // Write the custom option FlexBuffer with a serialized TensorFlow NodeDef
// for a Flex op. // for a Flex op.
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions( std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
const string& tensorflow_node_def); const std::string& tensorflow_node_def);
// These are the flatbuffer types for custom and builtin options. // These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>; using CustomOptions = flatbuffers::Vector<uint8_t>;
@ -71,11 +71,11 @@ struct Options {
class BaseOperator { class BaseOperator {
public: public:
// Build an operator with the given TF Lite name and tf.mini type. // Build an operator with the given TF Lite name and tf.mini type.
BaseOperator(const string& name, OperatorType type) BaseOperator(const std::string& name, OperatorType type)
: name_(name), type_(type) {} : name_(name), type_(type) {}
virtual ~BaseOperator() = default; virtual ~BaseOperator() = default;
string name() const { return name_; } std::string name() const { return name_; }
OperatorType type() const { return type_; } OperatorType type() const { return type_; }
// Given a tf.mini operator, create the corresponding flatbuffer options and // Given a tf.mini operator, create the corresponding flatbuffer options and
@ -111,7 +111,7 @@ class BaseOperator {
} }
private: private:
string name_; std::string name_;
OperatorType type_; OperatorType type_;
}; };
@ -123,7 +123,7 @@ class BaseOperator {
// Helper function to determine if a unsupported TensorFlow op should be // Helper function to determine if a unsupported TensorFlow op should be
// exported as an Flex op or a regular custom op. // exported as an Flex op or a regular custom op.
bool ShouldExportAsFlexOp(bool enable_select_tf_ops, bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
const string& tensorflow_op_name); const std::string& tensorflow_op_name);
} // namespace tflite } // namespace tflite

View File

@ -30,8 +30,8 @@ namespace {
class OperatorTest : public ::testing::Test { class OperatorTest : public ::testing::Test {
protected: protected:
// Return the operator for the given name and type. // Return the operator for the given name and type.
const BaseOperator& GetOperator(const string& name, OperatorType type) { const BaseOperator& GetOperator(const std::string& name, OperatorType type) {
using OpsByName = std::map<string, std::unique_ptr<BaseOperator>>; using OpsByName = std::map<std::string, std::unique_ptr<BaseOperator>>;
using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>; using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
static auto* by_name = new OpsByName(BuildOperatorByNameMap()); static auto* by_name = new OpsByName(BuildOperatorByNameMap());
@ -86,7 +86,7 @@ class OperatorTest : public ::testing::Test {
// Verify serialization and deserialization of simple operators (those // Verify serialization and deserialization of simple operators (those
// that don't have any configuration parameters). // that don't have any configuration parameters).
template <typename T> template <typename T>
void CheckSimpleOperator(const string& name, OperatorType type) { void CheckSimpleOperator(const std::string& name, OperatorType type) {
Options options; Options options;
auto output_toco_op = auto output_toco_op =
SerializeAndDeserialize(GetOperator(name, type), T(), &options); SerializeAndDeserialize(GetOperator(name, type), T(), &options);
@ -99,7 +99,7 @@ class OperatorTest : public ::testing::Test {
} }
template <typename T> template <typename T>
void CheckReducerOperator(const string& name, OperatorType type) { void CheckReducerOperator(const std::string& name, OperatorType type) {
T op; T op;
op.keep_dims = false; op.keep_dims = false;

View File

@ -25,7 +25,7 @@ DataBuffer::FlatBufferOffset CopyStringToBuffer(
const Array& array, flatbuffers::FlatBufferBuilder* builder) { const Array& array, flatbuffers::FlatBufferBuilder* builder) {
const auto& src_data = array.GetBuffer<ArrayDataType::kString>().data; const auto& src_data = array.GetBuffer<ArrayDataType::kString>().data;
::tflite::DynamicBuffer dyn_buffer; ::tflite::DynamicBuffer dyn_buffer;
for (const string& str : src_data) { for (const std::string& str : src_data) {
dyn_buffer.AddString(str.c_str(), str.length()); dyn_buffer.AddString(str.c_str(), str.length());
} }
char* tensor_buffer; char* tensor_buffer;
@ -58,12 +58,12 @@ DataBuffer::FlatBufferOffset CopyBuffer(
void CopyStringFromBuffer(const ::tflite::Buffer& buffer, Array* array) { void CopyStringFromBuffer(const ::tflite::Buffer& buffer, Array* array) {
auto* src_data = reinterpret_cast<const char*>(buffer.data()->data()); auto* src_data = reinterpret_cast<const char*>(buffer.data()->data());
std::vector<string>* dst_data = std::vector<std::string>* dst_data =
&array->GetMutableBuffer<ArrayDataType::kString>().data; &array->GetMutableBuffer<ArrayDataType::kString>().data;
int32_t num_strings = ::tflite::GetStringCount(src_data); int32_t num_strings = ::tflite::GetStringCount(src_data);
for (int i = 0; i < num_strings; i++) { for (int i = 0; i < num_strings; i++) {
::tflite::StringRef str_ref = ::tflite::GetString(src_data, i); ::tflite::StringRef str_ref = ::tflite::GetString(src_data, i);
string this_str(str_ref.str, str_ref.len); std::string this_str(str_ref.str, str_ref.len);
dst_data->push_back(this_str); dst_data->push_back(this_str);
} }
} }