From fec5b73955e39cdfa9d8ba710aaa7f555aa0c4e6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Feb 2019 08:17:47 -0800 Subject: [PATCH] Explicitly passing the number of dimensions for RNN states. PiperOrigin-RevId: 236125208 --- tensorflow/lite/toco/model_flags.proto | 1 + tensorflow/lite/toco/tooling_util.cc | 28 ++++++++++++++++---------- tensorflow/lite/toco/tooling_util.h | 3 ++- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/tensorflow/lite/toco/model_flags.proto b/tensorflow/lite/toco/model_flags.proto index bcdac295d26..dfc425073f5 100644 --- a/tensorflow/lite/toco/model_flags.proto +++ b/tensorflow/lite/toco/model_flags.proto @@ -85,6 +85,7 @@ message RnnState { // Will be expanded with 1's to fit the model. // TODO(benoitjacob): should allow a generic, explicit shape. optional int32 size = 3; + optional int32 num_dims = 4; } // An ArraysExtraInfo message stores a collection of additional Information diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index 82a86d7fc88..08ec795cee3 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -27,11 +27,11 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" #include "re2/re2.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/dump_graphviz.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" namespace toco { @@ -1462,16 +1462,22 @@ void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, } } -void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) { +void CreateOrCheckRnnStateArray(const string& name, int size, + int state_num_dims, Model* model) { int batch = 1; int num_dims = -1; - for (const auto& input_array : model->flags.input_arrays()) { - // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find - // a better match by name. - if (input_array.name() == name || num_dims == -1) { - num_dims = input_array.shape().dims_size(); - if (num_dims > 0) { - batch = input_array.shape().dims(0); + if (state_num_dims > 0) { + num_dims = state_num_dims; + } else { + // state_num_dims is not given. We will infer it from an input tensor. + for (const auto& input_array : model->flags.input_arrays()) { + // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find + // a better match by name. + if (input_array.name() == name || num_dims == -1) { + num_dims = input_array.shape().dims_size(); + if (num_dims > 0) { + batch = input_array.shape().dims(0); + } } } } @@ -1675,7 +1681,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { // Creation of the RNN state arrays for (const auto& rnn_state : model->flags.rnn_states()) { CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(), - model); + rnn_state.num_dims(), model); } model->flags.set_change_concat_input_ranges( diff --git a/tensorflow/lite/toco/tooling_util.h b/tensorflow/lite/toco/tooling_util.h index fc4aac7740c..b8a3dfca933 100644 --- a/tensorflow/lite/toco/tooling_util.h +++ b/tensorflow/lite/toco/tooling_util.h @@ -250,7 +250,8 @@ void DropMinMax(Model* model, const string& array_name); bool IsAllocatableTransientArray(const Model& model, const string& array_name); -void CreateOrCheckRnnStateArray(const string& name, int size, Model* model); +void CreateOrCheckRnnStateArray(const string& name, int size, + int state_num_dims, Model* model); string AvailableArrayName(const Model& model, const string& name);