Explicitly passing the number of dimensions for RNN states.

PiperOrigin-RevId: 236125208
This commit is contained in:
A. Unique TensorFlower 2019-02-28 08:17:47 -08:00 committed by TensorFlower Gardener
parent f38eea2aec
commit fec5b73955
3 changed files with 20 additions and 12 deletions

View File

@ -85,6 +85,7 @@ message RnnState {
// Will be expanded with 1's to fit the model.
// TODO(benoitjacob): should allow a generic, explicit shape.
optional int32 size = 3;
optional int32 num_dims = 4;
}
// An ArraysExtraInfo message stores a collection of additional Information

View File

@ -27,11 +27,11 @@ limitations under the License.
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "re2/re2.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/dump_graphviz.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace toco {
@ -1462,16 +1462,22 @@ void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
}
}
void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
void CreateOrCheckRnnStateArray(const string& name, int size,
int state_num_dims, Model* model) {
int batch = 1;
int num_dims = -1;
for (const auto& input_array : model->flags.input_arrays()) {
// Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
// a better match by name.
if (input_array.name() == name || num_dims == -1) {
num_dims = input_array.shape().dims_size();
if (num_dims > 0) {
batch = input_array.shape().dims(0);
if (state_num_dims > 0) {
num_dims = state_num_dims;
} else {
// state_num_dims is not given. We will infer it from an input tensor.
for (const auto& input_array : model->flags.input_arrays()) {
// Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
// a better match by name.
if (input_array.name() == name || num_dims == -1) {
num_dims = input_array.shape().dims_size();
if (num_dims > 0) {
batch = input_array.shape().dims(0);
}
}
}
}
@ -1675,7 +1681,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
// Creation of the RNN state arrays
for (const auto& rnn_state : model->flags.rnn_states()) {
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
model);
rnn_state.num_dims(), model);
}
model->flags.set_change_concat_input_ranges(

View File

@ -250,7 +250,8 @@ void DropMinMax(Model* model, const string& array_name);
bool IsAllocatableTransientArray(const Model& model, const string& array_name);
void CreateOrCheckRnnStateArray(const string& name, int size, Model* model);
void CreateOrCheckRnnStateArray(const string& name, int size,
int state_num_dims, Model* model);
string AvailableArrayName(const Model& model, const string& name);