Embed alphabet directly in model

This commit is contained in:
Reuben Morais 2019-10-31 15:02:01 +01:00
parent 493aaed151
commit 8c82081779
16 changed files with 164 additions and 112 deletions

View File

@ -780,12 +780,11 @@ def export():
graph_version = int(file_relative_read('GRAPH_VERSION').strip()) graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0 assert graph_version > 0
# Reshape with dimension [1] required to avoid this error:
# ERROR: Input array not provided for operation 'reshape'.
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version') outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate') outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len') outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step') outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language: if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('ascii')], name='metadata_language') outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('ascii')], name='metadata_language')

View File

@ -1 +1 @@
4 5

View File

@ -36,7 +36,7 @@ public:
if (line == " ") { if (line == " ") {
space_label_ = label; space_label_ = label;
} }
label_to_str_.push_back(line); label_to_str_[label] = line;
str_to_label_[line] = label; str_to_label_[line] = label;
++label; ++label;
} }
@ -45,9 +45,53 @@ public:
return 0; return 0;
} }
// Deserializes an alphabet previously produced by serialize() (or by the
// Python Alphabet.serialize()). Wire format, little-endian:
//   int16 entry_count, then entry_count records of
//   { int16 label, int16 value_length, char value[value_length] }.
//
// Populates label_to_str_, str_to_label_, size_ and (if a " " entry is
// present) space_label_.
//
// Returns 0 on success, 1 on a truncated or malformed buffer.
int deserialize(const char* buffer, const int buffer_size) {
  int offset = 0;
  // Reads a little-endian int16 at the current offset, bounds-checked.
  // Bytes are assembled manually instead of via a pointer cast:
  // `buffer + offset` has no alignment guarantee, and dereferencing it as
  // an int16_t* would be undefined behavior (unaligned load / strict
  // aliasing) on some targets.
  auto read_int16 = [&](int16_t* out) -> bool {
    if (buffer_size - offset < (int)sizeof(int16_t)) {
      return false;
    }
    unsigned char lo = (unsigned char)buffer[offset];
    unsigned char hi = (unsigned char)buffer[offset + 1];
    *out = (int16_t)((hi << 8) | lo);
    offset += sizeof(int16_t);
    return true;
  };
  int16_t size;
  if (!read_int16(&size) || size < 0) {
    return 1;
  }
  size_ = size;
  for (int i = 0; i < size; ++i) {
    int16_t label;
    int16_t val_len;
    if (!read_int16(&label) || !read_int16(&val_len)) {
      return 1;
    }
    // Reject negative lengths explicitly: a negative val_len would pass a
    // plain `buffer_size - offset < val_len` check and then wrap to a huge
    // size_t inside the std::string(const char*, count) constructor.
    if (val_len < 0 || buffer_size - offset < val_len) {
      return 1;
    }
    std::string val(buffer + offset, val_len);
    offset += val_len;
    label_to_str_[label] = val;
    str_to_label_[val] = label;
    if (val == " ") {
      space_label_ = label;
    }
  }
  return 0;
}
const std::string& StringFromLabel(unsigned int label) const { const std::string& StringFromLabel(unsigned int label) const {
assert(label < size_); auto it = label_to_str_.find(label);
return label_to_str_[label]; if (it != label_to_str_.end()) {
return it->second;
} else {
std::cerr << "Invalid label " << label << std::endl;
abort();
}
} }
unsigned int LabelFromString(const std::string& string) const { unsigned int LabelFromString(const std::string& string) const {
@ -55,7 +99,7 @@ public:
if (it != str_to_label_.end()) { if (it != str_to_label_.end()) {
return it->second; return it->second;
} else { } else {
std::cerr << "Invalid label " << string << std::endl; std::cerr << "Invalid string " << string << std::endl;
abort(); abort();
} }
} }
@ -84,7 +128,7 @@ public:
private: private:
size_t size_; size_t size_;
unsigned int space_label_; unsigned int space_label_;
std::vector<std::string> label_to_str_; std::unordered_map<unsigned int, std::string> label_to_str_;
std::unordered_map<std::string, unsigned int> str_to_label_; std::unordered_map<std::string, unsigned int> str_to_label_;
}; };

View File

