TFLite GPU: Make GPU delegate recognize MobileNet v3.
PiperOrigin-RevId: 256432626
This commit is contained in:
parent
3b239555fe
commit
107a35a551
@ -263,24 +263,22 @@ class ObjectReader {
|
||||
tflite_node_(tflite_node),
|
||||
tensor_to_value_(tensor_to_value) {}
|
||||
|
||||
Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) {
|
||||
Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) const {
|
||||
if (idx >= tflite_node_->inputs->size) {
|
||||
return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx));
|
||||
}
|
||||
RETURN_IF_ERROR(
|
||||
ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value));
|
||||
return OkStatus();
|
||||
return ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value);
|
||||
}
|
||||
|
||||
int GetNumberOfRuntimeInputs() {
|
||||
int GetNumberOfRuntimeInputs() const {
|
||||
return GetNumberOfRuntimeInputsForNode(context_, tflite_node_);
|
||||
}
|
||||
|
||||
Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) {
|
||||
Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const {
|
||||
if (idx >= tflite_node_->inputs->size) {
|
||||
return OutOfRangeError(StrCat("Input tensor index: ", idx));
|
||||
}
|
||||
int32_t tensor_idx = tflite_node_->inputs->data[idx];
|
||||
const int tensor_idx = tflite_node_->inputs->data[idx];
|
||||
if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
|
||||
return OutOfRangeError(StrCat("Tensor index: ", tensor_idx));
|
||||
}
|
||||
@ -330,7 +328,7 @@ class ObjectReader {
|
||||
}
|
||||
|
||||
Status ReadValueByTensorIdx(uint32_t tensor_idx,
|
||||
Value<TensorRef<BHWC>>** value) {
|
||||
Value<TensorRef<BHWC>>** value) const {
|
||||
if (tensor_idx >= tensor_to_value_->size()) {
|
||||
return OutOfRangeError(
|
||||
StrCat("ReadValue: input tensor index: ", tensor_idx));
|
||||
@ -351,6 +349,12 @@ class ObjectReader {
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
TfLiteTensor* GetInputTensor(int index) const {
|
||||
return index >= 0 && index < tflite_node_->inputs->size
|
||||
? context_->tensors + tflite_node_->inputs->data[index]
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphFloat32* graph_ = nullptr;
|
||||
const TfLiteContext* context_ = nullptr;
|
||||
@ -1019,39 +1023,66 @@ class AddOperationParser : public TFLiteOperationParser {
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||
// TODO(eignasheva): add shapes check.
|
||||
TfLiteAddParams* tf_options = nullptr;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
return OkStatus();
|
||||
if (tflite_node->inputs->size != 2) {
|
||||
return UnimplementedError("ADD requires two input tensors.");
|
||||
}
|
||||
// TODO(eignasheva): Add shapes check.
|
||||
TfLiteAddParams* tf_options = nullptr;
|
||||
return RetrieveBuiltinData(tflite_node, &tf_options);
|
||||
}
|
||||
|
||||
Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration, GraphFloat32* graph,
|
||||
ObjectReader* reader) final {
|
||||
// TFLite currently only supports 2 input ADDs. Thus, the logic below only
|
||||
// considers 2 input cases. The underlying GPU shader programs can accept
|
||||
// more inputs, but the logic below would have to be expanded.
|
||||
|
||||
// Determine runtime/constant tensors.
|
||||
const TfLiteTensor* input0 = reader->GetInputTensor(0);
|
||||
if (!input0) {
|
||||
return InvalidArgumentError("Couldn't get the 1st input tensor for ADD.");
|
||||
}
|
||||
const TfLiteTensor* input1 = reader->GetInputTensor(1);
|
||||
if (!input1) {
|
||||
return InvalidArgumentError("Couldn't get the 2nd input tensor for ADD.");
|
||||
}
|
||||
const bool constant_tensor0 = IsConstantTensor(input0);
|
||||
const bool constant_tensor1 = IsConstantTensor(input1);
|
||||
if (constant_tensor0 && constant_tensor1) {
|
||||
return InvalidArgumentError("No runtime input tensors for ADD.");
|
||||
}
|
||||
const bool runtime_tensor0 = !constant_tensor0;
|
||||
const bool runtime_tensor1 = !constant_tensor1;
|
||||
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::ADD);
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
AddAttributes attr;
|
||||
for (int idx = 0; idx < tflite_node->inputs->size; ++idx) {
|
||||
if (!reader->AddInput(node, idx).ok()) {
|
||||
if (tflite_node->inputs->size != 2) {
|
||||
return InvalidArgumentError(
|
||||
"Broadcast Add should accept 2 inputs, one input tensor and "
|
||||
"broadcasted tensor");
|
||||
if (runtime_tensor0 && runtime_tensor1) {
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 1));
|
||||
} else {
|
||||
int runtime_tensor = 0;
|
||||
int constant_tensor = 1;
|
||||
TfLiteIntArray* constant_dims = input1->dims;
|
||||
if (constant_tensor0 && runtime_tensor1) {
|
||||
runtime_tensor = 1;
|
||||
constant_tensor = 0;
|
||||
constant_dims = input0->dims;
|
||||
}
|
||||
TfLiteIntArray dims;
|
||||
RETURN_IF_ERROR(reader->GetTensorDims(1, &dims));
|
||||
if (dims.size <= 0) {
|
||||
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
|
||||
if (constant_dims->size <= 0) {
|
||||
Tensor<Scalar, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
attr.param = tensor.data[0];
|
||||
} else {
|
||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
attr.param = std::move(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
node->operation.attributes = std::move(attr);
|
||||
|
||||
const auto* tf_options =
|
||||
@ -1059,9 +1090,8 @@ class AddOperationParser : public TFLiteOperationParser {
|
||||
if (!tf_options) {
|
||||
return InternalError("Missing tflite params");
|
||||
}
|
||||
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
|
||||
graph, node));
|
||||
return OkStatus();
|
||||
return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
|
||||
node);
|
||||
}
|
||||
};
|
||||
|
||||
@ -1427,41 +1457,119 @@ class ReLuOperationParser : public TFLiteOperationParser {
|
||||
int clip_;
|
||||
};
|
||||
|
||||
Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
|
||||
const TfLiteIntArray* dims = tflite_tensor.dims;
|
||||
switch (dims->size) {
|
||||
case 1:
|
||||
*bhwc = BHWC(dims->data[0], 1, 1, 1);
|
||||
return OkStatus();
|
||||
case 2:
|
||||
*bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]);
|
||||
return OkStatus();
|
||||
case 3:
|
||||
*bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
|
||||
return OkStatus();
|
||||
case 4:
|
||||
*bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
|
||||
return OkStatus();
|
||||
default:
|
||||
return InvalidArgumentError(
|
||||
absl::StrCat("Tensor \"", tflite_tensor.name,
|
||||
"\" has bad input dims size: ", dims->size, "."));
|
||||
}
|
||||
}
|
||||
|
||||
class MulOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||
// TODO(eignasheva): add params check
|
||||
if (tflite_node->inputs->size != 2) {
|
||||
return UnimplementedError("MUL requires two input tensors.");
|
||||
}
|
||||
// TODO(eignasheva): Add params check.
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration, GraphFloat32* graph,
|
||||
ObjectReader* reader) final {
|
||||
// Determine runtime/constant tensors.
|
||||
const TfLiteTensor* input0 = reader->GetInputTensor(0);
|
||||
if (!input0) {
|
||||
return InvalidArgumentError("Couldn't get the 1st input tensor for MUL.");
|
||||
}
|
||||
const TfLiteTensor* input1 = reader->GetInputTensor(1);
|
||||
if (!input1) {
|
||||
return InvalidArgumentError("Couldn't get the 2nd input tensor for MUL.");
|
||||
}
|
||||
const bool constant_tensor0 = IsConstantTensor(input0);
|
||||
const bool constant_tensor1 = IsConstantTensor(input1);
|
||||
if (constant_tensor0 && constant_tensor1) {
|
||||
return InvalidArgumentError("No runtime input tensors for MUL.");
|
||||
}
|
||||
const bool runtime_tensor0 = !constant_tensor0;
|
||||
const bool runtime_tensor1 = !constant_tensor1;
|
||||
|
||||
// Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st
|
||||
// input and the "smaller" input tensor ("mask") must be bound to 2nd input.
|
||||
if (runtime_tensor0 && runtime_tensor1) {
|
||||
BHWC shape0;
|
||||
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
|
||||
BHWC shape1;
|
||||
RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
|
||||
int input_tensor0 = 0;
|
||||
int input_tensor1 = 1;
|
||||
if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
|
||||
shape0.c == shape1.c) {
|
||||
input_tensor0 = 1;
|
||||
input_tensor1 = 0;
|
||||
}
|
||||
return ParseApplyMask(input_tensor0, input_tensor1, graph, reader);
|
||||
}
|
||||
|
||||
// Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st
|
||||
// input and the constant input tensor must be bound to 2nd input.
|
||||
int runtime_tensor = 0;
|
||||
int constant_tensor = 1;
|
||||
TfLiteIntArray* constant_dims = input1->dims;
|
||||
if (constant_tensor0 && runtime_tensor1) {
|
||||
runtime_tensor = 1;
|
||||
constant_tensor = 0;
|
||||
constant_dims = input0->dims;
|
||||
}
|
||||
return ParseMultiplyScalar(runtime_tensor, constant_tensor, constant_dims,
|
||||
graph, reader);
|
||||
}
|
||||
|
||||
private:
|
||||
Status ParseApplyMask(int input_tensor0, int input_tensor1,
|
||||
GraphFloat32* graph, ObjectReader* reader) {
|
||||
Node* node = graph->NewNode();
|
||||
if (reader->GetNumberOfRuntimeInputs() == 2) {
|
||||
// ApplyMask operation
|
||||
node->operation.type = ToString(OperationType::APPLY_MASK);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 1));
|
||||
} else {
|
||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
|
||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
|
||||
return reader->AddOutputs(node);
|
||||
}
|
||||
|
||||
Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor,
|
||||
const TfLiteIntArray* constant_dims,
|
||||
GraphFloat32* graph, ObjectReader* reader) {
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::MULTIPLY_SCALAR);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
|
||||
MultiplyScalarAttributes attr;
|
||||
TfLiteIntArray dims;
|
||||
RETURN_IF_ERROR(reader->GetTensorDims(1, &dims));
|
||||
if (dims.size <= 0) {
|
||||
if (constant_dims->size <= 0) {
|
||||
Tensor<Scalar, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
attr.param = tensor.data[0];
|
||||
} else {
|
||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
attr.param = std::move(tensor);
|
||||
}
|
||||
node->operation.attributes = std::move(attr);
|
||||
}
|
||||
return reader->AddOutputs(node);
|
||||
}
|
||||
};
|
||||
@ -1963,26 +2071,7 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
||||
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
|
||||
TensorRef<BHWC>* tensor_ref) {
|
||||
tensor_ref->type = ToDataType(tflite_tensor.type);
|
||||
const TfLiteIntArray* dims = tflite_tensor.dims;
|
||||
switch (dims->size) {
|
||||
case 1:
|
||||
tensor_ref->shape = BHWC(dims->data[0], 1, 1, 1);
|
||||
break;
|
||||
case 2:
|
||||
tensor_ref->shape = BHWC(dims->data[0], 1, 1, dims->data[1]);
|
||||
break;
|
||||
case 3:
|
||||
tensor_ref->shape = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
|
||||
break;
|
||||
case 4:
|
||||
tensor_ref->shape =
|
||||
BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
|
||||
break;
|
||||
default:
|
||||
return InvalidArgumentError(StrCat(
|
||||
"Tensor ref has unsupported number of dimensions: ", dims->size));
|
||||
}
|
||||
return OkStatus();
|
||||
return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
|
||||
}
|
||||
|
||||
Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
|
||||
|
@ -286,12 +286,12 @@ cc_library(
|
||||
srcs = ["mul.cc"],
|
||||
hdrs = ["mul.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/gl:node_shader",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
|
||||
@ -33,34 +34,21 @@ namespace {
|
||||
class ApplyMask : public NodeShader {
|
||||
public:
|
||||
static bool IsSupported(const GenerationContext& ctx) {
|
||||
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
if (inputs.size() != 2) return false;
|
||||
const auto& shape0 = inputs[0]->tensor.shape;
|
||||
const auto& shape1 = inputs[1]->tensor.shape;
|
||||
|
||||
// Implementation requires 2 input tensors: source and mask.
|
||||
if (inputs.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
// [H, W, C] x [H, W, 0][0]
|
||||
if (shape1.c == 1) return true;
|
||||
|
||||
auto src_shape = inputs[0]->tensor.shape;
|
||||
auto mask_shape = inputs[1]->tensor.shape;
|
||||
if (shape0.c != shape1.c) return false;
|
||||
|
||||
// Height and width dimensions of the two input tensors must be the same.
|
||||
if (src_shape.h != mask_shape.h || src_shape.w != mask_shape.w) {
|
||||
return false;
|
||||
}
|
||||
// [H, W, C] x [H, W, C]
|
||||
if (shape0.h == shape1.h && shape0.w == shape1.w) return true;
|
||||
|
||||
// Broadcast will be done if mask tensor has 1 channel.
|
||||
if (mask_shape.c == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Bitwise multiplication will be done if mask tensor has the same amount of
|
||||
// channels as source tensor.
|
||||
if (src_shape.c == mask_shape.c) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Other cases are not supported.
|
||||
return false;
|
||||
// [H, W, C] x [0, 0, C]
|
||||
return shape1.h == 1 && shape1.w == 1;
|
||||
}
|
||||
|
||||
Status GenerateCode(const GenerationContext& ctx,
|
||||
@ -69,19 +57,20 @@ class ApplyMask : public NodeShader {
|
||||
return InvalidArgumentError(
|
||||
"This case is not supported by apply mask operation");
|
||||
}
|
||||
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
const auto& shape0 = inputs[0]->tensor.shape;
|
||||
const auto& shape1 = inputs[1]->tensor.shape;
|
||||
|
||||
std::string source;
|
||||
if (inputs[1]->tensor.shape.c == 1) {
|
||||
// Broadcast case, mask channels size == 1.
|
||||
source =
|
||||
"value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * "
|
||||
"$input_data_1[gid.x, gid.y, 0]$.x;";
|
||||
std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
|
||||
if (shape1.c == 1) {
|
||||
// [H, W, C] x [H, W, 0][0]
|
||||
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
|
||||
} else if (shape0.h == shape1.h && shape0.w == shape1.w) {
|
||||
// [H, W, C] x [H, W, C]
|
||||
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
|
||||
} else {
|
||||
// Bitwise multiplication case, src channels size == mask channels size.
|
||||
source =
|
||||
"value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * "
|
||||
"$input_data_1[gid.x, gid.y, 0]$;";
|
||||
// [H, W, C] x [0, 0, C]
|
||||
absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
|
||||
}
|
||||
|
||||
*generated_code = {
|
||||
|
Loading…
Reference in New Issue
Block a user