Added support of NEG for GPU backends.

PiperOrigin-RevId: 330009294
Change-Id: I8955b499adfb1ca23084d36d79d8509bbc37e4eb
This commit is contained in:
Raman Sarokin 2020-09-03 15:51:44 -07:00 committed by TensorFlower Gardener
parent ec878bb3e3
commit 073362c7dc
12 changed files with 66 additions and 0 deletions

View File

@ -58,6 +58,9 @@ std::string GetOneInputCode(const OperationType& op_type,
case OperationType::LOG:
result = "$0 = log($0);\n";
break;
case OperationType::NEG:
result = "$0 = -($0);\n";
break;
case OperationType::RSQRT:
result = "$0 = rsqrt($0);\n";
break;

View File

@ -208,6 +208,30 @@ TEST_F(OpenCLOperationTest, Log) {
}
}
TEST_F(OpenCLOperationTest, Neg) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 1, 2);
src_tensor.data = {1.0f, -2.0f, 0.0f, 4.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::NEG);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-1.0f, 2.0f, 0.0f, -4.0f}));
}
}
}
TEST_F(OpenCLOperationTest, Rsqrt) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 1, 2);

View File

@ -336,6 +336,7 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
case OperationType::EXP:
case OperationType::HARD_SWISH:
case OperationType::LOG:
case OperationType::NEG:
case OperationType::RSQRT:
case OperationType::SIGMOID:
case OperationType::SIN:

View File

@ -801,6 +801,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
case OperationType::ELU:
case OperationType::EXP:
case OperationType::LOG:
case OperationType::NEG:
case OperationType::RSQRT:
case OperationType::SIGMOID:
case OperationType::SIN:
@ -2574,6 +2575,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
case kTfLiteBuiltinMul:
return std::make_unique<MulOperationParser>();
case kTfLiteBuiltinNeg:
return std::make_unique<ElementwiseOperationParser>(OperationType::NEG);
case kTfLiteBuiltinPad:
return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
case kTfLiteBuiltinPow:

View File

@ -134,6 +134,8 @@ std::string ToString(enum OperationType op) {
return "minimum";
case OperationType::MUL:
return "mul";
case OperationType::NEG:
return "neg";
case OperationType::NOT_EQUAL:
return "not_equal";
case OperationType::PAD:
@ -224,6 +226,7 @@ OperationType OperationTypeFromString(const std::string& name) {
OperationType::MEAN_STDDEV_NORMALIZATION},
{"minimum", OperationType::MINIMUM},
{"mul", OperationType::MUL},
{"neg", OperationType::NEG},
{"not_equal", OperationType::NOT_EQUAL},
{"pad", OperationType::PAD},
{"pooling_2d", OperationType::POOLING_2D},

View File

@ -63,6 +63,7 @@ enum class OperationType {
MEAN_STDDEV_NORMALIZATION,
MINIMUM,
MUL,
NEG,
NOT_EQUAL,
PAD,
POOLING_2D,

View File

@ -69,6 +69,9 @@ class ElementwiseOneArgument : public NodeShader {
value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
)";
break;
case OperationType::NEG:
source = "value_0 = -(value_0);";
break;
case OperationType::RSQRT:
source = R"(
const float nan = normalize(vec4(0, 0, 0, 0)).x;

View File

@ -129,6 +129,18 @@ TEST(ElementwiseOneArgumentTest, Log) {
Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
}
TEST(ElementwiseOneArgumentTest, Neg) {
OperationType op_type = OperationType::NEG;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, -3.1415926, 0.0, 1.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {-1.0, 3.1415926, 0.0, -1.0}));
}
TEST(ElementwiseOneArgumentTest, Rsqrt) {
OperationType op_type = OperationType::RSQRT;
const BHWC shape(1, 2, 2, 1);

View File

@ -103,6 +103,7 @@ class Registry : public NodeShader {
insert_elementwise_op(Type::EXP);
insert_elementwise_op(Type::HARD_SWISH);
insert_elementwise_op(Type::LOG);
insert_elementwise_op(Type::NEG);
insert_elementwise_op(Type::MAXIMUM);
insert_elementwise_op(Type::MINIMUM);
insert_elementwise_op(Type::POW);

View File

@ -371,6 +371,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
case OperationType::EXP:
case OperationType::HARD_SWISH:
case OperationType::LOG:
case OperationType::NEG:
case OperationType::RSQRT:
case OperationType::SIGMOID:
case OperationType::SIN:

View File

@ -45,6 +45,7 @@ std::string OneInputFunctor(OperationType op_type, const std::string& value) {
"$0.w < FLT(0.0f) ? exp($0.w) - FLT(1.0f) : $0.w)"},
{OperationType::EXP, "exp($0)"},
{OperationType::LOG, "log($0)"},
{OperationType::NEG, "-($0)"},
{OperationType::SQRT, "sqrt($0)"},
{OperationType::RSQRT, "1.0 / sqrt($0)"},
{OperationType::SQUARE, "$0 * $0"},

View File

@ -257,6 +257,19 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testNeg {
OperationType op_type = OperationType::NEG;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
XCTAssertTrue(model.PopulateTensor(0, {-1.0, 3.1415926, 0.0, 1.0}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
status = CompareVectors({1.0, -3.1415926, 0.0, -1.0}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testPow {
OperationType op_type = OperationType::POW;
const BHWC shape(1, 2, 2, 1);