GpuType replaced with DeviceInfo.

PiperOrigin-RevId: 303447708
Change-Id: If384b4b8acad6eec4861b392f3765575fd85de15
Authored by Raman Sarokin on 2020-03-27 18:46:02 -07:00; committed by TensorFlower Gardener
parent 86f5733999
commit 1303fce83c
15 changed files with 83 additions and 228 deletions
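
The migration pattern is easiest to see in the test_util.mm and delegate hunks below: instead of each kernel calling GetGpuType() internally, the caller now builds a DeviceInfo from the Metal device name once and threads it through Compile() into the kernel selectors. A minimal sketch of the new call sequence, pieced together from this diff (graph construction and error handling elided; RETURN_IF_ERROR as used in the sources, `graph` assumed built elsewhere):

  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
  std::string device_name = std::string([[device name] UTF8String]);
  DeviceInfo device_info(device_name);  // AppleGPUInfo is parsed from the device name

  RuntimeOptions options;
  options.storage_precision = RuntimeOptions::Precision::FP32;
  options.accumulator_precision = RuntimeOptions::Precision::FP32;

  CompiledModel compiled_model;
  RETURN_IF_ERROR(Compile(graph, device_info, options, &compiled_model));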


@@ -173,29 +173,6 @@ objc_library(
     ],
 )
 
-objc_library(
-    name = "environment_test_lib",
-    testonly = 1,
-    srcs = ["environment_test.mm"],
-    sdk_frameworks = ["XCTest"],
-    deps = [
-        ":environment",
-        "//tensorflow/lite/delegates/gpu/metal/kernels:test_util",
-    ],
-)
-
-ios_unit_test(
-    name = "environment_test",
-    testonly = 1,
-    minimum_os_version = "10.0",
-    runner = tflite_ios_lab_runner("IOS_LATEST"),
-    tags = tf_gpu_tests_tags() + [
-        "notap",
-        "tflite_not_portable_android",
-    ],
-    deps = [":environment_test_lib"],
-)
-
 objc_library(
     name = "inference_context",
     srcs = ["inference_context.mm"],
@@ -273,7 +250,6 @@ objc_library(
     srcs = [
         "//tensorflow/lite/delegates/gpu/metal:common_test.mm",
        "//tensorflow/lite/delegates/gpu/metal:compiled_model_test.mm",
-        "//tensorflow/lite/delegates/gpu/metal:environment_test.mm",
         "//tensorflow/lite/delegates/gpu/metal:inference_context_test.mm",
     ],
     hdrs = [


@@ -88,12 +88,14 @@ std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
 
 std::vector<ComputeTaskDescriptorPtr> SelectConvolutionTransposed(
     int id, ValueId input_id, ValueId output_id,
-    const ConvolutionTransposedAttributes& attr,
+    const ConvolutionTransposedAttributes& attr, const DeviceInfo& device_info,
     const metal::RuntimeOptions& options) {
   if (CheckConvolutionTransposed4x4Support(attr)) {
-    return ConvolutionTransposed4x4(id, input_id, output_id, attr, options);
+    return ConvolutionTransposed4x4(id, input_id, output_id, attr, device_info,
+                                    options);
   } else {
-    return ConvolutionTransposed(id, input_id, output_id, attr, options);
+    return ConvolutionTransposed(id, input_id, output_id, attr, device_info,
+                                 options);
   }
 }
 
@@ -165,6 +167,7 @@ bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
 absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
                                 const std::vector<ValueId>& inputs,
                                 const std::vector<ValueId>& outputs,
+                                const DeviceInfo& device_info,
                                 const RuntimeOptions& options,
                                 int* last_node_id, int* last_value_id,
                                 std::vector<ComputeTaskDescriptorPtr>* tasks) {
@@ -219,8 +222,9 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
 
         BHWC conv_shape{dst_shape.b, 36, tiles_x * tiles_y, dst_shape.c};
         (*last_node_id) += 1;
-        auto t1 = ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1,
-                                          conv_shape, attr, options);
+        auto t1 =
+            ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1,
+                                    conv_shape, attr, device_info, options);
         tasks->insert(tasks->end(), t1.begin(), t1.end());
 
         Winograd36To4x4Attributes wino_down_attr;
@@ -233,7 +237,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
         (*last_value_id) += 2;
       } else {
         *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
-                                    attr, options);
+                                    attr, device_info, options);
       }
       break;
     }
@@ -242,7 +246,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
           node_id, inputs[0], outputs[0],
          absl::any_cast<ConvolutionTransposedAttributes>(
              node->operation.attributes),
-          options);
+          device_info, options);
       break;
     case OperationType::DEPTHWISE_CONVOLUTION:
       *tasks =
@@ -255,7 +259,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
       *tasks = FullyConnected(
          node_id, inputs[0], outputs[0],
          absl::any_cast<FullyConnectedAttributes>(node->operation.attributes),
-          options);
+          device_info, options);
       break;
     case OperationType::MAX_UNPOOLING_2D:
       *tasks = MaxUnpooling(
@@ -388,7 +392,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
 
 } // namespace
 
-absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
+absl::Status Compile(const GraphFloat32& graph, const DeviceInfo& device_info,
+                     const RuntimeOptions& options,
                      CompiledModel* compiled_model) {
   int last_node_id = 0;
   for (const auto& node : graph.nodes()) {
@@ -412,7 +417,7 @@ absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
         RegisterCustomOps(graph, node, inputs, outputs, options, &tasks);
     if (!custom_status.ok()) {
       auto primary_status =
-          RegisterPrimaryOps(graph, node, inputs, outputs, options,
+          RegisterPrimaryOps(graph, node, inputs, outputs, device_info, options,
                              &last_node_id, &last_value_id, &tasks);
       if (!primary_status.ok()) {
         return absl::UnimplementedError(


@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
@@ -26,7 +27,8 @@ namespace gpu {
 namespace metal {
 
 // Builds CompiledModel out of GraphFloat32 graph using provided RuntimeOptions.
-absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
+absl::Status Compile(const GraphFloat32& graph, const DeviceInfo& device_info,
+                     const RuntimeOptions& options,
                      CompiledModel* compiled_model);
 
 } // namespace metal


@@ -22,16 +22,6 @@ namespace tflite {
 namespace gpu {
 namespace metal {
 
-enum class GpuType {
-  kUnknown,
-  kA7,   // iPhone 5s, iPad Air, iPad Mini 2, iPad Mini 3.
-  kA8,   // A8 iPhone 6, A8X iPad Air 2, iPad Mini 4.
-  kA9,   // A9 iPhone 6s, iPad (2017), A9X iPad Pro (1st generation).
-  kA10,  // iPhone 7, iPad (2018), A10X iPad Pro (2nd generation).
-  kA11,  // iPhone 8/X.
-  kA12,  // iPhone Xs.
-};
-
 enum class AppleGPU {
   kUnknown,
   kA7,
@@ -53,6 +43,8 @@ struct AppleGPUInfo {
   explicit AppleGPUInfo(const std::string& device_name);
   AppleGPU gpu_type;
 
+  bool IsLocalMemoryPreferredOverGlobal() const;
+
   bool IsBionic() const;
 
   // floating point rounding mode
@@ -72,8 +64,6 @@ struct DeviceInfo {
   int GetComputeUnitsCount() const;
 };
 
-GpuType GetGpuType();
-
 } // namespace metal
 } // namespace gpu
 } // namespace tflite


@@ -15,13 +15,8 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/metal/environment.h"
 
 #import <Metal/Metal.h>
 
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/lite/delegates/gpu/metal/common.h"
+#include <map>
+#include <string>
 
 namespace tflite {
 namespace gpu {
@@ -50,6 +45,12 @@ AppleGPUInfo::AppleGPUInfo(const std::string& device_name) {
   }
 }
 
+bool AppleGPUInfo::IsLocalMemoryPreferredOverGlobal() const {
+  return gpu_type == AppleGPU::kA7 ||
+         gpu_type == AppleGPU::kA8 ||
+         gpu_type == AppleGPU::kA8X;
+}
+
 bool AppleGPUInfo::IsBionic() const {
   return gpu_type == AppleGPU::kA11 ||
          gpu_type == AppleGPU::kA12 ||
@@ -103,72 +104,6 @@ int DeviceInfo::GetComputeUnitsCount() const {
   return apple_info.GetComputeUnitsCount();
 }
 
-GpuType GetGpuType() {
-  int max_feature_set = 0;
-#if defined(__IPHONE_9_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_9_0
-  std::vector<std::pair<MTLFeatureSet, int>> features;
-  if (@available(iOS 8.0, *)) {
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v1, 7);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v1, 8);
-  }
-  if (@available(iOS 9.0, *)) {
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v2, 7);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v2, 8);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v1, 9);
-  }
-  if (@available(iOS 10.0, *)) {
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v3, 7);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v3, 8);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v2, 9);
-  }
-  if (@available(iOS 11.0, *)) {
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v4, 8);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v3, 9);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily4_v1, 11);
-  }
-  if (@available(iOS 12.0, *)) {
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v5, 7);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v5, 8);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v4, 9);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily4_v2, 11);
-    features.emplace_back(MTLFeatureSet_iOS_GPUFamily5_v1, 12);
-  }
-  id<MTLDevice> device = GetBestSupportedMetalDevice();
-  for (auto &type : features) {
-    if ([device supportsFeatureSet:type.first]) {
-      max_feature_set = std::max(max_feature_set, type.second);
-    }
-  }
-#elif defined(__MAC_10_5) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_5
-  std::vector<std::pair<MTLFeatureSet, int>> features;
-  if (@available(macOS 10.15, *)) {
-    features.emplace_back(MTLFeatureSet_macOS_GPUFamily2_v1, 12);
-  }
-  id<MTLDevice> device = GetBestSupportedMetalDevice();
-  for (auto &type : features) {
-    if ([device supportsFeatureSet:type.first]) {
-      max_feature_set = std::max(max_feature_set, type.second);
-    }
-  }
-#endif
-  switch (max_feature_set) {
-    case 7:
-      return GpuType::kA7;
-    case 8:
-      return GpuType::kA8;
-    case 9:
-      return GpuType::kA9;
-    case 10:
-      return GpuType::kA10;
-    case 11:
-      return GpuType::kA11;
-    case 12:
-      return GpuType::kA12;
-    default:
-      return GpuType::kUnknown;
-  };
-}
-
 } // namespace metal
 } // namespace gpu
 } // namespace tflite
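
With these changes, the Metal feature-set probing that GetGpuType() performed is gone entirely: the GPU model is parsed once from the device-name string, and kernels branch on AppleGPUInfo predicates instead of comparing GpuType values. A hedged usage sketch of the new API (the device-name literal is hypothetical, for illustration only):

  DeviceInfo device_info("Apple A12 GPU");  // hypothetical device-name string
  const AppleGPUInfo& gpu = device_info.apple_info;
  if (gpu.IsLocalMemoryPreferredOverGlobal()) {
    // A7/A8/A8X: favor threadgroup (local) memory code paths.
  } else if (gpu.IsBionic()) {
    // A11 and newer: wider blocks, FP16-friendly tuning.
  }
  const int cu_count = gpu.GetComputeUnitsCount();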


@@ -1,49 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/delegates/gpu/metal/environment.h"
-
-#import <XCTest/XCTest.h>
-
-#include "tensorflow/lite/delegates/gpu/metal/common.h"
-
-using ::tflite::gpu::metal::GetGpuType;
-
-@interface EnvironmentTest : XCTestCase
-@end
-
-@implementation EnvironmentTest
-
-- (void)testCompileTimeOSDetection {
-#if IOS_VERSION > 0
-  XCTAssertTrue(MACOS_VERSION == 0 && TVOS_VERSION == 0, @"IOS_VERSION: %d", int{IOS_VERSION});
-#endif
-#if MACOS_VERSION > 0
-  XCTAssertTrue(IOS_VERSION == 0 && TVOS_VERSION == 0, @"MACOS_VERSION: %d", int{MACOS_VERSION});
-#endif
-#if TVOS_VERSION > 0
-  XCTAssertTrue(IOS_VERSION == 0 && MACOS_VERSION == 0, @"TVOS_VERSION: %d", int{TVOS_VERSION});
-#endif
-}
-
-- (void)testGetGpuType {
-#if (IOS_VERSION > 0) || (TVOS_VERSION > 0)
-  auto gpuType = GetGpuType();
-  XCTAssertTrue(gpuType != GpuType::kUnknown);
-#endif
-}
-
-@end


@@ -833,6 +833,7 @@ objc_library(
         "//tensorflow/lite/delegates/gpu/metal:api",
         "//tensorflow/lite/delegates/gpu/metal:common",
         "//tensorflow/lite/delegates/gpu/metal:compiled_model",
+        "//tensorflow/lite/delegates/gpu/metal:environment",
         "//tensorflow/lite/delegates/gpu/metal:inference_context",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "@FP16",


@@ -659,32 +659,19 @@ bool IsKernelYIs1(const Convolution2DAttributes& attr) {
          attr.padding.appended.h == 0;
 }
 
-int GetMaximumPossibleWavesCount(const BHWC& dst_shape, GpuType gpu) {
-  if (gpu == GpuType::kA7 || gpu == GpuType::kA8) {
+int GetMaximumPossibleWavesCount(const AppleGPUInfo& apple_info,
+                                 const BHWC& dst_shape) {
+  if (apple_info.IsLocalMemoryPreferredOverGlobal()) {
     return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, {1, 1, 1});
   } else {
     return GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, {1, 1, 1});
   }
 }
 
-int GetCountOfComputeUnits(GpuType gpu) {
-  if (gpu == GpuType::kA7 || gpu == GpuType::kA8) {
-    return 4;
-  } else if (gpu == GpuType::kA9 || gpu == GpuType::kA10) {
-    return 6;
-  } else if (gpu == GpuType::kA11) {
-    return 3;
-  } else if (gpu == GpuType::kA12) {
-    return 4;
-  } else {
-    // unknown gpu
-    return 4;
-  }
-}
-
-int GetRecommendedBlockSize(const BHWC& dst_shape, GpuType gpu) {
-  const int max_waves = GetMaximumPossibleWavesCount(dst_shape, gpu);
-  const int cu_count = GetCountOfComputeUnits(gpu);
+int GetRecommendedBlockSize(const AppleGPUInfo& apple_info,
+                            const BHWC& dst_shape) {
+  const int max_waves = GetMaximumPossibleWavesCount(apple_info, dst_shape);
+  const int cu_count = apple_info.GetComputeUnitsCount();
   if (max_waves >= cu_count * 64) {
     return 8;
   } else if (max_waves >= cu_count * 32) {
@@ -696,8 +683,9 @@ int GetRecommendedBlockSize(const BHWC& dst_shape, GpuType gpu) {
   }
 }
 
-ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr,
-                                const BHWC& dst_shape, GpuType gpu) {
+ConvParams GetConvParamsForA7A8(const AppleGPUInfo& apple_info,
+                                const Convolution2DAttributes& attr,
+                                const BHWC& dst_shape) {
   const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4);
   const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4);
 
@@ -711,7 +699,7 @@ ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr,
   params.linear_whs = false;
   params.work_group_launch_order = int3(0, 1, 2);
 
-  int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu);
+  int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
 
   if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) {
     params.block_size.z = 4;
@@ -771,14 +759,14 @@ ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr,
   return params;
 }
 
-ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr,
-                                       const BHWC& dst_shape, GpuType gpu) {
+ConvParams GetConvParamsForA9AndHigher(const AppleGPUInfo& apple_info,
+                                       const Convolution2DAttributes& attr,
+                                       const BHWC& dst_shape) {
   const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4);
   const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4);
-  int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu);
-  bool apple_gpu = gpu == GpuType::kA11 || gpu == GpuType::kA12;
+  int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
   int3 block_size = int3(1, 1, 1);
-  if (blk_total_size >= 2 && apple_gpu) {
+  if (blk_total_size >= 2 && apple_info.IsBionic()) {
    if (dst_shape.h % 2 != 0 && dst_shape.w % 2 == 0) {
      block_size.x = 2;
    } else {
@@ -816,7 +804,7 @@ ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr,
     params.work_group_size = int3(32, 1, 1);
     params.work_group_launch_order = int3(0, 1, 2);
   }
-  float precise_threshold = gpu == GpuType::kA12 ? 1.0f : 1.04f;
+  float precise_threshold = apple_info.IsBionic() ? 1.0f : 1.04f;
   float precise_ratio = static_cast<float>(g2) / static_cast<float>(g3);
   if (precise_ratio > precise_threshold) {
     params.linear_wh = false;
@@ -852,13 +840,13 @@ ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr,
   return params;
 }
 
-ConvParams GetConvParams(const Convolution2DAttributes& attr,
+ConvParams GetConvParams(const DeviceInfo& device_info,
+                         const Convolution2DAttributes& attr,
                          const BHWC& dst_shape) {
-  auto gpu_type = GetGpuType();
-  if (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) {
-    return GetConvParamsForA7A8(attr, dst_shape, gpu_type);
+  if (device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
+    return GetConvParamsForA7A8(device_info.apple_info, attr, dst_shape);
   } else {
-    return GetConvParamsForA9AndHigher(attr, dst_shape, gpu_type);
+    return GetConvParamsForA9AndHigher(device_info.apple_info, attr, dst_shape);
   }
 }
 
@@ -898,8 +886,9 @@ std::pair<uint3, uint3> GetDispatchSizes(const ConvParams& params,
 
 std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
     int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
-    const Convolution2DAttributes& attr, const metal::RuntimeOptions& options) {
-  ConvParams params = GetConvParams(attr, dst_shape);
+    const Convolution2DAttributes& attr, const DeviceInfo& device_info,
+    const metal::RuntimeOptions& options) {
+  ConvParams params = GetConvParams(device_info, attr, dst_shape);
 
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
@@ -953,7 +942,8 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
 
 std::vector<ComputeTaskDescriptorPtr> ConvolutionWino4x4To6x6(
     int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
-    const Convolution2DAttributes& attr, const RuntimeOptions& options) {
+    const Convolution2DAttributes& attr, const DeviceInfo& device_info,
+    const RuntimeOptions& options) {
   const int dst_slices = IntegralDivideRoundUp(attr.weights.shape.o, 4);
   ConvParams params;
   params.work_group_launch_order = int3(2, 0, 1);
@@ -965,8 +955,7 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionWino4x4To6x6(
   params.different_weights_for_height = true;
   params.x_kernel_is_1 = true;
   params.y_kernel_is_1 = true;
-  auto gpu_type = GetGpuType();
-  if (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) {
+  if (device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
     params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
     params.work_group_size = int3(32, 1, 1);
     params.block_size = int3(4, 1, 4);


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
@@ -29,11 +30,13 @@ namespace metal {
 
 std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
     int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
-    const Convolution2DAttributes& attr, const RuntimeOptions& options);
+    const Convolution2DAttributes& attr, const DeviceInfo& device_info,
+    const RuntimeOptions& options);
 
 std::vector<ComputeTaskDescriptorPtr> ConvolutionWino4x4To6x6(
     int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
-    const Convolution2DAttributes& attr, const RuntimeOptions& options);
+    const Convolution2DAttributes& attr, const DeviceInfo& device_info,
+    const RuntimeOptions& options);
 
 } // namespace metal
 } // namespace gpu


@@ -119,12 +119,12 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
 
 std::vector<ComputeTaskDescriptorPtr> FullyConnected(
     int id, ValueId input_id, ValueId output_id,
-    const FullyConnectedAttributes& attr, const RuntimeOptions& options) {
+    const FullyConnectedAttributes& attr, const DeviceInfo& device_info,
+    const RuntimeOptions& options) {
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
-  auto gpu_type = GetGpuType();
-  bool shared = gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8;
+  bool shared = device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
   desc->shader_source =
       GetFullyConnectedCode(shared, attr.weights.shape.i, attr.weights.shape.o);
 


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
@@ -33,7 +34,8 @@ namespace metal {
 // will be inefficient
 std::vector<ComputeTaskDescriptorPtr> FullyConnected(
     int id, ValueId input_id, ValueId output_id,
-    const FullyConnectedAttributes& attr, const RuntimeOptions& options);
+    const FullyConnectedAttributes& attr, const DeviceInfo& device_info,
+    const RuntimeOptions& options);
 
 } // namespace metal
 } // namespace gpu


@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 
 namespace tflite {
 namespace gpu {
@@ -77,15 +78,17 @@ absl::Status SingleOpModel::Invoke() {
     output_ids.push_back(output.id);
   }
 
+  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+  std::string device_name = std::string([[device name] UTF8String]);
+  DeviceInfo device_info(device_name);
   RuntimeOptions options;
   options.storage_precision = RuntimeOptions::Precision::FP32;
   options.accumulator_precision = RuntimeOptions::Precision::FP32;
   CompiledModel compiled_model;
-  RETURN_IF_ERROR(Compile(graph_, options, &compiled_model));
+  RETURN_IF_ERROR(Compile(graph_, device_info, options, &compiled_model));
   CompiledModel optimized_model;
   RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model));
 
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
   TFLInferenceContext* graph = [[TFLInferenceContext alloc] init];
   RETURN_IF_ERROR([graph compileModelWithDevice:device
                                 taskDescriptors:optimized_model


@@ -442,7 +442,7 @@ std::string GetDeconvolution4x4(const int2& block_size, bool use_local_mem) {
 std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed(
     int id, ValueId input_id, ValueId output_id,
     const ConvolutionTransposedAttributes& params,
-    const RuntimeOptions& options) {
+    const DeviceInfo& device_info, const RuntimeOptions& options) {
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
@@ -454,9 +454,8 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed(
   const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4);
   const int shared_size =
       sizeof(float) * 4 * src_depth * src_local_size_x * src_local_size_y;
-  auto gpu_type = GetGpuType();
   if (shared_size < 1000 * 16 &&
-      (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8)) {
+      device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
     desc->shader_source =
         GetDeconvolutionShared(params, kThreadGroupWidth, kThreadGroupHeight);
   } else {
@@ -543,7 +542,7 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed(
 std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed4x4(
     int id, ValueId input_id, ValueId output_id,
     const ConvolutionTransposedAttributes& params,
-    const RuntimeOptions& options) {
+    const DeviceInfo& device_info, const RuntimeOptions& options) {
   const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4);
   const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4);
   const int kernel_x = 4;
@@ -596,12 +595,10 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed4x4(
   desc->id = id;
   desc->is_linkable = false;
 
-  const auto gpu_type = GetGpuType();
-  const bool powervr = gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8 ||
-                       gpu_type == GpuType::kA9 || gpu_type == GpuType::kA10;
   const bool recommended_2x =
-      !powervr && options.storage_precision == RuntimeOptions::Precision::FP16;
-  const bool use_local_mem = powervr;
+      device_info.apple_info.IsBionic() &&
+      options.storage_precision == RuntimeOptions::Precision::FP16;
+  const bool use_local_mem = !device_info.apple_info.IsBionic();
   const int2 block_size(recommended_2x ? 2 : 1, 1);
 
   desc->shader_source = GetDeconvolution4x4(block_size, use_local_mem);


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
@@ -30,12 +31,12 @@ namespace metal {
 std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed(
     int id, ValueId input_id, ValueId output_id,
     const ConvolutionTransposedAttributes& params,
-    const RuntimeOptions& options);
+    const DeviceInfo& device_info, const RuntimeOptions& options);
 
 std::vector<ComputeTaskDescriptorPtr> ConvolutionTransposed4x4(
     int id, ValueId input_id, ValueId output_id,
     const ConvolutionTransposedAttributes& params,
-    const RuntimeOptions& options);
+    const DeviceInfo& device_info, const RuntimeOptions& options);
 
 bool CheckConvolutionTransposed4x4Support(
     const ConvolutionTransposedAttributes& attr);


@@ -391,7 +391,7 @@ class Delegate {
 
     // TODO(impjdi): Merge these.
     CompiledModel compiled_model;
-    RETURN_IF_ERROR(Compile(graph, runtime_options, &compiled_model));
+    RETURN_IF_ERROR(Compile(graph, device_info, runtime_options, &compiled_model));
     CompiledModel optimized_model;
     RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model));