Add Maximum & Minimum op support for GPU delegate

Refactored elementwise op kernel to handle Maximum & Minimum.

PiperOrigin-RevId: 296146084
Change-Id: Iefd333b79638d8705b28167657af475aa75e639a
Terry Heo 2020-02-20 00:20:05 -08:00 committed by TensorFlower Gardener
parent a6ec8dadc4
commit 76562fef92
12 changed files with 445 additions and 74 deletions
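At a glance, the new scalar path is driven like this (a minimal sketch assembled from the elementwise_test.cc cases later in this commit; op_def and creation_context_ come from the test fixture, not new API):

// Sketch: MAXIMUM where the second operand is a compile-time scalar.
ElementwiseAttributes attr;
attr.param = -1.0f;  // a rank-0 constant input is parsed into a float

BroadcastSettings broadcast;  // defaults; ignored on the scalar path
ElementwiseTwoInput operation = CreateElementwiseTwoInput(
    creation_context_, op_def, OperationType::MAXIMUM, broadcast, attr);
// The generated kernel then computes max(x, -1.0f) element-wise.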

View File

@@ -34,6 +34,8 @@ TFLite on GPU supports the following ops in 16-bit and 32-bit float precision:
* `LOGISTIC v1`
* `LSTM v2 (Basic LSTM only)`
* `MAX_POOL_2D v1`
* `MAXIMUM v1`
* `MINIMUM v1`
* `MUL v1`
* `PAD v1`
* `PRELU v1`

View File

@@ -106,7 +106,9 @@ ElementwiseTwoInput::ElementwiseTwoInput(ElementwiseTwoInput&& operation)
    : ElementwiseOperation(std::move(operation)),
      link_index_(operation.link_index_),
      op_type_(operation.op_type_),
-     broadcast_(operation.broadcast_) {}
+     broadcast_(operation.broadcast_),
+     scalar_para_(operation.scalar_para_),
+     use_scalar_para_(operation.use_scalar_para_) {}
ElementwiseTwoInput& ElementwiseTwoInput::operator=(
ElementwiseTwoInput&& operation) {
@@ -114,30 +116,43 @@ ElementwiseTwoInput& ElementwiseTwoInput::operator=(
    link_index_ = operation.link_index_;
    op_type_ = operation.op_type_;
    broadcast_ = operation.broadcast_;
+   scalar_para_ = operation.scalar_para_;
+   use_scalar_para_ = operation.use_scalar_para_;
    ElementwiseOperation::operator=(std::move(operation));
  }
  return *this;
}

-void ElementwiseTwoInput::SetLinkIndex(int index) { link_index_ = index; }
+void ElementwiseTwoInput::SetLinkIndex(int index) {
+  link_index_ = index;
+  if (use_scalar_para_) {
+    scalar_para_.SetName(absl::StrCat("scalar_para_", index));
+  }
+}

std::string ElementwiseTwoInput::GetCoreCode(
    const LinkingContext& context) const {
-  const std::string size_name = "src_size_" + std::to_string(link_index_);
-  TensorCodeGenerator src_tensor(
-      absl::StrCat("src_data_", link_index_),
-      WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
-      definition_.src_tensors[1]);
-  const std::string x_coord = broadcast_.width ? "0" : context.x_coord;
-  const std::string y_coord = broadcast_.height ? "0" : context.y_coord;
-  const std::string s_coord = broadcast_.channels ? "0" : context.s_coord;
-  const std::string second_var = "second_var_" + std::to_string(link_index_);
-  std::string result = " FLT4 " + second_var + " = " +
-                       src_tensor.ReadWHS(x_coord, y_coord, s_coord) + ";\n";
-  if (broadcast_.channels) {
-    result += " " + second_var + ".y = " + second_var + ".x;\n";
-    result += " " + second_var + ".z = " + second_var + ".x;\n";
-    result += " " + second_var + ".w = " + second_var + ".x;\n";
+  std::string result;
+  std::string second_var;
+  if (use_scalar_para_) {
+    second_var = absl::StrCat("(FLT)(", scalar_para_.GetName(), ")");
+  } else {
+    const std::string size_name = "src_size_" + std::to_string(link_index_);
+    TensorCodeGenerator src_tensor(
+        absl::StrCat("src_data_", link_index_),
+        WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
+        definition_.src_tensors[1]);
+    const std::string x_coord = broadcast_.width ? "0" : context.x_coord;
+    const std::string y_coord = broadcast_.height ? "0" : context.y_coord;
+    const std::string s_coord = broadcast_.channels ? "0" : context.s_coord;
+    second_var = "second_var_" + std::to_string(link_index_);
+    result = " FLT4 " + second_var + " = " +
+             src_tensor.ReadWHS(x_coord, y_coord, s_coord) + ";\n";
+    if (broadcast_.channels) {
+      result += " " + second_var + ".y = " + second_var + ".x;\n";
+      result += " " + second_var + ".z = " + second_var + ".x;\n";
+      result += " " + second_var + ".w = " + second_var + ".x;\n";
+    }
  }
switch (op_type_) {
case OperationType::ADD:
@@ -146,6 +161,12 @@ std::string ElementwiseTwoInput::GetCoreCode(
case OperationType::DIV:
result += "$0 /= $1;\n";
break;
case OperationType::MAXIMUM:
result += "$0 = max($0, $1);\n";
break;
case OperationType::MINIMUM:
result += "$0 = min($0, $1);\n";
break;
case OperationType::MUL:
result += "$0 *= $1;\n";
break;
@@ -167,20 +188,44 @@ std::string ElementwiseTwoInput::GetCoreCode(

std::string ElementwiseTwoInput::GetArgsDeclaration() const {
  std::string args;
-  absl::StrAppend(&args, ",\n",
-                  GetTensorDeclaration(AccessType::READ,
-                                       absl::StrCat("src_data_", link_index_),
-                                       definition_.src_tensors[1]));
-  absl::StrAppend(&args, ",\n int4 src_size_", link_index_);
+  if (use_scalar_para_) {
+    absl::StrAppend(&args, ",\n ", scalar_para_.GetDeclaration());
+  } else {
+    absl::StrAppend(&args, ",\n",
+                    GetTensorDeclaration(AccessType::READ,
+                                         absl::StrCat("src_data_", link_index_),
+                                         definition_.src_tensors[1]));
+    absl::StrAppend(&args, ",\n int4 src_size_", link_index_);
+  }
  return args;
}

Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) {
-  RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
-  RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
+  if (use_scalar_para_) {
+    RETURN_IF_ERROR(kernel->SetBytesAuto(scalar_para_));
+  } else {
+    RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
+    RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
+  }
  return OkStatus();
}

+ElementwiseTwoInput CreateElementwiseTwoInput(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const OperationType& op_type, const BroadcastSettings& broadcast,
+    const ElementwiseAttributes& attr) {
+  ElementwiseTwoInput operation(definition, op_type, broadcast);
+  auto scalar = absl::get_if<float>(&attr.param);
+  if (scalar) {
+    const auto scalar_precision = creation_context.device->IsPowerVR()
+                                      ? CalculationsPrecision::F32
+                                      : definition.precision;
+    operation.SetScalarPara(FLT(scalar_precision, *scalar));
+  }
+  operation.SetLinkIndex(0);
+  return operation;
+}
+
ElementwiseTwoInput CreateElementwiseTwoInput(
    const OperationDef& definition, const OperationType& op_type,
    const BroadcastSettings& broadcast) {
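Net effect of the scalar branch: the second operand is inlined as a cast uniform instead of a tensor read, so no second tensor or src_size uniform is bound. Roughly, for MAXIMUM with link_index_ 0 (an illustration only; $0/$1 are the placeholders the elementwise linking step substitutes):

// Scalar case: GetCoreCode() contributes
//   $0 = max($0, (FLT)(scalar_para_0));
// Tensor case: the second operand is read (and channel-broadcast if needed):
//   FLT4 second_var_0 = <src_data_0 read>;
//   $0 = max($0, second_var_0);

Note that CreateElementwiseTwoInput pins the scalar uniform to F32 on PowerVR devices regardless of the op's precision.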

View File

@@ -63,7 +63,8 @@ class ElementwiseTwoInput : public ElementwiseOperation {
                      const BroadcastSettings& broadcast)
      : ElementwiseOperation(definition),
        op_type_(op_type),
-       broadcast_(broadcast) {}
+       broadcast_(broadcast),
+       use_scalar_para_(false) {}
// Move only
ElementwiseTwoInput(ElementwiseTwoInput&& operation);
@@ -75,13 +76,24 @@ class ElementwiseTwoInput : public ElementwiseOperation {
std::string GetCoreCode(const LinkingContext& context) const override;
std::string GetArgsDeclaration() const override;
Status BindArguments(CLKernel* kernel) override;
inline void SetScalarPara(FLT scalar) {
scalar_para_ = scalar;
use_scalar_para_ = true;
}
private:
int link_index_;
OperationType op_type_;
BroadcastSettings broadcast_;
FLT scalar_para_;
bool use_scalar_para_;
};
ElementwiseTwoInput CreateElementwiseTwoInput(
const CreationContext& creation_context, const OperationDef& definition,
const OperationType& op_type, const BroadcastSettings& broadcast,
const ElementwiseAttributes& attr);
ElementwiseTwoInput CreateElementwiseTwoInput(
const OperationDef& definition, const OperationType& op_type,
const BroadcastSettings& broadcast);

View File

@@ -425,6 +425,118 @@ TEST_F(OpenCLOperationTest, Add) {
}
}
TEST_F(OpenCLOperationTest, Maximum) {
TensorFloat32 src_tensor_0, src_tensor_1;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_1.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {0.0f, -6.2f, 2.0f, -3.0f};
src_tensor_1.data = {1.0f, 2.0f, 3.0f, -2.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ElementwiseTwoInput operation =
CreateElementwiseTwoInput(op_def, OperationType::MAXIMUM);
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {1.0f, 2.0f, 3.0f, -2.0f}));
}
}
}
TEST_F(OpenCLOperationTest, MaximumWithScalar) {
TensorFloat32 src_tensor_0;
src_tensor_0.shape = BHWC(1, 4, 1, 1);
src_tensor_0.data = {0.0f, -6.2f, 2.0f, -3.0f};
ElementwiseAttributes attr;
attr.param = -1.0f;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
BroadcastSettings broadcast;
ElementwiseTwoInput operation = CreateElementwiseTwoInput(
creation_context_, op_def, OperationType::MAXIMUM, broadcast, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
BHWC(1, 4, 1, 1), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, -1.0f, 2.0f, -1.0f}));
}
}
}
TEST_F(OpenCLOperationTest, Minimum) {
TensorFloat32 src_tensor_0, src_tensor_1;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_1.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {0.0f, -6.2f, 2.0f, -3.0f};
src_tensor_1.data = {1.0f, 2.0f, 3.0f, -2.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ElementwiseTwoInput operation =
CreateElementwiseTwoInput(op_def, OperationType::MINIMUM);
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, -6.2f, 2.0f, -3.0f}));
}
}
}
TEST_F(OpenCLOperationTest, MinimumWithScalar) {
TensorFloat32 src_tensor_0;
src_tensor_0.shape = BHWC(1, 4, 1, 1);
src_tensor_0.data = {0.0f, -6.2f, 2.0f, -3.0f};
ElementwiseAttributes attr;
attr.param = -1.0f;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
BroadcastSettings broadcast;
ElementwiseTwoInput operation = CreateElementwiseTwoInput(
creation_context_, op_def, OperationType::MINIMUM, broadcast, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
BHWC(1, 4, 1, 1), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-1.0f, -6.2f, -1.0f, -3.0f}));
}
}
}
TEST_F(OpenCLOperationTest, Mul) {
TensorFloat32 src_tensor_0, src_tensor_1;
src_tensor_0.shape = BHWC(1, 2, 1, 2);

View File

@@ -231,6 +231,8 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
      return OkStatus();
    }
    case OperationType::DIV:
+   case OperationType::MAXIMUM:
+   case OperationType::MINIMUM:
    case OperationType::POW:
    case OperationType::SQUARED_DIFF:
    case OperationType::SUB: {
@@ -238,8 +240,10 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
      broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
      broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
      broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
-     ElementwiseTwoInput operation =
-         CreateElementwiseTwoInput(op_def, op_type, broadcast);
+     const auto attr =
+         absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
+     ElementwiseTwoInput operation = CreateElementwiseTwoInput(
+         creation_context, op_def, op_type, broadcast, attr);
      *gpu_op = absl::make_unique<ElementwiseTwoInput>(std::move(operation));
      return OkStatus();
    }

View File

@@ -389,6 +389,39 @@ Status CheckInputsOutputs(const TfLiteContext* context,
return OkStatus();
}
// Checks the node's input tensors, allowing at most one of them to be constant.
Status CheckInputsOutputsAllowingOneConstInput(const TfLiteContext* context,
const TfLiteNode* tflite_node,
int inputs, int outputs) {
int number_of_const_inputs = 0;
int number_of_runtime_inputs = 0;
for (int i = 0; i < tflite_node->inputs->size; i++) {
if (IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) {
number_of_const_inputs++;
} else {
number_of_runtime_inputs++;
}
}
if (tflite_node->inputs->size != inputs) {
return InternalError(absl::StrFormat(
"Expected %d input tensor(s), but node has %d input(s).", inputs,
tflite_node->inputs->size));
}
if (number_of_const_inputs > 1) {
return InternalError(absl::StrFormat(
"Expected 1 const input tensor, but node has %d const input(s).",
number_of_const_inputs));
}
int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
if (runtime_outputs != outputs) {
return InternalError(
absl::StrFormat("Expected %d output tensor(s), but node has %d runtime "
"output(s).",
outputs, runtime_outputs));
}
return OkStatus();
}
// A parser responsible for parsing TFLite operation and adding it to a graph.
class TFLiteOperationParser {
public:
@@ -642,6 +675,55 @@ Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
}
}
Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
TensorOrScalar* tensor_or_scalar) {
const std::string& opname = node->operation.type;
// Determine runtime/constant tensors.
const TfLiteTensor* input0 = reader->GetInputTensor(0);
if (!input0) {
return InvalidArgumentError("Couldn't get the 1st input tensor for " +
opname);
}
const TfLiteTensor* input1 = reader->GetInputTensor(1);
if (!input1) {
return InvalidArgumentError("Couldn't get the 2nd input tensor for " +
opname);
}
const bool constant_tensor0 = IsConstantTensor(input0);
const bool constant_tensor1 = IsConstantTensor(input1);
if (constant_tensor0 && constant_tensor1) {
return InvalidArgumentError("No runtime input tensors for " + opname);
}
const bool runtime_tensor0 = !constant_tensor0;
const bool runtime_tensor1 = !constant_tensor1;
if (runtime_tensor0 && runtime_tensor1) {
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
} else {
int runtime_tensor = 0;
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
if (constant_tensor0 && runtime_tensor1) {
runtime_tensor = 1;
constant_tensor = 0;
constant_dims = input0->dims;
}
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
if (constant_dims->size <= 0) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = tensor.data[0];
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = std::move(tensor);
}
}
return OkStatus();
}
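The helper leaves *tensor_or_scalar in one of three states; a minimal sketch of the mapping, assuming the TensorOrScalar alias added to operations.h in this commit:

// Two runtime inputs: both become graph inputs and the variant stays
// absl::monostate.
TensorOrScalar param;
// One rank-0 constant input (constant_dims->size <= 0): a plain float.
param = -1.0f;
// One rank-1 constant input: the whole constant vector is stored.
Tensor<Linear, DataType::FLOAT32> vec;
vec.shape = Linear(2);
vec.data = {0.5f, 2.0f};
param = std::move(vec);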
class AddOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
@@ -663,51 +745,11 @@ class AddOperationParser : public TFLiteOperationParser {
    // considers 2 input cases. The underlying GPU shader programs can accept
    // more inputs, but the logic below would have to be expanded.

-   // Determine runtime/constant tensors.
-   const TfLiteTensor* input0 = reader->GetInputTensor(0);
-   if (!input0) {
-     return InvalidArgumentError("Couldn't get the 1st input tensor for ADD.");
-   }
-   const TfLiteTensor* input1 = reader->GetInputTensor(1);
-   if (!input1) {
-     return InvalidArgumentError("Couldn't get the 2nd input tensor for ADD.");
-   }
-   const bool constant_tensor0 = IsConstantTensor(input0);
-   const bool constant_tensor1 = IsConstantTensor(input1);
-   if (constant_tensor0 && constant_tensor1) {
-     return InvalidArgumentError("No runtime input tensors for ADD.");
-   }
-   const bool runtime_tensor0 = !constant_tensor0;
-   const bool runtime_tensor1 = !constant_tensor1;
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    AddAttributes attr;
-   if (runtime_tensor0 && runtime_tensor1) {
-     RETURN_IF_ERROR(reader->AddInput(node, 0));
-     RETURN_IF_ERROR(reader->AddInput(node, 1));
-   } else {
-     int runtime_tensor = 0;
-     int constant_tensor = 1;
-     TfLiteIntArray* constant_dims = input1->dims;
-     if (constant_tensor0 && runtime_tensor1) {
-       runtime_tensor = 1;
-       constant_tensor = 0;
-       constant_dims = input0->dims;
-     }
-     RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
-     if (constant_dims->size <= 0) {
-       Tensor<Scalar, DataType::FLOAT32> tensor;
-       RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
-       attr.param = tensor.data[0];
-     } else {
-       Tensor<Linear, DataType::FLOAT32> tensor;
-       RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
-       attr.param = std::move(tensor);
-     }
-   }
+   RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
    node->operation.attributes = std::move(attr);
    const auto* tf_options =
        reinterpret_cast<const TfLiteAddParams*>(tflite_node->builtin_data);
@@ -1053,6 +1095,11 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
} else if (IsTwoArgumentOperation()) {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/2,
/*outputs=*/1));
} else if (IsTwoArgumentOperationWithConst()) {
RETURN_IF_ERROR(CheckInputsOutputsAllowingOneConstInput(context,
tflite_node,
/*inputs=*/2,
/*outputs=*/1));
} else {
return InvalidArgumentError("Op can only handle 1 or 2 operand(s).");
}
@@ -1103,6 +1150,16 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(
MaybeFuseActivationToTheSingleOutput(activation, graph, node));
}
} else if (IsTwoArgumentOperationWithConst()) {
ElementwiseAttributes attr;
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
auto const_vector =
absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
&attr.param);
if (const_vector) {
return InvalidArgumentError("Constant vector is not supported");
}
node->operation.attributes = std::move(attr);
} else {
return InvalidArgumentError("Incorrect operation type passed");
}
@@ -1161,6 +1218,16 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
}
}
bool IsTwoArgumentOperationWithConst() const {
switch (operation_type_) {
case OperationType::MINIMUM:
case OperationType::MAXIMUM:
return true;
default:
return false;
}
}
OperationType operation_type_;
};
@@ -2547,10 +2614,16 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
return absl::make_unique<ElementwiseOperationParser>(OperationType::LOG);
case kTfLiteBuiltinLstm:
return absl::make_unique<LSTMOperationParser>();
case kTfLiteBuiltinMaximum:
return absl::make_unique<ElementwiseOperationParser>(
OperationType::MAXIMUM);
case kTfLiteBuiltinMaxPool2d:
return absl::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
case kTfLiteBuiltinMean:
return absl::make_unique<MeanOperationParser>();
case kTfLiteBuiltinMinimum:
return absl::make_unique<ElementwiseOperationParser>(
OperationType::MINIMUM);
case kTfLiteBuiltinMirrorPad:
return absl::make_unique<PadOperationParser>(/*mirror_pad=*/true);
case kTfLiteBuiltinMul:

View File

@@ -98,10 +98,14 @@ std::string ToString(enum OperationType op) {
return "log";
case OperationType::LSTM:
return "lstm";
case OperationType::MAXIMUM:
return "maximum";
case OperationType::MAX_UNPOOLING_2D:
return "max_unpooling";
case OperationType::MEAN:
return "mean";
case OperationType::MINIMUM:
return "minimum";
case OperationType::MUL:
return "mul";
case OperationType::PAD:
@@ -165,8 +169,10 @@ OperationType OperationTypeFromString(const std::string& name) {
{"hard_swish", OperationType::HARD_SWISH},
{"log", OperationType::LOG},
{"lstm", OperationType::LSTM},
{"maximum", OperationType::MAXIMUM},
{"max_unpooling", OperationType::MAX_UNPOOLING_2D},
{"mean", OperationType::MEAN},
{"minimum", OperationType::MINIMUM},
{"mul", OperationType::MUL},
{"pad", OperationType::PAD},
{"pooling_2d", OperationType::POOLING_2D},

View File

@@ -47,8 +47,10 @@ enum class OperationType {
HARD_SWISH,
LOG,
LSTM,
MAXIMUM,
MAX_UNPOOLING_2D,
MEAN,
MINIMUM,
MUL,
PAD,
POOLING_2D,
@@ -75,6 +77,9 @@ std::string ToString(enum OperationType op);
OperationType OperationTypeFromString(const std::string& name);
typedef absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
TensorOrScalar;
struct Padding2D {
Padding2D() = default;
Padding2D& operator=(const Padding2D& value);
@@ -352,8 +357,7 @@ struct LstmAttributes {
};

struct MultiplyAttributes {
-  absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
-      param;
+  TensorOrScalar param;
};
enum class SamplingType {
@@ -435,8 +439,7 @@ struct SliceAttributes {
BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr);

struct AddAttributes {
-  absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
-      param;
+  TensorOrScalar param;
};
struct FullyConnectedAttributes {
@@ -452,6 +455,10 @@ BHWC CalculateOutputShape(const BHWC& input,
// @return shape of a tensor after Mean operation is applied to the given input.
BHWC CalculateOutputShape(const BHWC& input, const MeanAttributes& attr);
struct ElementwiseAttributes {
TensorOrScalar param;
};
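ElementwiseAttributes carries the same TensorOrScalar variant that AddAttributes and MultiplyAttributes now alias; consumers probe it with absl::get_if, as the CL factory and GL shader in this commit do. A minimal sketch:

ElementwiseAttributes attr;
attr.param = 2.0f;
if (const float* scalar = absl::get_if<float>(&attr.param)) {
  // Scalar operand: the only constant form MAXIMUM/MINIMUM accept here;
  // the parser rejects constant vectors for these ops.
} else if (absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param)) {
  // Constant vector operand.
}  // absl::monostate means both operands are runtime tensors.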
struct ReshapeAttributes {
BHWC new_shape;
};

View File

@@ -139,6 +139,14 @@ class ElementwiseTwoArguments : public NodeShader {
source = "value_0 /= value_1;";
break;
}
case OperationType::MAXIMUM: {
source = "value_0 = max(value_0, value_1);";
break;
}
case OperationType::MINIMUM: {
source = "value_0 = min(value_0, value_1);";
break;
}
case OperationType::POW: {
// From documentation:
// The result is undefined if x<0 or if x=0 and y≤0.
@@ -167,6 +175,37 @@ class ElementwiseTwoArguments : public NodeShader {
return OkStatus();
}
Status ImplementElementwiseWithScalar(const GenerationContext& ctx,
const float scalar,
GeneratedCode* generated_code) const {
std::string source;
switch (operation_type_) {
case OperationType::MAXIMUM: {
source = "value_0 = max(value_0, $scalar$);";
break;
}
case OperationType::MINIMUM: {
source = "value_0 = min(value_0, $scalar$);";
break;
}
default:
return InvalidArgumentError(
"Incorrect elementwise with scalar operation type.");
}
*generated_code = {
/*parameters=*/{{"scalar", scalar}},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/source,
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
return OkStatus();
}
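For reference, with MINIMUM and scalar == -1.0f the method above fills in the following (values read directly off the code above; the GL code generator later replaces $scalar$ from the "scalar" parameters entry):

GeneratedCode expected = {
    /*parameters=*/{{"scalar", -1.0f}},
    /*objects=*/{},
    /*shared_variables=*/{},
    /*workload=*/uint3(),
    /*workgroup=*/uint3(),
    /*source_code=*/"value_0 = min(value_0, $scalar$);",
    /*input=*/IOStructure::AUTO,
    /*output=*/IOStructure::AUTO,
};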
bool IsSupportedBroadcast(const GenerationContext& ctx) const {
auto inputs = ctx.graph->FindInputs(ctx.node->id);
auto outputs = ctx.graph->FindOutputs(ctx.node->id);
@@ -219,8 +258,15 @@ class ElementwiseTwoArguments : public NodeShader {
    if (IsSupportedBroadcast(ctx)) {
      return ImplementElementwiseBroadcast(ctx, generated_code);
    }
+   auto attr =
+       absl::any_cast<ElementwiseAttributes>(ctx.node->operation.attributes);
+   auto scalar = absl::get_if<float>(&attr.param);
+   if (scalar) {
+     return ImplementElementwiseWithScalar(ctx, *scalar, generated_code);
+   }
    return InvalidArgumentError(
-       "This case is not supported by subtract operation");
+       "This case is not supported by elementwise with two arguments "
+       "operation");
  }
private:
@@ -244,6 +290,8 @@ std::unique_ptr<NodeShader> NewElementwiseNodeShader(
case OperationType::TANH:
return absl::make_unique<ElementwiseOneArgument>(operation_type);
case OperationType::DIV:
case OperationType::MAXIMUM:
case OperationType::MINIMUM:
case OperationType::POW:
case OperationType::SQUARED_DIFF:
case OperationType::SUB:

View File

@@ -100,6 +100,64 @@ TEST(ElementwiseTest, Log) {
Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
}
TEST(ElementwiseTest, Maximum) {
OperationType op_type = OperationType::MAXIMUM;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 2.0, 3.0, -2.0}));
}
TEST(ElementwiseTest, MaximumWithScalar) {
OperationType op_type = OperationType::MAXIMUM;
const BHWC shape(1, 2, 2, 1);
ElementwiseAttributes attr;
attr.param = -1.0f;
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/std::move(attr)},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, -1.0, 2.0, -1.0}));
}
TEST(ElementwiseTest, Minimum) {
OperationType op_type = OperationType::MINIMUM;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, -2.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, -6.2, 2.0, -3.0}));
}
TEST(ElementwiseTest, MinimumWithScalar) {
OperationType op_type = OperationType::MINIMUM;
const BHWC shape(1, 2, 2, 1);
ElementwiseAttributes attr;
attr.param = -1.0f;
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/std::move(attr)},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, -3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {-1.0, -6.2, -1.0, -3.0}));
}
TEST(ElementwiseTest, Pow) {
OperationType op_type = OperationType::POW;
const BHWC shape(1, 2, 2, 1);

View File

@@ -96,6 +96,8 @@ class Registry : public NodeShader {
insert_elementwise_op(Type::DIV);
insert_elementwise_op(Type::HARD_SWISH);
insert_elementwise_op(Type::LOG);
insert_elementwise_op(Type::MAXIMUM);
insert_elementwise_op(Type::MINIMUM);
insert_elementwise_op(Type::POW);
insert_elementwise_op(Type::RSQRT);
insert_elementwise_op(Type::SIGMOID);

View File

@@ -266,10 +266,12 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
    case OperationType::TANH:
      *tasks = ElementwiseWithOneInput(node_id, inputs[0], outputs[0], op_type);
      break;
-   case OperationType::SUB:
    case OperationType::DIV:
+   case OperationType::MAXIMUM:
+   case OperationType::MINIMUM:
    case OperationType::POW:
    case OperationType::SQUARED_DIFF:
+   case OperationType::SUB:
      *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
      break;
    case OperationType::BATCH_NORMALIZATION: