diff --git a/tensorflow/lite/experimental/kernels/BUILD b/tensorflow/lite/experimental/kernels/BUILD index 82bfdf60d18..fa171786683 100644 --- a/tensorflow/lite/experimental/kernels/BUILD +++ b/tensorflow/lite/experimental/kernels/BUILD @@ -102,6 +102,7 @@ cc_library( ":gru_cell", "//tensorflow/lite:framework", "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:cpu_backend_support", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels/internal:tensor", diff --git a/tensorflow/lite/experimental/kernels/gru_cell.cc b/tensorflow/lite/experimental/kernels/gru_cell.cc index 4fc308c085c..2801db85a32 100644 --- a/tensorflow/lite/experimental/kernels/gru_cell.cc +++ b/tensorflow/lite/experimental/kernels/gru_cell.cc @@ -40,7 +40,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input, float* output, float* output_state, const RuntimeShape& activation_shape, float* activation, const RuntimeShape& concat_shape, float* concat, - const tflite::FullyConnectedParams& fc_params) { + const tflite::FullyConnectedParams& fc_params, + tflite::CpuBackendContext* cpu_backend_context) { const int n_batch = input_shape.Dims(0); const int n_input = input_shape.Dims(1); const int n_output = state_shape.Dims(1); @@ -61,7 +62,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input, // [r u] = [x h] * gate_weight + gate_bias FullyConnected(fc_params, concat_shape, concat, gate_weight_shape, gate_weight, gate_bias_shape, gate_bias, activation_shape, - activation); + activation, cpu_backend_context); // [r u] = sigmoid([r u]) auto ru = MapAsArrayWithLastDimAsRows(activation, activation_shape); @@ -78,7 +79,7 @@ void GruCell(const RuntimeShape& input_shape, const float* input, // c = [x hr] * candidate_weight + candidate_bias FullyConnected(fc_params, concat_shape, concat, candidate_weight_shape, candidate_weight, candidate_bias_shape, candidate_bias, - output_shape, output); + output_shape, output, cpu_backend_context); auto c = MapAsArrayWithLastDimAsRows(output, output_shape); // output = (1 - u) .* tanh(c) + u .* h diff --git a/tensorflow/lite/experimental/kernels/gru_cell.h b/tensorflow/lite/experimental/kernels/gru_cell.h index 195614bb79f..cd7b02e2a69 100644 --- a/tensorflow/lite/experimental/kernels/gru_cell.h +++ b/tensorflow/lite/experimental/kernels/gru_cell.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_ +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/tensor.h" namespace tflite { @@ -34,7 +35,8 @@ void GruCell(const RuntimeShape& input_shape, const float* input, float* output, float* output_state, const RuntimeShape& activation_shape, float* activation, const RuntimeShape& concat_shape, float* concat, - const tflite::FullyConnectedParams& fc_params); + const tflite::FullyConnectedParams& fc_params, + tflite::CpuBackendContext* cpu_backend_context); } // namespace gru_cell } // namespace experimental diff --git a/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc b/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc index d91bd3bd044..fc0d681f3bc 100644 --- a/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc +++ b/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/experimental/kernels/gru_cell.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" +#include "tensorflow/lite/kernels/cpu_backend_support.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -31,7 +33,8 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state, const TfLiteTensor* candidate_weight, const TfLiteTensor* candidate_bias, TfLiteTensor* output, TfLiteTensor* output_state, TfLiteTensor* activation, - TfLiteTensor* concat) { + TfLiteTensor* concat, + tflite::CpuBackendContext* cpu_backend_context) { const int n_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; @@ -61,13 +64,13 @@ void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state, fc_params.float_activation_min = std::numeric_limits::lowest(); fc_params.float_activation_max = std::numeric_limits::max(); for (int i = 0; i < n_time; ++i) { - gru_cell::GruCell(input_shape, input_data, state_shape, input_state_data, - gate_weight_shape, gate_weight_data, gate_bias_shape, - gate_bias_data, candidate_weight_shape, - candidate_weight_data, candidate_bias_shape, - candidate_bias_data, output_shape, output_data, - output_state_data, activation_shape, activation_data, - concat_shape, concat_data, fc_params); + gru_cell::GruCell( + input_shape, input_data, state_shape, input_state_data, + gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data, + candidate_weight_shape, candidate_weight_data, candidate_bias_shape, + candidate_bias_data, output_shape, output_data, output_state_data, + activation_shape, activation_data, concat_shape, concat_data, fc_params, + cpu_backend_context); input_data += n_batch_input; output_data += n_batch_output; input_state_data = output_state_data; @@ -109,12 +112,14 @@ enum TemporaryTensor { }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { + cpu_backend_support::IncrementUsageCounter(context); auto* scratch_tensor_index = new int; context->AddTensors(context, kTemporaryNum, scratch_tensor_index); return scratch_tensor_index; } void Free(TfLiteContext* context, void* buffer) { + cpu_backend_support::DecrementUsageCounter(context); delete reinterpret_cast(buffer); } @@ -216,10 +221,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output_state = GetOutput(context, node, kOutputState); TfLiteTensor* activation = GetTemporary(context, node, kActivation); TfLiteTensor* concat = GetTemporary(context, node, kConcat); + auto cpu_backend_context = cpu_backend_support::GetFromContext(context); if (gate_weight->type == kTfLiteFloat32) { GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight, - candidate_bias, output, output_state, activation, concat); + candidate_bias, output, output_state, activation, concat, + cpu_backend_context); } else { context->ReportError(context, "Unsupported combination of data types for GruCell");