commit dc77be8f92
Author: TensorFlower Gardener
Date:   2019-05-07 02:07:18 -07:00

Merge pull request #27731 from mobvoi:tflite_gru_new

PiperOrigin-RevId: 246978454

5 changed files with 590 additions and 2 deletions

tensorflow/lite/experimental/kernels/BUILD

@@ -35,7 +35,7 @@ cc_library(
)
cc_library(
name = "experimental_ops",
name = "ctc_beam_search_decoder_op",
srcs = [
"ctc_beam_search_decoder.cc",
],
@@ -66,7 +66,7 @@ cc_test(
srcs = ["ctc_beam_search_decoder_test.cc"],
tags = ["tflite_not_portable_ios"],
deps = [
":experimental_ops",
":ctc_beam_search_decoder_op",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:test_util",
@@ -74,3 +74,54 @@ cc_test(
"@flatbuffers",
],
)
cc_library(
name = "gru_cell",
srcs = ["gru_cell.cc"],
hdrs = ["gru_cell.h"],
deps = [
"//tensorflow/lite/kernels:cpu_backend_context",
"//tensorflow/lite/kernels/internal:optimized_base",
"//tensorflow/lite/kernels/internal:tensor",
"//third_party/eigen3",
],
)
cc_library(
name = "unidirectional_sequence_gru_op",
srcs = [
"unidirectional_sequence_gru.cc",
],
# Suppress warnings that are introduced by Eigen Tensor.
copts = tflite_copts() + [
"-Wno-error=reorder",
] + select({
"//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
"//conditions:default": [
],
}),
deps = [
":gru_cell",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:cpu_backend_context",
"//tensorflow/lite/kernels:cpu_backend_support",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels/internal:tensor",
"@flatbuffers",
],
)
cc_test(
name = "unidirectional_sequence_gru_test",
size = "small",
srcs = ["unidirectional_sequence_gru_test.cc"],
tags = ["tflite_not_portable_ios"],
deps = [
":unidirectional_sequence_gru_op",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)

tensorflow/lite/experimental/kernels/gru_cell.cc

@@ -0,0 +1,94 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
#include <vector>
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
namespace tflite {
namespace ops {
namespace experimental {
namespace gru_cell {
using optimized_ops::ArrayMap;
using optimized_ops::FullyConnected;
using optimized_ops::MapAsArrayWithLastDimAsRows;
using reference_ops::Concatenation;
void GruCell(const RuntimeShape& input_shape, const float* input,
const RuntimeShape& state_shape, const float* input_state,
const RuntimeShape& gate_weight_shape, const float* gate_weight,
const RuntimeShape& gate_bias_shape, const float* gate_bias,
const RuntimeShape& candidate_weight_shape,
const float* candidate_weight,
const RuntimeShape& candidate_bias_shape,
const float* candidate_bias, const RuntimeShape& output_shape,
float* output, float* output_state,
const RuntimeShape& activation_shape, float* activation,
const RuntimeShape& concat_shape, float* concat,
const tflite::FullyConnectedParams& fc_params,
tflite::CpuBackendContext* cpu_backend_context) {
const int n_batch = input_shape.Dims(0);
const int n_input = input_shape.Dims(1);
const int n_output = state_shape.Dims(1);
// [x h] = concat(input, state)
std::vector<float const*> concat_arrays_data;
std::vector<RuntimeShape const*> concat_arrays_shapes;
concat_arrays_data.push_back(input);
concat_arrays_data.push_back(input_state);
concat_arrays_shapes.push_back(&input_shape);
concat_arrays_shapes.push_back(&state_shape);
tflite::ConcatenationParams concat_params;
concat_params.axis = 1;
concat_params.inputs_count = concat_arrays_data.size();
Concatenation(concat_params, &(concat_arrays_shapes[0]),
&(concat_arrays_data[0]), concat_shape, concat);
// [r u] = [x h] * gate_weight + gate_bias
FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
gate_weight, gate_bias_shape, gate_bias, activation_shape,
activation, cpu_backend_context);
// [r u] = sigmoid([r u])
auto ru = MapAsArrayWithLastDimAsRows(activation, activation_shape);
ru = ru.unaryExpr(Eigen::internal::scalar_logistic_op<float>());
auto r = ru.block(0 * n_output, 0, n_output, n_batch);
auto u = ru.block(1 * n_output, 0, n_output, n_batch);
// hr = h .* r
auto h = MapAsArrayWithLastDimAsRows(input_state, state_shape);
auto xh = MapAsArrayWithLastDimAsRows(concat, concat_shape);
auto hr = xh.block(n_input, 0, n_output, n_batch);
hr = h * r;
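// Note: hr aliases rows [n_input, n_input + n_output) of the concat buffer,
// so this assignment rewrites [x h] into [x h.*r] in place, which is exactly
// the input the candidate matmul below expects.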
// c = [x hr] * candidate_weight + candidate_bias
FullyConnected(fc_params, concat_shape, concat, candidate_weight_shape,
candidate_weight, candidate_bias_shape, candidate_bias,
output_shape, output, cpu_backend_context);
auto c = MapAsArrayWithLastDimAsRows(output, output_shape);
// output = (1 - u) .* tanh(c) + u .* h
c = (1.0 - u) * c.tanh() + u * h;
memcpy(output_state, output, n_batch * n_output * sizeof(float));
}
} // namespace gru_cell
} // namespace experimental
} // namespace ops
} // namespace tflite
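
For reference, GruCell above implements the standard GRU update, restated from the in-line comments (here W_g, b_g are the gate weight/bias, W_c, b_c the candidate weight/bias, \odot is element-wise multiplication, and [a; b] is concatenation):

$$
\begin{aligned}
[r;\,u] &= \sigma\big(W_g\,[x;\,h] + b_g\big) \\
c &= W_c\,[x;\,h \odot r] + b_c \\
h' &= (1 - u) \odot \tanh(c) + u \odot h
\end{aligned}
$$

Both output and output_state receive $h'$; the driver in the next file feeds it back as the next timestep's state.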

tensorflow/lite/experimental/kernels/gru_cell.h

@@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
namespace tflite {
namespace ops {
namespace experimental {
namespace gru_cell {
void GruCell(const RuntimeShape& input_shape, const float* input,
const RuntimeShape& state_shape, const float* input_state,
const RuntimeShape& gate_weight_shape, const float* gate_weight,
const RuntimeShape& gate_bias_shape, const float* gate_bias,
const RuntimeShape& candidate_weight_shape,
const float* candidate_weight,
const RuntimeShape& candidate_bias_shape,
const float* candidate_bias, const RuntimeShape& output_shape,
float* output, float* output_state,
const RuntimeShape& activation_shape, float* activation,
const RuntimeShape& concat_shape, float* concat,
const tflite::FullyConnectedParams& fc_params,
tflite::CpuBackendContext* cpu_backend_context);
} // namespace gru_cell
} // namespace experimental
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
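
Note that activation ([n_batch, 2*n_output]) and concat ([n_batch, n_input+n_output]) are caller-provided scratch buffers rather than outputs proper: the op wrapper below allocates them as TfLite temporaries in Prepare, and GruCell writes its intermediates into them.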

tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc

@@ -0,0 +1,250 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace experimental {
namespace unidirectional_sequence_gru {
namespace {
void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
const TfLiteTensor* gate_weight, const TfLiteTensor* gate_bias,
const TfLiteTensor* candidate_weight,
const TfLiteTensor* candidate_bias, TfLiteTensor* output,
TfLiteTensor* output_state, TfLiteTensor* activation,
TfLiteTensor* concat,
tflite::CpuBackendContext* cpu_backend_context) {
const int n_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
const int n_output = output->dims->data[2];
const int n_batch_input = n_batch * n_input;
const int n_batch_output = n_batch * n_output;
const RuntimeShape input_shape({n_batch, n_input});
const float* input_data = GetTensorData<float>(input);
const RuntimeShape state_shape = GetTensorShape(input_state);
const float* input_state_data = GetTensorData<float>(input_state);
const RuntimeShape gate_weight_shape = GetTensorShape(gate_weight);
const float* gate_weight_data = GetTensorData<float>(gate_weight);
const RuntimeShape gate_bias_shape = GetTensorShape(gate_bias);
const float* gate_bias_data = GetTensorData<float>(gate_bias);
const RuntimeShape candidate_weight_shape = GetTensorShape(candidate_weight);
const float* candidate_weight_data = GetTensorData<float>(candidate_weight);
const RuntimeShape candidate_bias_shape = GetTensorShape(candidate_bias);
const float* candidate_bias_data = GetTensorData<float>(candidate_bias);
const RuntimeShape activation_shape = GetTensorShape(activation);
const RuntimeShape output_shape = RuntimeShape({n_batch, n_output});
float* output_data = GetTensorData<float>(output);
float* output_state_data = GetTensorData<float>(output_state);
float* activation_data = GetTensorData<float>(activation);
const RuntimeShape concat_shape = GetTensorShape(concat);
float* concat_data = GetTensorData<float>(concat);
tflite::FullyConnectedParams fc_params;
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
fc_params.float_activation_max = std::numeric_limits<float>::max();
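// Run the cell once per timestep: each iteration consumes one
// [n_batch, n_input] slice of the input and produces one [n_batch, n_output]
// slice of the output. After the first step the cell reads its state from
// output_state, which the previous iteration refreshed.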
for (int i = 0; i < n_time; ++i) {
gru_cell::GruCell(
input_shape, input_data, state_shape, input_state_data,
gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
candidate_bias_data, output_shape, output_data, output_state_data,
activation_shape, activation_data, concat_shape, concat_data, fc_params,
cpu_backend_context);
input_data += n_batch_input;
output_data += n_batch_output;
input_state_data = output_state_data;
}
}
} // namespace
enum InputTensor {
// Input tensor of size [n_time, n_batch, n_input]
kInput = 0,
// Input state tensor of size [n_batch, n_output]
kInputState = 1,
// Gate weight tensor of size [2*n_output, n_input+n_output]
kGateWeight = 2,
// Gate bias tensor of size [2*n_output]
kGateBias = 3,
// Candidate weight tensor of size [n_output, n_input+n_output]
kCandidateWeight = 4,
// Candidate bias tensor of size [n_output]
kCandidateBias = 5,
kInputNum = 6
};
enum OutputTensor {
// Output tensor of size [n_time, n_batch, n_output]
kOutput = 0,
// Output state tensor of size [n_batch, n_output]
kOutputState = 1,
kOutputNum = 2
};
enum TemporaryTensor {
// Scratch buffer for activation of size [n_batch, 2*n_output]
kActivation = 0,
// Scratch buffer for concat of size [n_batch, n_input+n_output]
kConcat = 1,
kTemporaryNum = 2
};
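// Example layout with n_time=2, n_batch=2, n_input=2, n_output=3 (the values
// used in the unit test): input [2,2,2], input_state [2,3], gate_weight [6,5],
// gate_bias [6], candidate_weight [3,5], candidate_bias [3], output [2,2,3],
// output_state [2,3], activation [2,6], concat [2,5].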
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
cpu_backend_support::IncrementUsageCounter(context);
auto* scratch_tensor_index = new int;
context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
return scratch_tensor_index;
}
void Free(TfLiteContext* context, void* buffer) {
cpu_backend_support::DecrementUsageCounter(context);
delete reinterpret_cast<int*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
TF_LITE_ENSURE_EQ(context, node->inputs->size, kInputNum);
TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);
// input's dim = [n_time, n_batch, n_input]
const TfLiteTensor* input = GetInput(context, node, kInput);
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int n_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
// input_state's dim = [n_batch, n_output]
const TfLiteTensor* input_state = GetInput(context, node, kInputState);
TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
const int n_output = input_state->dims->data[1];
// gate_weight's dim = [2 * n_output, n_input + n_output]
const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);
// gate_bias's dim = [2 * n_output]
const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);
// candidate_weight's dim = [n_output, n_input + n_output]
const TfLiteTensor* candidate_weight =
GetInput(context, node, kCandidateWeight);
TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
n_input + n_output);
// candidate_bias's dim = [n_output]
const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);
// output's dim = [n_time, n_batch, n_output]
TfLiteTensor* output = GetOutput(context, node, kOutput);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = n_time;
output_size->data[1] = n_batch;
output_size->data[2] = n_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
// output_state's dim = [n_batch, n_output]
TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state,
TfLiteIntArrayCopy(input_state->dims)));
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(kTemporaryNum);
// activation's dim = [n_batch, 2 * n_output]
node->temporaries->data[kActivation] = *scratch_tensor_index;
TfLiteTensor* activation = GetTemporary(context, node, kActivation);
activation->type = input->type;
activation->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* activation_size = TfLiteIntArrayCreate(2);
activation_size->data[0] = n_batch;
activation_size->data[1] = 2 * n_output;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, activation, activation_size));
// concat's dim = [n_batch, n_input + n_output]
node->temporaries->data[kConcat] = (*scratch_tensor_index) + kConcat;
TfLiteTensor* concat = GetTemporary(context, node, kConcat);
concat->type = input->type;
concat->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* concat_size = TfLiteIntArrayCreate(2);
concat_size->data[0] = n_batch;
concat_size->data[1] = n_input + n_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, concat, concat_size));
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInput);
const TfLiteTensor* input_state = GetInput(context, node, kInputState);
const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
const TfLiteTensor* candidate_weight =
GetInput(context, node, kCandidateWeight);
const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
TfLiteTensor* output = GetOutput(context, node, kOutput);
TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
TfLiteTensor* activation = GetTemporary(context, node, kActivation);
TfLiteTensor* concat = GetTemporary(context, node, kConcat);
auto cpu_backend_context = cpu_backend_support::GetFromContext(context);
if (gate_weight->type == kTfLiteFloat32) {
GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
candidate_bias, output, output_state, activation, concat,
cpu_backend_context);
} else {
context->ReportError(context,
"Unsupported combination of data types for GruCell");
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace unidirectional_sequence_gru
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU() {
static TfLiteRegistration r = {
unidirectional_sequence_gru::Init, unidirectional_sequence_gru::Free,
unidirectional_sequence_gru::Prepare, unidirectional_sequence_gru::Eval};
return &r;
}
} // namespace experimental
} // namespace ops
} // namespace tflite
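
Since this lands as a custom op rather than a builtin, a client has to register it by name before building an interpreter. A minimal sketch, assuming a FlatBufferModel is already loaded (the helper name and the forward declaration mirror the test below; none of this is part of the commit):

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

namespace tflite {
namespace ops {
namespace experimental {
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU();
}  // namespace experimental
}  // namespace ops
}  // namespace tflite

// Hypothetical helper: wires the GRU kernel into a resolver under the name
// the test uses, then builds an interpreter for the given model.
std::unique_ptr<tflite::Interpreter> BuildGruInterpreter(
    const tflite::FlatBufferModel& model) {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  resolver.AddCustom(
      "UNIDIRECTIONAL_SEQUENCE_GRU",
      tflite::ops::experimental::Register_UNIDIRECTIONAL_SEQUENCE_GRU());
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(model, resolver)(&interpreter);
  return interpreter;
}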

tensorflow/lite/experimental/kernels/unidirectional_sequence_gru_test.cc

@@ -0,0 +1,147 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/test_util.h"
namespace tflite {
namespace ops {
namespace experimental {
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU();
namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
class GRUOpModel : public SingleOpModel {
public:
explicit GRUOpModel(int n_batch, int n_input, int n_output,
const std::vector<std::vector<int>>& input_shapes,
const TensorType& weight_type = TensorType_FLOAT32)
: n_batch_(n_batch), n_input_(n_input), n_output_(n_output) {
input_ = AddInput(TensorType_FLOAT32);
input_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
gate_weight_ = AddInput(TensorType_FLOAT32);
gate_bias_ = AddInput(TensorType_FLOAT32);
candidate_weight_ = AddInput(TensorType_FLOAT32);
candidate_bias_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
output_state_ = AddOutput(TensorType_FLOAT32);
SetCustomOp("UNIDIRECTIONAL_SEQUENCE_GRU", {},
Register_UNIDIRECTIONAL_SEQUENCE_GRU);
BuildInterpreter(input_shapes);
}
void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
void SetInputState(const std::vector<float>& f) {
PopulateTensor(input_state_, f);
}
void SetGateWeight(const std::vector<float>& f) {
PopulateTensor(gate_weight_, f);
}
void SetGateBias(const std::vector<float>& f) {
PopulateTensor(gate_bias_, f);
}
void SetCandidateWeight(const std::vector<float>& f) {
PopulateTensor(candidate_weight_, f);
}
void SetCandidateBias(const std::vector<float>& f) {
PopulateTensor(candidate_bias_, f);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int num_batches() { return n_batch_; }
int num_inputs() { return n_input_; }
int num_outputs() { return n_output_; }
private:
int input_;
int input_state_;
int gate_weight_;
int gate_bias_;
int candidate_weight_;
int candidate_bias_;
int output_;
int output_state_;
int n_batch_;
int n_input_;
int n_output_;
};
TEST(GRUTest, SimpleTest) {
const int n_time = 2;
const int n_batch = 2;
const int n_input = 2;
const int n_output = 3;
GRUOpModel m(n_batch, n_input, n_output,
{{n_time, n_batch, n_input},
{n_batch, n_output},
{2 * n_output, n_input + n_output},
{2 * n_output},
{n_output, n_input + n_output},
{n_output}});
// All data is randomly generated.
m.SetInput({0.89495724, 0.34482682, 0.68505806, 0.7135783, 0.3167085,
0.93647677, 0.47361764, 0.39643127});
m.SetInputState(
{0.09992421, 0.3028481, 0.78305984, 0.50438094, 0.11269058, 0.10244724});
m.SetGateWeight({0.7256918, 0.8945897, 0.03285786, 0.42637166, 0.119376324,
0.83035135, 0.16997327, 0.42302176, 0.77598256, 0.2660894,
0.9587266, 0.6218451, 0.88164485, 0.12272458, 0.2699055,
0.18399088, 0.21930052, 0.3374841, 0.70866305, 0.9523419,
0.25170696, 0.60988617, 0.79823977, 0.64477515, 0.2602957,
0.5053131, 0.93722224, 0.8451359, 0.97905475, 0.38669217});
m.SetGateBias(
{0.032708533, 0.018445263, 0.15320699, 0.8163046, 0.26683575, 0.1412022});
m.SetCandidateWeight({0.96165305, 0.95572084, 0.11534478, 0.96965164,
0.33562955, 0.8680755, 0.003066936, 0.057793964,
0.8671354, 0.33354893, 0.7313398, 0.78492093,
0.19530584, 0.116550304, 0.13599132});
m.SetCandidateBias({0.89837056, 0.54769796, 0.63364106});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(n_time, n_batch, n_output));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{0.20112592, 0.45286041, 0.80842507, 0.59567153, 0.2619998,
0.22922856, 0.27715868, 0.5247152, 0.82300174, 0.65812796,
0.38217607, 0.3401444})));
}
} // namespace
} // namespace experimental
} // namespace ops
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}