Sync to the head.
This commit is contained in:
parent
4f94c911c1
commit
629ddf2e6d
@ -102,6 +102,7 @@ cc_library(
|
||||
":gru_cell",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:cpu_backend_support",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels:op_macros",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
|
@ -40,7 +40,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
||||
float* output, float* output_state,
|
||||
const RuntimeShape& activation_shape, float* activation,
|
||||
const RuntimeShape& concat_shape, float* concat,
|
||||
const tflite::FullyConnectedParams& fc_params) {
|
||||
const tflite::FullyConnectedParams& fc_params,
|
||||
tflite::CpuBackendContext* cpu_backend_context) {
|
||||
const int n_batch = input_shape.Dims(0);
|
||||
const int n_input = input_shape.Dims(1);
|
||||
const int n_output = state_shape.Dims(1);
|
||||
@ -61,7 +62,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
||||
// [r u] = [x h] * gate_weight + gate_bias
|
||||
FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
|
||||
gate_weight, gate_bias_shape, gate_bias, activation_shape,
|
||||
activation);
|
||||
activation, cpu_backend_context);
|
||||
|
||||
// [r u] = sigmoid([r u])
|
||||
auto ru = MapAsArrayWithLastDimAsRows(activation, activation_shape);
|
||||
@ -78,7 +79,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
||||
// c = [x hr] * candidate_weight + candidate_bias
|
||||
FullyConnected(fc_params, concat_shape, concat, candidate_weight_shape,
|
||||
candidate_weight, candidate_bias_shape, candidate_bias,
|
||||
output_shape, output);
|
||||
output_shape, output, cpu_backend_context);
|
||||
|
||||
auto c = MapAsArrayWithLastDimAsRows(output, output_shape);
|
||||
// output = (1 - u) .* tanh(c) + u .* h
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -34,7 +35,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input,
|
||||
float* output, float* output_state,
|
||||
const RuntimeShape& activation_shape, float* activation,
|
||||
const RuntimeShape& concat_shape, float* concat,
|
||||
const tflite::FullyConnectedParams& fc_params);
|
||||
const tflite::FullyConnectedParams& fc_params,
|
||||
tflite::CpuBackendContext* cpu_backend_context);
|
||||
|
||||
} // namespace gru_cell
|
||||
} // namespace experimental
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_support.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
@ -31,7 +33,8 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
|
||||
const TfLiteTensor* candidate_weight,
|
||||
const TfLiteTensor* candidate_bias, TfLiteTensor* output,
|
||||
TfLiteTensor* output_state, TfLiteTensor* activation,
|
||||
TfLiteTensor* concat) {
|
||||
TfLiteTensor* concat,
|
||||
tflite::CpuBackendContext* cpu_backend_context) {
|
||||
const int n_time = input->dims->data[0];
|
||||
const int n_batch = input->dims->data[1];
|
||||
const int n_input = input->dims->data[2];
|
||||
@ -61,13 +64,13 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
|
||||
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
|
||||
fc_params.float_activation_max = std::numeric_limits<float>::max();
|
||||
for (int i = 0; i < n_time; ++i) {
|
||||
gru_cell::GruCell(input_shape, input_data, state_shape, input_state_data,
|
||||
gate_weight_shape, gate_weight_data, gate_bias_shape,
|
||||
gate_bias_data, candidate_weight_shape,
|
||||
candidate_weight_data, candidate_bias_shape,
|
||||
candidate_bias_data, output_shape, output_data,
|
||||
output_state_data, activation_shape, activation_data,
|
||||
concat_shape, concat_data, fc_params);
|
||||
gru_cell::GruCell(
|
||||
input_shape, input_data, state_shape, input_state_data,
|
||||
gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
|
||||
candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
|
||||
candidate_bias_data, output_shape, output_data, output_state_data,
|
||||
activation_shape, activation_data, concat_shape, concat_data, fc_params,
|
||||
cpu_backend_context);
|
||||
input_data += n_batch_input;
|
||||
output_data += n_batch_output;
|
||||
input_state_data = output_state_data;
|
||||
@ -109,12 +112,14 @@ enum TemporaryTensor {
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
cpu_backend_support::IncrementUsageCounter(context);
|
||||
auto* scratch_tensor_index = new int;
|
||||
context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
|
||||
return scratch_tensor_index;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
cpu_backend_support::DecrementUsageCounter(context);
|
||||
delete reinterpret_cast<int*>(buffer);
|
||||
}
|
||||
|
||||
@ -216,10 +221,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
|
||||
TfLiteTensor* activation = GetTemporary(context, node, kActivation);
|
||||
TfLiteTensor* concat = GetTemporary(context, node, kConcat);
|
||||
auto cpu_backend_context = cpu_backend_support::GetFromContext(context);
|
||||
|
||||
if (gate_weight->type == kTfLiteFloat32) {
|
||||
GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
|
||||
candidate_bias, output, output_state, activation, concat);
|
||||
candidate_bias, output, output_state, activation, concat,
|
||||
cpu_backend_context);
|
||||
} else {
|
||||
context->ReportError(context,
|
||||
"Unsupported combination of data types for GruCell");
|
||||
|
Loading…
Reference in New Issue
Block a user