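Sync to the head: pass a CpuBackendContext through the GRU kernel (Eval -> GruImpl -> GruCell -> FullyConnected), register cpu_backend_support usage in Init/Free, and add the cpu_backend_support BUILD dependency.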

This commit is contained in:
sxwang 2019-04-30 11:12:30 +08:00
parent 4f94c911c1
commit 629ddf2e6d
4 changed files with 24 additions and 13 deletions

View File

@ -102,6 +102,7 @@ cc_library(
":gru_cell", ":gru_cell",
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:cpu_backend_support",
"//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal:tensor",

View File

@ -40,7 +40,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
float* output, float* output_state, float* output, float* output_state,
const RuntimeShape& activation_shape, float* activation, const RuntimeShape& activation_shape, float* activation,
const RuntimeShape& concat_shape, float* concat, const RuntimeShape& concat_shape, float* concat,
const tflite::FullyConnectedParams& fc_params) { const tflite::FullyConnectedParams& fc_params,
tflite::CpuBackendContext* cpu_backend_context) {
const int n_batch = input_shape.Dims(0); const int n_batch = input_shape.Dims(0);
const int n_input = input_shape.Dims(1); const int n_input = input_shape.Dims(1);
const int n_output = state_shape.Dims(1); const int n_output = state_shape.Dims(1);
@ -61,7 +62,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
// [r u] = [x h] * gate_weight + gate_bias // [r u] = [x h] * gate_weight + gate_bias
FullyConnected(fc_params, concat_shape, concat, gate_weight_shape, FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
gate_weight, gate_bias_shape, gate_bias, activation_shape, gate_weight, gate_bias_shape, gate_bias, activation_shape,
activation); activation, cpu_backend_context);
// [r u] = sigmoid([r u]) // [r u] = sigmoid([r u])
auto ru = MapAsArrayWithLastDimAsRows(activation, activation_shape); auto ru = MapAsArrayWithLastDimAsRows(activation, activation_shape);
@ -78,7 +79,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
// c = [x hr] * candidate_weight + candidate_bias // c = [x hr] * candidate_weight + candidate_bias
FullyConnected(fc_params, concat_shape, concat, candidate_weight_shape, FullyConnected(fc_params, concat_shape, concat, candidate_weight_shape,
candidate_weight, candidate_bias_shape, candidate_bias, candidate_weight, candidate_bias_shape, candidate_bias,
output_shape, output); output_shape, output, cpu_backend_context);
auto c = MapAsArrayWithLastDimAsRows(output, output_shape); auto c = MapAsArrayWithLastDimAsRows(output, output_shape);
// output = (1 - u) .* tanh(c) + u .* h // output = (1 - u) .* tanh(c) + u .* h

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
namespace tflite { namespace tflite {
@ -34,7 +35,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
float* output, float* output_state, float* output, float* output_state,
const RuntimeShape& activation_shape, float* activation, const RuntimeShape& activation_shape, float* activation,
const RuntimeShape& concat_shape, float* concat, const RuntimeShape& concat_shape, float* concat,
const tflite::FullyConnectedParams& fc_params); const tflite::FullyConnectedParams& fc_params,
tflite::CpuBackendContext* cpu_backend_context);
} // namespace gru_cell } // namespace gru_cell
} // namespace experimental } // namespace experimental

View File

@ -17,6 +17,8 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/kernels/gru_cell.h" #include "tensorflow/lite/experimental/kernels/gru_cell.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
@ -31,7 +33,8 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
const TfLiteTensor* candidate_weight, const TfLiteTensor* candidate_weight,
const TfLiteTensor* candidate_bias, TfLiteTensor* output, const TfLiteTensor* candidate_bias, TfLiteTensor* output,
TfLiteTensor* output_state, TfLiteTensor* activation, TfLiteTensor* output_state, TfLiteTensor* activation,
TfLiteTensor* concat) { TfLiteTensor* concat,
tflite::CpuBackendContext* cpu_backend_context) {
const int n_time = input->dims->data[0]; const int n_time = input->dims->data[0];
const int n_batch = input->dims->data[1]; const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2]; const int n_input = input->dims->data[2];
@ -61,13 +64,13 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
fc_params.float_activation_min = std::numeric_limits<float>::lowest(); fc_params.float_activation_min = std::numeric_limits<float>::lowest();
fc_params.float_activation_max = std::numeric_limits<float>::max(); fc_params.float_activation_max = std::numeric_limits<float>::max();
for (int i = 0; i < n_time; ++i) { for (int i = 0; i < n_time; ++i) {
gru_cell::GruCell(input_shape, input_data, state_shape, input_state_data, gru_cell::GruCell(
gate_weight_shape, gate_weight_data, gate_bias_shape, input_shape, input_data, state_shape, input_state_data,
gate_bias_data, candidate_weight_shape, gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
candidate_weight_data, candidate_bias_shape, candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
candidate_bias_data, output_shape, output_data, candidate_bias_data, output_shape, output_data, output_state_data,
output_state_data, activation_shape, activation_data, activation_shape, activation_data, concat_shape, concat_data, fc_params,
concat_shape, concat_data, fc_params); cpu_backend_context);
input_data += n_batch_input; input_data += n_batch_input;
output_data += n_batch_output; output_data += n_batch_output;
input_state_data = output_state_data; input_state_data = output_state_data;
@ -109,12 +112,14 @@ enum TemporaryTensor {
}; };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
cpu_backend_support::IncrementUsageCounter(context);
auto* scratch_tensor_index = new int; auto* scratch_tensor_index = new int;
context->AddTensors(context, kTemporaryNum, scratch_tensor_index); context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
return scratch_tensor_index; return scratch_tensor_index;
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
cpu_backend_support::DecrementUsageCounter(context);
delete reinterpret_cast<int*>(buffer); delete reinterpret_cast<int*>(buffer);
} }
@ -216,10 +221,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_state = GetOutput(context, node, kOutputState); TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
TfLiteTensor* activation = GetTemporary(context, node, kActivation); TfLiteTensor* activation = GetTemporary(context, node, kActivation);
TfLiteTensor* concat = GetTemporary(context, node, kConcat); TfLiteTensor* concat = GetTemporary(context, node, kConcat);
auto cpu_backend_context = cpu_backend_support::GetFromContext(context);
if (gate_weight->type == kTfLiteFloat32) { if (gate_weight->type == kTfLiteFloat32) {
GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight, GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
candidate_bias, output, output_state, activation, concat); candidate_bias, output, output_state, activation, concat,
cpu_backend_context);
} else { } else {
context->ReportError(context, context->ReportError(context,
"Unsupported combination of data types for GruCell"); "Unsupported combination of data types for GruCell");