Changed accumulator precision for devices that does not support round to nearest.

PiperOrigin-RevId: 299852182
Change-Id: I386b9518201675e1357027de483f82419ba2cd7c
This commit is contained in:
Raman Sarokin 2020-03-09 08:55:06 -07:00 committed by TensorFlower Gardener
parent adc1b86230
commit 9f93835bbd
2 changed files with 13 additions and 1 deletions

View File

@ -99,6 +99,7 @@ objc_library(
"//tensorflow/lite/delegates/gpu/metal:api",
"//tensorflow/lite/delegates/gpu/metal:buffer_convert",
"//tensorflow/lite/delegates/gpu/metal:compiled_model",
"//tensorflow/lite/delegates/gpu/metal:environment",
"//tensorflow/lite/delegates/gpu/metal:inference_context",
"@com_google_absl//absl/types:span",
],

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/api.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
@ -295,7 +296,17 @@ class Delegate {
if (options_.allow_precision_loss) {
storage_type_size = sizeof(HalfBits);
runtime_options.storage_precision = RuntimeOptions::Precision::FP16;
runtime_options.accumulator_precision = RuntimeOptions::Precision::FP16;
const auto gpu_type = GetGpuType();
const bool powervr = gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8 ||
gpu_type == GpuType::kA9 || gpu_type == GpuType::kA10;
if (powervr) {
// PowerVR gpus support only round to zero for floating-point operations,
// to increase precision we will use F32 accumulator in this case
runtime_options.accumulator_precision = RuntimeOptions::Precision::FP32;
} else {
// Apple own gpus support round to nearest and have better precision
runtime_options.accumulator_precision = RuntimeOptions::Precision::FP16;
}
} else {
storage_type_size = sizeof(float);
runtime_options.storage_precision = RuntimeOptions::Precision::FP32;