Explicitly passing the number of dimensions for RNN states.
PiperOrigin-RevId: 236125208
This commit is contained in:
parent
f38eea2aec
commit
fec5b73955
@ -85,6 +85,7 @@ message RnnState {
|
||||
// Will be expanded with 1's to fit the model.
|
||||
// TODO(benoitjacob): should allow a generic, explicit shape.
|
||||
optional int32 size = 3;
|
||||
optional int32 num_dims = 4;
|
||||
}
|
||||
|
||||
// An ArraysExtraInfo message stores a collection of additional Information
|
||||
|
@ -27,11 +27,11 @@ limitations under the License.
|
||||
#include "absl/strings/str_replace.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "re2/re2.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/toco/dump_graphviz.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace toco {
|
||||
|
||||
@ -1462,16 +1462,22 @@ void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
|
||||
}
|
||||
}
|
||||
|
||||
void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
|
||||
void CreateOrCheckRnnStateArray(const string& name, int size,
|
||||
int state_num_dims, Model* model) {
|
||||
int batch = 1;
|
||||
int num_dims = -1;
|
||||
for (const auto& input_array : model->flags.input_arrays()) {
|
||||
// Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
|
||||
// a better match by name.
|
||||
if (input_array.name() == name || num_dims == -1) {
|
||||
num_dims = input_array.shape().dims_size();
|
||||
if (num_dims > 0) {
|
||||
batch = input_array.shape().dims(0);
|
||||
if (state_num_dims > 0) {
|
||||
num_dims = state_num_dims;
|
||||
} else {
|
||||
// state_num_dims is not given. We will infer it from an input tensor.
|
||||
for (const auto& input_array : model->flags.input_arrays()) {
|
||||
// Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
|
||||
// a better match by name.
|
||||
if (input_array.name() == name || num_dims == -1) {
|
||||
num_dims = input_array.shape().dims_size();
|
||||
if (num_dims > 0) {
|
||||
batch = input_array.shape().dims(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1675,7 +1681,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
|
||||
// Creation of the RNN state arrays
|
||||
for (const auto& rnn_state : model->flags.rnn_states()) {
|
||||
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
|
||||
model);
|
||||
rnn_state.num_dims(), model);
|
||||
}
|
||||
|
||||
model->flags.set_change_concat_input_ranges(
|
||||
|
@ -250,7 +250,8 @@ void DropMinMax(Model* model, const string& array_name);
|
||||
|
||||
bool IsAllocatableTransientArray(const Model& model, const string& array_name);
|
||||
|
||||
void CreateOrCheckRnnStateArray(const string& name, int size, Model* model);
|
||||
void CreateOrCheckRnnStateArray(const string& name, int size,
|
||||
int state_num_dims, Model* model);
|
||||
|
||||
string AvailableArrayName(const Model& model, const string& name);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user