TFLite GPU: Implement HARD_SWISH for MobileNet v3.

PiperOrigin-RevId: 257276206
This commit is contained in:
Juhyun Lee 2019-07-09 14:50:05 -07:00 committed by TensorFlower Gardener
parent 83788ea45a
commit 77f6f8ccde
10 changed files with 344 additions and 197 deletions

View File

@ -816,6 +816,24 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
} }
}; };
// Parses the TFLite HARD_SWISH builtin into a GPU graph node.
// hard_swish(x) = x * relu6(x + 3) / 6, used by MobileNet v3.
class HardSwishOperationParser : public TFLiteOperationParser {
 public:
  // Supported iff the op has exactly one input and one output tensor.
  Status IsSupported(const TfLiteContext* context,
                     const TfLiteNode* tflite_node,
                     const TfLiteRegistration*) final {
    return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
                              /*outputs=*/1);
  }

  // Emits a single HARD_SWISH node wired to the op's input 0 and its outputs.
  Status Parse(const TfLiteNode*, const TfLiteRegistration*,
               GraphFloat32* graph, ObjectReader* reader) final {
    Node* hard_swish = graph->NewNode();
    hard_swish->operation.type = ToString(OperationType::HARD_SWISH);
    RETURN_IF_ERROR(reader->AddInput(hard_swish, 0));
    return reader->AddOutputs(hard_swish);
  }
};
class ReshapeOperationParser : public TFLiteOperationParser { class ReshapeOperationParser : public TFLiteOperationParser {
public: public:
Status IsSupported(const TfLiteContext* context, Status IsSupported(const TfLiteContext* context,
@ -2003,6 +2021,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
return make_unique<ElementwiseOperationParser>(OperationType::DIV); return make_unique<ElementwiseOperationParser>(OperationType::DIV);
case kTfLiteBuiltinFullyConnected: case kTfLiteBuiltinFullyConnected:
return make_unique<FullyConnectedOperationParser>(); return make_unique<FullyConnectedOperationParser>();
case kTfLiteBuiltinHardSwish:
return make_unique<HardSwishOperationParser>();
case kTfLiteBuiltinLogistic: case kTfLiteBuiltinLogistic:
return make_unique<ElementwiseOperationParser>(OperationType::SIGMOID); return make_unique<ElementwiseOperationParser>(OperationType::SIGMOID);
case kTfLiteBuiltinLog: case kTfLiteBuiltinLog:

View File

@ -46,50 +46,58 @@ Padding2D& Padding2D::operator-(const Padding2D& value) {
std::string ToString(enum OperationType op) { std::string ToString(enum OperationType op) {
switch (op) { switch (op) {
case OperationType::UNKNOWN:
break;
case OperationType::ABS: case OperationType::ABS:
return "abs"; return "abs";
case OperationType::ADD: case OperationType::ADD:
return "add"; return "add";
case OperationType::APPLY_MASK: case OperationType::APPLY_MASK:
return "apply_mask"; return "apply_mask";
case OperationType::BATCH_TO_SPACE:
return "batch_to_space";
case OperationType::POOLING_2D:
return "pooling_2d";
case OperationType::MAX_UNPOOLING_2D:
return "max_unpooling";
case OperationType::BATCH_NORMALIZATION: case OperationType::BATCH_NORMALIZATION:
return "batch_normalization"; return "batch_normalization";
case OperationType::BATCH_TO_SPACE:
return "batch_to_space";
case OperationType::CONCAT: case OperationType::CONCAT:
return "concat"; return "concat";
case OperationType::CONST: case OperationType::CONST:
return "const"; return "const";
case OperationType::CONVOLUTION_2D: case OperationType::CONVOLUTION_2D:
return "convolution_2d"; return "convolution_2d";
case OperationType::CONVOLUTION_TRANSPOSED:
return "convolution_transposed";
case OperationType::COS: case OperationType::COS:
return "cos"; return "cos";
case OperationType::DEPTHWISE_CONVOLUTION: case OperationType::DEPTHWISE_CONVOLUTION:
return "depthwise_convolution"; return "depthwise_convolution";
case OperationType::DIV: case OperationType::DIV:
return "div"; return "div";
case OperationType::FULLY_CONNECTED:
return "fully_connected";
case OperationType::HARD_SWISH:
return "hard_swish";
case OperationType::LOG: case OperationType::LOG:
return "log"; return "log";
case OperationType::LSTM:
return "lstm";
case OperationType::MAX_UNPOOLING_2D:
return "max_unpooling";
case OperationType::MUL: case OperationType::MUL:
return "mul"; return "mul";
case OperationType::MULTIPLY_SCALAR:
return "multiply_scalar";
case OperationType::PAD: case OperationType::PAD:
return "pad"; return "pad";
case OperationType::POOLING_2D:
return "pooling_2d";
case OperationType::POW: case OperationType::POW:
return "pow"; return "pow";
case OperationType::PRELU: case OperationType::PRELU:
return "prelu"; return "prelu";
case OperationType::RELU: case OperationType::RELU:
return "relu"; return "relu";
case OperationType::RESIZE:
return "resize";
case OperationType::RESHAPE: case OperationType::RESHAPE:
return "reshape"; return "reshape";
case OperationType::RESIZE:
return "resize";
case OperationType::RSQRT: case OperationType::RSQRT:
return "rsqrt"; return "rsqrt";
case OperationType::SIGMOID: case OperationType::SIGMOID:
@ -110,18 +118,12 @@ std::string ToString(enum OperationType op) {
return "squared_diff"; return "squared_diff";
case OperationType::SUB: case OperationType::SUB:
return "subtract"; return "subtract";
case OperationType::UPSAMPLE_2D:
return "upsample_2d";
case OperationType::CONVOLUTION_TRANSPOSED:
return "convolution_transposed";
case OperationType::MULTIPLY_SCALAR:
return "multiply_scalar";
case OperationType::FULLY_CONNECTED:
return "fully_connected";
case OperationType::TANH: case OperationType::TANH:
return "tanh"; return "tanh";
case OperationType::LSTM: case OperationType::UPSAMPLE_2D:
return "lstm"; return "upsample_2d";
default:
break;
} }
return "unknown_operation"; return "unknown_operation";
} }
@ -140,6 +142,7 @@ OperationType OperationTypeFromString(const std::string& name) {
{"cos", OperationType::COS}, {"cos", OperationType::COS},
{"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION}, {"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION},
{"fully_connected", OperationType::FULLY_CONNECTED}, {"fully_connected", OperationType::FULLY_CONNECTED},
{"hard_swish", OperationType::HARD_SWISH},
{"log", OperationType::LOG}, {"log", OperationType::LOG},
{"lstm", OperationType::LSTM}, {"lstm", OperationType::LSTM},
{"max_unpooling", OperationType::MAX_UNPOOLING_2D}, {"max_unpooling", OperationType::MAX_UNPOOLING_2D},

View File

@ -46,14 +46,15 @@ enum class OperationType {
DEPTHWISE_CONVOLUTION, DEPTHWISE_CONVOLUTION,
DIV, DIV,
FULLY_CONNECTED, FULLY_CONNECTED,
HARD_SWISH,
LOG, LOG,
LSTM, LSTM,
MAX_UNPOOLING_2D, MAX_UNPOOLING_2D,
MUL, MUL,
MULTIPLY_SCALAR, MULTIPLY_SCALAR,
PAD,
POOLING_2D, POOLING_2D,
POW, POW,
PAD,
PRELU, PRELU,
RELU, RELU,
RESHAPE, RESHAPE,

View File

@ -34,19 +34,18 @@ class ElementwiseOneArgument : public NodeShader {
GeneratedCode* generated_code) const final { GeneratedCode* generated_code) const final {
std::string source; std::string source;
switch (operation_type_) { switch (operation_type_) {
case OperationType::ABS: { case OperationType::ABS:
source = "value_0 = abs(value_0);"; source = "value_0 = abs(value_0);";
break; break;
} case OperationType::COS:
case OperationType::SIN: {
source = "value_0 = sin(value_0);";
break;
}
case OperationType::COS: {
source = "value_0 = cos(value_0);"; source = "value_0 = cos(value_0);";
break; break;
} case OperationType::HARD_SWISH:
case OperationType::LOG: { source =
"value_0 *= clamp(value_0 / 6.0 + vec4(0.5), vec4(0.0), "
"vec4(1.0));";
break;
case OperationType::LOG:
source = R"( source = R"(
const float nan = normalize(vec4(0, 0, 0, 0)).x; const float nan = normalize(vec4(0, 0, 0, 0)).x;
value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan; value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan;
@ -55,18 +54,7 @@ class ElementwiseOneArgument : public NodeShader {
value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan; value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
)"; )";
break; break;
} case OperationType::RSQRT:
case OperationType::SQRT: {
source = R"(
const float nan = normalize(vec4(0,0,0,0)).x;
value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
)";
break;
}
case OperationType::RSQRT: {
source = R"( source = R"(
const float nan = normalize(vec4(0, 0, 0, 0)).x; const float nan = normalize(vec4(0, 0, 0, 0)).x;
value_0.x = value_0.x >= 0.0 ? 1.0 / sqrt(value_0.x) : nan; value_0.x = value_0.x >= 0.0 ? 1.0 / sqrt(value_0.x) : nan;
@ -75,19 +63,27 @@ class ElementwiseOneArgument : public NodeShader {
value_0.w = value_0.w >= 0.0 ? 1.0 / sqrt(value_0.w) : nan; value_0.w = value_0.w >= 0.0 ? 1.0 / sqrt(value_0.w) : nan;
)"; )";
break; break;
} case OperationType::SIGMOID:
case OperationType::SQUARE: {
source = "value_0 = value_0 * value_0;";
break;
}
case OperationType::SIGMOID: {
source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));"; source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
break; break;
} case OperationType::SIN:
case OperationType::TANH: { source = "value_0 = sin(value_0);";
break;
case OperationType::SQRT:
source = R"(
const float nan = normalize(vec4(0, 0, 0, 0)).x;
value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
)";
break;
case OperationType::SQUARE:
source = "value_0 = value_0 * value_0;";
break;
case OperationType::TANH:
source = "value_0 = tanh(value_0);"; source = "value_0 = tanh(value_0);";
break; break;
}
default: default:
return InvalidArgumentError("Incorrect elementwise operation type."); return InvalidArgumentError("Incorrect elementwise operation type.");
} }
@ -183,19 +179,20 @@ std::unique_ptr<NodeShader> NewElementwiseNodeShader(
OperationType operation_type) { OperationType operation_type) {
switch (operation_type) { switch (operation_type) {
case OperationType::ABS: case OperationType::ABS:
case OperationType::SIN:
case OperationType::COS: case OperationType::COS:
case OperationType::LOG: case OperationType::LOG:
case OperationType::SQRT: case OperationType::HARD_SWISH:
case OperationType::RSQRT: case OperationType::RSQRT:
case OperationType::SQUARE:
case OperationType::SIGMOID: case OperationType::SIGMOID:
case OperationType::SIN:
case OperationType::SQRT:
case OperationType::SQUARE:
case OperationType::TANH: case OperationType::TANH:
return absl::make_unique<ElementwiseOneArgument>(operation_type); return absl::make_unique<ElementwiseOneArgument>(operation_type);
case OperationType::SUB:
case OperationType::DIV: case OperationType::DIV:
case OperationType::POW: case OperationType::POW:
case OperationType::SQUARED_DIFF: case OperationType::SQUARED_DIFF:
case OperationType::SUB:
return absl::make_unique<ElementwiseTwoArguments>(operation_type); return absl::make_unique<ElementwiseTwoArguments>(operation_type);
default: default:
return nullptr; return nullptr;

View File

@ -28,139 +28,45 @@ namespace gpu {
namespace gl { namespace gl {
namespace { namespace {
class ElementwiseOneArgumentTest : public ::testing::Test { TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
public:
ElementwiseOneArgumentTest() = default;
~ElementwiseOneArgumentTest() override = default;
TensorRef<BHWC> GetTensorRef(int ref) {
TensorRef<BHWC> tensor_ref; TensorRef<BHWC> tensor_ref;
tensor_ref.type = DataType::FLOAT32; tensor_ref.type = DataType::FLOAT32;
tensor_ref.ref = ref; tensor_ref.ref = ref;
tensor_ref.shape = BHWC(1, 2, 2, 1); tensor_ref.shape = shape;
return tensor_ref; return tensor_ref;
} }
};
TEST_F(ElementwiseOneArgumentTest, Abs) { TEST(ElementwiseTest, Abs) {
OperationType op_type = OperationType::ABS; OperationType op_type = OperationType::ABS;
SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, const BHWC shape(1, 2, 2, 1);
{GetTensorRef(1)}); SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0), EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 6.2, 2.0, 4.0})); Pointwise(FloatNear(1e-6), {0.0, 6.2, 2.0, 4.0}));
} }
TEST_F(ElementwiseOneArgumentTest, Sin) { TEST(ElementwiseTest, Cos) {
OperationType op_type = OperationType::SIN;
SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
{GetTensorRef(1)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
}
TEST_F(ElementwiseOneArgumentTest, Cos) {
OperationType op_type = OperationType::COS; OperationType op_type = OperationType::COS;
SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, const BHWC shape(1, 2, 2, 1);
{GetTensorRef(1)}); SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1})); ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0), EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, -1.0, -1.0, 0.540302})); Pointwise(FloatNear(1e-6), {1.0, -1.0, -1.0, 0.540302}));
} }
TEST_F(ElementwiseOneArgumentTest, Log) { TEST(ElementwiseTest, Div) {
OperationType op_type = OperationType::LOG;
SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
{GetTensorRef(1)});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 3.1415926, 1.0, 1.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
}
TEST_F(ElementwiseOneArgumentTest, Sqrt) {
OperationType op_type = OperationType::SQRT;
SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
{GetTensorRef(1)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
}
TEST_F(ElementwiseOneArgumentTest, Rsqrt) {
OperationType op_type = OperationType::RSQRT;
SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
{GetTensorRef(1)});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
}
TEST_F(ElementwiseOneArgumentTest, Square) {
OperationType op_type = OperationType::SQUARE;
SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
{GetTensorRef(1)});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
}
TEST_F(ElementwiseOneArgumentTest, Sigmoid) {
OperationType op_type = OperationType::SIGMOID;
SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
{GetTensorRef(1)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
}
TEST_F(ElementwiseOneArgumentTest, Tanh) {
OperationType op_type = OperationType::TANH;
SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
{GetTensorRef(1)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
}
class ElementwiseTwoArgumentsTest : public ::testing::Test {
public:
ElementwiseTwoArgumentsTest() = default;
~ElementwiseTwoArgumentsTest() override = default;
TensorRef<BHWC> GetTensorRef(int ref) {
TensorRef<BHWC> tensor_ref;
tensor_ref.type = DataType::FLOAT32;
tensor_ref.ref = ref;
tensor_ref.shape = BHWC(1, 2, 2, 1);
return tensor_ref;
}
};
TEST_F(ElementwiseTwoArgumentsTest, Sub) {
OperationType op_type = OperationType::SUB;
SingleOpModel model({ToString(op_type), {}},
{GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0}));
}
TEST_F(ElementwiseTwoArgumentsTest, Div) {
OperationType op_type = OperationType::DIV; OperationType op_type = OperationType::DIV;
SingleOpModel model({ToString(op_type), {}}, const BHWC shape(1, 2, 2, 1);
{GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)}); SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0})); ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
@ -168,10 +74,39 @@ TEST_F(ElementwiseTwoArgumentsTest, Div) {
Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0})); Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0}));
} }
TEST_F(ElementwiseTwoArgumentsTest, Pow) { TEST(ElementwiseTest, HardSwish) {
OperationType op_type = OperationType::HARD_SWISH;
const BHWC shape(1, 1, 1, 7);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(
model.PopulateTensor(0, {-4.5f, -3.0f, -1.5f, 0.0f, 1.5f, 3.0f, 4.5f}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6f),
{0.0f, 0.0f, -0.375f, 0.0f, 1.125f, 3.f, 4.5f}));
}
TEST(ElementwiseTest, Log) {
  // Componentwise natural log; all inputs here are strictly positive.
  const OperationType op = OperationType::LOG;
  const BHWC shape(1, 2, 2, 1);
  SingleOpModel model({/*type=*/ToString(op), /*attributes=*/{}},
                      /*inputs=*/{GetTensorRef(0, shape)},
                      /*outputs=*/{GetTensorRef(1, shape)});
  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 3.1415926, 1.0, 1.0}));
  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op)));
  EXPECT_THAT(model.GetOutput(0),
              Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
}
TEST(ElementwiseTest, Pow) {
OperationType op_type = OperationType::POW; OperationType op_type = OperationType::POW;
SingleOpModel model({ToString(op_type), {}}, const BHWC shape(1, 2, 2, 1);
{GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)}); SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0})); ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
@ -179,10 +114,73 @@ TEST_F(ElementwiseTwoArgumentsTest, Pow) {
Pointwise(FloatNear(1e-6), {0.0, 1.0, 8.0, 256.0})); Pointwise(FloatNear(1e-6), {0.0, 1.0, 8.0, 256.0}));
} }
TEST_F(ElementwiseTwoArgumentsTest, SquaredDiff) { TEST(ElementwiseTest, Rsqrt) {
OperationType op_type = OperationType::RSQRT;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
}
TEST(ElementwiseTest, Sigmoid) {
  // Componentwise logistic function: 1 / (1 + exp(-x)).
  const OperationType op = OperationType::SIGMOID;
  const BHWC shape(1, 2, 2, 1);
  SingleOpModel model({/*type=*/ToString(op), /*attributes=*/{}},
                      /*inputs=*/{GetTensorRef(0, shape)},
                      /*outputs=*/{GetTensorRef(1, shape)});
  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op)));
  EXPECT_THAT(model.GetOutput(0),
              Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
}
TEST(ElementwiseTest, Sin) {
  // Componentwise sine; +/-pi map to ~0 within the 1e-6 tolerance.
  const OperationType op = OperationType::SIN;
  const BHWC shape(1, 2, 2, 1);
  SingleOpModel model({/*type=*/ToString(op), /*attributes=*/{}},
                      /*inputs=*/{GetTensorRef(0, shape)},
                      /*outputs=*/{GetTensorRef(1, shape)});
  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op)));
  EXPECT_THAT(model.GetOutput(0),
              Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
}
TEST(ElementwiseTest, Sqrt) {
  // Componentwise square root on non-negative inputs.
  const OperationType op = OperationType::SQRT;
  const BHWC shape(1, 2, 2, 1);
  SingleOpModel model({/*type=*/ToString(op), /*attributes=*/{}},
                      /*inputs=*/{GetTensorRef(0, shape)},
                      /*outputs=*/{GetTensorRef(1, shape)});
  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op)));
  EXPECT_THAT(model.GetOutput(0),
              Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
}
TEST(ElementwiseTest, Square) {
  // Componentwise x * x; negative inputs produce positive outputs.
  const OperationType op = OperationType::SQUARE;
  const BHWC shape(1, 2, 2, 1);
  SingleOpModel model({/*type=*/ToString(op), /*attributes=*/{}},
                      /*inputs=*/{GetTensorRef(0, shape)},
                      /*outputs=*/{GetTensorRef(1, shape)});
  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op)));
  EXPECT_THAT(model.GetOutput(0),
              Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
}
TEST(ElementwiseTest, SquaredDiff) {
OperationType op_type = OperationType::SQUARED_DIFF; OperationType op_type = OperationType::SQUARED_DIFF;
SingleOpModel model({ToString(op_type), {}}, const BHWC shape(1, 2, 2, 1);
{GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)}); SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 2.0, 2.0, 4.0})); ASSERT_TRUE(model.PopulateTensor(0, {0.0, 2.0, 2.0, 4.0}));
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 1.0, 5.0, 4.0})); ASSERT_TRUE(model.PopulateTensor(1, {1.0, 1.0, 5.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type))); ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
@ -190,6 +188,32 @@ TEST_F(ElementwiseTwoArgumentsTest, SquaredDiff) {
Pointwise(FloatNear(1e-6), {1.0, 1.0, 9.0, 0.0})); Pointwise(FloatNear(1e-6), {1.0, 1.0, 9.0, 0.0}));
} }
TEST(ElementwiseTest, Sub) {
  // Two-argument elementwise subtraction: output = input0 - input1.
  const OperationType op = OperationType::SUB;
  const BHWC shape(1, 2, 2, 1);
  SingleOpModel model(
      {/*type=*/ToString(op), /*attributes=*/{}},
      /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
      /*outputs=*/{GetTensorRef(2, shape)});
  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
  ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0}));
  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op)));
  EXPECT_THAT(model.GetOutput(0),
              Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0}));
}
TEST(ElementwiseTest, Tanh) {
  // Componentwise hyperbolic tangent; saturates toward +/-1.
  const OperationType op = OperationType::TANH;
  const BHWC shape(1, 2, 2, 1);
  SingleOpModel model({/*type=*/ToString(op), /*attributes=*/{}},
                      /*inputs=*/{GetTensorRef(0, shape)},
                      /*outputs=*/{GetTensorRef(1, shape)});
  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op)));
  EXPECT_THAT(model.GetOutput(0),
              Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
}
} // namespace } // namespace
} // namespace gl } // namespace gl
} // namespace gpu } // namespace gpu

