PReLU fixed for PowerVR.
PiperOrigin-RevId: 265574591
This commit is contained in:
parent
149cba920c
commit
560d2b422c
@ -24,10 +24,11 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
PReLU::PReLU(const OperationDef& definition, const PReLUAttributes& attr)
|
||||
PReLU::PReLU(const OperationDef& definition, const PReLUAttributes& attr,
|
||||
CalculationsPrecision scalar_precision)
|
||||
: ElementwiseOperation(definition) {
|
||||
if (attr.clip != 0) {
|
||||
clip_ = FLT(definition.precision, attr.clip);
|
||||
clip_ = FLT(scalar_precision, attr.clip);
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,7 +68,7 @@ std::string PReLU::GetCoreCode(const std::string& src,
|
||||
std::string PReLU::GetArgsDeclaration() const {
|
||||
std::string args = absl::StrCat(",\n ", alpha_.GetDeclaration());
|
||||
if (clip_.Active()) {
|
||||
args = absl::StrCat(args, ",\n ", clip_.GetDeclaration());
|
||||
absl::StrAppend(&args, ",\n ", clip_.GetDeclaration());
|
||||
}
|
||||
return args;
|
||||
}
|
||||
@ -88,7 +89,10 @@ Status CreatePReLU(const CreationContext& creation_context,
|
||||
if (!alpha) {
|
||||
return InvalidArgumentError("Alpha is missing");
|
||||
}
|
||||
*result = PReLU(definition, attr);
|
||||
const auto scalar_precision = creation_context.device->IsPowerVR()
|
||||
? CalculationsPrecision::F32
|
||||
: definition.precision;
|
||||
*result = PReLU(definition, attr, scalar_precision);
|
||||
RETURN_IF_ERROR(result->UploadParameters(*alpha, creation_context.context));
|
||||
result->SetLinkIndex(0);
|
||||
return OkStatus();
|
||||
|
@ -52,7 +52,8 @@ class PReLU : public ElementwiseOperation {
|
||||
const PReLUAttributes& attr, PReLU* result);
|
||||
|
||||
private:
|
||||
PReLU(const OperationDef& definition, const PReLUAttributes& attr);
|
||||
PReLU(const OperationDef& definition, const PReLUAttributes& attr,
|
||||
CalculationsPrecision scalar_precision);
|
||||
|
||||
template <DataType T>
|
||||
Status UploadParameters(const ::tflite::gpu::Tensor<Linear, T>& parameters,
|
||||
|
Loading…
Reference in New Issue
Block a user