Added support of NEG for GPU backends.
PiperOrigin-RevId: 330009294 Change-Id: I8955b499adfb1ca23084d36d79d8509bbc37e4eb
This commit is contained in:
parent
ec878bb3e3
commit
073362c7dc
@ -58,6 +58,9 @@ std::string GetOneInputCode(const OperationType& op_type,
|
||||
case OperationType::LOG:
|
||||
result = "$0 = log($0);\n";
|
||||
break;
|
||||
case OperationType::NEG:
|
||||
result = "$0 = -($0);\n";
|
||||
break;
|
||||
case OperationType::RSQRT:
|
||||
result = "$0 = rsqrt($0);\n";
|
||||
break;
|
||||
|
@ -208,6 +208,30 @@ TEST_F(OpenCLOperationTest, Log) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, Neg) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor.data = {1.0f, -2.0f, 0.0f, 4.0f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation =
|
||||
CreateElementwiseOneInput(op_def, OperationType::NEG);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-1.0f, 2.0f, 0.0f, -4.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, Rsqrt) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 1, 2);
|
||||
|
@ -336,6 +336,7 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
|
||||
case OperationType::EXP:
|
||||
case OperationType::HARD_SWISH:
|
||||
case OperationType::LOG:
|
||||
case OperationType::NEG:
|
||||
case OperationType::RSQRT:
|
||||
case OperationType::SIGMOID:
|
||||
case OperationType::SIN:
|
||||
|
@ -801,6 +801,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
|
||||
case OperationType::ELU:
|
||||
case OperationType::EXP:
|
||||
case OperationType::LOG:
|
||||
case OperationType::NEG:
|
||||
case OperationType::RSQRT:
|
||||
case OperationType::SIGMOID:
|
||||
case OperationType::SIN:
|
||||
@ -2574,6 +2575,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
||||
return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
|
||||
case kTfLiteBuiltinMul:
|
||||
return std::make_unique<MulOperationParser>();
|
||||
case kTfLiteBuiltinNeg:
|
||||
return std::make_unique<ElementwiseOperationParser>(OperationType::NEG);
|
||||
case kTfLiteBuiltinPad:
|
||||
return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
|
||||
case kTfLiteBuiltinPow:
|
||||
|
@ -134,6 +134,8 @@ std::string ToString(enum OperationType op) {
|
||||
return "minimum";
|
||||
case OperationType::MUL:
|
||||
return "mul";
|
||||
case OperationType::NEG:
|
||||
return "neg";
|
||||
case OperationType::NOT_EQUAL:
|
||||
return "not_equal";
|
||||
case OperationType::PAD:
|
||||
@ -224,6 +226,7 @@ OperationType OperationTypeFromString(const std::string& name) {
|
||||
OperationType::MEAN_STDDEV_NORMALIZATION},
|
||||
{"minimum", OperationType::MINIMUM},
|
||||
{"mul", OperationType::MUL},
|
||||
{"neg", OperationType::NEG},
|
||||
{"not_equal", OperationType::NOT_EQUAL},
|
||||
{"pad", OperationType::PAD},
|
||||
{"pooling_2d", OperationType::POOLING_2D},
|
||||
|
@ -63,6 +63,7 @@ enum class OperationType {
|
||||
MEAN_STDDEV_NORMALIZATION,
|
||||
MINIMUM,
|
||||
MUL,
|
||||
NEG,
|
||||
NOT_EQUAL,
|
||||
PAD,
|
||||
POOLING_2D,
|
||||
|
@ -69,6 +69,9 @@ class ElementwiseOneArgument : public NodeShader {
|
||||
value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
|
||||
)";
|
||||
break;
|
||||
case OperationType::NEG:
|
||||
source = "value_0 = -(value_0);";
|
||||
break;
|
||||
case OperationType::RSQRT:
|
||||
source = R"(
|
||||
const float nan = normalize(vec4(0, 0, 0, 0)).x;
|
||||
|
@ -129,6 +129,18 @@ TEST(ElementwiseOneArgumentTest, Log) {
|
||||
Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
|
||||
}
|
||||
|
||||
TEST(ElementwiseOneArgumentTest, Neg) {
|
||||
OperationType op_type = OperationType::NEG;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||
/*inputs=*/{GetTensorRef(0, shape)},
|
||||
/*outputs=*/{GetTensorRef(1, shape)});
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {1.0, -3.1415926, 0.0, 1.0}));
|
||||
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||
EXPECT_THAT(model.GetOutput(0),
|
||||
Pointwise(FloatNear(1e-6), {-1.0, 3.1415926, 0.0, -1.0}));
|
||||
}
|
||||
|
||||
TEST(ElementwiseOneArgumentTest, Rsqrt) {
|
||||
OperationType op_type = OperationType::RSQRT;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
|
@ -103,6 +103,7 @@ class Registry : public NodeShader {
|
||||
insert_elementwise_op(Type::EXP);
|
||||
insert_elementwise_op(Type::HARD_SWISH);
|
||||
insert_elementwise_op(Type::LOG);
|
||||
insert_elementwise_op(Type::NEG);
|
||||
insert_elementwise_op(Type::MAXIMUM);
|
||||
insert_elementwise_op(Type::MINIMUM);
|
||||
insert_elementwise_op(Type::POW);
|
||||
|
@ -371,6 +371,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
case OperationType::EXP:
|
||||
case OperationType::HARD_SWISH:
|
||||
case OperationType::LOG:
|
||||
case OperationType::NEG:
|
||||
case OperationType::RSQRT:
|
||||
case OperationType::SIGMOID:
|
||||
case OperationType::SIN:
|
||||
|
@ -45,6 +45,7 @@ std::string OneInputFunctor(OperationType op_type, const std::string& value) {
|
||||
"$0.w < FLT(0.0f) ? exp($0.w) - FLT(1.0f) : $0.w)"},
|
||||
{OperationType::EXP, "exp($0)"},
|
||||
{OperationType::LOG, "log($0)"},
|
||||
{OperationType::NEG, "-($0)"},
|
||||
{OperationType::SQRT, "sqrt($0)"},
|
||||
{OperationType::RSQRT, "1.0 / sqrt($0)"},
|
||||
{OperationType::SQUARE, "$0 * $0"},
|
||||
|
@ -257,6 +257,19 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testNeg {
|
||||
OperationType op_type = OperationType::NEG;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||
/*inputs=*/{GetTensorRef(0, shape)},
|
||||
/*outputs=*/{GetTensorRef(1, shape)});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {-1.0, 3.1415926, 0.0, 1.0}));
|
||||
auto status = model.Invoke();
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
status = CompareVectors({1.0, -3.1415926, 0.0, -1.0}, model.GetOutput(0), 1e-6f);
|
||||
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
|
||||
}
|
||||
|
||||
- (void)testPow {
|
||||
OperationType op_type = OperationType::POW;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
|
Loading…
x
Reference in New Issue
Block a user