Merge pull request #27731 from mobvoi:tflite_gru_new
PiperOrigin-RevId: 246978454
This commit is contained in:
commit
dc77be8f92
@ -35,7 +35,7 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "experimental_ops",
|
||||
name = "ctc_beam_search_decoder_op",
|
||||
srcs = [
|
||||
"ctc_beam_search_decoder.cc",
|
||||
],
|
||||
@ -66,7 +66,7 @@ cc_test(
|
||||
srcs = ["ctc_beam_search_decoder_test.cc"],
|
||||
tags = ["tflite_not_portable_ios"],
|
||||
deps = [
|
||||
":experimental_ops",
|
||||
":ctc_beam_search_decoder_op",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
@ -74,3 +74,54 @@ cc_test(
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
# Single-step GRU cell computation; consumed by :unidirectional_sequence_gru_op.
cc_library(
    name = "gru_cell",
    srcs = ["gru_cell.cc"],
    hdrs = ["gru_cell.h"],
    deps = [
        "//tensorflow/lite/kernels:cpu_backend_context",
        "//tensorflow/lite/kernels/internal:optimized_base",
        "//tensorflow/lite/kernels/internal:tensor",
        "//third_party/eigen3",
    ],
)
|
||||
|
||||
# Custom-op kernel that runs :gru_cell over a sequence of time steps.
cc_library(
    name = "unidirectional_sequence_gru_op",
    srcs = ["unidirectional_sequence_gru.cc"],
    # Eigen Tensor introduces these warnings; do not let them fail the build.
    copts = tflite_copts() + ["-Wno-error=reorder"] + select({
        "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
        "//conditions:default": [],
    }),
    deps = [
        ":gru_cell",
        "//tensorflow/lite:framework",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/kernels:cpu_backend_context",
        "//tensorflow/lite/kernels:cpu_backend_support",
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/kernels:op_macros",
        "//tensorflow/lite/kernels/internal:tensor",
        "@flatbuffers",
    ],
)
|
||||
|
||||
# Unit test for :unidirectional_sequence_gru_op.
cc_test(
    name = "unidirectional_sequence_gru_test",
    size = "small",
    srcs = ["unidirectional_sequence_gru_test.cc"],
    tags = ["tflite_not_portable_ios"],
    deps = [
        ":unidirectional_sequence_gru_op",
        "//tensorflow/lite:framework",
        "//tensorflow/lite/kernels:test_util",
        "@com_google_googletest//:gtest",
    ],
)
|
||||
|
94
tensorflow/lite/experimental/kernels/gru_cell.cc
Normal file
94
tensorflow/lite/experimental/kernels/gru_cell.cc
Normal file
@ -0,0 +1,94 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace experimental {
|
||||
namespace gru_cell {
|
||||
|
||||
using optimized_ops::ArrayMap;
|
||||
using optimized_ops::FullyConnected;
|
||||
using optimized_ops::MapAsArrayWithLastDimAsRows;
|
||||
using reference_ops::Concatenation;
|
||||
|
||||
void GruCell(const RuntimeShape& input_shape, const float* input,
|
||||
const RuntimeShape& state_shape, const float* input_state,
|
||||
const RuntimeShape& gate_weight_shape, const float* gate_weight,
|
||||
const RuntimeShape& gate_bias_shape, const float* gate_bias,
|
||||
const RuntimeShape& candidate_weight_shape,
|
||||
const float* candidate_weight,
|
||||
const RuntimeShape& candidate_bias_shape,
|
||||
const float* candidate_bias, const RuntimeShape& output_shape,
|
||||
float* output, float* output_state,
|
||||
const RuntimeShape& activation_shape, float* activation,
|
||||
const RuntimeShape& concat_shape, float* concat,
|
||||
const tflite::FullyConnectedParams& fc_params,
|
||||
tflite::CpuBackendContext* cpu_backend_context) {
|
||||
const int n_batch = input_shape.Dims(0);
|
||||
const int n_input = input_shape.Dims(1);
|
||||
const int n_output = state_shape.Dims(1);
|
||||
|
||||
// [x h] = concat(input, state)
|
||||
std::vector<float const*> concat_arrays_data;
|
||||
std::vector<RuntimeShape const*> concat_arrays_shapes;
|
||||
concat_arrays_data.push_back(input);
|
||||
concat_arrays_data.push_back(input_state);
|
||||
concat_arrays_shapes.push_back(&input_shape);
|
||||
concat_arrays_shapes.push_back(&state_shape);
|
||||
tflite::ConcatenationParams concat_params;
|
||||
concat_params.axis = 1;
|
||||
concat_params.inputs_count = concat_arrays_data.size();
|
||||
Concatenation(concat_params, &(concat_arrays_shapes[0]),
|
||||
&(concat_arrays_data[0]), concat_shape, concat);
|
||||
|
||||
// [r u] = [x h] * gate_weight + gate_bias
|
||||
FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
|
||||
gate_weight, gate_bias_shape, gate_bias, activation_shape,
|
||||
activation, cpu_backend_context);
|
||||
|
||||
// [r u] = sigmoid([r u])
|
||||
auto ru = MapAsArrayWithLastDimAsRows(activation, activation_shape);
|
||||
ru = ru.unaryExpr(Eigen::internal::scalar_logistic_op<float>());
|
||||
auto r = ru.block(0 * n_output, 0, n_output, n_batch);
|
||||
auto u = ru.block(1 * n_output, 0, n_output, n_batch);
|
||||
|
||||
// hr = h .* r
|
||||
auto h = MapAsArrayWithLastDimAsRows(input_state, state_shape);
|
||||
auto xh = MapAsArrayWithLastDimAsRows(concat, concat_shape);
|
||||
auto hr = xh.block(n_input, 0, n_output, n_batch);
|
||||
hr = h * r;
|
||||
|
||||
// c = [x hr] * candidate_weight + candidate_bias
|
||||
FullyConnected(fc_params, concat_shape, concat, candidate_weight_shape,
|
||||
candidate_weight, candidate_bias_shape, candidate_bias,
|
||||
output_shape, output, cpu_backend_context);
|
||||
|
||||
auto c = MapAsArrayWithLastDimAsRows(output, output_shape);
|
||||
// output = (1 - u) .* tanh(c) + u .* h
|
||||
c = (1.0 - u) * c.tanh() + u * h;
|
||||
|
||||
memcpy(output_state, output, n_batch * n_output * sizeof(float));
|
||||
}
|
||||
|
||||
} // namespace gru_cell
|
||||
} // namespace experimental
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
46
tensorflow/lite/experimental/kernels/gru_cell.h
Normal file
46
tensorflow/lite/experimental/kernels/gru_cell.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace experimental {
|
||||
namespace gru_cell {
|
||||
|
||||
void GruCell(const RuntimeShape& input_shape, const float* input,
|
||||
const RuntimeShape& state_shape, const float* input_state,
|
||||
const RuntimeShape& gate_weight_shape, const float* gate_weight,
|
||||
const RuntimeShape& gate_bias_shape, const float* gate_bias,
|
||||
const RuntimeShape& candidate_weight_shape,
|
||||
const float* candidate_weight,
|
||||
const RuntimeShape& candidate_bias_shape,
|
||||
const float* candidate_bias, const RuntimeShape& output_shape,
|
||||
float* output, float* output_state,
|
||||
const RuntimeShape& activation_shape, float* activation,
|
||||
const RuntimeShape& concat_shape, float* concat,
|
||||
const tflite::FullyConnectedParams& fc_params,
|
||||
tflite::CpuBackendContext* cpu_backend_context);
|
||||
|
||||
} // namespace gru_cell
|
||||
} // namespace experimental
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
|
@ -0,0 +1,250 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_support.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace experimental {
|
||||
namespace unidirectional_sequence_gru {
|
||||
namespace {
|
||||
|
||||
void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
|
||||
const TfLiteTensor* gate_weight, const TfLiteTensor* gate_bias,
|
||||
const TfLiteTensor* candidate_weight,
|
||||
const TfLiteTensor* candidate_bias, TfLiteTensor* output,
|
||||
TfLiteTensor* output_state, TfLiteTensor* activation,
|
||||
TfLiteTensor* concat,
|
||||
tflite::CpuBackendContext* cpu_backend_context) {
|
||||
const int n_time = input->dims->data[0];
|
||||
const int n_batch = input->dims->data[1];
|
||||
const int n_input = input->dims->data[2];
|
||||
const int n_output = output->dims->data[2];
|
||||
const int n_batch_input = n_batch * n_input;
|
||||
const int n_batch_output = n_batch * n_output;
|
||||
const RuntimeShape input_shape({n_batch, n_input});
|
||||
const float* input_data = GetTensorData<float>(input);
|
||||
const RuntimeShape state_shape = GetTensorShape(input_state);
|
||||
const float* input_state_data = GetTensorData<float>(input_state);
|
||||
const RuntimeShape gate_weight_shape = GetTensorShape(gate_weight);
|
||||
const float* gate_weight_data = GetTensorData<float>(gate_weight);
|
||||
const RuntimeShape gate_bias_shape = GetTensorShape(gate_bias);
|
||||
const float* gate_bias_data = GetTensorData<float>(gate_bias);
|
||||
const RuntimeShape candidate_weight_shape = GetTensorShape(candidate_weight);
|
||||
const float* candidate_weight_data = GetTensorData<float>(candidate_weight);
|
||||
const RuntimeShape candidate_bias_shape = GetTensorShape(candidate_bias);
|
||||
const float* candidate_bias_data = GetTensorData<float>(candidate_bias);
|
||||
const RuntimeShape activation_shape = GetTensorShape(activation);
|
||||
const RuntimeShape output_shape = RuntimeShape({n_batch, n_output});
|
||||
float* output_data = GetTensorData<float>(output);
|
||||
float* output_state_data = GetTensorData<float>(output_state);
|
||||
float* activation_data = GetTensorData<float>(activation);
|
||||
const RuntimeShape concat_shape = GetTensorShape(concat);
|
||||
float* concat_data = GetTensorData<float>(concat);
|
||||
tflite::FullyConnectedParams fc_params;
|
||||
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
|
||||
fc_params.float_activation_max = std::numeric_limits<float>::max();
|
||||
for (int i = 0; i < n_time; ++i) {
|
||||
gru_cell::GruCell(
|
||||
input_shape, input_data, state_shape, input_state_data,
|
||||
gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
|
||||
candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
|
||||
candidate_bias_data, output_shape, output_data, output_state_data,
|
||||
activation_shape, activation_data, concat_shape, concat_data, fc_params,
|
||||
cpu_backend_context);
|
||||
input_data += n_batch_input;
|
||||
output_data += n_batch_output;
|
||||
input_state_data = output_state_data;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Positions of the op's input tensors; each comment gives the expected shape.
enum InputTensor {
  kInput = 0,            // [n_time, n_batch, n_input]
  kInputState = 1,       // [n_batch, n_output]
  kGateWeight = 2,       // [2*n_output, n_input+n_output]
  kGateBias = 3,         // [2*n_output]
  kCandidateWeight = 4,  // [n_output, n_input+n_output]
  kCandidateBias = 5,    // [n_output]
  kInputNum = 6          // Count of input tensors, not an index.
};
|
||||
|
||||
// Positions of the op's output tensors; each comment gives the shape.
enum OutputTensor {
  kOutput = 0,       // Output tensor, [n_time, n_batch, n_output]
  kOutputState = 1,  // Output state tensor, [n_batch, n_output]
  kOutputNum = 2     // Count of output tensors, not an index.
};
|
||||
|
||||
// Positions of the op's scratch tensors within node->temporaries.
enum TemporaryTensor {
  kActivation = 0,   // Gate activation scratch, [n_batch, 2*n_output]
  kConcat = 1,       // Concatenation scratch, [n_batch, n_input+n_output]
  kTemporaryNum = 2  // Count of scratch tensors, not an index.
};
|
||||
|
||||
// Creates the per-node state: bumps the cpu_backend_support usage counter,
// asks the context for kTemporaryNum new tensors and returns a heap-allocated
// int holding the index of the first one. Free() releases that int.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  cpu_backend_support::IncrementUsageCounter(context);
  int* first_scratch_index = new int;
  context->AddTensors(context, kTemporaryNum, first_scratch_index);
  return first_scratch_index;
}
|
||||
|
||||
// Undoes Init(): drops the cpu_backend_support usage count and deletes the
// int that Init() allocated.
void Free(TfLiteContext* context, void* buffer) {
  cpu_backend_support::DecrementUsageCounter(context);
  int* first_scratch_index = reinterpret_cast<int*>(buffer);
  delete first_scratch_index;
}
|
||||
|
||||
// Validates tensor counts and shapes, then sizes the outputs and the two
// scratch tensors. Returns kTfLiteError as soon as any check or resize fails.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Index of the first scratch tensor, allocated in Init().
  int* first_scratch = reinterpret_cast<int*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, node->inputs->size, kInputNum);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);

  // input: [n_time, n_batch, n_input]
  const TfLiteTensor* input = GetInput(context, node, kInput);
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];

  // input_state: [n_batch, n_output]
  const TfLiteTensor* input_state = GetInput(context, node, kInputState);
  TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
  const int n_output = input_state->dims->data[1];

  // gate_weight: [2 * n_output, n_input + n_output]
  const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);

  // gate_bias: [2 * n_output]
  const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);

  // candidate_weight: [n_output, n_input + n_output]
  const TfLiteTensor* candidate_weight =
      GetInput(context, node, kCandidateWeight);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
                    n_input + n_output);

  // candidate_bias: [n_output]
  const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);

  // output: [n_time, n_batch, n_output]
  TfLiteTensor* output = GetOutput(context, node, kOutput);
  TfLiteIntArray* output_dims = TfLiteIntArrayCreate(3);
  output_dims->data[0] = n_time;
  output_dims->data[1] = n_batch;
  output_dims->data[2] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_dims));

  // output_state: same dims as input_state, i.e. [n_batch, n_output]
  TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
  TfLiteIntArray* output_state_dims = TfLiteIntArrayCopy(input_state->dims);
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state, output_state_dims));

  // Replace any previous temporaries list with one sized for this op.
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(kTemporaryNum);

  // activation scratch: [n_batch, 2 * n_output]
  node->temporaries->data[kActivation] = *first_scratch;
  TfLiteTensor* activation = GetTemporary(context, node, kActivation);
  activation->type = input->type;
  activation->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* activation_dims = TfLiteIntArrayCreate(2);
  activation_dims->data[0] = n_batch;
  activation_dims->data[1] = 2 * n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, activation, activation_dims));

  // concat scratch: [n_batch, n_input + n_output]
  node->temporaries->data[kConcat] = (*first_scratch) + kConcat;
  TfLiteTensor* concat = GetTemporary(context, node, kConcat);
  concat->type = input->type;
  concat->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* concat_dims = TfLiteIntArrayCreate(2);
  concat_dims->data[0] = n_batch;
  concat_dims->data[1] = n_input + n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, concat, concat_dims));

  return kTfLiteOk;
}
|
||||
|
||||
// Runs the GRU over the whole input sequence.
//
// Only float32 is supported. GruImpl accesses every input and output tensor
// through float pointers, so all of them are validated here; checking just
// the gate weights would let a tensor of another type be reinterpreted as
// float (or overrun, if its element type is narrower).
//
// Returns kTfLiteOk on success, or kTfLiteError after reporting an error when
// any tensor is not float32.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, kInput);
  const TfLiteTensor* input_state = GetInput(context, node, kInputState);
  const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
  const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
  const TfLiteTensor* candidate_weight =
      GetInput(context, node, kCandidateWeight);
  const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
  TfLiteTensor* output = GetOutput(context, node, kOutput);
  TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
  // The scratch tensors take their type from `input` in Prepare(), so they
  // are covered by the check on `input` below.
  TfLiteTensor* activation = GetTemporary(context, node, kActivation);
  TfLiteTensor* concat = GetTemporary(context, node, kConcat);
  auto cpu_backend_context = cpu_backend_support::GetFromContext(context);

  const TfLiteTensor* float_tensors[] = {
      input,          input_state, gate_weight,  gate_bias, candidate_weight,
      candidate_bias, output,      output_state,
  };
  for (const TfLiteTensor* tensor : float_tensors) {
    if (tensor->type != kTfLiteFloat32) {
      context->ReportError(context,
                           "Unsupported combination of data types for GruCell");
      return kTfLiteError;
    }
  }

  GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
          candidate_bias, output, output_state, activation, concat,
          cpu_backend_context);
  return kTfLiteOk;
}
|
||||
|
||||
} // namespace unidirectional_sequence_gru
|
||||
|
||||
// Returns the registration that binds the op's Init, Free, Prepare and Eval
// callbacks. The same function-local static instance is returned on every
// call.
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU() {
  namespace gru = unidirectional_sequence_gru;
  static TfLiteRegistration registration = {gru::Init, gru::Free, gru::Prepare,
                                            gru::Eval};
  return &registration;
}
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
@ -0,0 +1,147 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace experimental {
|
||||
|
||||
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU();
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class GRUOpModel : public SingleOpModel {
|
||||
public:
|
||||
explicit GRUOpModel(const std::vector<std::vector<int>>& input_shapes,
|
||||
const TensorType& weight_type = TensorType_FLOAT32) {
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
input_state_ =
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
|
||||
gate_weight_ = AddInput(TensorType_FLOAT32);
|
||||
gate_bias_ = AddInput(TensorType_FLOAT32);
|
||||
candidate_weight_ = AddInput(TensorType_FLOAT32);
|
||||
candidate_bias_ = AddInput(TensorType_FLOAT32);
|
||||
|
||||
output_ = AddOutput(TensorType_FLOAT32);
|
||||
output_state_ = AddOutput(TensorType_FLOAT32);
|
||||
|
||||
SetCustomOp("UNIDIRECTIONAL_SEQUENCE_GRU", {},
|
||||
Register_UNIDIRECTIONAL_SEQUENCE_GRU);
|
||||
BuildInterpreter(input_shapes);
|
||||
}
|
||||
|
||||
void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
|
||||
|
||||
void SetInputState(const std::vector<float>& f) {
|
||||
PopulateTensor(input_state_, f);
|
||||
}
|
||||
|
||||
void SetGateWeight(const std::vector<float>& f) {
|
||||
PopulateTensor(gate_weight_, f);
|
||||
}
|
||||
|
||||
void SetGateBias(const std::vector<float>& f) {
|
||||
PopulateTensor(gate_bias_, f);
|
||||
}
|
||||
|
||||
void SetCandidateWeight(const std::vector<float>& f) {
|
||||
PopulateTensor(candidate_weight_, f);
|
||||
}
|
||||
|
||||
void SetCandidateBias(const std::vector<float>& f) {
|
||||
PopulateTensor(candidate_bias_, f);
|
||||
}
|
||||
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
|
||||
int num_batches() { return n_batch_; }
|
||||
int num_inputs() { return n_input_; }
|
||||
int num_outputs() { return n_output_; }
|
||||
|
||||
private:
|
||||
int input_;
|
||||
int input_state_;
|
||||
int gate_weight_;
|
||||
int gate_bias_;
|
||||
int candidate_weight_;
|
||||
int candidate_bias_;
|
||||
|
||||
int output_;
|
||||
int output_state_;
|
||||
int n_batch_;
|
||||
int n_input_;
|
||||
int n_output_;
|
||||
};
|
||||
|
||||
// Feeds a two-step, two-batch sequence through the op and compares the
// output with precomputed reference values.
TEST(GRUTest, SimpleTest) {
  const int n_time = 2;
  const int n_batch = 2;
  const int n_input = 2;
  const int n_output = 3;

  const std::vector<std::vector<int>> input_shapes = {
      {n_time, n_batch, n_input},          // input
      {n_batch, n_output},                 // input_state
      {2 * n_output, n_input + n_output},  // gate_weight
      {2 * n_output},                      // gate_bias
      {n_output, n_input + n_output},      // candidate_weight
      {n_output},                          // candidate_bias
  };
  GRUOpModel m(input_shapes);

  // All data is randomly generated.
  const std::vector<float> input_values = {
      0.89495724, 0.34482682, 0.68505806, 0.7135783,
      0.3167085,  0.93647677, 0.47361764, 0.39643127};
  const std::vector<float> state_values = {0.09992421, 0.3028481,  0.78305984,
                                           0.50438094, 0.11269058, 0.10244724};
  const std::vector<float> gate_weight_values = {
      0.7256918,  0.8945897,  0.03285786, 0.42637166, 0.119376324,
      0.83035135, 0.16997327, 0.42302176, 0.77598256, 0.2660894,
      0.9587266,  0.6218451,  0.88164485, 0.12272458, 0.2699055,
      0.18399088, 0.21930052, 0.3374841,  0.70866305, 0.9523419,
      0.25170696, 0.60988617, 0.79823977, 0.64477515, 0.2602957,
      0.5053131,  0.93722224, 0.8451359,  0.97905475, 0.38669217};
  const std::vector<float> gate_bias_values = {
      0.032708533, 0.018445263, 0.15320699, 0.8163046, 0.26683575, 0.1412022};
  const std::vector<float> candidate_weight_values = {
      0.96165305, 0.95572084,  0.11534478,  0.96965164, 0.33562955,
      0.8680755,  0.003066936, 0.057793964, 0.8671354,  0.33354893,
      0.7313398,  0.78492093,  0.19530584,  0.116550304, 0.13599132};
  const std::vector<float> candidate_bias_values = {0.89837056, 0.54769796,
                                                    0.63364106};
  const std::vector<float> expected_output = {
      0.20112592, 0.45286041, 0.80842507, 0.59567153, 0.2619998,  0.22922856,
      0.27715868, 0.5247152,  0.82300174, 0.65812796, 0.38217607, 0.3401444};

  m.SetInput(input_values);
  m.SetInputState(state_values);
  m.SetGateWeight(gate_weight_values);
  m.SetGateBias(gate_bias_values);
  m.SetCandidateWeight(candidate_weight_values);
  m.SetCandidateBias(candidate_bias_values);

  m.Invoke();

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(n_time, n_batch, n_output));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_output)));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace experimental
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
Loading…
Reference in New Issue
Block a user