Sync to the head.

parent 4f94c911c1
commit 629ddf2e6d
@@ -102,6 +102,7 @@ cc_library(
         ":gru_cell",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/kernels:cpu_backend_support",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels:op_macros",
         "//tensorflow/lite/kernels/internal:tensor",
@@ -40,7 +40,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
              float* output, float* output_state,
              const RuntimeShape& activation_shape, float* activation,
              const RuntimeShape& concat_shape, float* concat,
-             const tflite::FullyConnectedParams& fc_params) {
+             const tflite::FullyConnectedParams& fc_params,
+             tflite::CpuBackendContext* cpu_backend_context) {
   const int n_batch = input_shape.Dims(0);
   const int n_input = input_shape.Dims(1);
   const int n_output = state_shape.Dims(1);

@@ -61,7 +62,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
   // [r u] = [x h] * gate_weight + gate_bias
   FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
                  gate_weight, gate_bias_shape, gate_bias, activation_shape,
-                 activation);
+                 activation, cpu_backend_context);

   // [r u] = sigmoid([r u])
   auto ru = MapAsArrayWithLastDimAsRows(activation, activation_shape);

@@ -78,7 +79,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
   // c = [x hr] * candidate_weight + candidate_bias
   FullyConnected(fc_params, concat_shape, concat, candidate_weight_shape,
                  candidate_weight, candidate_bias_shape, candidate_bias,
-                 output_shape, output);
+                 output_shape, output, cpu_backend_context);

   auto c = MapAsArrayWithLastDimAsRows(output, output_shape);
   // output = (1 - u) .* tanh(c) + u .* h
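The comments in these hunks spell out the recurrence that GruCell evaluates with two FullyConnected calls: [r u] = sigmoid([x h] * gate_weight + gate_bias), c = [x h.*r] * candidate_weight + candidate_bias, and output = (1 - u) .* tanh(c) + u .* h. The C++ below is a minimal, dependency-free sketch of that math for a single batch element; the row-major weight layout, the MatVec/Sigmoid helpers, and the GruStep name are assumptions made for illustration, not the TFLite kernel itself.

// Illustration only, not the TFLite GruCell kernel.
#include <cmath>
#include <vector>

// Hypothetical helper: y = W * x + b, with W stored row-major as [rows, cols].
static std::vector<float> MatVec(const std::vector<float>& w, int rows,
                                 int cols, const std::vector<float>& x,
                                 const std::vector<float>& b) {
  std::vector<float> y(rows);
  for (int r = 0; r < rows; ++r) {
    float acc = b[r];
    for (int c = 0; c < cols; ++c) acc += w[r * cols + c] * x[c];
    y[r] = acc;
  }
  return y;
}

static float Sigmoid(float v) { return 1.0f / (1.0f + std::exp(-v)); }

// One GRU step for a single batch element, following the commented equations:
//   [r u] = sigmoid([x h] * gate_weight + gate_bias)
//   c     = [x h.*r] * candidate_weight + candidate_bias
//   h'    = (1 - u) .* tanh(c) + u .* h
std::vector<float> GruStep(const std::vector<float>& x,        // [n_input]
                           const std::vector<float>& h,        // [n_output]
                           const std::vector<float>& gate_w,   // [2*n_output, n_input+n_output]
                           const std::vector<float>& gate_b,   // [2*n_output]
                           const std::vector<float>& cand_w,   // [n_output, n_input+n_output]
                           const std::vector<float>& cand_b) { // [n_output]
  const int n_input = static_cast<int>(x.size());
  const int n_output = static_cast<int>(h.size());

  // concat = [x h]
  std::vector<float> xh(x);
  xh.insert(xh.end(), h.begin(), h.end());

  // [r u] = sigmoid([x h] * gate_weight + gate_bias)
  std::vector<float> ru =
      MatVec(gate_w, 2 * n_output, n_input + n_output, xh, gate_b);
  for (float& v : ru) v = Sigmoid(v);

  // concat = [x h.*r]
  std::vector<float> xhr(x);
  for (int i = 0; i < n_output; ++i) xhr.push_back(h[i] * ru[i]);

  // c = [x h.*r] * candidate_weight + candidate_bias
  std::vector<float> c =
      MatVec(cand_w, n_output, n_input + n_output, xhr, cand_b);

  // output = (1 - u) .* tanh(c) + u .* h
  std::vector<float> h_new(n_output);
  for (int i = 0; i < n_output; ++i) {
    const float u = ru[n_output + i];
    h_new[i] = (1.0f - u) * std::tanh(c[i]) + u * h[i];
  }
  return h_new;
}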
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
 #define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_

+#include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"

 namespace tflite {

@@ -34,7 +35,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
              float* output, float* output_state,
              const RuntimeShape& activation_shape, float* activation,
              const RuntimeShape& concat_shape, float* concat,
-             const tflite::FullyConnectedParams& fc_params);
+             const tflite::FullyConnectedParams& fc_params,
+             tflite::CpuBackendContext* cpu_backend_context);

 }  // namespace gru_cell
 }  // namespace experimental
@@ -17,6 +17,8 @@ limitations under the License.

 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/experimental/kernels/gru_cell.h"
+#include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/cpu_backend_support.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"

@@ -31,7 +33,8 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
              const TfLiteTensor* candidate_weight,
              const TfLiteTensor* candidate_bias, TfLiteTensor* output,
              TfLiteTensor* output_state, TfLiteTensor* activation,
-             TfLiteTensor* concat) {
+             TfLiteTensor* concat,
+             tflite::CpuBackendContext* cpu_backend_context) {
   const int n_time = input->dims->data[0];
   const int n_batch = input->dims->data[1];
   const int n_input = input->dims->data[2];
@@ -61,13 +64,13 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
   fc_params.float_activation_min = std::numeric_limits<float>::lowest();
   fc_params.float_activation_max = std::numeric_limits<float>::max();
   for (int i = 0; i < n_time; ++i) {
-    gru_cell::GruCell(input_shape, input_data, state_shape, input_state_data,
-                      gate_weight_shape, gate_weight_data, gate_bias_shape,
-                      gate_bias_data, candidate_weight_shape,
-                      candidate_weight_data, candidate_bias_shape,
-                      candidate_bias_data, output_shape, output_data,
-                      output_state_data, activation_shape, activation_data,
-                      concat_shape, concat_data, fc_params);
+    gru_cell::GruCell(
+        input_shape, input_data, state_shape, input_state_data,
+        gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
+        candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
+        candidate_bias_data, output_shape, output_data, output_state_data,
+        activation_shape, activation_data, concat_shape, concat_data, fc_params,
+        cpu_backend_context);
     input_data += n_batch_input;
     output_data += n_batch_output;
     input_state_data = output_state_data;
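The loop in this hunk unrolls the cell over the time dimension: each iteration consumes one [batch, input] slice, writes one [batch, output] slice, and then feeds output_state_data back in as the next step's input state. Below is a minimal sketch of that pointer-walking pattern with hypothetical names; CellStep stands in for gru_cell::GruCell (which takes shapes, weights, biases, scratch buffers, and the backend context in the real kernel), and the cell is assumed to read its input state before overwriting the state buffer, as GruCell does by staging [x h] into the concat scratch buffer.

// Sketch of the sequence loop (hypothetical names, illustration only).
void RunSequence(const float* input, const float* input_state, float* output,
                 float* output_state, int n_time, int n_batch, int n_input,
                 int n_output,
                 void (*CellStep)(const float* x, const float* h_in, float* y,
                                  float* h_out, int n_batch, int n_input,
                                  int n_output)) {
  const int n_batch_input = n_batch * n_input;
  const int n_batch_output = n_batch * n_output;
  for (int t = 0; t < n_time; ++t) {
    // One step: reads a [batch, input] slice and the current state, writes a
    // [batch, output] slice and the updated state.
    CellStep(input, input_state, output, output_state, n_batch, n_input,
             n_output);
    input += n_batch_input;       // advance to the next time slice
    output += n_batch_output;     // each step appends its own output slice
    input_state = output_state;   // the new state feeds the next step
  }
}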
@@ -109,12 +112,14 @@ enum TemporaryTensor {
 };

 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  cpu_backend_support::IncrementUsageCounter(context);
   auto* scratch_tensor_index = new int;
   context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
   return scratch_tensor_index;
 }

 void Free(TfLiteContext* context, void* buffer) {
+  cpu_backend_support::DecrementUsageCounter(context);
   delete reinterpret_cast<int*>(buffer);
 }

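The calls added to Init and Free implement a usage-count handshake with the shared CPU backend: every op instance increments the counter when it is created and decrements it when it is destroyed, and Eval (next hunk) fetches the shared object with cpu_backend_support::GetFromContext. The sketch below is a hypothetical, self-contained illustration of that reference-counting pattern; SharedBackend, BackendRegistry, and the method names are made up for the example and are not the real cpu_backend_support API.

// Hypothetical illustration of the usage-counter pattern; not the real
// cpu_backend_support implementation.
#include <map>
#include <memory>

struct SharedBackend {  // stands in for the shared CPU backend object
  int max_num_threads = 1;
};

class BackendRegistry {
 public:
  // Init(): create the shared backend on first use, then bump the count.
  void Increment(const void* context) {
    Entry& e = entries_[context];
    if (!e.backend) e.backend = std::make_unique<SharedBackend>();
    ++e.count;
  }
  // Eval(): every op instance sees the same backend for a given context.
  SharedBackend* Get(const void* context) {
    auto it = entries_.find(context);
    return it == entries_.end() ? nullptr : it->second.backend.get();
  }
  // Free(): drop the backend once the last op instance releases it.
  void Decrement(const void* context) {
    auto it = entries_.find(context);
    if (it == entries_.end()) return;
    if (--it->second.count <= 0) entries_.erase(it);
  }

 private:
  struct Entry {
    std::unique_ptr<SharedBackend> backend;
    int count = 0;
  };
  std::map<const void*, Entry> entries_;
};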
@@ -216,10 +221,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
   TfLiteTensor* activation = GetTemporary(context, node, kActivation);
   TfLiteTensor* concat = GetTemporary(context, node, kConcat);
+  auto cpu_backend_context = cpu_backend_support::GetFromContext(context);

   if (gate_weight->type == kTfLiteFloat32) {
     GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
-            candidate_bias, output, output_state, activation, concat);
+            candidate_bias, output, output_state, activation, concat,
+            cpu_backend_context);
   } else {
     context->ReportError(context,
                          "Unsupported combination of data types for GruCell");