PReLU fixed for PowerVR.

PiperOrigin-RevId: 265574591
This commit is contained in:
A. Unique TensorFlower 2019-08-26 17:10:38 -07:00 committed by TensorFlower Gardener
parent 149cba920c
commit 560d2b422c
2 changed files with 10 additions and 5 deletions

View File

@ -24,10 +24,11 @@ namespace tflite {
namespace gpu {
namespace cl {
PReLU::PReLU(const OperationDef& definition, const PReLUAttributes& attr)
PReLU::PReLU(const OperationDef& definition, const PReLUAttributes& attr,
CalculationsPrecision scalar_precision)
: ElementwiseOperation(definition) {
if (attr.clip != 0) {
clip_ = FLT(definition.precision, attr.clip);
clip_ = FLT(scalar_precision, attr.clip);
}
}
@ -67,7 +68,7 @@ std::string PReLU::GetCoreCode(const std::string& src,
std::string PReLU::GetArgsDeclaration() const {
std::string args = absl::StrCat(",\n ", alpha_.GetDeclaration());
if (clip_.Active()) {
args = absl::StrCat(args, ",\n ", clip_.GetDeclaration());
absl::StrAppend(&args, ",\n ", clip_.GetDeclaration());
}
return args;
}
@ -88,7 +89,10 @@ Status CreatePReLU(const CreationContext& creation_context,
if (!alpha) {
return InvalidArgumentError("Alpha is missing");
}
*result = PReLU(definition, attr);
const auto scalar_precision = creation_context.device->IsPowerVR()
? CalculationsPrecision::F32
: definition.precision;
*result = PReLU(definition, attr, scalar_precision);
RETURN_IF_ERROR(result->UploadParameters(*alpha, creation_context.context));
result->SetLinkIndex(0);
return OkStatus();

View File

@ -52,7 +52,8 @@ class PReLU : public ElementwiseOperation {
const PReLUAttributes& attr, PReLU* result);
private:
PReLU(const OperationDef& definition, const PReLUAttributes& attr);
PReLU(const OperationDef& definition, const PReLUAttributes& attr,
CalculationsPrecision scalar_precision);
template <DataType T>
Status UploadParameters(const ::tflite::gpu::Tensor<Linear, T>& parameters,