diff --git a/tensorflow/lite/experimental/kernels/BUILD b/tensorflow/lite/experimental/kernels/BUILD
index 78af889cf1e..bf4a007fb8c 100644
--- a/tensorflow/lite/experimental/kernels/BUILD
+++ b/tensorflow/lite/experimental/kernels/BUILD
@@ -35,7 +35,7 @@ cc_library(
 )
 
 cc_library(
-    name = "experimental_ops",
+    name = "ctc_beam_search_decoder_op",
     srcs = [
         "ctc_beam_search_decoder.cc",
     ],
@@ -66,7 +66,7 @@ cc_test(
     srcs = ["ctc_beam_search_decoder_test.cc"],
    tags = ["tflite_not_portable_ios"],
     deps = [
-        ":experimental_ops",
+        ":ctc_beam_search_decoder_op",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/kernels:test_util",
@@ -74,3 +74,54 @@ cc_test(
         "@flatbuffers",
     ],
 )
+
+cc_library(
+    name = "gru_cell",
+    srcs = ["gru_cell.cc"],
+    hdrs = ["gru_cell.h"],
+    deps = [
+        "//tensorflow/lite/kernels:cpu_backend_context",
+        "//tensorflow/lite/kernels/internal:optimized_base",
+        "//tensorflow/lite/kernels/internal:tensor",
+        "//third_party/eigen3",
+    ],
+)
+
+cc_library(
+    name = "unidirectional_sequence_gru_op",
+    srcs = [
+        "unidirectional_sequence_gru.cc",
+    ],
+    # Suppress warnings that are introduced by Eigen Tensor.
+    copts = tflite_copts() + [
+        "-Wno-error=reorder",
+    ] + select({
+        "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
+        "//conditions:default": [
+        ],
+    }),
+    deps = [
+        ":gru_cell",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/kernels:cpu_backend_context",
+        "//tensorflow/lite/kernels:cpu_backend_support",
+        "//tensorflow/lite/kernels:kernel_util",
+        "//tensorflow/lite/kernels:op_macros",
+        "//tensorflow/lite/kernels/internal:tensor",
+        "@flatbuffers",
+    ],
+)
+
+cc_test(
+    name = "unidirectional_sequence_gru_test",
+    size = "small",
+    srcs = ["unidirectional_sequence_gru_test.cc"],
+    tags = ["tflite_not_portable_ios"],
+    deps = [
+        ":unidirectional_sequence_gru_op",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+    ],
+)
diff --git a/tensorflow/lite/experimental/kernels/gru_cell.cc b/tensorflow/lite/experimental/kernels/gru_cell.cc
new file mode 100644
index 00000000000..c21896ae83f
--- /dev/null
+++ b/tensorflow/lite/experimental/kernels/gru_cell.cc
@@ -0,0 +1,94 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/kernels/gru_cell.h"
+
+#include <cstring>
+#include <vector>
+
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+
+namespace tflite {
+namespace ops {
+namespace experimental {
+namespace gru_cell {
+
+using optimized_ops::ArrayMap;
+using optimized_ops::FullyConnected;
+using optimized_ops::MapAsArrayWithLastDimAsRows;
+using reference_ops::Concatenation;
+
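+// Computes a single GRU step. In terms of the inline comments below, the
+// cell evaluates:
+//   [r u] = sigmoid([x h] * gate_weight + gate_bias)
+//   hr    = h .* r
+//   c     = tanh([x hr] * candidate_weight + candidate_bias)
+//   h'    = (1 - u) .* c + u .* h
+// where x is the input, h the incoming state, r the reset gate, u the update
+// gate, and h' is written to both `output` and `output_state`.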
+void GruCell(const RuntimeShape& input_shape, const float* input,
+             const RuntimeShape& state_shape, const float* input_state,
+             const RuntimeShape& gate_weight_shape, const float* gate_weight,
+             const RuntimeShape& gate_bias_shape, const float* gate_bias,
+             const RuntimeShape& candidate_weight_shape,
+             const float* candidate_weight,
+             const RuntimeShape& candidate_bias_shape,
+             const float* candidate_bias, const RuntimeShape& output_shape,
+             float* output, float* output_state,
+             const RuntimeShape& activation_shape, float* activation,
+             const RuntimeShape& concat_shape, float* concat,
+             const tflite::FullyConnectedParams& fc_params,
+             tflite::CpuBackendContext* cpu_backend_context) {
+  const int n_batch = input_shape.Dims(0);
+  const int n_input = input_shape.Dims(1);
+  const int n_output = state_shape.Dims(1);
+
+  // [x h] = concat(input, state)
+  std::vector<const float*> concat_arrays_data;
+  std::vector<const RuntimeShape*> concat_arrays_shapes;
+  concat_arrays_data.push_back(input);
+  concat_arrays_data.push_back(input_state);
+  concat_arrays_shapes.push_back(&input_shape);
+  concat_arrays_shapes.push_back(&state_shape);
+  tflite::ConcatenationParams concat_params;
+  concat_params.axis = 1;
+  concat_params.inputs_count = concat_arrays_data.size();
+  Concatenation(concat_params, &(concat_arrays_shapes[0]),
+                &(concat_arrays_data[0]), concat_shape, concat);
+
+  // [r u] = [x h] * gate_weight + gate_bias
+  FullyConnected(fc_params, concat_shape, concat, gate_weight_shape,
+                 gate_weight, gate_bias_shape, gate_bias, activation_shape,
+                 activation, cpu_backend_context);
+
+  // [r u] = sigmoid([r u])
+  auto ru = MapAsArrayWithLastDimAsRows(activation, activation_shape);
+  ru = ru.unaryExpr(Eigen::internal::scalar_logistic_op<float>());
+  auto r = ru.block(0 * n_output, 0, n_output, n_batch);
+  auto u = ru.block(1 * n_output, 0, n_output, n_batch);
+
+  // hr = h .* r
+  auto h = MapAsArrayWithLastDimAsRows(input_state, state_shape);
+  auto xh = MapAsArrayWithLastDimAsRows(concat, concat_shape);
+  auto hr = xh.block(n_input, 0, n_output, n_batch);
+  hr = h * r;
+
+  // c = [x hr] * candidate_weight + candidate_bias
+  FullyConnected(fc_params, concat_shape, concat, candidate_weight_shape,
+                 candidate_weight, candidate_bias_shape, candidate_bias,
+                 output_shape, output, cpu_backend_context);
+
+  auto c = MapAsArrayWithLastDimAsRows(output, output_shape);
+  // output = (1 - u) .* tanh(c) + u .* h
+  c = (1.0 - u) * c.tanh() + u * h;
+
+  memcpy(output_state, output, n_batch * n_output * sizeof(float));
+}
+
+}  // namespace gru_cell
+}  // namespace experimental
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/kernels/gru_cell.h b/tensorflow/lite/experimental/kernels/gru_cell.h
new file mode 100644
index 00000000000..cd7b02e2a69
--- /dev/null
+++ b/tensorflow/lite/experimental/kernels/gru_cell.h
@@ -0,0 +1,46 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
+
+#include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+
+namespace tflite {
+namespace ops {
+namespace experimental {
+namespace gru_cell {
+
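+// Single-step GRU kernel used by the UNIDIRECTIONAL_SEQUENCE_GRU op. Expected
+// shapes (see unidirectional_sequence_gru.cc): input [n_batch, n_input],
+// state [n_batch, n_output], gate_weight [2*n_output, n_input+n_output],
+// gate_bias [2*n_output], candidate_weight [n_output, n_input+n_output],
+// candidate_bias [n_output]. `activation` and `concat` are caller-provided
+// scratch buffers of shape [n_batch, 2*n_output] and
+// [n_batch, n_input+n_output], respectively.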
+void GruCell(const RuntimeShape& input_shape, const float* input,
+             const RuntimeShape& state_shape, const float* input_state,
+             const RuntimeShape& gate_weight_shape, const float* gate_weight,
+             const RuntimeShape& gate_bias_shape, const float* gate_bias,
+             const RuntimeShape& candidate_weight_shape,
+             const float* candidate_weight,
+             const RuntimeShape& candidate_bias_shape,
+             const float* candidate_bias, const RuntimeShape& output_shape,
+             float* output, float* output_state,
+             const RuntimeShape& activation_shape, float* activation,
+             const RuntimeShape& concat_shape, float* concat,
+             const tflite::FullyConnectedParams& fc_params,
+             tflite::CpuBackendContext* cpu_backend_context);
+
+}  // namespace gru_cell
+}  // namespace experimental
+}  // namespace ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_GRU_CELL_H_
diff --git a/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc b/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc
new file mode 100644
index 00000000000..fc0d681f3bc
--- /dev/null
+++ b/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc
@@ -0,0 +1,250 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <limits>
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/kernels/gru_cell.h"
+#include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/cpu_backend_support.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace experimental {
+namespace unidirectional_sequence_gru {
+namespace {
+
+void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
+             const TfLiteTensor* gate_weight, const TfLiteTensor* gate_bias,
+             const TfLiteTensor* candidate_weight,
+             const TfLiteTensor* candidate_bias, TfLiteTensor* output,
+             TfLiteTensor* output_state, TfLiteTensor* activation,
+             TfLiteTensor* concat,
+             tflite::CpuBackendContext* cpu_backend_context) {
+  const int n_time = input->dims->data[0];
+  const int n_batch = input->dims->data[1];
+  const int n_input = input->dims->data[2];
+  const int n_output = output->dims->data[2];
+  const int n_batch_input = n_batch * n_input;
+  const int n_batch_output = n_batch * n_output;
+  const RuntimeShape input_shape({n_batch, n_input});
+  const float* input_data = GetTensorData<float>(input);
+  const RuntimeShape state_shape = GetTensorShape(input_state);
+  const float* input_state_data = GetTensorData<float>(input_state);
+  const RuntimeShape gate_weight_shape = GetTensorShape(gate_weight);
+  const float* gate_weight_data = GetTensorData<float>(gate_weight);
+  const RuntimeShape gate_bias_shape = GetTensorShape(gate_bias);
+  const float* gate_bias_data = GetTensorData<float>(gate_bias);
+  const RuntimeShape candidate_weight_shape = GetTensorShape(candidate_weight);
+  const float* candidate_weight_data = GetTensorData<float>(candidate_weight);
+  const RuntimeShape candidate_bias_shape = GetTensorShape(candidate_bias);
+  const float* candidate_bias_data = GetTensorData<float>(candidate_bias);
+  const RuntimeShape activation_shape = GetTensorShape(activation);
+  const RuntimeShape output_shape = RuntimeShape({n_batch, n_output});
+  float* output_data = GetTensorData<float>(output);
+  float* output_state_data = GetTensorData<float>(output_state);
+  float* activation_data = GetTensorData<float>(activation);
+  const RuntimeShape concat_shape = GetTensorShape(concat);
+  float* concat_data = GetTensorData<float>(concat);
+  tflite::FullyConnectedParams fc_params;
+  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+  fc_params.float_activation_max = std::numeric_limits<float>::max();
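+  // Run the GRU cell once per time step: the input pointer advances by one
+  // [n_batch, n_input] slice per step, the output pointer by one
+  // [n_batch, n_output] slice, and the state written to output_state is fed
+  // back as the next step's input state.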
+  for (int i = 0; i < n_time; ++i) {
+    gru_cell::GruCell(
+        input_shape, input_data, state_shape, input_state_data,
+        gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
+        candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
+        candidate_bias_data, output_shape, output_data, output_state_data,
+        activation_shape, activation_data, concat_shape, concat_data,
+        fc_params, cpu_backend_context);
+    input_data += n_batch_input;
+    output_data += n_batch_output;
+    input_state_data = output_state_data;
+  }
+}
+
+}  // namespace
+
+enum InputTensor {
+  // Input tensor of size [n_time, n_batch, n_input]
+  kInput = 0,
+  // Input state tensor of size [n_batch, n_output]
+  kInputState = 1,
+  // Gate weight tensor of size [2*n_output, n_input+n_output]
+  kGateWeight = 2,
+  // Gate bias tensor of size [2*n_output]
+  kGateBias = 3,
+  // Candidate weight tensor of size [n_output, n_input+n_output]
+  kCandidateWeight = 4,
+  // Candidate bias tensor of size [n_output]
+  kCandidateBias = 5,
+  kInputNum = 6
+};
+
+enum OutputTensor {
+  // Output tensor of size [n_time, n_batch, n_output]
+  kOutput = 0,
+  // Output state tensor of size [n_batch, n_output]
+  kOutputState = 1,
+  kOutputNum = 2
+};
+
+enum TemporaryTensor {
+  // Scratch buffer for the gate activations of size [n_batch, 2*n_output]
+  kActivation = 0,
+  // Scratch buffer for the concatenated [input state] of size
+  // [n_batch, n_input+n_output]
+  kConcat = 1,
+  kTemporaryNum = 2
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  cpu_backend_support::IncrementUsageCounter(context);
+  auto* scratch_tensor_index = new int;
+  context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
+  return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+  cpu_backend_support::DecrementUsageCounter(context);
+  delete reinterpret_cast<int*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, kInputNum);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);
+
+  // input's dim = [n_time, n_batch, n_input]
+  const TfLiteTensor* input = GetInput(context, node, kInput);
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
+  const int n_time = input->dims->data[0];
+  const int n_batch = input->dims->data[1];
+  const int n_input = input->dims->data[2];
+
+  // input_state's dim = [n_batch, n_output]
+  const TfLiteTensor* input_state = GetInput(context, node, kInputState);
+  TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
+  const int n_output = input_state->dims->data[1];
+
+  // gate_weight's dim = [2 * n_output, n_input + n_output]
+  const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
+  TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
+  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);
+
+  // gate_bias's dim = [2 * n_output]
+  const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
+  TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);
+
+  // candidate_weight's dim = [n_output, n_input + n_output]
+  const TfLiteTensor* candidate_weight =
+      GetInput(context, node, kCandidateWeight);
+  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
+  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
+                    n_input + n_output);
+
+  // candidate_bias's dim = [n_output]
+  const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
+  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);
+
+  // output's dim = [n_time, n_batch, n_output]
+  TfLiteTensor* output = GetOutput(context, node, kOutput);
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
+  output_size->data[0] = n_time;
+  output_size->data[1] = n_batch;
+  output_size->data[2] = n_output;
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, output, output_size));
+
+  // output_state's dim = [n_batch, n_output]
+  TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
+  TF_LITE_ENSURE_OK(
+      context, context->ResizeTensor(context, output_state,
+                                     TfLiteIntArrayCopy(input_state->dims)));
+
+  TfLiteIntArrayFree(node->temporaries);
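+  // Allocate the two scratch tensors declared in TemporaryTensor: the gate
+  // activation buffer and the [input state] concatenation buffer.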
+  node->temporaries = TfLiteIntArrayCreate(kTemporaryNum);
+
+  // activation's dim = [n_batch, 2 * n_output]
+  node->temporaries->data[kActivation] = *scratch_tensor_index;
+  TfLiteTensor* activation = GetTemporary(context, node, kActivation);
+  activation->type = input->type;
+  activation->allocation_type = kTfLiteArenaRw;
+  TfLiteIntArray* activation_size = TfLiteIntArrayCreate(2);
+  activation_size->data[0] = n_batch;
+  activation_size->data[1] = 2 * n_output;
+  TF_LITE_ENSURE_OK(
+      context, context->ResizeTensor(context, activation, activation_size));
+
+  // concat's dim = [n_batch, n_input + n_output]
+  node->temporaries->data[kConcat] = (*scratch_tensor_index) + kConcat;
+  TfLiteTensor* concat = GetTemporary(context, node, kConcat);
+  concat->type = input->type;
+  concat->allocation_type = kTfLiteArenaRw;
+  TfLiteIntArray* concat_size = TfLiteIntArrayCreate(2);
+  concat_size->data[0] = n_batch;
+  concat_size->data[1] = n_input + n_output;
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, concat, concat_size));
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInput);
+  const TfLiteTensor* input_state = GetInput(context, node, kInputState);
+  const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
+  const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
+  const TfLiteTensor* candidate_weight =
+      GetInput(context, node, kCandidateWeight);
+  const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
+  TfLiteTensor* output = GetOutput(context, node, kOutput);
+  TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
+  TfLiteTensor* activation = GetTemporary(context, node, kActivation);
+  TfLiteTensor* concat = GetTemporary(context, node, kConcat);
+  auto cpu_backend_context = cpu_backend_support::GetFromContext(context);
+
+  if (gate_weight->type == kTfLiteFloat32) {
+    GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
+            candidate_bias, output, output_state, activation, concat,
+            cpu_backend_context);
+  } else {
+    context->ReportError(context,
+                         "Unsupported combination of data types for GruCell");
+    return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace unidirectional_sequence_gru
+
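+// Registered as the custom op "UNIDIRECTIONAL_SEQUENCE_GRU"; see the
+// SetCustomOp call in unidirectional_sequence_gru_test.cc for usage.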
+TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU() {
+  static TfLiteRegistration r = {
+      unidirectional_sequence_gru::Init, unidirectional_sequence_gru::Free,
+      unidirectional_sequence_gru::Prepare, unidirectional_sequence_gru::Eval};
+  return &r;
+}
+
+}  // namespace experimental
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru_test.cc b/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru_test.cc
new file mode 100644
index 00000000000..f035c873dbd
--- /dev/null
+++ b/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru_test.cc
@@ -0,0 +1,147 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace ops {
+namespace experimental {
+
+TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU();
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+class GRUOpModel : public SingleOpModel {
+ public:
+  explicit GRUOpModel(int n_batch, int n_input, int n_output,
+                      const std::vector<std::vector<int>>& input_shapes,
+                      const TensorType& weight_type = TensorType_FLOAT32)
+      : n_batch_(n_batch), n_input_(n_input), n_output_(n_output) {
+    input_ = AddInput(TensorType_FLOAT32);
+    input_state_ =
+        AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
+    gate_weight_ = AddInput(TensorType_FLOAT32);
+    gate_bias_ = AddInput(TensorType_FLOAT32);
+    candidate_weight_ = AddInput(TensorType_FLOAT32);
+    candidate_bias_ = AddInput(TensorType_FLOAT32);
+
+    output_ = AddOutput(TensorType_FLOAT32);
+    output_state_ = AddOutput(TensorType_FLOAT32);
+
+    SetCustomOp("UNIDIRECTIONAL_SEQUENCE_GRU", {},
+                Register_UNIDIRECTIONAL_SEQUENCE_GRU);
+    BuildInterpreter(input_shapes);
+  }
+
+  void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
+
+  void SetInputState(const std::vector<float>& f) {
+    PopulateTensor(input_state_, f);
+  }
+
+  void SetGateWeight(const std::vector<float>& f) {
+    PopulateTensor(gate_weight_, f);
+  }
+
+  void SetGateBias(const std::vector<float>& f) {
+    PopulateTensor(gate_bias_, f);
+  }
+
+  void SetCandidateWeight(const std::vector<float>& f) {
+    PopulateTensor(candidate_weight_, f);
+  }
+
+  void SetCandidateBias(const std::vector<float>& f) {
+    PopulateTensor(candidate_bias_, f);
+  }
+
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+  int num_batches() { return n_batch_; }
+  int num_inputs() { return n_input_; }
+  int num_outputs() { return n_output_; }
+
+ private:
+  int input_;
+  int input_state_;
+  int gate_weight_;
+  int gate_bias_;
+  int candidate_weight_;
+  int candidate_bias_;
+
+  int output_;
+  int output_state_;
+  int n_batch_;
+  int n_input_;
+  int n_output_;
+};
+
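+// Runs a two-step, two-batch GRU with fixed weights and checks both the
+// output shape and the output values against precomputed results.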
+TEST(GRUTest, SimpleTest) {
+  const int n_time = 2;
+  const int n_batch = 2;
+  const int n_input = 2;
+  const int n_output = 3;
+
+  GRUOpModel m(n_batch, n_input, n_output,
+               {{n_time, n_batch, n_input},
+                {n_batch, n_output},
+                {2 * n_output, n_input + n_output},
+                {2 * n_output},
+                {n_output, n_input + n_output},
+                {n_output}});
+  // All data is randomly generated.
+  m.SetInput({0.89495724, 0.34482682, 0.68505806, 0.7135783, 0.3167085,
+              0.93647677, 0.47361764, 0.39643127});
+  m.SetInputState(
+      {0.09992421, 0.3028481, 0.78305984, 0.50438094, 0.11269058, 0.10244724});
+  m.SetGateWeight({0.7256918, 0.8945897, 0.03285786, 0.42637166, 0.119376324,
+                   0.83035135, 0.16997327, 0.42302176, 0.77598256, 0.2660894,
+                   0.9587266, 0.6218451, 0.88164485, 0.12272458, 0.2699055,
+                   0.18399088, 0.21930052, 0.3374841, 0.70866305, 0.9523419,
+                   0.25170696, 0.60988617, 0.79823977, 0.64477515, 0.2602957,
+                   0.5053131, 0.93722224, 0.8451359, 0.97905475, 0.38669217});
+  m.SetGateBias(
+      {0.032708533, 0.018445263, 0.15320699, 0.8163046, 0.26683575,
+       0.1412022});
+  m.SetCandidateWeight({0.96165305, 0.95572084, 0.11534478, 0.96965164,
+                        0.33562955, 0.8680755, 0.003066936, 0.057793964,
+                        0.8671354, 0.33354893, 0.7313398, 0.78492093,
+                        0.19530584, 0.116550304, 0.13599132});
+  m.SetCandidateBias({0.89837056, 0.54769796, 0.63364106});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAre(n_time, n_batch, n_output));
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {0.20112592, 0.45286041, 0.80842507, 0.59567153, 0.2619998,
+                   0.22922856, 0.27715868, 0.5247152, 0.82300174, 0.65812796,
+                   0.38217607, 0.3401444})));
+}
+
+}  // namespace
+}  // namespace experimental
+}  // namespace ops
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}