Fix parameter inference from model shape
This commit is contained in:
parent
778f5deb9d
commit
799baf1f99
@ -323,14 +323,15 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
|||||||
return;
|
return;
|
||||||
#endif // DS_NATIVE_MODEL
|
#endif // DS_NATIVE_MODEL
|
||||||
} else {
|
} else {
|
||||||
Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, mfcc_feats_per_timestep}));
|
Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, MFCC_FEATURES}));
|
||||||
|
|
||||||
auto input_mapped = input.tensor<float, 3>();
|
auto input_mapped = input.flat<float>();
|
||||||
int idx = 0;
|
int i;
|
||||||
for (int i = 0; i < n_frames; i++) {
|
for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
|
||||||
for (int j = 0; j < mfcc_feats_per_timestep; j++, idx++) {
|
input_mapped(i) = aMfcc[i];
|
||||||
input_mapped(0, i, j) = aMfcc[idx];
|
}
|
||||||
}
|
for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
|
||||||
|
input_mapped(i) = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor input_lengths(DT_INT32, TensorShape({1}));
|
Tensor input_lengths(DT_INT32, TensorShape({1}));
|
||||||
@ -482,9 +483,8 @@ DS_CreateModel(const char* aModelPath,
|
|||||||
if (node.name() == "input_node") {
|
if (node.name() == "input_node") {
|
||||||
const auto& shape = node.attr().at("shape").shape();
|
const auto& shape = node.attr().at("shape").shape();
|
||||||
model->n_steps = shape.dim(1).size();
|
model->n_steps = shape.dim(1).size();
|
||||||
model->mfcc_feats_per_timestep = shape.dim(2).size();
|
model->n_context = (shape.dim(2).size()-1)/2;
|
||||||
// mfcc_features_per_timestep = MFCC_FEATURES * ((2*n_context) + 1)
|
model->mfcc_feats_per_timestep = shape.dim(2).size() * shape.dim(3).size();
|
||||||
model->n_context = (model->mfcc_feats_per_timestep - MFCC_FEATURES) / (2 * MFCC_FEATURES);
|
|
||||||
} else if (node.name() == "logits_shape") {
|
} else if (node.name() == "logits_shape") {
|
||||||
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
|
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
|
||||||
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
|
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user