@ -369,7 +369,7 @@ main(int argc, char **argv)
// Initialise DeepSpeech // Initialise DeepSpeech
ModelState* ctx; ModelState* ctx;
int status = DS_CreateModel(model, alphabet, beam_width, &ctx); int status = DS_CreateModel(model, beam_width, &ctx);
if (status != 0) { if (status != 0) {
fprintf(stderr, "Could not create model.\n"); fprintf(stderr, "Could not create model.\n");
return 1; return 1;

View File

@ -18,9 +18,18 @@ class Scorer(swigwrapper.Scorer):
def __init__(self, alpha, beta, model_path, trie_path, alphabet): def __init__(self, alpha, beta, model_path, trie_path, alphabet):
super(Scorer, self).__init__() super(Scorer, self).__init__()
err = self.init(alpha, beta, model_path, trie_path, alphabet.config_file()) serialized = alphabet.serialize()
native_alphabet = swigwrapper.Alphabet()
err = native_alphabet.deserialize(serialized, len(serialized))
if err != 0: if err != 0:
raise ValueError("Scorer initialization failed with error code {}".format(err), err) raise ValueError("Error when deserializing alphabet.")
err = self.init(alpha, beta,
model_path.encode('utf-8'),
trie_path.encode('utf-8'),
native_alphabet)
if err != 0:
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder(probs_seq,
@ -35,8 +44,7 @@ def ctc_beam_search_decoder(probs_seq,
step, with each element being a list of normalized step, with each element being a list of normalized
probabilities over alphabet and blank. probabilities over alphabet and blank.
:type probs_seq: 2-D list :type probs_seq: 2-D list
:param alphabet: alphabet list. :param alphabet: Alphabet
:alphabet: Alphabet
:param beam_size: Width for beam search. :param beam_size: Width for beam search.
:type beam_size: int :type beam_size: int
:param cutoff_prob: Cutoff probability in pruning, :param cutoff_prob: Cutoff probability in pruning,
@ -53,8 +61,13 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the confidence. results, in descending order of the confidence.
:rtype: list :rtype: list
""" """
serialized = alphabet.serialize()
native_alphabet = swigwrapper.Alphabet()
err = native_alphabet.deserialize(serialized, len(serialized))
if err != 0:
raise ValueError("Error when deserializing alphabet.")
beam_results = swigwrapper.ctc_beam_search_decoder( beam_results = swigwrapper.ctc_beam_search_decoder(
probs_seq, alphabet.config_file(), beam_size, cutoff_prob, cutoff_top_n, probs_seq, native_alphabet, beam_size, cutoff_prob, cutoff_top_n,
scorer) scorer)
beam_results = [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results] beam_results = [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
return beam_results return beam_results
@ -95,9 +108,12 @@ def ctc_beam_search_decoder_batch(probs_seq,
results, in descending order of the confidence. results, in descending order of the confidence.
:rtype: list :rtype: list
""" """
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch( serialized = alphabet.serialize()
probs_seq, seq_lengths, alphabet.config_file(), beam_size, num_processes, native_alphabet = swigwrapper.Alphabet()
cutoff_prob, cutoff_top_n, scorer) err = native_alphabet.deserialize(serialized, len(serialized))
if err != 0:
raise ValueError("Error when deserializing alphabet.")
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, native_alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
batch_beam_results = [ batch_beam_results = [
[(res.confidence, alphabet.decode(res.tokens)) for res in beam_results] [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
for beam_results in batch_beam_results for beam_results in batch_beam_results

View File

@ -3,6 +3,7 @@
%{ %{
#include "ctc_beam_search_decoder.h" #include "ctc_beam_search_decoder.h"
#define SWIG_FILE_WITH_INIT #define SWIG_FILE_WITH_INIT
#define SWIG_PYTHON_STRICT_BYTE_CHAR
%} %}
%include "pyabc.i" %include "pyabc.i"
@ -16,57 +17,12 @@ import_array();
// Convert NumPy arrays to pointer+lengths // Convert NumPy arrays to pointer+lengths
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)}; %apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)}; %apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)}; %apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
// Add overloads converting char* to Alphabet
%inline %{
std::vector<Output>
ctc_beam_search_decoder(const double *probs,
int time_dim,
int class_dim,
char* alphabet_config_path,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer)
{
Alphabet a;
if (a.init(alphabet_config_path)) {
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
}
return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size,
cutoff_prob, cutoff_top_n, ext_scorer);
}
std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch(const double *probs,
int batch_dim,
int time_dim,
int class_dim,
const int *seq_lengths,
int seq_lengths_size,
char* alphabet_config_path,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer)
{
Alphabet a;
if (a.init(alphabet_config_path)) {
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
}
return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim,
seq_lengths, seq_lengths_size, a, beam_size,
num_processes, cutoff_prob, cutoff_top_n,
ext_scorer);
}
%}
%ignore Scorer::dictionary; %ignore Scorer::dictionary;
%include "../alphabet.h"
%include "output.h" %include "output.h"
%include "scorer.h" %include "scorer.h"
%include "ctc_beam_search_decoder.h" %include "ctc_beam_search_decoder.h"

View File

@ -257,7 +257,6 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
int int
DS_CreateModel(const char* aModelPath, DS_CreateModel(const char* aModelPath,
const char* aAlphabetConfigPath,
unsigned int aBeamWidth, unsigned int aBeamWidth,
ModelState** retval) ModelState** retval)
{ {
@ -283,7 +282,7 @@ DS_CreateModel(const char* aModelPath,
return DS_ERR_FAIL_CREATE_MODEL; return DS_ERR_FAIL_CREATE_MODEL;
} }
int err = model->init(aModelPath, aAlphabetConfigPath, aBeamWidth); int err = model->init(aModelPath, aBeamWidth);
if (err != DS_ERR_OK) { if (err != DS_ERR_OK) {
return err; return err;
} }

View File

@ -77,8 +77,6 @@ enum DeepSpeech_Error_Codes
* @brief An object providing an interface to a trained DeepSpeech model. * @brief An object providing an interface to a trained DeepSpeech model.
* *
* @param aModelPath The path to the frozen model graph. * @param aModelPath The path to the frozen model graph.
* @param aAlphabetConfigPath The path to the configuration file specifying
* the alphabet used by the network. See alphabet.h.
* @param aBeamWidth The beam width used by the decoder. A larger beam * @param aBeamWidth The beam width used by the decoder. A larger beam
* width generates better results at the cost of decoding * width generates better results at the cost of decoding
* time. * time.
@ -88,7 +86,6 @@ enum DeepSpeech_Error_Codes
*/ */
DEEPSPEECH_EXPORT DEEPSPEECH_EXPORT
int DS_CreateModel(const char* aModelPath, int DS_CreateModel(const char* aModelPath,
const char* aAlphabetConfigPath,
unsigned int aBeamWidth, unsigned int aBeamWidth,
ModelState** retval); ModelState** retval);

View File

@ -13,8 +13,7 @@
* @param aModelPath The path to the frozen model graph. * @param aModelPath The path to the frozen model graph.
* @param aNCep UNUSED, DEPRECATED. * @param aNCep UNUSED, DEPRECATED.
* @param aNContext UNUSED, DEPRECATED. * @param aNContext UNUSED, DEPRECATED.
* @param aAlphabetConfigPath The path to the configuration file specifying * @param aAlphabetConfigPath UNUSED, DEPRECATED.
* the alphabet used by the network. See alphabet.h.
* @param aBeamWidth The beam width used by the decoder. A larger beam * @param aBeamWidth The beam width used by the decoder. A larger beam
* width generates better results at the cost of decoding * width generates better results at the cost of decoding
* time. * time.
@ -25,11 +24,11 @@
int DS_CreateModel(const char* aModelPath, int DS_CreateModel(const char* aModelPath,
unsigned int /*aNCep*/, unsigned int /*aNCep*/,
unsigned int /*aNContext*/, unsigned int /*aNContext*/,
const char* aAlphabetConfigPath, const char* /*aAlphabetConfigPath*/,
unsigned int aBeamWidth, unsigned int aBeamWidth,
ModelState** retval) ModelState** retval)
{ {
return DS_CreateModel(aModelPath, aAlphabetConfigPath, aBeamWidth, retval); return DS_CreateModel(aModelPath, aBeamWidth, retval);
} }
/** /**

View File

@ -25,12 +25,8 @@ ModelState::~ModelState()
int int
ModelState::init(const char* model_path, ModelState::init(const char* model_path,
const char* alphabet_path,
unsigned int beam_width) unsigned int beam_width)
{ {
if (alphabet_.init(alphabet_path)) {
return DS_ERR_INVALID_ALPHABET;
}
beam_width_ = beam_width; beam_width_ = beam_width;
return DS_ERR_OK; return DS_ERR_OK;
} }

View File

@ -30,9 +30,7 @@ struct ModelState {
ModelState(); ModelState();
virtual ~ModelState(); virtual ~ModelState();
virtual int init(const char* model_path, virtual int init(const char* model_path, unsigned int beam_width);
const char* alphabet_path,
unsigned int beam_width);
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0; virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;

View File

@ -1,5 +1,6 @@
#include "tflitemodelstate.h" #include "tflitemodelstate.h"
#include "tensorflow/lite/string_util.h"
#include "workspace_status.h" #include "workspace_status.h"
using namespace tflite; using namespace tflite;
@ -91,10 +92,9 @@ TFLiteModelState::~TFLiteModelState()
int int
TFLiteModelState::init(const char* model_path, TFLiteModelState::init(const char* model_path,
const char* alphabet_path,
unsigned int beam_width) unsigned int beam_width)
{ {
int err = ModelState::init(model_path, alphabet_path, beam_width); int err = ModelState::init(model_path, beam_width);
if (err != DS_ERR_OK) { if (err != DS_ERR_OK) {
return err; return err;
} }
@ -126,17 +126,17 @@ TFLiteModelState::init(const char* model_path,
mfccs_idx_ = get_output_tensor_by_name("mfccs"); mfccs_idx_ = get_output_tensor_by_name("mfccs");
int metadata_version_idx = get_output_tensor_by_name("metadata_version"); int metadata_version_idx = get_output_tensor_by_name("metadata_version");
// int metadata_language_idx = get_output_tensor_by_name("metadata_language");
int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate"); int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate");
int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len"); int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step"); int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet");
std::vector<int> metadata_exec_plan; std::vector<int> metadata_exec_plan;
metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]);
// metadata_exec_plan.push_back(find_parent_node_ids(metadata_language_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]);
for (int i = 0; i < metadata_exec_plan.size(); ++i) { for (int i = 0; i < metadata_exec_plan.size(); ++i) {
assert(metadata_exec_plan[i] > -1); assert(metadata_exec_plan[i] > -1);
@ -200,6 +200,12 @@ TFLiteModelState::init(const char* model_path,
audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0); audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0); audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0);
tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len);
if (err != 0) {
return DS_ERR_INVALID_ALPHABET;
}
assert(sample_rate_ > 0); assert(sample_rate_ > 0);
assert(audio_win_len_ > 0); assert(audio_win_len_ > 0);
assert(audio_win_step_ > 0); assert(audio_win_step_ > 0);

View File

@ -31,7 +31,6 @@ struct TFLiteModelState : public ModelState
virtual ~TFLiteModelState(); virtual ~TFLiteModelState();
virtual int init(const char* model_path, virtual int init(const char* model_path,
const char* alphabet_path,
unsigned int beam_width) override; unsigned int beam_width) override;
virtual void compute_mfcc(const std::vector<float>& audio_buffer, virtual void compute_mfcc(const std::vector<float>& audio_buffer,

View File

@ -25,10 +25,9 @@ TFModelState::~TFModelState()
int int
TFModelState::init(const char* model_path, TFModelState::init(const char* model_path,
const char* alphabet_path,
unsigned int beam_width) unsigned int beam_width)
{ {
int err = ModelState::init(model_path, alphabet_path, beam_width); int err = ModelState::init(model_path, beam_width);
if (err != DS_ERR_OK) { if (err != DS_ERR_OK) {
return err; return err;
} }
@ -78,20 +77,16 @@ TFModelState::init(const char* model_path,
return DS_ERR_FAIL_CREATE_SESS; return DS_ERR_FAIL_CREATE_SESS;
} }
std::vector<tensorflow::Tensor> metadata_outputs; std::vector<tensorflow::Tensor> version_output;
status = session_->Run({}, { status = session_->Run({}, {
"metadata_version", "metadata_version"
// "metadata_language", }, {}, &version_output);
"metadata_sample_rate",
"metadata_feature_win_len",
"metadata_feature_win_step"
}, {}, &metadata_outputs);
if (!status.ok()) { if (!status.ok()) {
std::cout << "Unable to fetch metadata: " << status << std::endl; std::cerr << "Unable to fetch graph version: " << status << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE; return DS_ERR_MODEL_INCOMPATIBLE;
} }
int graph_version = metadata_outputs[0].scalar<int>()(); int graph_version = version_output[0].scalar<int>()();
if (graph_version < ds_graph_version()) { if (graph_version < ds_graph_version()) {
std::cerr << "Specified model file version (" << graph_version << ") is " std::cerr << "Specified model file version (" << graph_version << ") is "
<< "incompatible with minimum version supported by this client (" << "incompatible with minimum version supported by this client ("
@ -101,12 +96,30 @@ TFModelState::init(const char* model_path,
return DS_ERR_MODEL_INCOMPATIBLE; return DS_ERR_MODEL_INCOMPATIBLE;
} }
sample_rate_ = metadata_outputs[1].scalar<int>()(); std::vector<tensorflow::Tensor> metadata_outputs;
int win_len_ms = metadata_outputs[2].scalar<int>()(); status = session_->Run({}, {
int win_step_ms = metadata_outputs[3].scalar<int>()(); "metadata_sample_rate",
"metadata_feature_win_len",
"metadata_feature_win_step",
"metadata_alphabet",
}, {}, &metadata_outputs);
if (!status.ok()) {
std::cout << "Unable to fetch metadata: " << status << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE;
}
sample_rate_ = metadata_outputs[0].scalar<int>()();
int win_len_ms = metadata_outputs[1].scalar<int>()();
int win_step_ms = metadata_outputs[2].scalar<int>()();
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0); audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0); audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
string serialized_alphabet = metadata_outputs[3].scalar<string>()();
err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size());
if (err != 0) {
return DS_ERR_INVALID_ALPHABET;
}
assert(sample_rate_ > 0); assert(sample_rate_ > 0);
assert(audio_win_len_ > 0); assert(audio_win_len_ > 0);
assert(audio_win_step_ > 0); assert(audio_win_step_ > 0);

View File

@ -19,7 +19,6 @@ struct TFModelState : public ModelState
virtual ~TFModelState(); virtual ~TFModelState();
virtual int init(const char* model_path, virtual int init(const char* model_path,
const char* alphabet_path,
unsigned int beam_width) override; unsigned int beam_width) override;
virtual void infer(const std::vector<float>& mfcc, virtual void infer(const std::vector<float>& mfcc,

View File

@ -1,27 +1,29 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import codecs import codecs
import re
import numpy as np import numpy as np
import re
import struct
from util.flags import FLAGS
from six.moves import range from six.moves import range
class Alphabet(object): class Alphabet(object):
def __init__(self, config_file): def __init__(self, config_file):
self._config_file = config_file self._config_file = config_file
self._label_to_str = [] self._label_to_str = {}
self._str_to_label = {} self._str_to_label = {}
self._size = 0 self._size = 0
with codecs.open(config_file, 'r', 'utf-8') as fin: if config_file:
for line in fin: with codecs.open(config_file, 'r', 'utf-8') as fin:
if line[0:2] == '\\#': for line in fin:
line = '#\n' if line[0:2] == '\\#':
elif line[0] == '#': line = '#\n'
continue elif line[0] == '#':
self._label_to_str += line[:-1] # remove the line ending continue
self._str_to_label[line[:-1]] = self._size self._label_to_str[self._size] = line[:-1] # remove the line ending
self._size += 1 self._str_to_label[line[:-1]] = self._size
self._size += 1
def _string_from_label(self, label): def _string_from_label(self, label):
return self._label_to_str[label] return self._label_to_str[label]
@ -51,6 +53,35 @@ class Alphabet(object):
res += self._string_from_label(label) res += self._string_from_label(label)
return res return res
def serialize(self):
    """Pack this alphabet into a compact binary form.

    Layout (little-endian): an int16 entry count, followed by one
    record per label of {int16 label, int16 byte-length, UTF-8 bytes}.

    :return: the serialized alphabet as a ``bytes`` object
    """
    parts = [struct.pack('<h', self._size)]
    for label, value in self._label_to_str.items():
        encoded = value.encode('utf-8')
        parts.append(struct.pack('<hh', label, len(encoded)))
        parts.append(encoded)
    return b''.join(parts)
@staticmethod
def deserialize(buf):
    """Reconstruct an Alphabet from the binary form produced by serialize().

    Layout (little-endian): an int16 entry count, followed by one
    record per label of {int16 label, int16 byte-length, UTF-8 bytes}.

    :param buf: bytes object produced by :meth:`serialize`
    :return: a new Alphabet instance with no backing config file
    """
    #pylint: disable=protected-access
    res = Alphabet(config_file=None)
    offset = 0
    def unpack_and_fwd(fmt, buf):
        # Unpack fmt at the current offset and advance past it.
        nonlocal offset
        result = struct.unpack_from(fmt, buf, offset)
        offset += struct.calcsize(fmt)
        return result
    # Assign the private _size field. Writing to `res.size` would shadow
    # the size() method with a plain int and leave _size stuck at 0.
    res._size = unpack_and_fwd('<h', buf)[0]
    for _ in range(res._size):
        label, val_len = unpack_and_fwd('<hh', buf)
        val = unpack_and_fwd('<{}s'.format(val_len), buf)[0].decode('utf-8')
        res._label_to_str[label] = val
        res._str_to_label[val] = label
    return res
def size(self): def size(self):
return self._size return self._size