TFLite GPU: Make GPU delegate recognize MobileNet v3.
PiperOrigin-RevId: 256432626
This commit is contained in:
parent
3b239555fe
commit
107a35a551
@ -263,24 +263,22 @@ class ObjectReader {
|
|||||||
tflite_node_(tflite_node),
|
tflite_node_(tflite_node),
|
||||||
tensor_to_value_(tensor_to_value) {}
|
tensor_to_value_(tensor_to_value) {}
|
||||||
|
|
||||||
Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) {
|
Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) const {
|
||||||
if (idx >= tflite_node_->inputs->size) {
|
if (idx >= tflite_node_->inputs->size) {
|
||||||
return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx));
|
return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx));
|
||||||
}
|
}
|
||||||
RETURN_IF_ERROR(
|
return ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value);
|
||||||
ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value));
|
|
||||||
return OkStatus();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int GetNumberOfRuntimeInputs() {
|
int GetNumberOfRuntimeInputs() const {
|
||||||
return GetNumberOfRuntimeInputsForNode(context_, tflite_node_);
|
return GetNumberOfRuntimeInputsForNode(context_, tflite_node_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) {
|
Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const {
|
||||||
if (idx >= tflite_node_->inputs->size) {
|
if (idx >= tflite_node_->inputs->size) {
|
||||||
return OutOfRangeError(StrCat("Input tensor index: ", idx));
|
return OutOfRangeError(StrCat("Input tensor index: ", idx));
|
||||||
}
|
}
|
||||||
int32_t tensor_idx = tflite_node_->inputs->data[idx];
|
const int tensor_idx = tflite_node_->inputs->data[idx];
|
||||||
if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
|
if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
|
||||||
return OutOfRangeError(StrCat("Tensor index: ", tensor_idx));
|
return OutOfRangeError(StrCat("Tensor index: ", tensor_idx));
|
||||||
}
|
}
|
||||||
@ -330,7 +328,7 @@ class ObjectReader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status ReadValueByTensorIdx(uint32_t tensor_idx,
|
Status ReadValueByTensorIdx(uint32_t tensor_idx,
|
||||||
Value<TensorRef<BHWC>>** value) {
|
Value<TensorRef<BHWC>>** value) const {
|
||||||
if (tensor_idx >= tensor_to_value_->size()) {
|
if (tensor_idx >= tensor_to_value_->size()) {
|
||||||
return OutOfRangeError(
|
return OutOfRangeError(
|
||||||
StrCat("ReadValue: input tensor index: ", tensor_idx));
|
StrCat("ReadValue: input tensor index: ", tensor_idx));
|
||||||
@ -351,6 +349,12 @@ class ObjectReader {
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfLiteTensor* GetInputTensor(int index) const {
|
||||||
|
return index >= 0 && index < tflite_node_->inputs->size
|
||||||
|
? context_->tensors + tflite_node_->inputs->data[index]
|
||||||
|
: nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GraphFloat32* graph_ = nullptr;
|
GraphFloat32* graph_ = nullptr;
|
||||||
const TfLiteContext* context_ = nullptr;
|
const TfLiteContext* context_ = nullptr;
|
||||||
@ -1019,37 +1023,64 @@ class AddOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||||
// TODO(eignasheva): add shapes check.
|
if (tflite_node->inputs->size != 2) {
|
||||||
|
return UnimplementedError("ADD requires two input tensors.");
|
||||||
|
}
|
||||||
|
// TODO(eignasheva): Add shapes check.
|
||||||
TfLiteAddParams* tf_options = nullptr;
|
TfLiteAddParams* tf_options = nullptr;
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
return RetrieveBuiltinData(tflite_node, &tf_options);
|
||||||
return OkStatus();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Parse(const TfLiteNode* tflite_node,
|
Status Parse(const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration, GraphFloat32* graph,
|
const TfLiteRegistration* registration, GraphFloat32* graph,
|
||||||
ObjectReader* reader) final {
|
ObjectReader* reader) final {
|
||||||
|
// TFLite currently only supports 2 input ADDs. Thus, the logic below only
|
||||||
|
// considers 2 input cases. The underlying GPU shader programs can accept
|
||||||
|
// more inputs, but the logic below would have to be expanded.
|
||||||
|
|
||||||
|
// Determine runtime/constant tensors.
|
||||||
|
const TfLiteTensor* input0 = reader->GetInputTensor(0);
|
||||||
|
if (!input0) {
|
||||||
|
return InvalidArgumentError("Couldn't get the 1st input tensor for ADD.");
|
||||||
|
}
|
||||||
|
const TfLiteTensor* input1 = reader->GetInputTensor(1);
|
||||||
|
if (!input1) {
|
||||||
|
return InvalidArgumentError("Couldn't get the 2nd input tensor for ADD.");
|
||||||
|
}
|
||||||
|
const bool constant_tensor0 = IsConstantTensor(input0);
|
||||||
|
const bool constant_tensor1 = IsConstantTensor(input1);
|
||||||
|
if (constant_tensor0 && constant_tensor1) {
|
||||||
|
return InvalidArgumentError("No runtime input tensors for ADD.");
|
||||||
|
}
|
||||||
|
const bool runtime_tensor0 = !constant_tensor0;
|
||||||
|
const bool runtime_tensor1 = !constant_tensor1;
|
||||||
|
|
||||||
Node* node = graph->NewNode();
|
Node* node = graph->NewNode();
|
||||||
node->operation.type = ToString(OperationType::ADD);
|
node->operation.type = ToString(OperationType::ADD);
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||||
|
|
||||||
AddAttributes attr;
|
AddAttributes attr;
|
||||||
for (int idx = 0; idx < tflite_node->inputs->size; ++idx) {
|
if (runtime_tensor0 && runtime_tensor1) {
|
||||||
if (!reader->AddInput(node, idx).ok()) {
|
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||||
if (tflite_node->inputs->size != 2) {
|
RETURN_IF_ERROR(reader->AddInput(node, 1));
|
||||||
return InvalidArgumentError(
|
} else {
|
||||||
"Broadcast Add should accept 2 inputs, one input tensor and "
|
int runtime_tensor = 0;
|
||||||
"broadcasted tensor");
|
int constant_tensor = 1;
|
||||||
}
|
TfLiteIntArray* constant_dims = input1->dims;
|
||||||
TfLiteIntArray dims;
|
if (constant_tensor0 && runtime_tensor1) {
|
||||||
RETURN_IF_ERROR(reader->GetTensorDims(1, &dims));
|
runtime_tensor = 1;
|
||||||
if (dims.size <= 0) {
|
constant_tensor = 0;
|
||||||
Tensor<Scalar, DataType::FLOAT32> tensor;
|
constant_dims = input0->dims;
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
|
}
|
||||||
attr.param = tensor.data[0];
|
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
|
||||||
} else {
|
if (constant_dims->size <= 0) {
|
||||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
Tensor<Scalar, DataType::FLOAT32> tensor;
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
|
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||||
attr.param = std::move(tensor);
|
attr.param = tensor.data[0];
|
||||||
}
|
} else {
|
||||||
|
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||||
|
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||||
|
attr.param = std::move(tensor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
node->operation.attributes = std::move(attr);
|
node->operation.attributes = std::move(attr);
|
||||||
@ -1059,9 +1090,8 @@ class AddOperationParser : public TFLiteOperationParser {
|
|||||||
if (!tf_options) {
|
if (!tf_options) {
|
||||||
return InternalError("Missing tflite params");
|
return InternalError("Missing tflite params");
|
||||||
}
|
}
|
||||||
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
|
return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
|
||||||
graph, node));
|
node);
|
||||||
return OkStatus();
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1427,41 +1457,119 @@ class ReLuOperationParser : public TFLiteOperationParser {
|
|||||||
int clip_;
|
int clip_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
|
||||||
|
const TfLiteIntArray* dims = tflite_tensor.dims;
|
||||||
|
switch (dims->size) {
|
||||||
|
case 1:
|
||||||
|
*bhwc = BHWC(dims->data[0], 1, 1, 1);
|
||||||
|
return OkStatus();
|
||||||
|
case 2:
|
||||||
|
*bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]);
|
||||||
|
return OkStatus();
|
||||||
|
case 3:
|
||||||
|
*bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
|
||||||
|
return OkStatus();
|
||||||
|
case 4:
|
||||||
|
*bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
|
||||||
|
return OkStatus();
|
||||||
|
default:
|
||||||
|
return InvalidArgumentError(
|
||||||
|
absl::StrCat("Tensor \"", tflite_tensor.name,
|
||||||
|
"\" has bad input dims size: ", dims->size, "."));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class MulOperationParser : public TFLiteOperationParser {
|
class MulOperationParser : public TFLiteOperationParser {
|
||||||
public:
|
public:
|
||||||
Status IsSupported(const TfLiteContext* context,
|
Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||||
// TODO(eignasheva): add params check
|
if (tflite_node->inputs->size != 2) {
|
||||||
|
return UnimplementedError("MUL requires two input tensors.");
|
||||||
|
}
|
||||||
|
// TODO(eignasheva): Add params check.
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Parse(const TfLiteNode* tflite_node,
|
Status Parse(const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration, GraphFloat32* graph,
|
const TfLiteRegistration* registration, GraphFloat32* graph,
|
||||||
ObjectReader* reader) final {
|
ObjectReader* reader) final {
|
||||||
Node* node = graph->NewNode();
|
// Determine runtime/constant tensors.
|
||||||
if (reader->GetNumberOfRuntimeInputs() == 2) {
|
const TfLiteTensor* input0 = reader->GetInputTensor(0);
|
||||||
// ApplyMask operation
|
if (!input0) {
|
||||||
node->operation.type = ToString(OperationType::APPLY_MASK);
|
return InvalidArgumentError("Couldn't get the 1st input tensor for MUL.");
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 1));
|
|
||||||
} else {
|
|
||||||
node->operation.type = ToString(OperationType::MULTIPLY_SCALAR);
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
|
||||||
MultiplyScalarAttributes attr;
|
|
||||||
TfLiteIntArray dims;
|
|
||||||
RETURN_IF_ERROR(reader->GetTensorDims(1, &dims));
|
|
||||||
if (dims.size <= 0) {
|
|
||||||
Tensor<Scalar, DataType::FLOAT32> tensor;
|
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
|
|
||||||
attr.param = tensor.data[0];
|
|
||||||
} else {
|
|
||||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(1, &tensor));
|
|
||||||
attr.param = std::move(tensor);
|
|
||||||
}
|
|
||||||
node->operation.attributes = std::move(attr);
|
|
||||||
}
|
}
|
||||||
|
const TfLiteTensor* input1 = reader->GetInputTensor(1);
|
||||||
|
if (!input1) {
|
||||||
|
return InvalidArgumentError("Couldn't get the 2nd input tensor for MUL.");
|
||||||
|
}
|
||||||
|
const bool constant_tensor0 = IsConstantTensor(input0);
|
||||||
|
const bool constant_tensor1 = IsConstantTensor(input1);
|
||||||
|
if (constant_tensor0 && constant_tensor1) {
|
||||||
|
return InvalidArgumentError("No runtime input tensors for MUL.");
|
||||||
|
}
|
||||||
|
const bool runtime_tensor0 = !constant_tensor0;
|
||||||
|
const bool runtime_tensor1 = !constant_tensor1;
|
||||||
|
|
||||||
|
// Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st
|
||||||
|
// input and the "smaller" input tensor ("mask") must be bound to 2nd input.
|
||||||
|
if (runtime_tensor0 && runtime_tensor1) {
|
||||||
|
BHWC shape0;
|
||||||
|
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
|
||||||
|
BHWC shape1;
|
||||||
|
RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
|
||||||
|
int input_tensor0 = 0;
|
||||||
|
int input_tensor1 = 1;
|
||||||
|
if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
|
||||||
|
shape0.c == shape1.c) {
|
||||||
|
input_tensor0 = 1;
|
||||||
|
input_tensor1 = 0;
|
||||||
|
}
|
||||||
|
return ParseApplyMask(input_tensor0, input_tensor1, graph, reader);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st
|
||||||
|
// input and the constant input tensor must be bound to 2nd input.
|
||||||
|
int runtime_tensor = 0;
|
||||||
|
int constant_tensor = 1;
|
||||||
|
TfLiteIntArray* constant_dims = input1->dims;
|
||||||
|
if (constant_tensor0 && runtime_tensor1) {
|
||||||
|
runtime_tensor = 1;
|
||||||
|
constant_tensor = 0;
|
||||||
|
constant_dims = input0->dims;
|
||||||
|
}
|
||||||
|
return ParseMultiplyScalar(runtime_tensor, constant_tensor, constant_dims,
|
||||||
|
graph, reader);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status ParseApplyMask(int input_tensor0, int input_tensor1,
|
||||||
|
GraphFloat32* graph, ObjectReader* reader) {
|
||||||
|
Node* node = graph->NewNode();
|
||||||
|
node->operation.type = ToString(OperationType::APPLY_MASK);
|
||||||
|
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
|
||||||
|
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
|
||||||
|
return reader->AddOutputs(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor,
|
||||||
|
const TfLiteIntArray* constant_dims,
|
||||||
|
GraphFloat32* graph, ObjectReader* reader) {
|
||||||
|
Node* node = graph->NewNode();
|
||||||
|
node->operation.type = ToString(OperationType::MULTIPLY_SCALAR);
|
||||||
|
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
|
||||||
|
MultiplyScalarAttributes attr;
|
||||||
|
if (constant_dims->size <= 0) {
|
||||||
|
Tensor<Scalar, DataType::FLOAT32> tensor;
|
||||||
|
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||||
|
attr.param = tensor.data[0];
|
||||||
|
} else {
|
||||||
|
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||||
|
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||||
|
attr.param = std::move(tensor);
|
||||||
|
}
|
||||||
|
node->operation.attributes = std::move(attr);
|
||||||
return reader->AddOutputs(node);
|
return reader->AddOutputs(node);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1963,26 +2071,7 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
|||||||
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
|
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
|
||||||
TensorRef<BHWC>* tensor_ref) {
|
TensorRef<BHWC>* tensor_ref) {
|
||||||
tensor_ref->type = ToDataType(tflite_tensor.type);
|
tensor_ref->type = ToDataType(tflite_tensor.type);
|
||||||
const TfLiteIntArray* dims = tflite_tensor.dims;
|
return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
|
||||||
switch (dims->size) {
|
|
||||||
case 1:
|
|
||||||
tensor_ref->shape = BHWC(dims->data[0], 1, 1, 1);
|
|
||||||
break;
|
|
||||||
case 2:
|
|
||||||
tensor_ref->shape = BHWC(dims->data[0], 1, 1, dims->data[1]);
|
|
||||||
break;
|
|
||||||
case 3:
|
|
||||||
tensor_ref->shape = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
|
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
tensor_ref->shape =
|
|
||||||
BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return InvalidArgumentError(StrCat(
|
|
||||||
"Tensor ref has unsupported number of dimensions: ", dims->size));
|
|
||||||
}
|
|
||||||
return OkStatus();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
|
Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
|
||||||
|
@ -286,12 +286,12 @@ cc_library(
|
|||||||
srcs = ["mul.cc"],
|
srcs = ["mul.cc"],
|
||||||
hdrs = ["mul.h"],
|
hdrs = ["mul.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
"//tensorflow/lite/delegates/gpu/common:types",
|
"//tensorflow/lite/delegates/gpu/common:types",
|
||||||
"//tensorflow/lite/delegates/gpu/gl:node_shader",
|
"//tensorflow/lite/delegates/gpu/gl:node_shader",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
|
||||||
@ -33,34 +34,21 @@ namespace {
|
|||||||
class ApplyMask : public NodeShader {
|
class ApplyMask : public NodeShader {
|
||||||
public:
|
public:
|
||||||
static bool IsSupported(const GenerationContext& ctx) {
|
static bool IsSupported(const GenerationContext& ctx) {
|
||||||
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||||
|
if (inputs.size() != 2) return false;
|
||||||
|
const auto& shape0 = inputs[0]->tensor.shape;
|
||||||
|
const auto& shape1 = inputs[1]->tensor.shape;
|
||||||
|
|
||||||
// Implementation requires 2 input tensors: source and mask.
|
// [H, W, C] x [H, W, 0][0]
|
||||||
if (inputs.size() != 2) {
|
if (shape1.c == 1) return true;
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto src_shape = inputs[0]->tensor.shape;
|
if (shape0.c != shape1.c) return false;
|
||||||
auto mask_shape = inputs[1]->tensor.shape;
|
|
||||||
|
|
||||||
// Height and width dimensions of the two input tensors must be the same.
|
// [H, W, C] x [H, W, C]
|
||||||
if (src_shape.h != mask_shape.h || src_shape.w != mask_shape.w) {
|
if (shape0.h == shape1.h && shape0.w == shape1.w) return true;
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Broadcast will be done if mask tensor has 1 channel.
|
// [H, W, C] x [0, 0, C]
|
||||||
if (mask_shape.c == 1) {
|
return shape1.h == 1 && shape1.w == 1;
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bitwise multiplication will be done if mask tensor has the same amount of
|
|
||||||
// channels as source tensor.
|
|
||||||
if (src_shape.c == mask_shape.c) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Other cases are not supported.
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GenerateCode(const GenerationContext& ctx,
|
Status GenerateCode(const GenerationContext& ctx,
|
||||||
@ -69,19 +57,20 @@ class ApplyMask : public NodeShader {
|
|||||||
return InvalidArgumentError(
|
return InvalidArgumentError(
|
||||||
"This case is not supported by apply mask operation");
|
"This case is not supported by apply mask operation");
|
||||||
}
|
}
|
||||||
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||||
|
const auto& shape0 = inputs[0]->tensor.shape;
|
||||||
|
const auto& shape1 = inputs[1]->tensor.shape;
|
||||||
|
|
||||||
std::string source;
|
std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
|
||||||
if (inputs[1]->tensor.shape.c == 1) {
|
if (shape1.c == 1) {
|
||||||
// Broadcast case, mask channels size == 1.
|
// [H, W, C] x [H, W, 0][0]
|
||||||
source =
|
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
|
||||||
"value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * "
|
} else if (shape0.h == shape1.h && shape0.w == shape1.w) {
|
||||||
"$input_data_1[gid.x, gid.y, 0]$.x;";
|
// [H, W, C] x [H, W, C]
|
||||||
|
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
|
||||||
} else {
|
} else {
|
||||||
// Bitwise multiplication case, src channels size == mask channels size.
|
// [H, W, C] x [0, 0, C]
|
||||||
source =
|
absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
|
||||||
"value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * "
|
|
||||||
"$input_data_1[gid.x, gid.y, 0]$;";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*generated_code = {
|
*generated_code = {
|
||||||
|
Loading…
Reference in New Issue
Block a user