Add support for full LSTM operation in GPU delegate.
PiperOrigin-RevId: 330964067 Change-Id: I88ce0eb1f1e3ddb0c54f95fd9a50af359dde1791
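Editor's note: the sketch below is not part of this change; it is a minimal illustration, using the standard TFLite C++ API, of how a model containing a full-kernel LSTM op would be routed through the GPU delegate this commit extends ("model.tflite" is a placeholder path):

#include <memory>

#include "tensorflow/lite/delegates/gpu/delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int main() {
  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  if (!model) return 1;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);

  TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
  TfLiteDelegate* delegate = TfLiteGpuDelegateV2Create(&options);
  // Nodes the delegate supports (now including full-kernel LSTM) run on the
  // GPU; the rest of the graph falls back to the CPU kernels.
  if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return 1;

  interpreter->AllocateTensors();
  interpreter->Invoke();

  interpreter.reset();  // Destroy the interpreter before the delegate.
  TfLiteGpuDelegateV2Delete(delegate);
  return 0;
}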
tensorflow/lite/delegates/gpu/cl/kernels/BUILD
@@ -745,6 +745,25 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "lstm_full_test",
+    srcs = ["lstm_full_test.cc"],
+    linkstatic = True,
+    tags = tf_gpu_tests_tags() + [
+        "linux",
+        "local",
+    ],
+    deps = [
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/delegates/gpu:delegate",
+        "//tensorflow/lite/kernels:test_main",
+        "//tensorflow/lite/kernels:test_util",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_googletest//:gtest",
+        "@flatbuffers",
+    ],
+)
+
 cc_library(
     name = "mean_stddev_normalization",
     srcs = ["mean_stddev_normalization.cc"],
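Editor's note: assuming a Linux machine with a supported GPU (the target is tagged with tf_gpu_tests_tags()), the new test would be run with a Bazel invocation along the lines of `bazel test //tensorflow/lite/delegates/gpu/cl/kernels:lstm_full_test`.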
tensorflow/lite/delegates/gpu/cl/kernels/lstm_full_test.cc (new file, 1181 lines)
File diff suppressed because it is too large.
tensorflow/lite/delegates/gpu/common/BUILD
@@ -107,12 +107,39 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "lstm_parser",
+    srcs = ["lstm_parser.cc"],
+    hdrs = ["lstm_parser.h"],
+    deps = [
+        ":data_type",
+        ":model",
+        ":model_builder_helper",
+        ":object_reader",
+        ":operations",
+        ":shape",
+        ":status",
+        ":tensor",
+        "//tensorflow/lite:string",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels:lstm_shared",
+        "//tensorflow/lite/kernels/internal:quantization_util",
+        "//tensorflow/lite/kernels/internal:tensor",
+        "//tensorflow/lite/kernels/internal:tensor_utils",
+        "//tensorflow/lite/kernels/internal:types",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:any",
+    ],
+)
+
 cc_library(
     name = "model_builder",
     srcs = ["model_builder.cc"],
     hdrs = ["model_builder.h"],
     deps = [
         ":data_type",
+        ":lstm_parser",
         ":model",
         ":model_builder_helper",
         ":model_transformer",
tensorflow/lite/delegates/gpu/common/lstm_parser.cc (new file, 553 lines)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"

#include <optional>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/any.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/string_type.h"

namespace tflite {
namespace gpu {
namespace {
Value* CreateNewSimilarValue(GraphFloat32* graph, const Value* old_value) {
  Value* new_value = graph->NewValue();
  new_value->quant_params = old_value->quant_params;
  new_value->tensor.shape = old_value->tensor.shape;
  new_value->tensor.type = old_value->tensor.type;
  new_value->tensor.ref = -1;
  return new_value;
}

absl::Status SetFullyConnectedWeights(int weights_tensor_id,
                                      ObjectReader* reader,
                                      FullyConnectedAttributes* attr) {
  Tensor<HW, DataType::FLOAT32> weights;
  RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
  attr->weights.data = std::move(weights.data);
  attr->weights.id = weights.id;
  attr->weights.shape.o = weights.shape.h;
  attr->weights.shape.h = 1;
  attr->weights.shape.w = 1;
  attr->weights.shape.i = weights.shape.w;
  return absl::OkStatus();
}

bool HasTensor(const TfLiteNode* node, const int index) {
  return (index < node->inputs->size) &&
         (node->inputs->data[index] != kTfLiteOptionalTensor);
}

bool HasCifg(const TfLiteNode* node) {
  return !HasTensor(
      node, tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor);
}

bool HasPeephole(const TfLiteNode* node) {
  // Use forget weights to detect peephole instead of input weights, as input
  // weights may be missing for CIFG.
  return HasTensor(
      node, tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor);
}

bool HasNormalization(const TfLiteNode* node) {
  return HasTensor(
      node,
      tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
}

bool HasProjection(const TfLiteNode* node) {
  return HasTensor(node,
                   tflite::ops::builtin::lstm::full::kProjectionWeightsTensor);
}
// Builds the subgraph for a single LSTM gate.
// Returns a Value representing the gate's output.
// High-level parameters:
//   - Has normalization (if true: provide normalization weights).
//   - Has peephole connection (if true: provide peephole weights).
//   - Which activation function to use.
// Note: no support for aux input.
//
// Implements the following
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
//   temp = input_weights * input_tensor + recurrent_weights * output_state;
//   if (peephole):
//     temp += peephole_weights .* cell_state;
//   if (layer normalization):
//     gate = activate(normalization_weights .* mean_stddev_norm(temp) + bias);
//   else:
//     gate = activate(temp + bias);
//
absl::Status BuildLstmGate(GraphFloat32* graph, ObjectReader* reader,
                           Value* output_state, Value* cell_state,
                           int input_weight_id, int recurrent_weight_id,
                           int cell_weight_id, int bias_id,
                           int normalization_weight_id,
                           const TfLiteFusedActivation activation,
                           bool has_peephole, bool has_normalization,
                           Value** gate_out) {
  Value* input_times_weights = CreateNewSimilarValue(graph, cell_state);
  {
    // #1 matrix multiplication: input_weights * input_tensor
    // If there is no normalization, the bias is also added here.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(
        SetFullyConnectedWeights(input_weight_id, reader, &fc_attr));
    if (!has_normalization) {
      RETURN_IF_ERROR(reader->ReadTensor(bias_id, &(fc_attr.bias)));
    }
    node->operation.attributes = std::move(fc_attr);
    RETURN_IF_ERROR(
        reader->AddInput(node, tflite::ops::builtin::lstm::full::kInputTensor));
    RETURN_IF_ERROR(graph->SetProducer(node->id, input_times_weights->id));
  }

  Value* output_state_times_weights = CreateNewSimilarValue(graph, cell_state);
  {
    // #2 matrix multiplication: recurrent_weights * output_state
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(
        SetFullyConnectedWeights(recurrent_weight_id, reader, &fc_attr));
    node->operation.attributes = std::move(fc_attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, output_state->id));
    RETURN_IF_ERROR(
        graph->SetProducer(node->id, output_state_times_weights->id));
  }

  Value* cell_state_times_weights;
  if (has_peephole) {
    // #3 elementwise multiplication: cell_weight .* cell_state
    cell_state_times_weights = CreateNewSimilarValue(graph, cell_state);
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> weights;
    RETURN_IF_ERROR(reader->ReadTensor(cell_weight_id, &weights));
    attr.param = std::move(weights);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_times_weights->id));
  }

  Value* gate_before_normalization = CreateNewSimilarValue(graph, cell_state);
  Node* add_node = graph->NewNode();
  {
    // #4 elementwise addition: #1 + #2 + #3
    add_node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(graph->AddConsumer(add_node->id, input_times_weights->id));
    RETURN_IF_ERROR(
        graph->AddConsumer(add_node->id, output_state_times_weights->id));
    if (has_peephole) {
      RETURN_IF_ERROR(
          graph->AddConsumer(add_node->id, cell_state_times_weights->id));
    }
    RETURN_IF_ERROR(
        graph->SetProducer(add_node->id, gate_before_normalization->id));
  }

  if (!has_normalization) {
    // #5 activation function: activate(temp + bias)
    // The bias was already added in node #1.
    RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, add_node));
    *gate_out = gate_before_normalization;
    return absl::OkStatus();
  }

  Value* normalized_gate =
      CreateNewSimilarValue(graph, gate_before_normalization);
  {
    // #6 normalization: mean_stddev_norm(temp)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MEAN_STDDEV_NORMALIZATION);
    RETURN_IF_ERROR(
        graph->AddConsumer(node->id, gate_before_normalization->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, normalized_gate->id));
  }
  Value* reweighted_normalized_gate =
      CreateNewSimilarValue(graph, normalized_gate);
  {
    // #7 elementwise multiplication: norm_weights .* #6
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> norm_weights;
    RETURN_IF_ERROR(reader->ReadTensor(normalization_weight_id, &norm_weights));
    attr.param = std::move(norm_weights);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, normalized_gate->id));
    RETURN_IF_ERROR(
        graph->SetProducer(node->id, reweighted_normalized_gate->id));
  }
  Value* gate = CreateNewSimilarValue(graph, reweighted_normalized_gate);
  {
    // #8 elementwise addition: #7 + bias
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> bias;
    RETURN_IF_ERROR(reader->ReadTensor(bias_id, &bias));
    attr.param = std::move(bias);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(
        graph->AddConsumer(node->id, reweighted_normalized_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, gate->id));

    // #9 activation function
    RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
  }
  *gate_out = gate;
  return absl::OkStatus();
}
// Builds the subgraph for the LSTM cell state update.
// Returns a Value representing the updated cell state.
// High-level parameters:
//   - clip: if > 0, clamp the resulting cell state to [-clip, +clip].
//
// Implements the following
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
//
//   cell_state_new = clip(forget_gate .* cell_state + input_gate .* cell_gate);
//
absl::Status BuildCellStateUpdate(GraphFloat32* graph, ObjectReader* reader,
                                  Value* forget_gate, Value* input_gate,
                                  Value* cell_gate, float cell_clip,
                                  Value** cell_state_new) {
  Value* cell_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kCellStateTensor, &cell_state));
  Value* cell_state_contrib = CreateNewSimilarValue(graph, cell_gate);
  {
    // #1 elementwise multiplication: forget_gate .* cell_state
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_contrib->id));
  }
  Value* cell_gate_contrib = CreateNewSimilarValue(graph, cell_gate);
  {
    // #2 elementwise multiplication: input_gate .* cell_gate
    // Note: with CIFG, input_gate equals (1 - forget_gate).
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, input_gate->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_gate_contrib->id));
  }
  Value* new_cell_state = CreateNewSimilarValue(graph, cell_gate);
  {
    // #3 elementwise addition: #1 + #2
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state_contrib->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate_contrib->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, new_cell_state->id));
  }

  if (cell_clip <= 0.0f) {
    *cell_state_new = new_cell_state;
    return absl::OkStatus();
  }

  // TODO(b/157166356): Maybe add OperationType::CLAMP?
  Value* max_clipped_state = CreateNewSimilarValue(graph, new_cell_state);
  {
    // #4 elementwise minimum: min(#3, clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MINIMUM);
    ElementwiseAttributes attr;
    attr.param = cell_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
  }
  Value* clipped_cell_state = CreateNewSimilarValue(graph, max_clipped_state);
  {
    // #5 elementwise maximum: max(#4, -clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAXIMUM);
    ElementwiseAttributes attr;
    attr.param = -cell_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_cell_state->id));
  }
  *cell_state_new = clipped_cell_state;
  return absl::OkStatus();
}
// Builds the subgraph for the LSTM output state update.
// Returns a Value representing the updated output state.
// High-level parameters:
//   - Has projection (if true, provide projection_weights).
//   - Has projection bias (only with projection).
//   - clip: clamp the projection output to [-clip, +clip].
//   - Which activation function to use.
// Note: the updated output state does not depend on the old output state
// directly, only through the output gate.
//
// Implements the following
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
//
//   temp = output_gate .* activate(cell_state);
//   if (projection):
//     output_state_new = clip(projection_weights * temp + projection_bias);
//   else:
//     output_state_new = temp;
//
absl::Status BuildOutputStateUpdate(GraphFloat32* graph, ObjectReader* reader,
                                    Value* output_state, Value* output_gate,
                                    Value* cell_state,
                                    TfLiteFusedActivation activation,
                                    bool has_projection, float proj_clip,
                                    Value** output_state_new) {
  Value* activated_state = CreateNewSimilarValue(graph, cell_state);
  {
    // #1 activation: activate(cell_state)
    Node* node = graph->NewNode();
    switch (activation) {
      case kTfLiteActTanh:
        node->operation.type = ToString(OperationType::TANH);
        break;
      case kTfLiteActSigmoid:
        node->operation.type = ToString(OperationType::SIGMOID);
        break;
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported activation: ", activation));
    }
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, activated_state->id));
  }

  Value* new_output_state = CreateNewSimilarValue(graph, cell_state);
  {
    // #2 elementwise multiplication: output_gate .* #1
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, activated_state->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, output_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, new_output_state->id));
  }

  if (!has_projection) {
    *output_state_new = new_output_state;
    return absl::OkStatus();
  }

  Value* projected_output_state = CreateNewSimilarValue(graph, output_state);
  {
    // #3 matrix multiplication: projection_weights * #2 + projection_bias
    Node* node = graph->NewNode();
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(SetFullyConnectedWeights(
        tflite::ops::builtin::lstm::full::kProjectionWeightsTensor, reader,
        &fc_attr));
    // The projection bias is optional.
    reader
        ->ReadTensor(tflite::ops::builtin::lstm::full::kProjectionBiasTensor,
                     &(fc_attr.bias))
        .IgnoreError();
    node->operation.attributes = std::move(fc_attr);
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, projected_output_state->id));
  }

  if (proj_clip <= 0.0f) {
    *output_state_new = projected_output_state;
    return absl::OkStatus();
  }

  // TODO(b/157166356): Maybe add OperationType::CLAMP?
  Value* max_clipped_state =
      CreateNewSimilarValue(graph, projected_output_state);
  {
    // #4 elementwise minimum: min(#3, clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MINIMUM);
    ElementwiseAttributes attr;
    attr.param = proj_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, projected_output_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
  }
  Value* clipped_output_state = CreateNewSimilarValue(graph, max_clipped_state);
  {
    // #5 elementwise maximum: max(#4, -clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAXIMUM);
    ElementwiseAttributes attr;
    attr.param = -proj_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_output_state->id));
  }
  *output_state_new = clipped_output_state;
  return absl::OkStatus();
}
}  // namespace

// Builds the subgraph for a single LSTM op.
// Returns a mapping from the used variable tensors to their updated Values.
//
// High-level parameters:
//   - Has CIFG:
//     If false, calculate input_gate regularly.
//     If true, compute input_gate as (1 - forget_gate).
//   - Has peephole: see BuildLstmGate. Applies to all gates.
//   - Has normalization: see BuildLstmGate. Applies to all gates.
//   - Has projection, projection_bias, proj_clip: see BuildOutputStateUpdate.
//   - Which activation to use:
//     Applies only to the cell gate and the output state update.
//     Other gates always use sigmoid.
//
absl::Status ParseLSTMAttributes(
    const TfLiteNode* tflite_node, const TfLiteRegistration* registration,
    GraphFloat32* graph, ObjectReader* reader, const TfLiteLSTMParams* params,
    absl::flat_hash_map<int, ValueId>* new_variable_input_values) {
  const bool has_cifg = HasCifg(tflite_node);
  const bool has_peephole = HasPeephole(tflite_node);
  const bool has_normalization = HasNormalization(tflite_node);
  const bool has_projection = HasProjection(tflite_node);

  Value* old_cell_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kCellStateTensor, &old_cell_state));

  if (old_cell_state->tensor.shape.b != 1) {
    return absl::InvalidArgumentError(
        "Batched execution is not supported for LSTM");
  }

  Value* old_output_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kOutputStateTensor, &old_output_state));

  Value* forget_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, old_cell_state,
      tflite::ops::builtin::lstm::full::kInputToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kForgetGateBiasTensor,
      tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor,
      kTfLiteActSigmoid, has_peephole, has_normalization, &forget_gate));

  Value* input_gate;
  if (has_cifg) {
    // With CIFG, input_gate is computed as (1 - forget_gate).
    Node* node = graph->NewNode();
    input_gate = CreateNewSimilarValue(graph, forget_gate);

    node->operation.type = ToString(OperationType::SUB);
    ElementwiseAttributes attr;
    attr.param = 1.0f;
    attr.runtime_tensor_is_second = true;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, input_gate->id));
  } else {
    RETURN_IF_ERROR(BuildLstmGate(
        graph, reader, old_output_state, old_cell_state,
        tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kRecurrentToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kCellToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kInputGateBiasTensor,
        tflite::ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor,
        kTfLiteActSigmoid, has_peephole, has_normalization, &input_gate));
  }

  // The cell gate does not use peephole connections to the cell state.
  Value* cell_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, old_cell_state,
      tflite::ops::builtin::lstm::full::kInputToCellWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
      /*cell_weight_id=*/-1,
      tflite::ops::builtin::lstm::full::kCellGateBiasTensor,
      tflite::ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor,
      params->activation, /*has_peephole=*/false, has_normalization,
      &cell_gate));

  Value* new_cell_state;
  RETURN_IF_ERROR(BuildCellStateUpdate(graph, reader, forget_gate, input_gate,
                                       cell_gate, params->cell_clip,
                                       &new_cell_state));

  Value* output_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, new_cell_state,
      tflite::ops::builtin::lstm::full::kInputToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kCellToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kOutputGateBiasTensor,
      tflite::ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor,
      kTfLiteActSigmoid, has_peephole, has_normalization, &output_gate));

  Value* new_output_state;
  RETURN_IF_ERROR(BuildOutputStateUpdate(graph, reader, old_output_state,
                                         output_gate, new_cell_state,
                                         params->activation, has_projection,
                                         params->proj_clip, &new_output_state));

  {
    // Copy the updated output state to the op's output.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::COPY);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
    RETURN_IF_ERROR(reader->AddOutput(
        node, tflite::ops::builtin::lstm::full::kOutputTensor));
  }

  new_variable_input_values->clear();
  new_variable_input_values->emplace(
      tflite::ops::builtin::lstm::full::kCellStateTensor, new_cell_state->id);
  new_variable_input_values->emplace(
      tflite::ops::builtin::lstm::full::kOutputStateTensor,
      new_output_state->id);
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite
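Editor's reference sketch (not part of the commit): the arithmetic that BuildLstmGate above decomposes into FULLY_CONNECTED, MUL, ADD, and MEAN_STDDEV_NORMALIZATION graph nodes, written as plain C++ over float vectors. All names are illustrative, and the 1e-8f normalization epsilon is an assumption made only to keep the sketch self-contained:

#include <cmath>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<std::vector<float>>;  // [output][input]

Vec MatVec(const Mat& w, const Vec& x) {
  Vec y(w.size(), 0.0f);
  for (size_t i = 0; i < w.size(); ++i) {
    for (size_t j = 0; j < x.size(); ++j) y[i] += w[i][j] * x[j];
  }
  return y;
}

// One gate:
//   temp = input_w * input + recurrent_w * output_state
//          (+ peephole_w .* cell_state, when peephole weights are given);
//   gate = sigmoid(norm_w .* mean_stddev_norm(temp) + bias)  with layer norm,
//   gate = sigmoid(temp + bias)                              otherwise.
// Pass nullptr for peephole_w / norm_w to disable those features.
Vec LstmGate(const Mat& input_w, const Vec& input, const Mat& recurrent_w,
             const Vec& output_state, const Vec* peephole_w,
             const Vec& cell_state, const Vec* norm_w, const Vec& bias) {
  Vec temp = MatVec(input_w, input);
  const Vec recurrent = MatVec(recurrent_w, output_state);
  for (size_t i = 0; i < temp.size(); ++i) temp[i] += recurrent[i];
  if (peephole_w != nullptr) {
    for (size_t i = 0; i < temp.size(); ++i) {
      temp[i] += (*peephole_w)[i] * cell_state[i];
    }
  }
  if (norm_w != nullptr) {
    // Normalize to zero mean / unit stddev across the vector, then reweight.
    float mean = 0.0f, var = 0.0f;
    for (float v : temp) mean += v;
    mean /= temp.size();
    for (float v : temp) var += (v - mean) * (v - mean);
    var /= temp.size();
    const float inv_stddev = 1.0f / std::sqrt(var + 1e-8f);  // assumed epsilon
    for (size_t i = 0; i < temp.size(); ++i) {
      temp[i] = (temp[i] - mean) * inv_stddev * (*norm_w)[i];
    }
  }
  Vec gate(temp.size());
  for (size_t i = 0; i < temp.size(); ++i) {
    gate[i] = 1.0f / (1.0f + std::exp(-(temp[i] + bias[i])));  // sigmoid
  }
  return gate;
}

The forget, input, and output gates always use sigmoid, matching the parser; the cell gate would substitute params->activation, and with CIFG the input gate is simply 1 - forget_gate. The cell state then updates as clip(forget_gate .* cell_state + input_gate .* cell_gate), exactly as BuildCellStateUpdate assembles it.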
tensorflow/lite/delegates/gpu/common/lstm_parser.h (new file, 34 lines)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_LSTM_PARSER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_LSTM_PARSER_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"

namespace tflite {
namespace gpu {

// Parses a full-kernel LSTM op into GPU graph nodes. On success,
// *new_variable_input_values maps the op's variable input tensor indices
// (cell state, output state) to the ValueIds holding their updated values.
absl::Status ParseLSTMAttributes(
    const TfLiteNode* tflite_node, const TfLiteRegistration* registration,
    GraphFloat32* graph, ObjectReader* reader, const TfLiteLSTMParams* params,
    absl::flat_hash_map<int, ValueId>* new_variable_input_values);

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_LSTM_PARSER_H_
tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
@@ -1003,18 +1004,39 @@ class HardSwishOperationParser : public TFLiteOperationParser {
 //      /      \
 //  new_state1  activation0
 //
+// For full LSTM cells, see this blog post:
+// https://colah.github.io/posts/2015-08-Understanding-LSTMs/
+// In addition to Peephole connections and Combined Input Forget Gates (CIFG)
+// described in that post, this code also adds the following optional features:
+// - Configurable activations (sigmoid or TANH)
+// - L2 Normalization of gates: https://arxiv.org/abs/1607.06450
+// - Output projection:
+//   https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html
+// - Configurable clipping of cell state and output state.
 class LSTMOperationParser : public TFLiteOperationParser {
  public:
   absl::Status IsSupported(const TfLiteContext* context,
                            const TfLiteNode* tflite_node,
                            const TfLiteRegistration* registration) final {
-    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
+    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
     const TfLiteLSTMParams* tf_options;
     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     switch (tf_options->kernel_type) {
-      case kTfLiteLSTMFullKernel:
-        // TODO(b/157166356): Add check for input/output tensor counts.
+      case kTfLiteLSTMFullKernel: {
+        const int inputs = NumInputs(tflite_node);
+        if (inputs != 20 && inputs != 24) {
+          return absl::InternalError(
+              absl::StrCat("Expected 20 or 24 input tensors, but node has ",
+                           inputs, " input(s)."));
+        }
+        const int runtime_outputs = NumOutputs(tflite_node);
+        if (runtime_outputs != 1) {
+          return absl::InternalError(
+              absl::StrCat("Expected 1 output tensor, but node has ",
+                           runtime_outputs, " output(s)."));
+        }
         return CheckFullParameters(tf_options);
+      }
       case kTfLiteLSTMBasicKernel:
         RETURN_IF_ERROR(
             CheckInputsConstsOutputs(context, tflite_node, /*runtime_inputs=*/3,
@@ -1036,6 +1058,11 @@ class LSTMOperationParser : public TFLiteOperationParser {
     }
   }
 
+  absl::flat_hash_map<int, ValueId> GetNewValueIdsForVariableInputNodes()
+      final {
+    return new_variable_input_value_map_;
+  }
+
  private:
   absl::Status ParseBasic(const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration,
@@ -1108,14 +1135,24 @@ class LSTMOperationParser : public TFLiteOperationParser {
                            const TfLiteRegistration* registration,
                            GraphFloat32* graph, ObjectReader* reader,
                            const TfLiteLSTMParams* tf_options) {
-    return absl::UnimplementedError(
-        "Full LSTM support is not yet implemented.");
+    // Invoke the full LSTM parser.
+    RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph,
+                                        reader, tf_options,
+                                        &new_variable_input_value_map_));
+    return absl::OkStatus();
   }
 
   absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) {
-    return absl::UnimplementedError(
-        "Full LSTM support is not yet implemented.");
+    if (tf_options->activation != kTfLiteActSigmoid &&
+        tf_options->activation != kTfLiteActTanh) {
+      return absl::UnimplementedError(
+          "Only sigmoid or tanh activation is supported.");
+    }
+
+    return absl::OkStatus();
   }
 
+  absl::flat_hash_map<int, ValueId> new_variable_input_value_map_;
 };
 
 class MulOperationParser : public TFLiteOperationParser {
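Editor's summary (not part of the commit): in equation form, the update sequence that ParseLSTMAttributes assembles — layer normalization omitted for brevity, with \(\sigma\) the sigmoid, \(\odot\) elementwise multiplication, and \(\phi\) the configurable tanh/sigmoid activation:

\[
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + w_{cf} \odot c_{t-1} + b_f) \\
i_t &= \begin{cases} 1 - f_t & \text{(CIFG)} \\ \sigma(W_i x_t + U_i h_{t-1} + w_{ci} \odot c_{t-1} + b_i) & \text{otherwise} \end{cases} \\
g_t &= \phi(W_g x_t + U_g h_{t-1} + b_g) \\
c_t &= \mathrm{clip}(f_t \odot c_{t-1} + i_t \odot g_t) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + w_{co} \odot c_t + b_o) \\
h_t &= \begin{cases} \mathrm{clip}(W_{proj}\,(o_t \odot \phi(c_t)) + b_{proj}) & \text{(projection)} \\ o_t \odot \phi(c_t) & \text{otherwise} \end{cases}
\end{aligned}
\]

Note that the forget gate's peephole term uses the old cell state \(c_{t-1}\) while the output gate's uses the freshly computed \(c_t\), matching the order in which BuildLstmGate is invoked above.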