STT-tensorflow/tensorflow/examples/ios/camera/tensorflow_utils.mm

// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <Foundation/Foundation.h>

#include "tensorflow_utils.h"

#include <pthread.h>
#include <unistd.h>

#include <algorithm>
#include <fstream>
#include <queue>
#include <sstream>
#include <string>
#include <vector>

namespace {
// Helper class used to load protobufs efficiently.
class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
 public:
  explicit IfstreamInputStream(const std::string& file_name)
      : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {}
  ~IfstreamInputStream() { ifs_.close(); }

  int Read(void* buffer, int size) {
    if (!ifs_) {
      return -1;
    }
    ifs_.read(static_cast<char*>(buffer), size);
    return ifs_.gcount();
  }

 private:
  std::ifstream ifs_;
};

} // namespace
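
// Illustrative sketch (not part of the original file): IfstreamInputStream is
// meant to be handed to protobuf's CopyingInputStreamAdaptor, which takes
// ownership and adds buffering on top of the raw Read() calls. The file name
// here is hypothetical.
//
//   ::google::protobuf::io::CopyingInputStreamAdaptor adaptor(
//       new IfstreamInputStream("graph.pb"));
//   adaptor.SetOwnsCopyingStream(true);
//   // `adaptor` can now back a CodedInputStream, exactly as
//   // PortableReadFileToProto does below.
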
// Returns the top N confidence values over threshold in the provided vector,
// sorted by confidence in descending order.
void GetTopN(const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
                                    Eigen::Aligned>& prediction,
             const int num_results, const float threshold,
             std::vector<std::pair<float, int> >* top_results) {
  // Will contain top N results in ascending order.
  std::priority_queue<std::pair<float, int>,
                      std::vector<std::pair<float, int> >,
                      std::greater<std::pair<float, int> > >
      top_result_pq;

  const int count = prediction.size();
  for (int i = 0; i < count; ++i) {
    const float value = prediction(i);

    // Only add it if it beats the threshold and has a chance at being in
    // the top N.
    if (value < threshold) {
      continue;
    }

    top_result_pq.push(std::pair<float, int>(value, i));

    // If at capacity, kick the smallest value out.
    if (top_result_pq.size() > num_results) {
      top_result_pq.pop();
    }
  }

  // Copy to output vector and reverse into descending order.
  while (!top_result_pq.empty()) {
    top_results->push_back(top_result_pq.top());
    top_result_pq.pop();
  }
  std::reverse(top_results->begin(), top_results->end());
}
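
// Example call (illustrative; assumes `outputs` holds the result of a
// Session::Run whose first output is a 1-D float score tensor, e.g. a
// softmax). Tensor::flat<float>() yields exactly the TensorMap type the
// signature above expects.
//
//   std::vector<std::pair<float, int> > top_results;
//   GetTopN(outputs[0].flat<float>(), 5 /* num_results */,
//           0.1f /* threshold */, &top_results);
//   // top_results now holds up to five (score, index) pairs, best first.
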
bool PortableReadFileToProto(const std::string& file_name,
                             ::google::protobuf::MessageLite* proto) {
  ::google::protobuf::io::CopyingInputStreamAdaptor stream(
      new IfstreamInputStream(file_name));
  stream.SetOwnsCopyingStream(true);
  ::google::protobuf::io::CodedInputStream coded_stream(&stream);
  // Total bytes hard limit / warning limit are set to 1GB and 512MB
  // respectively.
  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
  return proto->ParseFromCodedStream(&coded_stream);
}
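
// Example (illustrative; the path is hypothetical):
//
//   tensorflow::GraphDef graph_def;
//   if (!PortableReadFileToProto("/tmp/graph.pb", &graph_def)) {
//     LOG(ERROR) << "Failed to parse graph proto";
//   }
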
NSString* FilePathForResourceName(NSString* name, NSString* extension) {
  NSString* file_path =
      [[NSBundle mainBundle] pathForResource:name ofType:extension];
  if (file_path == NULL) {
    LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "."
               << [extension UTF8String] << "' in bundle.";
    return nullptr;
  }
  return file_path;
}
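
// Example (illustrative; assumes a "labels.txt" resource is bundled with the
// app). Note that LOG(FATAL) aborts the process, so the nullptr return above
// is purely defensive.
//
//   NSString* labels_path = FilePathForResourceName(@"labels", @"txt");
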
tensorflow::Status LoadModel(NSString* file_name, NSString* file_type,
                             std::unique_ptr<tensorflow::Session>* session) {
  tensorflow::SessionOptions options;

  tensorflow::Session* session_pointer = nullptr;
  tensorflow::Status session_status =
      tensorflow::NewSession(options, &session_pointer);
  if (!session_status.ok()) {
    LOG(ERROR) << "Could not create TensorFlow Session: " << session_status;
    return session_status;
  }
  session->reset(session_pointer);

  tensorflow::GraphDef tensorflow_graph;

  NSString* model_path = FilePathForResourceName(file_name, file_type);
  if (!model_path) {
    LOG(ERROR) << "Failed to find model proto at " << [file_name UTF8String]
               << "." << [file_type UTF8String];
    return tensorflow::errors::NotFound([file_name UTF8String],
                                        [file_type UTF8String]);
  }
  const bool read_proto_succeeded =
      PortableReadFileToProto([model_path UTF8String], &tensorflow_graph);
  if (!read_proto_succeeded) {
    LOG(ERROR) << "Failed to load model proto from "
               << [model_path UTF8String];
    return tensorflow::errors::NotFound([model_path UTF8String]);
  }

  tensorflow::Status create_status = (*session)->Create(tensorflow_graph);
  if (!create_status.ok()) {
    LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status;
    return create_status;
  }

  return tensorflow::Status::OK();
}
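
// Example (illustrative; the resource names are hypothetical and must match
// files bundled with the app):
//
//   std::unique_ptr<tensorflow::Session> session;
//   tensorflow::Status load_status =
//       LoadModel(@"tensorflow_inception_graph", @"pb", &session);
//   if (load_status.ok()) {
//     // session->Run(...) can now be called with the graph's input and
//     // output tensor names.
//   }
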
tensorflow::Status LoadMemoryMappedModel(
    NSString* file_name, NSString* file_type,
    std::unique_ptr<tensorflow::Session>* session,
    std::unique_ptr<tensorflow::MemmappedEnv>* memmapped_env) {
  NSString* network_path = FilePathForResourceName(file_name, file_type);
  memmapped_env->reset(
      new tensorflow::MemmappedEnv(tensorflow::Env::Default()));
  tensorflow::Status mmap_status =
      (memmapped_env->get())->InitializeFromFile([network_path UTF8String]);
  if (!mmap_status.ok()) {
    LOG(ERROR) << "MMap failed with " << mmap_status.error_message();
    return mmap_status;
  }

  tensorflow::GraphDef tensorflow_graph;
  tensorflow::Status load_graph_status = ReadBinaryProto(
      memmapped_env->get(),
      tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
      &tensorflow_graph);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << "MMap load graph failed with "
               << load_graph_status.error_message();
    return load_graph_status;
  }

  tensorflow::SessionOptions options;
  // Disable optimizations on this graph so that constant folding doesn't
  // increase the memory footprint by creating new constant copies of the
  // weight parameters.
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(::tensorflow::OptimizerOptions::L0);
  options.env = memmapped_env->get();

  tensorflow::Session* session_pointer = nullptr;
  tensorflow::Status session_status =
      tensorflow::NewSession(options, &session_pointer);
  if (!session_status.ok()) {
    LOG(ERROR) << "Could not create TensorFlow Session: " << session_status;
    return session_status;
  }

  tensorflow::Status create_status = session_pointer->Create(tensorflow_graph);
  if (!create_status.ok()) {
    LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status;
    return create_status;
  }

  session->reset(session_pointer);

  return tensorflow::Status::OK();
}
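
// Example (illustrative; the resource names are hypothetical). The
// `memmapped_env` must outlive the session, since the session reads weights
// directly out of the mapped file rather than copying them into memory:
//
//   std::unique_ptr<tensorflow::Session> session;
//   std::unique_ptr<tensorflow::MemmappedEnv> memmapped_env;
//   tensorflow::Status load_status = LoadMemoryMappedModel(
//       @"memmapped_graph", @"pb", &session, &memmapped_env);
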
tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type,
                              std::vector<std::string>* label_strings) {
  // Read the label list.
  NSString* labels_path = FilePathForResourceName(file_name, file_type);
  if (!labels_path) {
    LOG(ERROR) << "Failed to find labels file at " << [file_name UTF8String]
               << "." << [file_type UTF8String];
    return tensorflow::errors::NotFound([file_name UTF8String],
                                        [file_type UTF8String]);
  }
  std::ifstream t;
  t.open([labels_path UTF8String]);
  std::string line;
  // Test the stream state via getline itself, so a spurious empty label isn't
  // pushed after the final read fails.
  while (std::getline(t, line)) {
    label_strings->push_back(line);
  }
  t.close();
  return tensorflow::Status::OK();
}
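
// Example (illustrative; assumes a newline-delimited labels file is bundled
// with the app under a hypothetical name):
//
//   std::vector<std::string> labels;
//   tensorflow::Status labels_status =
//       LoadLabels(@"imagenet_comp_graph_label_strings", @"txt", &labels);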