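// Benchmark program for KenLM.  Intended usage (see the help message in
// main()): convert text to vocabulary ids offline with -v, warm the files
// into RAM, then run a timed multi-threaded query with -q.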
#include "lm/model.hh"
|
|
#include "util/file_stream.hh"
|
|
#include "util/file.hh"
|
|
#include "util/file_piece.hh"
|
|
#include "util/usage.hh"
|
|
#include "util/thread_pool.hh"
|
|
|
|
#include <boost/range/iterator_range.hpp>
|
|
#include <boost/program_options.hpp>
|
|
|
|
#include <iostream>
|
|
|
|
#include <stdint.h>
|
|
|
|
namespace {

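// Read whitespace-delimited words from fd_in, map each to its vocabulary id,
// and write the ids to stdout as fixed-width integers.  Each newline in the
// input is written as the </s> (end of sentence) id.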
template <class Model, class Width> void ConvertToBytes(const Model &model, int fd_in) {
  util::FilePiece in(fd_in);
  util::FileStream out(1);
  Width width;
  StringPiece word;
  const Width end_sentence = (Width)model.GetVocabulary().EndSentence();
  while (true) {
    while (in.ReadWordSameLine(word)) {
      width = (Width)model.GetVocabulary().Index(word);
      out.write(&width, sizeof(Width));
    }
    if (!in.ReadLineOrEOF(word)) break;
    out.write(&end_sentence, sizeof(Width));
  }
}

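// Thread pool worker: scores one block of vocabulary ids against the model
// and accumulates the log probability sum locally, folding it into the
// shared total only when the pool is destroyed.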
template <class Model, class Width> class Worker {
  public:
    explicit Worker(const Model &model, double &add_total) : model_(model), total_(0.0), add_total_(add_total) {}

    // Destructors happen in the main thread, so there's no race for add_total_.
    ~Worker() { add_total_ += total_; }

    typedef boost::iterator_range<Width *> Request;

    void operator()(Request request) {
      const lm::ngram::State *const begin_state = &model_.BeginSentenceState();
      const lm::ngram::State *next_state = begin_state;
      const Width kEOS = model_.GetVocabulary().EndSentence();
      float sum = 0.0;
      // Process words in pairs, alternating between state_[1] and state_[0] as
      // output so FullScore never writes to the state it is reading from.
      const Width *even_end = request.begin() + (request.size() & ~1);
      const Width *i;
      for (i = request.begin(); i != even_end;) {
        sum += model_.FullScore(*next_state, *i, state_[1]).prob;
        // After </s>, the next word starts a new sentence from <s>.
        next_state = (*i++ == kEOS) ? begin_state : &state_[1];
        sum += model_.FullScore(*next_state, *i, state_[0]).prob;
        next_state = (*i++ == kEOS) ? begin_state : &state_[0];
      }
      // Handle the leftover word if the request length was odd.
      if (request.size() & 1) {
        sum += model_.FullScore(*next_state, *i, state_[2]).prob;
        next_state = (*i++ == kEOS) ? begin_state : &state_[2];
      }
      total_ += sum;
    }

  private:
    const Model &model_;
    double total_;
    double &add_total_;

    lm::ngram::State state_[3];
};

struct Config {
  int fd_in;                  // File descriptor to read words or vocab ids from.
  std::size_t threads;        // Worker threads for querying.
  std::size_t buf_per_thread; // Number of words to buffer per task.
  bool query;                 // true: query from vocab ids; false: convert words to ids.
};

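// Producer-consumer benchmark loop: the main thread reads fixed-width
// vocabulary ids from config.fd_in, cuts the stream into sentence-aligned
// buffers, and feeds them to a pool of Workers, then reports timing,
// throughput, and memory statistics.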
template <class Model, class Width> void QueryFromBytes(const Model &model, const Config &config) {
  util::FileStream out(1);
  out << "Threads: " << config.threads << '\n';
  const Width kEOS = model.GetVocabulary().EndSentence();
  double total = 0.0;
  // Number of items to have in queue in addition to everything in flight.
  const std::size_t kInQueue = 3;
  std::size_t total_queue = config.threads + kInQueue;
  std::vector<Width> backing(config.buf_per_thread * total_queue);
  double loaded_cpu;
  double loaded_wall;
  uint64_t queries = 0;
  {
    util::RecyclingThreadPool<Worker<Model, Width> > pool(total_queue, config.threads, Worker<Model, Width>(model, total), boost::iterator_range<Width *>((Width*)0, (Width*)0));

    // Seed the recycling queue with empty ranges into the backing store; only
    // the begin pointer matters because each range is resized after the read
    // below.
    for (std::size_t i = 0; i < total_queue; ++i) {
      pool.PopulateRecycling(boost::iterator_range<Width *>(&backing[i * config.buf_per_thread], &backing[i * config.buf_per_thread]));
    }

    loaded_cpu = util::CPUTime();
    loaded_wall = util::WallTime();
    out << "To Load, CPU: " << loaded_cpu << " Wall: " << loaded_wall << '\n';
    // Partial sentence carried over from the previous buffer.
    boost::iterator_range<Width *> overhang((Width*)0, (Width*)0);
    while (true) {
      boost::iterator_range<Width *> buf = pool.Consume();
      std::memmove(buf.begin(), overhang.begin(), overhang.size() * sizeof(Width));
      std::size_t got = util::ReadOrEOF(config.fd_in, buf.begin() + overhang.size(), (config.buf_per_thread - overhang.size()) * sizeof(Width));
      if (!got && overhang.empty()) break;
      UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width));
      Width *read_end = buf.begin() + overhang.size() + got / sizeof(Width);
      // Send only whole sentences to a worker: cut the buffer at the last </s>
      // and carry the remainder over to the next buffer.
      Width *last_eos;
      for (last_eos = read_end - 1; ; --last_eos) {
        UTIL_THROW_IF2(last_eos <= buf.begin(), "Encountered a sentence longer than the buffer size of " << config.buf_per_thread << " words. Rerun with increased buffer size. TODO: adaptable buffer");
        if (*last_eos == kEOS) break;
      }
      buf = boost::iterator_range<Width*>(buf.begin(), last_eos + 1);
      overhang = boost::iterator_range<Width*>(last_eos + 1, read_end);
      queries += buf.size();
      pool.Produce(buf);
    }
  } // Drain pool.

  double after_cpu = util::CPUTime();
  double after_wall = util::WallTime();
  util::FileStream(2, 70) << "Probability sum: " << total << '\n';
  out << "Queries: " << queries << '\n';
  out << "Excluding load, CPU: " << (after_cpu - loaded_cpu) << " Wall: " << (after_wall - loaded_wall) << '\n';
  double cpu_per_entry = ((after_cpu - loaded_cpu) / static_cast<double>(queries));
  double wall_per_entry = ((after_wall - loaded_wall) / static_cast<double>(queries));
  out << "Seconds per query excluding load, CPU: " << cpu_per_entry << " Wall: " << wall_per_entry << '\n';
  out << "Queries per second excluding load, CPU: " << (1.0/cpu_per_entry) << " Wall: " << (1.0/wall_per_entry) << '\n';
  out << "RSSMax: " << util::RSSMax() << '\n';
}

template <class Model, class Width> void DispatchFunction(const Model &model, const Config &config) {
  if (config.query) {
    QueryFromBytes<Model, Width>(model, config);
  } else {
    ConvertToBytes<Model, Width>(model, config.fd_in);
  }
}

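// Load the model, then dispatch on the narrowest unsigned integer type that
// can hold every vocabulary id, keeping the id stream as small as possible.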
template <class Model> void DispatchWidth(const char *file, const Config &config) {
  lm::ngram::Config model_config;
  model_config.load_method = util::READ;
  Model model(file, model_config);
  uint64_t bound = model.GetVocabulary().Bound();
  if (bound <= 256) {
    DispatchFunction<Model, uint8_t>(model, config);
  } else if (bound <= 65536) {
    DispatchFunction<Model, uint16_t>(model, config);
  } else if (bound <= (1ULL << 32)) {
    DispatchFunction<Model, uint32_t>(model, config);
  } else {
    DispatchFunction<Model, uint64_t>(model, config);
  }
}

void Dispatch(const char *file, const Config &config) {
  using namespace lm::ngram;
  lm::ngram::ModelType model_type;
  if (lm::ngram::RecognizeBinary(file, model_type)) {
    switch(model_type) {
      case PROBING:
        DispatchWidth<lm::ngram::ProbingModel>(file, config);
        break;
      case REST_PROBING:
        DispatchWidth<lm::ngram::RestProbingModel>(file, config);
        break;
      case TRIE:
        DispatchWidth<lm::ngram::TrieModel>(file, config);
        break;
      case QUANT_TRIE:
        DispatchWidth<lm::ngram::QuantTrieModel>(file, config);
        break;
      case ARRAY_TRIE:
        DispatchWidth<lm::ngram::ArrayTrieModel>(file, config);
        break;
      case QUANT_ARRAY_TRIE:
        DispatchWidth<lm::ngram::QuantArrayTrieModel>(file, config);
        break;
      default:
        UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type);
    }
  } else {
    UTIL_THROW(util::Exception, "Binarize before running benchmarks.");
  }
}

} // namespace

int main(int argc, char *argv[]) {
  try {
    Config config;
    config.fd_in = 0; // Read from stdin.
    std::string model;
    namespace po = boost::program_options;
    po::options_description options("Benchmark options");
    options.add_options()
      ("help,h", po::bool_switch(), "Show help message")
      ("model,m", po::value<std::string>(&model)->required(), "Model to query or convert vocab ids")
      ("threads,t", po::value<std::size_t>(&config.threads)->default_value(boost::thread::hardware_concurrency()), "Threads to use (querying only; TODO vocab conversion)")
      ("buffer,b", po::value<std::size_t>(&config.buf_per_thread)->default_value(4096), "Number of words to buffer per task.")
      ("vocab,v", po::bool_switch(), "Convert strings to vocab ids")
      ("query,q", po::bool_switch(), "Query from vocab ids");
    po::variables_map vm;
    po::store(po::parse_command_line(argc, argv, options), vm);
    if (argc == 1 || vm["help"].as<bool>()) {
      std::cerr << "Benchmark program for KenLM. Intended usage:\n"
        << "#Convert text to vocabulary ids offline. These ids are tied to a model.\n"
        << argv[0] << " -v -m $model <$text >$text.vocab\n"
        << "#Ensure files are in RAM.\n"
        << "cat $text.vocab $model >/dev/null\n"
        << "#Timed query against the model.\n"
        << argv[0] << " -q -m $model <$text.vocab\n";
      return 0;
    }
    po::notify(vm);
    if (!(vm["vocab"].as<bool>() ^ vm["query"].as<bool>())) {
      std::cerr << "Specify exactly one of -v (vocab conversion) or -q (query)." << std::endl;
      return 1;
    }
    config.query = vm["query"].as<bool>();
    if (!config.threads) {
      std::cerr << "Specify a non-zero number of threads with -t." << std::endl;
      return 1;
    }
    Dispatch(model.c_str(), config);
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;
    return 1;
  }
  return 0;
}