Add support for LeakyRelu
PiperOrigin-RevId: 222121991
commit 03bbe80587 (parent d5a620d8b9)
@@ -248,6 +248,7 @@ def generated_test_models():
         "sum",
         "l2norm",
         "l2_pool",
+        "leaky_relu",
         "less",
         "less_equal",
         "local_response_norm",
@@ -378,6 +378,20 @@ Options {
 }
 ```
 
+**LEAKY_RELU**
+
+```
+Inputs {
+  0: a tensor
+}
+Outputs {
+  0: a tensor equivalent to max(input, input * alpha)
+}
+Options {
+  alpha: slope of the activation at x < 0 (provided alpha <= 1)
+}
+```
+
 **LESS**
 
 ```
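For readers of the op documentation above, a minimal stand-alone sketch of the formula `max(input, input * alpha)` (illustrative C++ only; the names here are not part of the TFLite API or of this change):

```
#include <algorithm>
#include <cstddef>
#include <vector>

// Element-wise LeakyRelu as documented above: out[i] = max(x[i], alpha * x[i]).
// For alpha <= 1 this equals the usual piecewise definition
// (x for x >= 0, alpha * x for x < 0).
std::vector<float> LeakyReluReference(const std::vector<float>& input,
                                      float alpha) {
  std::vector<float> output(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    output[i] = std::max(input[i], input[i] * alpha);
  }
  return output;
}

int main() {
  // With alpha = 0.5: {-2, -1, 0, 1} -> {-1, -0.5, 0, 1}.
  const std::vector<float> out =
      LeakyReluReference({-2.0f, -1.0f, 0.0f, 1.0f}, 0.5f);
  return out.size() == 4 ? 0 : 1;
}
```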
@@ -288,8 +288,8 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     } break;
     default:
-      context->ReportError(context, "Only float32 supported currently, got %d.",
-                           input->type);
+      context->ReportError(context, "Only float32 supported currently, got %s.",
+                           TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
 }
@@ -309,8 +309,8 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     } break;
     default:
-      context->ReportError(context, "Only float32 supported currently, got %d.",
-                           input->type);
+      context->ReportError(context, "Only float32 supported currently, got %s.",
+                           TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
 }
@@ -328,8 +328,8 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     } break;
     default:
-      context->ReportError(context, "Only float32 supported currently, got %d.",
-                           input->type);
+      context->ReportError(context, "Only float32 supported currently, got %s.",
+                           TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
 }
@@ -367,8 +367,8 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     } break;
     default:
-      context->ReportError(context, "Only float32 supported currently, got %d.",
-                           input->type);
+      context->ReportError(context, "Only float32 supported currently, got %s.",
+                           TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
 }
@@ -407,9 +407,8 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
       break;
     }
     default:
-      context->ReportError(context, "Only float32 supported currently, got %d.",
-                           input->type);
-      return kTfLiteError;
+      context->ReportError(context, "Only float32 supported currently, got %s.",
+                           TfLiteTypeGetName(input->type));
   }
   return kTfLiteOk;
 }
@@ -604,8 +603,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
     }
     default:
       context->ReportError(
-          context, "Only float32 and uint8_t supported currently, got %d.",
-          input->type);
+          context, "Only float32 and uint8_t supported currently, got %s.",
+          TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
 }
@@ -636,8 +635,8 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteOk;
     }
     default:
-      context->ReportError(context, "Only float32 supported currently., got %d",
-                           input->type);
+      context->ReportError(context, "Only float32 supported currently., got %s",
+                           TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
 }
@@ -652,8 +651,8 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* alpha = GetInput(context, node, 1);
   TfLiteTensor* output = GetOutput(context, node, 0);
   if (input->type != kTfLiteFloat32) {
-    context->ReportError(context, "Only float32 supported currently, got %d.",
-                         input->type);
+    context->ReportError(context, "Only float32 supported currently, got %s.",
+                         TfLiteTypeGetName(input->type));
     return kTfLiteError;
   }
   reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
@@ -663,6 +662,28 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  const auto* params =
+      reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
+
+  LeakyReluParams op_params;
+  op_params.alpha = params->alpha;
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      optimized_ops::LeakyRelu(
+          op_params, GetTensorShape(input), GetTensorData<float>(input),
+          GetTensorShape(output), GetTensorData<float>(output));
+      return kTfLiteOk;
+    } break;
+    default:
+      context->ReportError(context, "Only float32 supported currently, got %s.",
+                           TfLiteTypeGetName(input->type));
+      return kTfLiteError;
+  }
+}
+
 } // namespace activations
 
 TfLiteRegistration* Register_RELU() {
@@ -721,6 +742,13 @@ TfLiteRegistration* Register_PRELU() {
   return &r;
 }
 
+TfLiteRegistration* Register_LEAKY_RELU() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 activations::GenericPrepare,
+                                 activations::LeakyReluEval};
+  return &r;
+}
+
 } // namespace builtin
 } // namespace ops
 } // namespace tflite
@@ -606,6 +606,39 @@ TEST(FloatActivationsOpTest, PRelu) {
   }));
 }
 
+class LeakyReluOpModel : public SingleOpModel {
+ public:
+  LeakyReluOpModel(const TensorData& input, float alpha) {
+    input_ = AddInput(input);
+    output_ = AddOutput(input);
+    SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
+                 CreateLeakyReluOptions(builder_, alpha).Union());
+    BuildInterpreter({GetShape(input_)});
+  }
+  void SetInput(std::initializer_list<float> data) {
+    PopulateTensor(input_, data);
+  }
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  int input_;
+  int output_;
+};
+
+TEST(FloatActivationsOpTest, LeakyRelu) {
+  LeakyReluOpModel m({TensorType_FLOAT32, {2, 3}}, 0.5f);
+
+  m.SetInput({
+      0.0f, 1.0f, 3.0f,    // Row 1
+      1.0f, -1.0f, -2.0f,  // Row 2
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                  0.0f, 1.0f, 3.0f,    // Row 1
+                  1.0f, -0.5f, -1.0f,  // Row 2
+              }));
+}
+
 } // namespace
 } // namespace tflite
@@ -65,6 +65,7 @@ using reference_ops::Greater;
 using reference_ops::GreaterEqual;
 using reference_ops::GreaterEqualWithScaling;
 using reference_ops::GreaterWithScaling;
+using reference_ops::LeakyRelu;
 using reference_ops::Less;
 using reference_ops::LessEqual;
 using reference_ops::LessEqualWithScaling;
@@ -558,6 +558,19 @@ inline void ReluX(const tflite::ActivationParams& params,
   }
 }
 
+inline void LeakyRelu(const tflite::LeakyReluParams& params,
+                      const RuntimeShape& input_shape, const float* input_data,
+                      const RuntimeShape& output_shape, float* output_data) {
+  gemmlowp::ScopedProfilingLabel label("LeakyRelu (not fused)");
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    const float val = input_data[i];
+    // Note that this implementation matches that of TensorFlow, and corresponds
+    // to the traditional LeakyRelu equation only for alpha <= 1.
+    output_data[i] = std::max(val, val * params.alpha);
+  }
+}
+
 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
                             const RuntimeShape& input_shape,
                             const float* input_data,
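As the comment in the kernel above notes, the max() form only matches the traditional LeakyRelu equation for alpha <= 1. A small stand-alone check of that edge case (illustrative only, values assumed for the example):

```
#include <algorithm>
#include <cassert>

int main() {
  const float x = -1.0f;
  const float alpha = 2.0f;
  // The kernel's form: max(x, alpha * x) = max(-1, -2) = -1.
  assert(std::max(x, x * alpha) == -1.0f);
  // The traditional piecewise form (alpha * x for x < 0) would give -2 instead,
  // which is why the comment restricts the equivalence to alpha <= 1.
  return 0;
}
```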
@@ -1006,6 +1006,10 @@ struct UnpackParams {
   int16 axis;
 };
 
+struct LeakyReluParams {
+  float alpha;
+};
+
 template <typename P>
 inline void SetActivationParams(float min, float max, P* params) {
   params->float_activation_min = min;
@@ -123,6 +123,7 @@ TfLiteRegistration* Register_SQUARE();
 TfLiteRegistration* Register_ZEROS_LIKE();
 TfLiteRegistration* Register_FLOOR_MOD();
 TfLiteRegistration* Register_RANGE();
+TfLiteRegistration* Register_LEAKY_RELU();
 
 TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
   context->ReportError(
@@ -256,6 +257,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
   AddBuiltin(BuiltinOperator_FLOOR_MOD, Register_FLOOR_MOD());
   AddBuiltin(BuiltinOperator_RANGE, Register_RANGE());
+  AddBuiltin(BuiltinOperator_LEAKY_RELU, Register_LEAKY_RELU());
 
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
@@ -755,6 +755,34 @@ def make_prelu_tests(zip_path):
       use_frozen_graph=True)
 
 
+def make_leaky_relu_tests(zip_path):
+  """Make a set of tests to do LeakyRelu."""
+
+  test_parameters = [
+      {
+          "input_shape": [[], [1], [5], [1, 10, 10, 3], [3, 3, 3, 3]],
+          "alpha": [0.1, 1.0, 2.0, -0.1, -1.0, -2.0],
+      },
+  ]
+
+  def build_graph(parameters):
+    """Build the graph for the test case."""
+
+    input_tensor = tf.placeholder(
+        dtype=tf.float32, name="input", shape=parameters["input_shape"])
+    out = tf.nn.leaky_relu(input_tensor, alpha=parameters["alpha"])
+    return [input_tensor], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    """Build the inputs for the test case."""
+    input_values = create_tensor_data(
+        np.float32, parameters["input_shape"], min_value=-3, max_value=10)
+    return [input_values], sess.run(
+        outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
 # This function tests various TensorFLow functions that generates Const op,
 # including `tf.ones`, `tf.zeros` and random functions.
 def make_constant_tests(zip_path):
@@ -1714,6 +1714,7 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
     case OperatorType::kRelu1:
     case OperatorType::kRelu6:
     case OperatorType::kPRelu:
+    case OperatorType::kLeakyRelu:
     case OperatorType::kSoftmax:
     case OperatorType::kLogSoftmax:
    case OperatorType::kLog:
@@ -2217,6 +2217,21 @@ tensorflow::Status ConvertUnidirectionalSequenceLstm(
   return tensorflow::Status::OK();
 }
 
+tensorflow::Status ConvertLeakyReluOperator(
+    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+    Model* model) {
+  CHECK_EQ(node.op(), "LeakyRelu");
+  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
+  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+  const auto& input_name = node.input(0);
+  auto* op = new LeakyReluOperator;
+  op->inputs.push_back(input_name);
+  op->outputs.push_back(node.name());
+  op->alpha = GetFloatAttr(node, "alpha");
+  model->operators.emplace_back(op);
+  return tensorflow::Status::OK();
+}
+
 } // namespace
 
 namespace internal {
@@ -2280,6 +2295,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
       ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2>},
      {"Identity", ConvertIdentityOperator},
      {"LRN", ConvertLRNOperator},
+     {"LeakyRelu", ConvertLeakyReluOperator},
      {"LegacyFedInput", ConvertPlaceholderOperator},
      {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>},
      {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
@@ -152,7 +152,8 @@ enum class OperatorType : uint8 {
   kCTCBeamSearchDecoder,
   kUnpack,
   kZerosLike,
-  kResizeNearestNeighbor
+  kResizeNearestNeighbor,
+  kLeakyRelu
 };
 
 // Helper to deal with TensorFlow arrays using a different ordering of
@@ -699,6 +700,19 @@ struct PReluOperator : Operator {
   PReluOperator() : Operator(OperatorType::kPRelu) {}
 };
 
+// LeakyRelu
+//   x -> max(x, alpha * x)
+//
+// Inputs:
+//   inputs[0]: required: the input array
+//
+// TensorFlow equivalent: LeakyRelu
+struct LeakyReluOperator : Operator {
+  LeakyReluOperator() : Operator(OperatorType::kLeakyRelu) {}
+
+  float alpha = 0.2f;  // 0.2 matches the default value for the TF op attribute.
+};
+
 // Element-wise Logistic operator:
 //   x -> Logistic(x) = 1 / (1 + exp(-x))
 //
@@ -1218,6 +1218,24 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
   int GetVersion(const Operator& op) const override { return 1; }
 };
 
+class LeakyRelu
+    : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
+                             ::tflite::BuiltinOptions_LeakyReluOptions> {
+ public:
+  using BuiltinOperator::BuiltinOperator;
+  flatbuffers::Offset<TfLiteOptions> WriteOptions(
+      const TocoOperator& op,
+      flatbuffers::FlatBufferBuilder* builder) const override {
+    return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
+  }
+  void ReadOptions(const TfLiteOptions& options,
+                   TocoOperator* op) const override {
+    op->alpha = options.alpha();
+  }
+
+  int GetVersion(const Operator& op) const override { return 1; }
+};
+
 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
     const string& tensorflow_node_def) {
   auto fbb = absl::make_unique<flexbuffers::Builder>();
@@ -1516,6 +1534,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
                                     OperatorType::kOneHot));
   ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
                                    OperatorType::kUnpack));
+  ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
+                                      OperatorType::kLeakyRelu));
 
   // Custom Operators.
   ops.push_back(
@@ -517,6 +517,14 @@ TEST_F(OperatorTest, BuiltinUnpack) {
   EXPECT_EQ(op.axis, output_toco_op->axis);
 }
 
+TEST_F(OperatorTest, BuiltinLeakyRelu) {
+  LeakyReluOperator op;
+  op.alpha = 3;
+  auto output_toco_op = SerializeAndDeserialize(
+      GetOperator("LEAKY_RELU", OperatorType::kLeakyRelu), op);
+  EXPECT_EQ(op.alpha, output_toco_op->alpha);
+}
+
 TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
   CTCBeamSearchDecoderOperator op;
   op.beam_width = 3;
@@ -411,6 +411,7 @@ const char* OperatorTypeName(OperatorType type) {
     HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
     HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
     HANDLE_OPERATORTYPENAME_CASE(ResizeNearestNeighbor)
+    HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
     default:
       LOG(FATAL) << "Unhandled op type";
 #undef HANDLE_OPERATORTYPENAME_CASE