View File

@ -60,10 +60,10 @@ class Registry : public NodeShader {
using Type = OperationType; using Type = OperationType;
using NewShaderFunc = std::function<std::unique_ptr<NodeShader>()>; using NewShaderFunc = std::function<std::unique_ptr<NodeShader>()>;
auto insert_op = [&](Type type, NewShaderFunc func) { const auto insert_op = [&](Type type, NewShaderFunc func) {
shaders_[ToString(type)].push_back(func()); shaders_[ToString(type)].push_back(func());
}; };
auto insert_elementwise_op = [&](Type operation_type) { const auto insert_elementwise_op = [&](Type operation_type) {
shaders_[ToString(operation_type)].push_back( shaders_[ToString(operation_type)].push_back(
NewElementwiseNodeShader(operation_type)); NewElementwiseNodeShader(operation_type));
}; };
@ -82,26 +82,27 @@ class Registry : public NodeShader {
insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader); insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader);
insert_op(Type::PAD, NewPadNodeShader); insert_op(Type::PAD, NewPadNodeShader);
insert_op(Type::POOLING_2D, NewPoolingNodeShader); insert_op(Type::POOLING_2D, NewPoolingNodeShader);
insert_op(Type::PRELU, NewPReLUNodeShader);
insert_op(Type::RELU, NewReLUNodeShader); insert_op(Type::RELU, NewReLUNodeShader);
insert_op(Type::RESHAPE, NewReshapeNodeShader); insert_op(Type::RESHAPE, NewReshapeNodeShader);
insert_op(Type::PRELU, NewPReLUNodeShader);
insert_op(Type::SLICE, NewSliceNodeShader); insert_op(Type::SLICE, NewSliceNodeShader);
insert_op(Type::SOFT_MAX, NewSoftMaxNodeShader); insert_op(Type::SOFT_MAX, NewSoftMaxNodeShader);
insert_op(Type::UPSAMPLE_2D, NewUpsamplingNodeShader); insert_op(Type::UPSAMPLE_2D, NewUpsamplingNodeShader);
insert_elementwise_op(Type::ABS); insert_elementwise_op(Type::ABS);
insert_elementwise_op(Type::COS); insert_elementwise_op(Type::COS);
insert_elementwise_op(Type::DIV);
insert_elementwise_op(Type::HARD_SWISH);
insert_elementwise_op(Type::LOG); insert_elementwise_op(Type::LOG);
insert_elementwise_op(Type::POW);
insert_elementwise_op(Type::RSQRT); insert_elementwise_op(Type::RSQRT);
insert_elementwise_op(Type::SIGMOID); insert_elementwise_op(Type::SIGMOID);
insert_elementwise_op(Type::SIN); insert_elementwise_op(Type::SIN);
insert_elementwise_op(Type::SQRT); insert_elementwise_op(Type::SQRT);
insert_elementwise_op(Type::SQUARE); insert_elementwise_op(Type::SQUARE);
insert_elementwise_op(Type::TANH);
insert_elementwise_op(Type::SUB);
insert_elementwise_op(Type::DIV);
insert_elementwise_op(Type::POW);
insert_elementwise_op(Type::SQUARED_DIFF); insert_elementwise_op(Type::SQUARED_DIFF);
insert_elementwise_op(Type::SUB);
insert_elementwise_op(Type::TANH);
#ifndef TFLITE_GPU_BINARY_RELEASE #ifndef TFLITE_GPU_BINARY_RELEASE
insert_op(Type::MAX_UNPOOLING_2D, NewMaxUnpoolingNodeShader); insert_op(Type::MAX_UNPOOLING_2D, NewMaxUnpoolingNodeShader);

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/hard_swish.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h" #include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
@ -172,6 +173,9 @@ Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
node->operation.attributes), node->operation.attributes),
options); options);
break; break;
case OperationType::HARD_SWISH:
tasks = HardSwish(node_id, inputs[0], outputs[0], options);
break;
case OperationType::MAX_UNPOOLING_2D: case OperationType::MAX_UNPOOLING_2D:
tasks = MaxUnpooling(node_id, inputs[0], inputs[1], outputs[0], tasks = MaxUnpooling(node_id, inputs[0], inputs[1], outputs[0],
absl::any_cast<MaxUnpooling2DAttributes>( absl::any_cast<MaxUnpooling2DAttributes>(

View File

@ -12,6 +12,7 @@ cc_library(
":depthwise_conv", ":depthwise_conv",
":elementwise", ":elementwise",
":fully_connected", ":fully_connected",
":hard_swish",
":max_unpooling", ":max_unpooling",
":mul", ":mul",
":padding", ":padding",
@ -122,6 +123,18 @@ cc_library(
], ],
) )
cc_library(
name = "hard_swish",
srcs = ["hard_swish.cc"],
hdrs = ["hard_swish.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
],
)
cc_library( cc_library(
name = "max_unpooling", name = "max_unpooling",
srcs = ["max_unpooling.cc"], srcs = ["max_unpooling.cc"],

View File

@ -0,0 +1,47 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/hard_swish.h"
#include <memory>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
namespace gpu {
namespace metal {
std::vector<ComputeTaskDescriptorPtr> HardSwish(int id, ValueId input_id,
                                                ValueId output_id,
                                                const RuntimeOptions& options) {
  // Builds a linkable (fusable) elementwise Metal task computing
  // value * clamp(value / 6 + 0.5, 0, 1), i.e. hard-swish.
  auto descriptor = std::make_shared<ComputeTaskDescriptor>();
  descriptor->id = id;
  // Linkable shaders are inlined into the consumer's compute kernel.
  descriptor->is_linkable = true;
  descriptor->shader_source = R"(
FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
  return value * clamp(value / 6.0f + FLT4(0.5f), FLT4(0.0f), FLT4(1.0f));
}
)";
  descriptor->input_buffers = {{input_id}};
  descriptor->output_buffer = {output_id};
  return {descriptor};
}
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@ -0,0 +1,37 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_HARD_SWISH_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_HARD_SWISH_H_
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
namespace gpu {
namespace metal {
std::vector<ComputeTaskDescriptorPtr> HardSwish(int id, ValueId input_id,
ValueId output_id,
const RuntimeOptions& options);
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_HARD_SWISH_H_