TFLite GPU: Implement HARD_SWISH for MobileNet v3.
PiperOrigin-RevId: 257276206
parent 83788ea45a
commit 77f6f8ccde
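
For context (not part of the diff): MobileNet v3's hard-swish activation is h_swish(x) = x * relu6(x + 3) / 6, which the GLSL and Metal shaders below express as x * clamp(x / 6 + 0.5, 0, 1). A minimal scalar sketch of that formula, handy for checking the new ElementwiseTest.HardSwish expectations by hand; the helper name HardSwishReference is illustrative only, not something this commit adds:

    // Scalar reference for hard-swish: x * clamp(x / 6 + 0.5, 0, 1).
    #include <algorithm>
    #include <cstdio>

    float HardSwishReference(float x) {
      const float gate = std::min(std::max(x / 6.0f + 0.5f, 0.0f), 1.0f);
      return x * gate;
    }

    int main() {
      // Same inputs as the new GL test; expected {0, 0, -0.375, 0, 1.125, 3, 4.5}.
      const float inputs[] = {-4.5f, -3.0f, -1.5f, 0.0f, 1.5f, 3.0f, 4.5f};
      for (float x : inputs) std::printf("%g -> %g\n", x, HardSwishReference(x));
      return 0;
    }
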
@@ -816,6 +816,24 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
   }
 };
 
+class HardSwishOperationParser : public TFLiteOperationParser {
+ public:
+  Status IsSupported(const TfLiteContext* context,
+                     const TfLiteNode* tflite_node,
+                     const TfLiteRegistration*) final {
+    return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
+                              /*outputs=*/1);
+  }
+
+  Status Parse(const TfLiteNode*, const TfLiteRegistration*,
+               GraphFloat32* graph, ObjectReader* reader) final {
+    Node* node = graph->NewNode();
+    node->operation.type = ToString(OperationType::HARD_SWISH);
+    RETURN_IF_ERROR(reader->AddInput(node, 0));
+    return reader->AddOutputs(node);
+  }
+};
+
 class ReshapeOperationParser : public TFLiteOperationParser {
  public:
   Status IsSupported(const TfLiteContext* context,
@@ -2003,6 +2021,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
       return make_unique<ElementwiseOperationParser>(OperationType::DIV);
     case kTfLiteBuiltinFullyConnected:
       return make_unique<FullyConnectedOperationParser>();
+    case kTfLiteBuiltinHardSwish:
+      return make_unique<HardSwishOperationParser>();
     case kTfLiteBuiltinLogistic:
       return make_unique<ElementwiseOperationParser>(OperationType::SIGMOID);
     case kTfLiteBuiltinLog:
@@ -46,50 +46,58 @@ Padding2D& Padding2D::operator-(const Padding2D& value) {
 
 std::string ToString(enum OperationType op) {
   switch (op) {
-    case OperationType::UNKNOWN:
-      break;
     case OperationType::ABS:
       return "abs";
     case OperationType::ADD:
       return "add";
     case OperationType::APPLY_MASK:
       return "apply_mask";
-    case OperationType::BATCH_TO_SPACE:
-      return "batch_to_space";
-    case OperationType::POOLING_2D:
-      return "pooling_2d";
-    case OperationType::MAX_UNPOOLING_2D:
-      return "max_unpooling";
     case OperationType::BATCH_NORMALIZATION:
       return "batch_normalization";
+    case OperationType::BATCH_TO_SPACE:
+      return "batch_to_space";
     case OperationType::CONCAT:
       return "concat";
     case OperationType::CONST:
       return "const";
     case OperationType::CONVOLUTION_2D:
       return "convolution_2d";
+    case OperationType::CONVOLUTION_TRANSPOSED:
+      return "convolution_transposed";
     case OperationType::COS:
       return "cos";
     case OperationType::DEPTHWISE_CONVOLUTION:
       return "depthwise_convolution";
     case OperationType::DIV:
       return "div";
+    case OperationType::FULLY_CONNECTED:
+      return "fully_connected";
+    case OperationType::HARD_SWISH:
+      return "hard_swish";
     case OperationType::LOG:
       return "log";
+    case OperationType::LSTM:
+      return "lstm";
+    case OperationType::MAX_UNPOOLING_2D:
+      return "max_unpooling";
     case OperationType::MUL:
       return "mul";
+    case OperationType::MULTIPLY_SCALAR:
+      return "multiply_scalar";
     case OperationType::PAD:
       return "pad";
+    case OperationType::POOLING_2D:
+      return "pooling_2d";
     case OperationType::POW:
       return "pow";
     case OperationType::PRELU:
       return "prelu";
     case OperationType::RELU:
       return "relu";
-    case OperationType::RESIZE:
-      return "resize";
     case OperationType::RESHAPE:
       return "reshape";
+    case OperationType::RESIZE:
+      return "resize";
     case OperationType::RSQRT:
       return "rsqrt";
     case OperationType::SIGMOID:
@@ -110,18 +118,12 @@ std::string ToString(enum OperationType op) {
       return "squared_diff";
     case OperationType::SUB:
       return "subtract";
-    case OperationType::UPSAMPLE_2D:
-      return "upsample_2d";
-    case OperationType::CONVOLUTION_TRANSPOSED:
-      return "convolution_transposed";
-    case OperationType::MULTIPLY_SCALAR:
-      return "multiply_scalar";
-    case OperationType::FULLY_CONNECTED:
-      return "fully_connected";
     case OperationType::TANH:
       return "tanh";
-    case OperationType::LSTM:
-      return "lstm";
+    case OperationType::UPSAMPLE_2D:
+      return "upsample_2d";
+    default:
+      break;
   }
   return "unknown_operation";
 }
@@ -140,6 +142,7 @@ OperationType OperationTypeFromString(const std::string& name) {
       {"cos", OperationType::COS},
       {"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION},
       {"fully_connected", OperationType::FULLY_CONNECTED},
+      {"hard_swish", OperationType::HARD_SWISH},
       {"log", OperationType::LOG},
       {"lstm", OperationType::LSTM},
       {"max_unpooling", OperationType::MAX_UNPOOLING_2D},
@@ -46,14 +46,15 @@ enum class OperationType {
   DEPTHWISE_CONVOLUTION,
   DIV,
   FULLY_CONNECTED,
+  HARD_SWISH,
   LOG,
   LSTM,
   MAX_UNPOOLING_2D,
   MUL,
   MULTIPLY_SCALAR,
+  PAD,
   POOLING_2D,
   POW,
-  PAD,
   PRELU,
   RELU,
   RESHAPE,
@@ -34,19 +34,18 @@ class ElementwiseOneArgument : public NodeShader {
                      GeneratedCode* generated_code) const final {
     std::string source;
     switch (operation_type_) {
-      case OperationType::ABS: {
+      case OperationType::ABS:
         source = "value_0 = abs(value_0);";
         break;
-      }
-      case OperationType::SIN: {
-        source = "value_0 = sin(value_0);";
-        break;
-      }
-      case OperationType::COS: {
+      case OperationType::COS:
         source = "value_0 = cos(value_0);";
         break;
-      }
-      case OperationType::LOG: {
+      case OperationType::HARD_SWISH:
+        source =
+            "value_0 *= clamp(value_0 / 6.0 + vec4(0.5), vec4(0.0), "
+            "vec4(1.0));";
+        break;
+      case OperationType::LOG:
         source = R"(
             const float nan = normalize(vec4(0, 0, 0, 0)).x;
             value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan;
@@ -55,18 +54,7 @@ class ElementwiseOneArgument : public NodeShader {
            value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
         )";
         break;
-      }
-      case OperationType::SQRT: {
-        source = R"(
-            const float nan = normalize(vec4(0,0,0,0)).x;
-            value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
-            value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
-            value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
-            value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
-        )";
-        break;
-      }
-      case OperationType::RSQRT: {
+      case OperationType::RSQRT:
         source = R"(
             const float nan = normalize(vec4(0, 0, 0, 0)).x;
             value_0.x = value_0.x >= 0.0 ? 1.0 / sqrt(value_0.x) : nan;
|
|||||||
value_0.w = value_0.w >= 0.0 ? 1.0 / sqrt(value_0.w) : nan;
|
value_0.w = value_0.w >= 0.0 ? 1.0 / sqrt(value_0.w) : nan;
|
||||||
)";
|
)";
|
||||||
break;
|
break;
|
||||||
}
|
case OperationType::SIGMOID:
|
||||||
case OperationType::SQUARE: {
|
|
||||||
source = "value_0 = value_0 * value_0;";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case OperationType::SIGMOID: {
|
|
||||||
source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
|
source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
|
||||||
break;
|
break;
|
||||||
}
|
case OperationType::SIN:
|
||||||
case OperationType::TANH: {
|
source = "value_0 = sin(value_0);";
|
||||||
|
break;
|
||||||
|
case OperationType::SQRT:
|
||||||
|
source = R"(
|
||||||
|
const float nan = normalize(vec4(0, 0, 0, 0)).x;
|
||||||
|
value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
|
||||||
|
value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
|
||||||
|
value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
|
||||||
|
value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
|
||||||
|
)";
|
||||||
|
break;
|
||||||
|
case OperationType::SQUARE:
|
||||||
|
source = "value_0 = value_0 * value_0;";
|
||||||
|
break;
|
||||||
|
case OperationType::TANH:
|
||||||
source = "value_0 = tanh(value_0);";
|
source = "value_0 = tanh(value_0);";
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
return InvalidArgumentError("Incorrect elementwise operation type.");
|
return InvalidArgumentError("Incorrect elementwise operation type.");
|
||||||
}
|
}
|
||||||
@@ -183,19 +179,20 @@ std::unique_ptr<NodeShader> NewElementwiseNodeShader(
     OperationType operation_type) {
   switch (operation_type) {
     case OperationType::ABS:
-    case OperationType::SIN:
     case OperationType::COS:
     case OperationType::LOG:
-    case OperationType::SQRT:
+    case OperationType::HARD_SWISH:
     case OperationType::RSQRT:
-    case OperationType::SQUARE:
     case OperationType::SIGMOID:
+    case OperationType::SIN:
+    case OperationType::SQRT:
+    case OperationType::SQUARE:
     case OperationType::TANH:
       return absl::make_unique<ElementwiseOneArgument>(operation_type);
-    case OperationType::SUB:
     case OperationType::DIV:
     case OperationType::POW:
     case OperationType::SQUARED_DIFF:
+    case OperationType::SUB:
       return absl::make_unique<ElementwiseTwoArguments>(operation_type);
     default:
       return nullptr;
@@ -28,139 +28,45 @@ namespace gpu {
 namespace gl {
 namespace {
 
-class ElementwiseOneArgumentTest : public ::testing::Test {
- public:
-  ElementwiseOneArgumentTest() = default;
-  ~ElementwiseOneArgumentTest() override = default;
-
-  TensorRef<BHWC> GetTensorRef(int ref) {
+TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
   TensorRef<BHWC> tensor_ref;
   tensor_ref.type = DataType::FLOAT32;
   tensor_ref.ref = ref;
-  tensor_ref.shape = BHWC(1, 2, 2, 1);
+  tensor_ref.shape = shape;
   return tensor_ref;
 }
-};
 
-TEST_F(ElementwiseOneArgumentTest, Abs) {
+TEST(ElementwiseTest, Abs) {
   OperationType op_type = OperationType::ABS;
-  SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
-                      {GetTensorRef(1)});
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
   ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
   ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
   EXPECT_THAT(model.GetOutput(0),
               Pointwise(FloatNear(1e-6), {0.0, 6.2, 2.0, 4.0}));
 }
 
-TEST_F(ElementwiseOneArgumentTest, Sin) {
-  OperationType op_type = OperationType::SIN;
-  SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
-                      {GetTensorRef(1)});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
-}
-
-TEST_F(ElementwiseOneArgumentTest, Cos) {
+TEST(ElementwiseTest, Cos) {
   OperationType op_type = OperationType::COS;
-  SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
-                      {GetTensorRef(1)});
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
   ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1}));
   ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
   EXPECT_THAT(model.GetOutput(0),
               Pointwise(FloatNear(1e-6), {1.0, -1.0, -1.0, 0.540302}));
 }
 
-TEST_F(ElementwiseOneArgumentTest, Log) {
-  OperationType op_type = OperationType::LOG;
-  SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
-                      {GetTensorRef(1)});
-  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 3.1415926, 1.0, 1.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
-}
-
-TEST_F(ElementwiseOneArgumentTest, Sqrt) {
-  OperationType op_type = OperationType::SQRT;
-  SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
-                      {GetTensorRef(1)});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
-}
-
-TEST_F(ElementwiseOneArgumentTest, Rsqrt) {
-  OperationType op_type = OperationType::RSQRT;
-  SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
-                      {GetTensorRef(1)});
-  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
-}
-
-TEST_F(ElementwiseOneArgumentTest, Square) {
-  OperationType op_type = OperationType::SQUARE;
-  SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
-                      {GetTensorRef(1)});
-  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
-}
-
-TEST_F(ElementwiseOneArgumentTest, Sigmoid) {
-  OperationType op_type = OperationType::SIGMOID;
-  SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
-                      {GetTensorRef(1)});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
-}
-
-TEST_F(ElementwiseOneArgumentTest, Tanh) {
-  OperationType op_type = OperationType::TANH;
-  SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)},
-                      {GetTensorRef(1)});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
-}
-
-class ElementwiseTwoArgumentsTest : public ::testing::Test {
- public:
-  ElementwiseTwoArgumentsTest() = default;
-  ~ElementwiseTwoArgumentsTest() override = default;
-
-  TensorRef<BHWC> GetTensorRef(int ref) {
-    TensorRef<BHWC> tensor_ref;
-    tensor_ref.type = DataType::FLOAT32;
-    tensor_ref.ref = ref;
-    tensor_ref.shape = BHWC(1, 2, 2, 1);
-    return tensor_ref;
-  }
-};
-
-TEST_F(ElementwiseTwoArgumentsTest, Sub) {
-  OperationType op_type = OperationType::SUB;
-  SingleOpModel model({ToString(op_type), {}},
-                      {GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
-  ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0}));
-}
-
-TEST_F(ElementwiseTwoArgumentsTest, Div) {
+TEST(ElementwiseTest, Div) {
   OperationType op_type = OperationType::DIV;
-  SingleOpModel model({ToString(op_type), {}},
-                      {GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)});
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
+      /*outputs=*/{GetTensorRef(2, shape)});
   ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
   ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0}));
   ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
@@ -168,10 +74,39 @@ TEST_F(ElementwiseTwoArgumentsTest, Div) {
               Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0}));
 }
 
-TEST_F(ElementwiseTwoArgumentsTest, Pow) {
+TEST(ElementwiseTest, HardSwish) {
+  OperationType op_type = OperationType::HARD_SWISH;
+  const BHWC shape(1, 1, 1, 7);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(
+      model.PopulateTensor(0, {-4.5f, -3.0f, -1.5f, 0.0f, 1.5f, 3.0f, 4.5f}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6f),
+                        {0.0f, 0.0f, -0.375f, 0.0f, 1.125f, 3.f, 4.5f}));
+}
+
+TEST(ElementwiseTest, Log) {
+  OperationType op_type = OperationType::LOG;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 3.1415926, 1.0, 1.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
+}
+
+TEST(ElementwiseTest, Pow) {
   OperationType op_type = OperationType::POW;
-  SingleOpModel model({ToString(op_type), {}},
-                      {GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)});
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
+      /*outputs=*/{GetTensorRef(2, shape)});
   ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
   ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0}));
   ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
@@ -179,10 +114,73 @@ TEST_F(ElementwiseTwoArgumentsTest, Pow) {
               Pointwise(FloatNear(1e-6), {0.0, 1.0, 8.0, 256.0}));
 }
 
-TEST_F(ElementwiseTwoArgumentsTest, SquaredDiff) {
+TEST(ElementwiseTest, Rsqrt) {
+  OperationType op_type = OperationType::RSQRT;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
+}
+
+TEST(ElementwiseTest, Sigmoid) {
+  OperationType op_type = OperationType::SIGMOID;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
+}
+
+TEST(ElementwiseTest, Sin) {
+  OperationType op_type = OperationType::SIN;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
+}
+
+TEST(ElementwiseTest, Sqrt) {
+  OperationType op_type = OperationType::SQRT;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
+}
+
+TEST(ElementwiseTest, Square) {
+  OperationType op_type = OperationType::SQUARE;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
+}
+
+TEST(ElementwiseTest, SquaredDiff) {
   OperationType op_type = OperationType::SQUARED_DIFF;
-  SingleOpModel model({ToString(op_type), {}},
-                      {GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)});
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
+      /*outputs=*/{GetTensorRef(2, shape)});
   ASSERT_TRUE(model.PopulateTensor(0, {0.0, 2.0, 2.0, 4.0}));
   ASSERT_TRUE(model.PopulateTensor(1, {1.0, 1.0, 5.0, 4.0}));
   ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
@@ -190,6 +188,32 @@ TEST_F(ElementwiseTwoArgumentsTest, SquaredDiff) {
               Pointwise(FloatNear(1e-6), {1.0, 1.0, 9.0, 0.0}));
 }
 
+TEST(ElementwiseTest, Sub) {
+  OperationType op_type = OperationType::SUB;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
+      /*outputs=*/{GetTensorRef(2, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
+  ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0}));
+}
+
+TEST(ElementwiseTest, Tanh) {
+  OperationType op_type = OperationType::TANH;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
+}
+
 } // namespace
 } // namespace gl
 } // namespace gpu
@@ -60,10 +60,10 @@ class Registry : public NodeShader {
     using Type = OperationType;
     using NewShaderFunc = std::function<std::unique_ptr<NodeShader>()>;
 
-    auto insert_op = [&](Type type, NewShaderFunc func) {
+    const auto insert_op = [&](Type type, NewShaderFunc func) {
       shaders_[ToString(type)].push_back(func());
     };
-    auto insert_elementwise_op = [&](Type operation_type) {
+    const auto insert_elementwise_op = [&](Type operation_type) {
       shaders_[ToString(operation_type)].push_back(
           NewElementwiseNodeShader(operation_type));
     };
@@ -82,26 +82,27 @@ class Registry : public NodeShader {
     insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader);
     insert_op(Type::PAD, NewPadNodeShader);
     insert_op(Type::POOLING_2D, NewPoolingNodeShader);
+    insert_op(Type::PRELU, NewPReLUNodeShader);
     insert_op(Type::RELU, NewReLUNodeShader);
     insert_op(Type::RESHAPE, NewReshapeNodeShader);
-    insert_op(Type::PRELU, NewPReLUNodeShader);
     insert_op(Type::SLICE, NewSliceNodeShader);
     insert_op(Type::SOFT_MAX, NewSoftMaxNodeShader);
     insert_op(Type::UPSAMPLE_2D, NewUpsamplingNodeShader);
 
     insert_elementwise_op(Type::ABS);
     insert_elementwise_op(Type::COS);
+    insert_elementwise_op(Type::DIV);
+    insert_elementwise_op(Type::HARD_SWISH);
     insert_elementwise_op(Type::LOG);
+    insert_elementwise_op(Type::POW);
     insert_elementwise_op(Type::RSQRT);
     insert_elementwise_op(Type::SIGMOID);
     insert_elementwise_op(Type::SIN);
     insert_elementwise_op(Type::SQRT);
     insert_elementwise_op(Type::SQUARE);
-    insert_elementwise_op(Type::TANH);
-    insert_elementwise_op(Type::SUB);
-    insert_elementwise_op(Type::DIV);
-    insert_elementwise_op(Type::POW);
     insert_elementwise_op(Type::SQUARED_DIFF);
+    insert_elementwise_op(Type::SUB);
+    insert_elementwise_op(Type::TANH);
 
 #ifndef TFLITE_GPU_BINARY_RELEASE
     insert_op(Type::MAX_UNPOOLING_2D, NewMaxUnpoolingNodeShader);
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/hard_swish.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h"
@@ -172,6 +173,9 @@ Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
                                node->operation.attributes),
                            options);
         break;
+      case OperationType::HARD_SWISH:
+        tasks = HardSwish(node_id, inputs[0], outputs[0], options);
+        break;
       case OperationType::MAX_UNPOOLING_2D:
         tasks = MaxUnpooling(node_id, inputs[0], inputs[1], outputs[0],
                              absl::any_cast<MaxUnpooling2DAttributes>(
@@ -12,6 +12,7 @@ cc_library(
         ":depthwise_conv",
         ":elementwise",
         ":fully_connected",
+        ":hard_swish",
        ":max_unpooling",
         ":mul",
         ":padding",
@@ -122,6 +123,18 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "hard_swish",
+    srcs = ["hard_swish.cc"],
+    hdrs = ["hard_swish.h"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
+        "//tensorflow/lite/delegates/gpu/metal:runtime_options",
+    ],
+)
+
 cc_library(
     name = "max_unpooling",
     srcs = ["max_unpooling.cc"],
tensorflow/lite/delegates/gpu/metal/kernels/hard_swish.cc (new file, 47 lines)
@@ -0,0 +1,47 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/kernels/hard_swish.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+std::vector<ComputeTaskDescriptorPtr> HardSwish(int id, ValueId input_id,
+                                                ValueId output_id,
+                                                const RuntimeOptions& options) {
+  auto desc = std::make_shared<ComputeTaskDescriptor>();
+  desc->id = id;
+  desc->is_linkable = true;
+  desc->shader_source = R"(
+    FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {
+      return value * clamp(value / 6.0f + FLT4(0.5f), FLT4(0.0f), FLT4(1.0f));
+    }
+  )";
+  desc->input_buffers = {{input_id}};
+  desc->output_buffer = {output_id};
+  return {desc};
+}
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
tensorflow/lite/delegates/gpu/metal/kernels/hard_swish.h
Normal file
37
tensorflow/lite/delegates/gpu/metal/kernels/hard_swish.h
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_HARD_SWISH_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_HARD_SWISH_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
|
std::vector<ComputeTaskDescriptorPtr> HardSwish(int id, ValueId input_id,
|
||||||
|
ValueId output_id,
|
||||||
|
const RuntimeOptions& options);
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_HARD_SWISH_H_
|