Simple modification.
This commit is contained in:
parent
742ba7d1de
commit
4f94c911c1
@ -15,7 +15,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
|
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
|
||||||
|
|
||||||
#include <limits>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||||
@ -40,7 +39,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
|||||||
const float* candidate_bias, const RuntimeShape& output_shape,
|
const float* candidate_bias, const RuntimeShape& output_shape,
|
||||||
float* output, float* output_state,
|
float* output, float* output_state,
|
||||||
const RuntimeShape& activation_shape, float* activation,
|
const RuntimeShape& activation_shape, float* activation,
|
||||||
const RuntimeShape& concat_shape, float* concat) {
|
const RuntimeShape& concat_shape, float* concat,
|
||||||
|
const tflite::FullyConnectedParams& fc_params) {
|
||||||
const int n_batch = input_shape.Dims(0);
|
const int n_batch = input_shape.Dims(0);
|
||||||
const int n_input = input_shape.Dims(1);
|
const int n_input = input_shape.Dims(1);
|
||||||
const int n_output = state_shape.Dims(1);
|
const int n_output = state_shape.Dims(1);
|
||||||
@ -59,9 +59,6 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
|||||||
&(concat_arrays_data[0]), concat_shape, concat);
|
&(concat_arrays_data[0]), concat_shape, concat);
|
||||||
|
|
||||||
// [r u] = [x h] * gate_weight + gate_bias
|
// [r u] = [x h] * gate_weight + gate_bias
|
||||||
tflite::FullyConnectedParams fc_params;
|
|
||||||
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
|
|
||||||
fc_params.float_activation_max = std::numeric_limits<float>::max();
|
|
||||||
FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
|
FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
|
||||||
gate_weight, gate_bias_shape, gate_bias, activation_shape,
|
gate_weight, gate_bias_shape, gate_bias, activation_shape,
|
||||||
activation);
|
activation);
|
||||||
|
@ -33,7 +33,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
|||||||
const float* candidate_bias, const RuntimeShape& output_shape,
|
const float* candidate_bias, const RuntimeShape& output_shape,
|
||||||
float* output, float* output_state,
|
float* output, float* output_state,
|
||||||
const RuntimeShape& activation_shape, float* activation,
|
const RuntimeShape& activation_shape, float* activation,
|
||||||
const RuntimeShape& concat_shape, float* concat);
|
const RuntimeShape& concat_shape, float* concat,
|
||||||
|
const tflite::FullyConnectedParams& fc_params);
|
||||||
|
|
||||||
} // namespace gru_cell
|
} // namespace gru_cell
|
||||||
} // namespace experimental
|
} // namespace experimental
|
||||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
|
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
@ -55,13 +57,17 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
|
|||||||
float* activation_data = GetTensorData<float>(activation);
|
float* activation_data = GetTensorData<float>(activation);
|
||||||
const RuntimeShape concat_shape = GetTensorShape(concat);
|
const RuntimeShape concat_shape = GetTensorShape(concat);
|
||||||
float* concat_data = GetTensorData<float>(concat);
|
float* concat_data = GetTensorData<float>(concat);
|
||||||
|
tflite::FullyConnectedParams fc_params;
|
||||||
|
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
|
||||||
|
fc_params.float_activation_max = std::numeric_limits<float>::max();
|
||||||
for (int i = 0; i < n_time; ++i) {
|
for (int i = 0; i < n_time; ++i) {
|
||||||
gru_cell::GruCell(
|
gru_cell::GruCell(input_shape, input_data, state_shape, input_state_data,
|
||||||
input_shape, input_data, state_shape, input_state_data,
|
gate_weight_shape, gate_weight_data, gate_bias_shape,
|
||||||
gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
|
gate_bias_data, candidate_weight_shape,
|
||||||
candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
|
candidate_weight_data, candidate_bias_shape,
|
||||||
candidate_bias_data, output_shape, output_data, output_state_data,
|
candidate_bias_data, output_shape, output_data,
|
||||||
activation_shape, activation_data, concat_shape, concat_data);
|
output_state_data, activation_shape, activation_data,
|
||||||
|
concat_shape, concat_data, fc_params);
|
||||||
input_data += n_batch_input;
|
input_data += n_batch_input;
|
||||||
output_data += n_batch_output;
|
output_data += n_batch_output;
|
||||||
input_state_data = output_state_data;
|
input_state_data = output_state_data;
|
||||||
|
@ -110,6 +110,7 @@ TEST(GRUTest, SimpleTest) {
|
|||||||
{2 * n_output},
|
{2 * n_output},
|
||||||
{n_output, n_input + n_output},
|
{n_output, n_input + n_output},
|
||||||
{n_output}});
|
{n_output}});
|
||||||
|
// All data is randomly generated.
|
||||||
m.SetInput({0.89495724, 0.34482682, 0.68505806, 0.7135783, 0.3167085,
|
m.SetInput({0.89495724, 0.34482682, 0.68505806, 0.7135783, 0.3167085,
|
||||||
0.93647677, 0.47361764, 0.39643127});
|
0.93647677, 0.47361764, 0.39643127});
|
||||||
m.SetInputState(
|
m.SetInputState(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user