STT/native_client/tfmodelstate.cc
Reuben Morais ae0cf8db6a Revert "Merge branch 'rename-real'"
This reverts commit ae9fdb183e, reversing
changes made to 2eb75b6206.
2020-08-26 11:46:09 +02:00

265 lines
8.2 KiB
C++

#include "tfmodelstate.h"
#include "workspace_status.h"
using namespace tensorflow;
using std::vector;
TFModelState::TFModelState()
: ModelState()
, mmap_env_(nullptr)
, session_(nullptr)
{
}
TFModelState::~TFModelState()
{
if (session_) {
Status status = session_->Close();
if (!status.ok()) {
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
}
}
}
int
TFModelState::init(const char* model_path)
{
int err = ModelState::init(model_path);
if (err != DS_ERR_OK) {
return err;
}
Status status;
SessionOptions options;
mmap_env_.reset(new MemmappedEnv(Env::Default()));
bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
if (!is_mmap) {
std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
} else {
status = mmap_env_->InitializeFromFile(model_path);
if (!status.ok()) {
std::cerr << status << std::endl;
return DS_ERR_FAIL_INIT_MMAP;
}
options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_opt_level(::OptimizerOptions::L0);
options.env = mmap_env_.get();
}
Session* session;
status = NewSession(options, &session);
if (!status.ok()) {
std::cerr << status << std::endl;
return DS_ERR_FAIL_INIT_SESS;
}
session_.reset(session);
if (is_mmap) {
status = ReadBinaryProto(mmap_env_.get(),
MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
&graph_def_);
} else {
status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
}
if (!status.ok()) {
std::cerr << status << std::endl;
return DS_ERR_FAIL_READ_PROTOBUF;
}
status = session_->Create(graph_def_);
if (!status.ok()) {
std::cerr << status << std::endl;
return DS_ERR_FAIL_CREATE_SESS;
}
std::vector<tensorflow::Tensor> version_output;
status = session_->Run({}, {
"metadata_version"
}, {}, &version_output);
if (!status.ok()) {
std::cerr << "Unable to fetch graph version: " << status << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE;
}
int graph_version = version_output[0].scalar<int>()();
if (graph_version < ds_graph_version()) {
std::cerr << "Specified model file version (" << graph_version << ") is "
<< "incompatible with minimum version supported by this client ("
<< ds_graph_version() << "). See "
<< "https://github.com/mozilla/DeepSpeech/blob/"
<< ds_git_version() << "/doc/USING.rst#model-compatibility "
<< "for more information" << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE;
}
std::vector<tensorflow::Tensor> metadata_outputs;
status = session_->Run({}, {
"metadata_sample_rate",
"metadata_feature_win_len",
"metadata_feature_win_step",
"metadata_beam_width",
"metadata_alphabet",
}, {}, &metadata_outputs);
if (!status.ok()) {
std::cout << "Unable to fetch metadata: " << status << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE;
}
sample_rate_ = metadata_outputs[0].scalar<int>()();
int win_len_ms = metadata_outputs[1].scalar<int>()();
int win_step_ms = metadata_outputs[2].scalar<int>()();
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
int beam_width = metadata_outputs[3].scalar<int>()();
beam_width_ = (unsigned int)(beam_width);
string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
if (err != 0) {
return DS_ERR_INVALID_ALPHABET;
}
assert(sample_rate_ > 0);
assert(audio_win_len_ > 0);
assert(audio_win_step_ > 0);
assert(beam_width_ > 0);
assert(alphabet_.GetSize() > 0);
for (int i = 0; i < graph_def_.node_size(); ++i) {
NodeDef node = graph_def_.node(i);
if (node.name() == "input_node") {
const auto& shape = node.attr().at("shape").shape();
n_steps_ = shape.dim(1).size();
n_context_ = (shape.dim(2).size()-1)/2;
n_features_ = shape.dim(3).size();
mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
} else if (node.name() == "previous_state_c") {
const auto& shape = node.attr().at("shape").shape();
state_size_ = shape.dim(1).size();
} else if (node.name() == "logits_shape") {
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
continue;
}
int final_dim_size = logits_shape.vec<int>()(2) - 1;
if (final_dim_size != alphabet_.GetSize()) {
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
<< "has size " << alphabet_.GetSize()
<< ", but model has " << final_dim_size
<< " classes in its output. Make sure you're passing an alphabet "
<< "file with the same size as the one used for training."
<< std::endl;
return DS_ERR_INVALID_ALPHABET;
}
}
}
if (n_context_ == -1 || n_features_ == -1) {
std::cerr << "Error: Could not infer input shape from model file. "
<< "Make sure input_node is a 4D tensor with shape "
<< "[batch_size=1, time, window_size, n_features]."
<< std::endl;
return DS_ERR_INVALID_SHAPE;
}
return DS_ERR_OK;
}
Tensor
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
{
Tensor ret(DT_FLOAT, shape);
auto ret_mapped = ret.flat<float>();
int i;
for (i = 0; i < vec.size(); ++i) {
ret_mapped(i) = vec[i];
}
for (; i < shape.num_elements(); ++i) {
ret_mapped(i) = 0.f;
}
return ret;
}
void
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
{
auto tensor_mapped = tensor.flat<float>();
if (num_elements == -1) {
num_elements = tensor.shape().num_elements();
}
for (int i = 0; i < num_elements; ++i) {
vec.push_back(tensor_mapped(i));
}
}
void
TFModelState::infer(const std::vector<float>& mfcc,
unsigned int n_frames,
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
vector<float>& logits_output,
vector<float>& state_c_output,
vector<float>& state_h_output)
{
const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor input_lengths(DT_INT32, TensorShape({1}));
input_lengths.scalar<int>()() = n_frames;
vector<Tensor> outputs;
Status status = session_->Run(
{
{"input_node", input},
{"input_lengths", input_lengths},
{"previous_state_c", previous_state_c_t},
{"previous_state_h", previous_state_h_t}
},
{"logits", "new_state_c", "new_state_h"},
{},
&outputs);
if (!status.ok()) {
std::cerr << "Error running session: " << status << "\n";
return;
}
copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);
state_c_output.clear();
state_c_output.reserve(state_size_);
copy_tensor_to_vector(outputs[1], state_c_output);
state_h_output.clear();
state_h_output.reserve(state_size_);
copy_tensor_to_vector(outputs[2], state_h_output);
}
void
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));
vector<Tensor> outputs;
Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
if (!status.ok()) {
std::cerr << "Error running session: " << status << "\n";
return;
}
// The feature computation graph is hardcoded to one audio length for now
const int n_windows = 1;
assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
copy_tensor_to_vector(outputs[0], mfcc_output);
}