Simple modification.

This commit is contained in:
sxwang 2019-04-29 16:33:13 +08:00
parent 742ba7d1de
commit 4f94c911c1
4 changed files with 17 additions and 12 deletions

View File

@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
#include <limits>
#include <vector>
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
@ -40,7 +39,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
const float* candidate_bias, const RuntimeShape& output_shape,
float* output, float* output_state,
const RuntimeShape& activation_shape, float* activation,
const RuntimeShape& concat_shape, float* concat) {
const RuntimeShape& concat_shape, float* concat,
const tflite::FullyConnectedParams& fc_params) {
const int n_batch = input_shape.Dims(0);
const int n_input = input_shape.Dims(1);
const int n_output = state_shape.Dims(1);
@ -59,9 +59,6 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
&(concat_arrays_data[0]), concat_shape, concat);
// [r u] = [x h] * gate_weight + gate_bias
tflite::FullyConnectedParams fc_params;
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
fc_params.float_activation_max = std::numeric_limits<float>::max();
FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
gate_weight, gate_bias_shape, gate_bias, activation_shape,
activation);

View File

@ -33,7 +33,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
const float* candidate_bias, const RuntimeShape& output_shape,
float* output, float* output_state,
const RuntimeShape& activation_shape, float* activation,
const RuntimeShape& concat_shape, float* concat);
const RuntimeShape& concat_shape, float* concat,
const tflite::FullyConnectedParams& fc_params);
} // namespace gru_cell
} // namespace experimental

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
@ -55,13 +57,17 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
float* activation_data = GetTensorData<float>(activation);
const RuntimeShape concat_shape = GetTensorShape(concat);
float* concat_data = GetTensorData<float>(concat);
tflite::FullyConnectedParams fc_params;
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
fc_params.float_activation_max = std::numeric_limits<float>::max();
for (int i = 0; i < n_time; ++i) {
gru_cell::GruCell(
input_shape, input_data, state_shape, input_state_data,
gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
candidate_bias_data, output_shape, output_data, output_state_data,
activation_shape, activation_data, concat_shape, concat_data);
gru_cell::GruCell(input_shape, input_data, state_shape, input_state_data,
gate_weight_shape, gate_weight_data, gate_bias_shape,
gate_bias_data, candidate_weight_shape,
candidate_weight_data, candidate_bias_shape,
candidate_bias_data, output_shape, output_data,
output_state_data, activation_shape, activation_data,
concat_shape, concat_data, fc_params);
input_data += n_batch_input;
output_data += n_batch_output;
input_state_data = output_state_data;

View File

@ -110,6 +110,7 @@ TEST(GRUTest, SimpleTest) {
{2 * n_output},
{n_output, n_input + n_output},
{n_output}});
// All data is randomly generated.
m.SetInput({0.89495724, 0.34482682, 0.68505806, 0.7135783, 0.3167085,
0.93647677, 0.47361764, 0.39643127});
m.SetInputState(