TFLite GPU: Make GPU delegate recognize MobileNet v3.

PiperOrigin-RevId: 256432626
Author: Juhyun Lee (2019-07-03 13:53:48 -07:00), committed by TensorFlower Gardener
parent 3b239555fe
commit 107a35a551
3 changed files with 189 additions and 111 deletions


@@ -263,24 +263,22 @@ class ObjectReader {
tflite_node_(tflite_node),
tensor_to_value_(tensor_to_value) {}
Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) {
Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) const {
if (idx >= tflite_node_->inputs->size) {
return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx));
}
RETURN_IF_ERROR(
ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value));
return OkStatus();
return ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value);
}
int GetNumberOfRuntimeInputs() {
int GetNumberOfRuntimeInputs() const {
return GetNumberOfRuntimeInputsForNode(context_, tflite_node_);
}
Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) {
Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const {
if (idx >= tflite_node_->inputs->size) {
return OutOfRangeError(StrCat("Input tensor index: ", idx));
}
int32_t tensor_idx = tflite_node_->inputs->data[idx];
const int tensor_idx = tflite_node_->inputs->data[idx];
if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
return OutOfRangeError(StrCat("Tensor index: ", tensor_idx));
}
@@ -330,7 +328,7 @@ class ObjectReader {
}
Status ReadValueByTensorIdx(uint32_t tensor_idx,
Value<TensorRef<BHWC>>** value) {
Value<TensorRef<BHWC>>** value) const {
if (tensor_idx >= tensor_to_value_->size()) {
return OutOfRangeError(
StrCat("ReadValue: input tensor index: ", tensor_idx));
@@ -351,6 +349,12 @@ class ObjectReader {
return OkStatus();
}
TfLiteTensor* GetInputTensor(int index) const {
return index >= 0 && index < tflite_node_->inputs->size
? context_->tensors + tflite_node_->inputs->data[index]
: nullptr;
}
private:
GraphFloat32* graph_ = nullptr;
const TfLiteContext* context_ = nullptr;
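
The new GetInputTensor helper gives parsers bounds-checked access to a node's raw input tensors; the rewritten ADD and MUL parsers below use it to tell constant (weight) inputs from runtime inputs. A minimal standalone sketch of the same bounds-check-then-index pattern, with hypothetical stand-in types instead of the real TfLiteTensor/TfLiteNode structs:

#include <vector>

// Illustration only: FakeTensor stands in for TfLiteTensor.
struct FakeTensor {
  bool is_constant;  // in real TfLite this would be allocation_type == kTfLiteMmapRo
};

// Same pattern as the new ObjectReader::GetInputTensor: return the index-th
// input tensor of a node, or nullptr when either index is out of range.
const FakeTensor* GetInputTensor(const std::vector<FakeTensor>& all_tensors,
                                 const std::vector<int>& node_inputs, int index) {
  if (index < 0 || index >= static_cast<int>(node_inputs.size())) return nullptr;
  const int tensor_idx = node_inputs[index];
  if (tensor_idx < 0 || tensor_idx >= static_cast<int>(all_tensors.size())) return nullptr;
  return &all_tensors[tensor_idx];
}
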
@@ -1019,39 +1023,66 @@ class AddOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
// TODO(eignasheva): add shapes check.
TfLiteAddParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
return OkStatus();
if (tflite_node->inputs->size != 2) {
return UnimplementedError("ADD requires two input tensors.");
}
// TODO(eignasheva): Add shapes check.
TfLiteAddParams* tf_options = nullptr;
return RetrieveBuiltinData(tflite_node, &tf_options);
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
// TFLite currently only supports 2 input ADDs. Thus, the logic below only
// considers 2 input cases. The underlying GPU shader programs can accept
// more inputs, but the logic below would have to be expanded.
// Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0);
if (!input0) {
return InvalidArgumentError("Couldn't get the 1st input tensor for ADD.");
}
const TfLiteTensor* input1 = reader->GetInputTensor(1);
if (!input1) {
return InvalidArgumentError("Couldn't get the 2nd input tensor for ADD.");
}
const bool constant_tensor0 = IsConstantTensor(input0);
const bool constant_tensor1 = IsConstantTensor(input1);
if (constant_tensor0 && constant_tensor1) {
return InvalidArgumentError("No runtime input tensors for ADD.");
}
const bool runtime_tensor0 = !constant_tensor0;
const bool runtime_tensor1 = !constant_tensor1;
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::ADD);
RETURN_IF_ERROR(reader->AddOutputs(node));
AddAttributes attr;
for (int idx = 0; idx < tflite_node->inputs->size; ++idx) {
if (!reader->AddInput(node, idx).ok()) {
if (tflite_node->inputs->size != 2) {
return InvalidArgumentError(
"Broadcast Add should accept 2 inputs, one input tensor and "
"broadcasted tensor");
if (runtime_tensor0 && runtime_tensor1) {
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
} else {
int runtime_tensor = 0;
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
if (constant_tensor0 && runtime_tensor1) {
runtime_tensor = 1;
constant_tensor = 0;
constant_dims = input0->dims;
}
TfLiteIntArray dims;
RETURN_IF_ERROR(reader->GetTensorDims(1, &dims));
if (dims.size <= 0) {
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
if (constant_dims->size <= 0) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = tensor.data[0];
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
}
}
}
node->operation.attributes = std::move(attr);
const auto* tf_options =
@@ -1059,9 +1090,8 @@ class AddOperationParser : public TFLiteOperationParser {
if (!tf_options) {
return InternalError("Missing tflite params");
}
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
graph, node));
return OkStatus();
return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
node);
}
};
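
To summarize the rewritten ADD parsing: two runtime tensors are both wired in as node inputs, while a single constant input becomes either a scalar attribute (when its dims size is <= 0) or a per-channel Linear tensor. A hedged standalone sketch of that decision, with plain C++ types standing in for the TfLite/ObjectReader machinery:

#include <cstdio>
#include <vector>

enum class AddParam { kNoConstant, kScalar, kPerChannelTensor };

// Mirrors the branch structure of the new AddOperationParser::Parse: the
// constant input's dims decide between a scalar and a per-channel parameter.
// (The both-constant case is rejected by the parser before this point.)
AddParam ClassifyAddConstant(bool input0_is_constant, bool input1_is_constant,
                             const std::vector<int>& constant_dims) {
  if (!input0_is_constant && !input1_is_constant) return AddParam::kNoConstant;
  return constant_dims.empty() ? AddParam::kScalar : AddParam::kPerChannelTensor;
}

int main() {
  // runtime + rank-0 constant -> broadcast a single scalar
  std::printf("%d\n", static_cast<int>(ClassifyAddConstant(false, true, {})));    // 1
  // runtime + 1-D constant -> per-channel Linear tensor
  std::printf("%d\n", static_cast<int>(ClassifyAddConstant(false, true, {32})));  // 2
}
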
@@ -1427,41 +1457,119 @@ class ReLuOperationParser : public TFLiteOperationParser {
int clip_;
};
Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
const TfLiteIntArray* dims = tflite_tensor.dims;
switch (dims->size) {
case 1:
*bhwc = BHWC(dims->data[0], 1, 1, 1);
return OkStatus();
case 2:
*bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]);
return OkStatus();
case 3:
*bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
return OkStatus();
case 4:
*bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
return OkStatus();
default:
return InvalidArgumentError(
absl::StrCat("Tensor \"", tflite_tensor.name,
"\" has bad input dims size: ", dims->size, "."));
}
}
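
The new free function ExtractTensorShape (which ConvertTfLiteTensorToTensorRef now simply calls, see below) maps a 1- to 4-D TfLite shape onto BHWC, padding the missing middle dimensions with 1. A standalone sketch of the same mapping over a plain std::vector, with worked examples in the trailing comment:

#include <array>
#include <stdexcept>
#include <vector>

// BHWC stand-in: {batch, height, width, channels}.
using BHWC4 = std::array<int, 4>;

// Same dims->size switch as ExtractTensorShape, written against std::vector.
BHWC4 ToBHWC(const std::vector<int>& dims) {
  switch (dims.size()) {
    case 1: return {dims[0], 1, 1, 1};
    case 2: return {dims[0], 1, 1, dims[1]};
    case 3: return {dims[0], 1, dims[1], dims[2]};
    case 4: return {dims[0], dims[1], dims[2], dims[3]};
    default: throw std::invalid_argument("bad input dims size");
  }
}

// Examples: a [1, 224, 224, 3] image stays BHWC(1, 224, 224, 3);
// a [1, 1001] logits tensor becomes BHWC(1, 1, 1, 1001).
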
class MulOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
// TODO(eignasheva): add params check
if (tflite_node->inputs->size != 2) {
return UnimplementedError("MUL requires two input tensors.");
}
// TODO(eignasheva): Add params check.
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
// Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0);
if (!input0) {
return InvalidArgumentError("Couldn't get the 1st input tensor for MUL.");
}
const TfLiteTensor* input1 = reader->GetInputTensor(1);
if (!input1) {
return InvalidArgumentError("Couldn't get the 2nd input tensor for MUL.");
}
const bool constant_tensor0 = IsConstantTensor(input0);
const bool constant_tensor1 = IsConstantTensor(input1);
if (constant_tensor0 && constant_tensor1) {
return InvalidArgumentError("No runtime input tensors for MUL.");
}
const bool runtime_tensor0 = !constant_tensor0;
const bool runtime_tensor1 = !constant_tensor1;
// Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st
// input and the "smaller" input tensor ("mask") must be bound to 2nd input.
if (runtime_tensor0 && runtime_tensor1) {
BHWC shape0;
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
BHWC shape1;
RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
int input_tensor0 = 0;
int input_tensor1 = 1;
if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
shape0.c == shape1.c) {
input_tensor0 = 1;
input_tensor1 = 0;
}
return ParseApplyMask(input_tensor0, input_tensor1, graph, reader);
}
// Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st
// input and the constant input tensor must be bound to 2nd input.
int runtime_tensor = 0;
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
if (constant_tensor0 && runtime_tensor1) {
runtime_tensor = 1;
constant_tensor = 0;
constant_dims = input0->dims;
}
return ParseMultiplyScalar(runtime_tensor, constant_tensor, constant_dims,
graph, reader);
}
private:
Status ParseApplyMask(int input_tensor0, int input_tensor1,
GraphFloat32* graph, ObjectReader* reader) {
Node* node = graph->NewNode();
if (reader->GetNumberOfRuntimeInputs() == 2) {
// ApplyMask operation
node->operation.type = ToString(OperationType::APPLY_MASK);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
} else {
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
return reader->AddOutputs(node);
}
Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor,
const TfLiteIntArray* constant_dims,
GraphFloat32* graph, ObjectReader* reader) {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MULTIPLY_SCALAR);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
MultiplyScalarAttributes attr;
TfLiteIntArray dims;
RETURN_IF_ERROR(reader->GetTensorDims(1, &dims));
if (dims.size <= 0) {
if (constant_dims->size <= 0) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = tensor.data[0];
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
}
node->operation.attributes = std::move(attr);
}
return reader->AddOutputs(node);
}
};
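
The new MUL parsing splits into two paths: with two runtime inputs it emits APPLY_MASK, swapping the inputs if needed so the "larger" tensor is bound first and the mask second; with one constant input it emits MULTIPLY_SCALAR. A simplified sketch of just that dispatch, assuming made-up Shape and flag types rather than the real TfLite structures:

#include <string>

struct Shape { int h, w, c; };  // hypothetical stand-in for BHWC spatial dims

// Sketch of MulOperationParser::Parse's dispatch. swap_inputs reports whether
// the two runtime inputs must be exchanged so the mask ends up as input 1.
std::string ChooseMulOperation(bool input0_is_constant, bool input1_is_constant,
                               const Shape& shape0, const Shape& shape1,
                               bool* swap_inputs) {
  *swap_inputs = false;
  if (!input0_is_constant && !input1_is_constant) {
    // Bind the smaller tensor (the mask) to the 2nd input.
    if (shape0.h <= shape1.h && shape0.w <= shape1.w && shape0.c == shape1.c) {
      *swap_inputs = true;
    }
    return "APPLY_MASK";
  }
  return "MULTIPLY_SCALAR";
}
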
@@ -1963,26 +2071,7 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
TensorRef<BHWC>* tensor_ref) {
tensor_ref->type = ToDataType(tflite_tensor.type);
const TfLiteIntArray* dims = tflite_tensor.dims;
switch (dims->size) {
case 1:
tensor_ref->shape = BHWC(dims->data[0], 1, 1, 1);
break;
case 2:
tensor_ref->shape = BHWC(dims->data[0], 1, 1, dims->data[1]);
break;
case 3:
tensor_ref->shape = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
break;
case 4:
tensor_ref->shape =
BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
break;
default:
return InvalidArgumentError(StrCat(
"Tensor ref has unsupported number of dimensions: ", dims->size));
}
return OkStatus();
return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
}
Status IsSupported(const TfLiteContext* context, TfLiteNode* node,


@@ -286,12 +286,12 @@ cc_library(
srcs = ["mul.cc"],
hdrs = ["mul.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/gl:node_shader",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)


@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
@@ -33,34 +34,21 @@ namespace {
class ApplyMask : public NodeShader {
public:
static bool IsSupported(const GenerationContext& ctx) {
auto inputs = ctx.graph->FindInputs(ctx.node->id);
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
if (inputs.size() != 2) return false;
const auto& shape0 = inputs[0]->tensor.shape;
const auto& shape1 = inputs[1]->tensor.shape;
// Implementation requires 2 input tensors: source and mask.
if (inputs.size() != 2) {
return false;
}
// [H, W, C] x [H, W, 0][0]
if (shape1.c == 1) return true;
auto src_shape = inputs[0]->tensor.shape;
auto mask_shape = inputs[1]->tensor.shape;
if (shape0.c != shape1.c) return false;
// Height and width dimensions of the two input tensors must be the same.
if (src_shape.h != mask_shape.h || src_shape.w != mask_shape.w) {
return false;
}
// [H, W, C] x [H, W, C]
if (shape0.h == shape1.h && shape0.w == shape1.w) return true;
// Broadcast will be done if mask tensor has 1 channel.
if (mask_shape.c == 1) {
return true;
}
// Bitwise multiplication will be done if mask tensor has the same amount of
// channels as source tensor.
if (src_shape.c == mask_shape.c) {
return true;
}
// Other cases are not supported.
return false;
// [H, W, C] x [0, 0, C]
return shape1.h == 1 && shape1.w == 1;
}
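
In the shader itself, the reworked ApplyMask::IsSupported admits three mask layouts against an [H, W, C] source: a single-channel mask, an identically shaped mask, and a 1x1 per-channel mask; GenerateCode below then picks the matching indexing expression for each case. A small standalone sketch of the predicate, using a hypothetical HWC struct in place of the real shape type:

struct HWC { int h, w, c; };

// Mirrors the new ApplyMask::IsSupported in mul.cc: src is the 1st input,
// mask the 2nd.
bool ApplyMaskSupported(const HWC& src, const HWC& mask) {
  if (mask.c == 1) return true;                         // [H, W, C] x [H, W, 1]
  if (src.c != mask.c) return false;
  if (src.h == mask.h && src.w == mask.w) return true;  // [H, W, C] x [H, W, C]
  return mask.h == 1 && mask.w == 1;                    // [H, W, C] x [1, 1, C]
}
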
Status GenerateCode(const GenerationContext& ctx,
@@ -69,19 +57,20 @@ class ApplyMask : public NodeShader {
return InvalidArgumentError(
"This case is not supported by apply mask operation");
}
auto inputs = ctx.graph->FindInputs(ctx.node->id);
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
const auto& shape0 = inputs[0]->tensor.shape;
const auto& shape1 = inputs[1]->tensor.shape;
std::string source;
if (inputs[1]->tensor.shape.c == 1) {
// Broadcast case, mask channels size == 1.
source =
"value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * "
"$input_data_1[gid.x, gid.y, 0]$.x;";
std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
if (shape1.c == 1) {
// [H, W, C] x [H, W, 0][0]
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
} else if (shape0.h == shape1.h && shape0.w == shape1.w) {
// [H, W, C] x [H, W, C]
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
} else {
// Bitwise multiplication case, src channels size == mask channels size.
source =
"value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * "
"$input_data_1[gid.x, gid.y, 0]$;";
// [H, W, C] x [0, 0, C]
absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
}
*generated_code = {