Simple modification.
This commit is contained in:
parent
742ba7d1de
commit
4f94c911c1
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
@ -40,7 +39,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
||||
const float* candidate_bias, const RuntimeShape& output_shape,
|
||||
float* output, float* output_state,
|
||||
const RuntimeShape& activation_shape, float* activation,
|
||||
const RuntimeShape& concat_shape, float* concat) {
|
||||
const RuntimeShape& concat_shape, float* concat,
|
||||
const tflite::FullyConnectedParams& fc_params) {
|
||||
const int n_batch = input_shape.Dims(0);
|
||||
const int n_input = input_shape.Dims(1);
|
||||
const int n_output = state_shape.Dims(1);
|
||||
@ -59,9 +59,6 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
||||
&(concat_arrays_data[0]), concat_shape, concat);
|
||||
|
||||
// [r u] = [x h] * gate_weight + gate_bias
|
||||
tflite::FullyConnectedParams fc_params;
|
||||
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
|
||||
fc_params.float_activation_max = std::numeric_limits<float>::max();
|
||||
FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
|
||||
gate_weight, gate_bias_shape, gate_bias, activation_shape,
|
||||
activation);
|
||||
|
@ -33,7 +33,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
||||
const float* candidate_bias, const RuntimeShape& output_shape,
|
||||
float* output, float* output_state,
|
||||
const RuntimeShape& activation_shape, float* activation,
|
||||
const RuntimeShape& concat_shape, float* concat);
|
||||
const RuntimeShape& concat_shape, float* concat,
|
||||
const tflite::FullyConnectedParams& fc_params);
|
||||
|
||||
} // namespace gru_cell
|
||||
} // namespace experimental
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
@ -55,13 +57,17 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
|
||||
float* activation_data = GetTensorData<float>(activation);
|
||||
const RuntimeShape concat_shape = GetTensorShape(concat);
|
||||
float* concat_data = GetTensorData<float>(concat);
|
||||
tflite::FullyConnectedParams fc_params;
|
||||
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
|
||||
fc_params.float_activation_max = std::numeric_limits<float>::max();
|
||||
for (int i = 0; i < n_time; ++i) {
|
||||
gru_cell::GruCell(
|
||||
input_shape, input_data, state_shape, input_state_data,
|
||||
gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
|
||||
candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
|
||||
candidate_bias_data, output_shape, output_data, output_state_data,
|
||||
activation_shape, activation_data, concat_shape, concat_data);
|
||||
gru_cell::GruCell(input_shape, input_data, state_shape, input_state_data,
|
||||
gate_weight_shape, gate_weight_data, gate_bias_shape,
|
||||
gate_bias_data, candidate_weight_shape,
|
||||
candidate_weight_data, candidate_bias_shape,
|
||||
candidate_bias_data, output_shape, output_data,
|
||||
output_state_data, activation_shape, activation_data,
|
||||
concat_shape, concat_data, fc_params);
|
||||
input_data += n_batch_input;
|
||||
output_data += n_batch_output;
|
||||
input_state_data = output_state_data;
|
||||
|
@ -110,6 +110,7 @@ TEST(GRUTest, SimpleTest) {
|
||||
{2 * n_output},
|
||||
{n_output, n_input + n_output},
|
||||
{n_output}});
|
||||
// All data is randomly generated.
|
||||
m.SetInput({0.89495724, 0.34482682, 0.68505806, 0.7135783, 0.3167085,
|
||||
0.93647677, 0.47361764, 0.39643127});
|
||||
m.SetInputState(
|
||||
|
Loading…
x
Reference in New Issue
Block a user