Sync to the head.

This commit is contained in:
sxwang 2019-04-30 11:12:30 +08:00
parent 4f94c911c1
commit 629ddf2e6d
4 changed files with 24 additions and 13 deletions

View File

@ -102,6 +102,7 @@ cc_library(
":gru_cell",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:cpu_backend_support",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels/internal:tensor",

View File

@ -40,7 +40,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
float* output, float* output_state,
const RuntimeShape& activation_shape, float* activation,
const RuntimeShape& concat_shape, float* concat,
const tflite::FullyConnectedParams& fc_params) {
const tflite::FullyConnectedParams& fc_params,
tflite::CpuBackendContext* cpu_backend_context) {
const int n_batch = input_shape.Dims(0);
const int n_input = input_shape.Dims(1);
const int n_output = state_shape.Dims(1);
@ -61,7 +62,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
// [r u] = [x h] * gate_weight + gate_bias
FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
gate_weight, gate_bias_shape, gate_bias, activation_shape,
activation);
activation, cpu_backend_context);
// [r u] = sigmoid([r u])
auto ru = MapAsArrayWithLastDimAsRows(activation, activation_shape);
@ -78,7 +79,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
// c = [x hr] * candidate_weight + candidate_bias
FullyConnected(fc_params, concat_shape, concat, candidate_weight_shape,
candidate_weight, candidate_bias_shape, candidate_bias,
output_shape, output);
output_shape, output, cpu_backend_context);
auto c = MapAsArrayWithLastDimAsRows(output, output_shape);
// output = (1 - u) .* tanh(c) + u .* h

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
namespace tflite {
@ -34,7 +35,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
float* output, float* output_state,
const RuntimeShape& activation_shape, float* activation,
const RuntimeShape& concat_shape, float* concat,
const tflite::FullyConnectedParams& fc_params);
const tflite::FullyConnectedParams& fc_params,
tflite::CpuBackendContext* cpu_backend_context);
} // namespace gru_cell
} // namespace experimental

View File

@ -17,6 +17,8 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -31,7 +33,8 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
const TfLiteTensor* candidate_weight,
const TfLiteTensor* candidate_bias, TfLiteTensor* output,
TfLiteTensor* output_state, TfLiteTensor* activation,
TfLiteTensor* concat) {
TfLiteTensor* concat,
tflite::CpuBackendContext* cpu_backend_context) {
const int n_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@ -61,13 +64,13 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
fc_params.float_activation_max = std::numeric_limits<float>::max();
for (int i = 0; i < n_time; ++i) {
gru_cell::GruCell(input_shape, input_data, state_shape, input_state_data,
gate_weight_shape, gate_weight_data, gate_bias_shape,
gate_bias_data, candidate_weight_shape,
candidate_weight_data, candidate_bias_shape,
candidate_bias_data, output_shape, output_data,
output_state_data, activation_shape, activation_data,
concat_shape, concat_data, fc_params);
gru_cell::GruCell(
input_shape, input_data, state_shape, input_state_data,
gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
candidate_bias_data, output_shape, output_data, output_state_data,
activation_shape, activation_data, concat_shape, concat_data, fc_params,
cpu_backend_context);
input_data += n_batch_input;
output_data += n_batch_output;
input_state_data = output_state_data;
@ -109,12 +112,14 @@ enum TemporaryTensor {
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
cpu_backend_support::IncrementUsageCounter(context);
auto* scratch_tensor_index = new int;
context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
return scratch_tensor_index;
}
void Free(TfLiteContext* context, void* buffer) {
cpu_backend_support::DecrementUsageCounter(context);
delete reinterpret_cast<int*>(buffer);
}
@ -216,10 +221,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
TfLiteTensor* activation = GetTemporary(context, node, kActivation);
TfLiteTensor* concat = GetTemporary(context, node, kConcat);
auto cpu_backend_context = cpu_backend_support::GetFromContext(context);
if (gate_weight->type == kTfLiteFloat32) {
GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
candidate_bias, output, output_state, activation, concat);
candidate_bias, output, output_state, activation, concat,
cpu_backend_context);
} else {
context->ReportError(context,
"Unsupported combination of data types for GruCell");