Added MetalExecutionEnvironment, similar to ClExecutionEnvironment.
RunGraph replaced with MetalExecutionEnvironment.ExecuteGPUOperation in ops tests.

PiperOrigin-RevId: 348476614
Change-Id: I7a7781e619aeff534c82dacb954b8a05a84151ab
parent 0f99ddcfe3
commit 208675a35a
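The change is easiest to see in the test pattern. Below is a condensed sketch of the new flow, assembled from the test diffs in this commit (the op_def/attr/src_tensor setup is omitted; the shapes are the ones used in the Winograd4x4To36 test):

// Old pattern: build NodeDescriptors, create an MTLDevice plus input/output
// maps, then call RunGraph. New pattern: MetalExecutionEnvironment owns the
// device and GpuInfo and runs a single ComputeTaskDescriptor end to end.
tflite::gpu::metal::MetalExecutionEnvironment env;
auto gpu_op = tflite::gpu::metal::Winograd4x4To36(op_def, attr);
auto op_ptr =
    absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
TensorFloat32 gpu_output;
auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
                                      BHWC(1, 36, 1, 1), &gpu_output);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
status = CompareVectors(dst_tensor.data, gpu_output.data, 1e-6f);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());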
@@ -63,7 +63,7 @@ absl::Status ClExecutionEnvironment::ExecuteGPUOperation(
           "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context, src_shape,
-                                 op_def.src_tensors[0], &src[i]));
+                                 op_def.src_tensors[i], &src[i]));
     RETURN_IF_ERROR(src[i].WriteData(creation_context.queue, src_cpu[i]));
     operation->SetSrc(&src[i], i);
   }
@@ -76,7 +76,7 @@ absl::Status ClExecutionEnvironment::ExecuteGPUOperation(
           "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context, dst_shape,
-                                 op_def.dst_tensors[0], &dst[i]));
+                                 op_def.dst_tensors[i], &dst[i]));

     operation->SetDst(&dst[i], i);
   }
@@ -111,7 +111,7 @@ absl::Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
           "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context, src_shape,
-                                 op_def.src_tensors[0], &src[i]));
+                                 op_def.src_tensors[i], &src[i]));
     RETURN_IF_ERROR(src[i].WriteData(creation_context.queue, src_cpu[i]));
     operation->SetSrc(&src[i], i);
   }
@@ -124,7 +124,7 @@ absl::Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
           "Layout doesn't have Batch dimension, but shape.b != 1");
     }
     RETURN_IF_ERROR(CreateTensor(*creation_context.context, dst_shape,
-                                 op_def.dst_tensors[0], &dst[i]));
+                                 op_def.dst_tensors[i], &dst[i]));

     operation->SetDst(&dst[i], i);
   }
@@ -803,9 +803,12 @@ objc_library(
     name = "test_util",
     testonly = 1,
     srcs = [
-        "test_util.mm",
+        "test_util.cc",
     ],
     hdrs = ["test_util.h"],
+    copts = [
+        "-ObjC++",
+    ],
     sdk_frameworks = [
         "Metal",
     ],
@@ -824,8 +827,10 @@ objc_library(
        "//tensorflow/lite/delegates/gpu/metal:common",
        "//tensorflow/lite/delegates/gpu/metal:compiled_model",
        "//tensorflow/lite/delegates/gpu/metal:inference_context",
        "//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor",
        "@FP16",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ],
)

@@ -284,13 +284,7 @@ using ::tflite::gpu::metal::SingleOpModel;
     src_tensor.data[i] = sin(i);
   }

-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-
-  std::map<ValueId, TensorFloat32> inputs_v0;
-  inputs_v0[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs_v0;
-  outputs_v0[1].shape = dst_shape;
-  outputs_v0[1].data.resize(dst_shape.DimensionsProduct());
+  TensorFloat32 output0;

   tflite::gpu::OperationDef op_def;
   op_def.precision = tflite::gpu::CalculationsPrecision::F32;
@@ -299,61 +293,43 @@ using ::tflite::gpu::metal::SingleOpModel;
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);

-  std::string device_name = std::string([[device name] UTF8String]);
-  tflite::gpu::GpuInfo gpu_info;
-  tflite::gpu::GetGpuInfoFromDeviceDescription(device_name, tflite::gpu::GpuApi::kMetal, &gpu_info);
-  auto gpu_op0 = ConvolutionGeneric(op_def, dst_shape, attr, gpu_info);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op0));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};
-  auto status = RunGraph(nodes, device, inputs_v0, &outputs_v0);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto gpu_op0 = ConvolutionGeneric(op_def, dst_shape, attr, env.GetGpuInfo());
+  auto op0_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op0));
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op0_ptr), dst_shape, &output0);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());

   tflite::gpu::metal::Winograd4x4To36Attributes wino_up_attr;
   wino_up_attr.padding = attr.padding;
   auto gpu_op1 = tflite::gpu::metal::Winograd4x4To36(op_def, wino_up_attr);
+  auto op1_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op1));

-  auto gpu_op2 = ConvolutionWino4x4To6x6(op_def, conv_shape, attr, gpu_info);
+  auto gpu_op2 = ConvolutionWino4x4To6x6(op_def, conv_shape, attr, env.GetGpuInfo());
+  auto op2_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op2));

   tflite::gpu::metal::Winograd36To4x4Attributes wino_down_attr;
   wino_down_attr.output_shape = dst_shape;
   wino_down_attr.biases = attr.bias;
   auto gpu_op3 = tflite::gpu::metal::Winograd36To4x4(op_def, wino_down_attr);
+  auto op3_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op3));

-  std::map<ValueId, TensorFloat32> inputs_v1;
-  inputs_v1[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs_v1;
-  outputs_v1[2].shape = conv_shape;
-  outputs_v1[2].shape.c = src_shape.c;
-  outputs_v1[2].data.resize(outputs_v1[2].shape.DimensionsProduct());
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op1));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {2};
-  status = RunGraph(nodes, device, inputs_v1, &outputs_v1);
-
-  std::map<ValueId, TensorFloat32> inputs_v2;
-  inputs_v2[2] = outputs_v1[2];
-  std::map<ValueId, TensorFloat32> outputs_v2;
-  outputs_v2[3].shape = conv_shape;
-  outputs_v2[3].data.resize(outputs_v2[3].shape.DimensionsProduct());
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op2));
-  nodes[0].src_tensors_ids = {2};
-  nodes[0].dst_tensors_ids = {3};
-  status = RunGraph(nodes, device, inputs_v2, &outputs_v2);
-
-  std::map<ValueId, TensorFloat32> inputs_v3;
-  inputs_v3[3] = outputs_v2[3];
-  std::map<ValueId, TensorFloat32> outputs_v3;
-  outputs_v3[1].shape = dst_shape;
-  outputs_v3[1].data.resize(outputs_v3[1].shape.DimensionsProduct());
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op3));
-  nodes[0].src_tensors_ids = {3};
-  nodes[0].dst_tensors_ids = {1};
-  status = RunGraph(nodes, device, inputs_v3, &outputs_v3);
+  TensorFloat32 output1;
+  BHWC output1_shape = conv_shape;
+  output1_shape.c = src_shape.c;
+  status = env.ExecuteGPUOperation(src_tensor, std::move(op1_ptr), output1_shape, &output1);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());

-  status = CompareVectors(outputs_v0[1].data, outputs_v3[1].data, 1e-4f);
+  TensorFloat32 output2;
+  BHWC output2_shape = conv_shape;
+  status = env.ExecuteGPUOperation(output1, std::move(op2_ptr), output2_shape, &output2);
+  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
+
+  TensorFloat32 output3;
+  BHWC output3_shape = dst_shape;
+  status = env.ExecuteGPUOperation(output2, std::move(op3_ptr), output3_shape, &output3);
+  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
+
+  status = CompareVectors(output0.data, output3.data, 1e-4f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }

@@ -19,11 +19,15 @@ limitations under the License.

#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
@@ -33,14 +37,14 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
-#include "tensorflow/lite/delegates/gpu/common/precision.h"
-#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
+#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"

 namespace tflite {
 namespace gpu {
 namespace metal {

-SingleOpModel::SingleOpModel(Operation&& operation, const std::vector<TensorRef<BHWC>>& inputs,
+SingleOpModel::SingleOpModel(Operation&& operation,
+                             const std::vector<TensorRef<BHWC>>& inputs,
                              const std::vector<TensorRef<BHWC>>& outputs) {
   auto node = graph_.NewNode();
   node->operation = std::move(operation);
@@ -88,21 +92,22 @@ absl::Status SingleOpModel::Invoke() {
   CompiledModel compiled_model;
   RETURN_IF_ERROR(Compile(graph_, gpu_info, precision, &compiled_model));
   CompiledModel optimized_model;
-  RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model));
+  RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model,
+                                        &optimized_model));

   InferenceContext inference_context;
-  RETURN_IF_ERROR(inference_context.CompileModelWithDevice(device, optimized_model, input_ids,
-                                                           output_ids, precision));
+  RETURN_IF_ERROR(inference_context.CompileModelWithDevice(
+      device, optimized_model, input_ids, output_ids, precision));
   std::map<ValueId, BHWC> input_dimensions;
   std::map<ValueId, id<MTLBuffer>> input_buffers;
   for (auto& input : inputs_) {
     input_dimensions[input.id] = input.shape;
-    NSUInteger elements_count =
-        input.shape.w * input.shape.h * AlignByN(input.shape.c, 4) * input.shape.b;
+    NSUInteger elements_count = input.shape.w * input.shape.h *
+                                AlignByN(input.shape.c, 4) * input.shape.b;
     std::vector<float> src_gpu(elements_count);
     id<MTLBuffer> input_buffer;
-    RETURN_IF_ERROR(
-        ConvertToPHWC4(absl::MakeConstSpan(input.data), input.shape, absl::MakeSpan(src_gpu)));
+    RETURN_IF_ERROR(ConvertToPHWC4(absl::MakeConstSpan(input.data), input.shape,
+                                   absl::MakeSpan(src_gpu)));
     input_buffer = [device newBufferWithBytes:src_gpu.data()
                                        length:(elements_count * sizeof(float))
                                       options:MTLResourceStorageModeShared];
@@ -114,16 +119,20 @@ absl::Status SingleOpModel::Invoke() {
     // Uninitialized output buffer.
     const ValueId key = outputDimension.first;
     const BHWC& dims = outputDimension.second;
-    const NSUInteger size = dims.b * dims.w * dims.h * AlignByN(dims.c, 4) * sizeof(float);
-    output_buffers[key] = [device newBufferWithLength:size options:MTLResourceStorageModeShared];
+    const NSUInteger size =
+        dims.b * dims.w * dims.h * AlignByN(dims.c, 4) * sizeof(float);
+    output_buffers[key] =
+        [device newBufferWithLength:size options:MTLResourceStorageModeShared];
   }

   // Inference itself.
-  std::map<ValueId, id<MTLBuffer>> inout_buffers(input_buffers.begin(), input_buffers.end());
+  std::map<ValueId, id<MTLBuffer>> inout_buffers(input_buffers.begin(),
+                                                 input_buffers.end());
   inout_buffers.insert(output_buffers.begin(), output_buffers.end());
   id<MTLCommandQueue> command_queue = [device newCommandQueue];
   id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
-  id<MTLComputeCommandEncoder> command_encoder = [command_buffer computeCommandEncoder];
+  id<MTLComputeCommandEncoder> command_encoder =
+      [command_buffer computeCommandEncoder];
   inference_context.EncodeWithEncoder(command_encoder, inout_buffers);
   [command_encoder endEncoding];
   [command_buffer commit];
@@ -134,34 +143,40 @@ absl::Status SingleOpModel::Invoke() {
     NSUInteger elements_count = dim.w * dim.h * AlignByN(dim.c, 4) * dim.b;
     output.shape = dim;
     output.data.resize(output.shape.DimensionsProduct());
-    float* output_pointer = reinterpret_cast<float*>([output_buffers[output.id] contents]);
-    RETURN_IF_ERROR(ConvertFromPHWC4(absl::MakeConstSpan(output_pointer, elements_count),
-                                     output.shape, absl::MakeSpan(output.data)));
+    float* output_pointer =
+        reinterpret_cast<float*>([output_buffers[output.id] contents]);
+    RETURN_IF_ERROR(
+        ConvertFromPHWC4(absl::MakeConstSpan(output_pointer, elements_count),
+                         output.shape, absl::MakeSpan(output.data)));
   }
   return absl::OkStatus();
 }

-absl::Status CompareVectors(const std::vector<float>& reference, const std::vector<float>& output,
-                            float max_error) {
+absl::Status CompareVectors(const std::vector<float>& reference,
+                            const std::vector<float>& output, float max_error) {
   if (reference.size() != output.size()) {
-    const std::string message = "CompareVectors: vectors size does not match for reference: " +
-                                std::to_string(reference.size()) +
-                                " vs. output: " + std::to_string(output.size());
+    const std::string message =
+        "CompareVectors: vectors size does not match for reference: " +
+        std::to_string(reference.size()) +
+        " vs. output: " + std::to_string(output.size());
     return absl::InternalError(message);
   }
   for (int i = 0; i < reference.size(); i++) {
     float error = std::abs(reference[i] - output[i]);
     if (error > max_error) {
       const std::string message =
-          "Reference: " + std::to_string(reference[i]) + ", output: " + std::to_string(output[i]) +
-          ", error: " + std::to_string(error) + ", max allowed error: " + std::to_string(max_error);
+          "Reference: " + std::to_string(reference[i]) +
+          ", output: " + std::to_string(output[i]) +
+          ", error: " + std::to_string(error) +
+          ", max allowed error: " + std::to_string(max_error);
       return absl::InternalError(message);
     }
   }
   return absl::OkStatus();
 }

-absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes, id<MTLDevice> device,
+absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes,
+                      id<MTLDevice> device,
                       const std::map<ValueId, TensorFloat32>& inputs,
                       std::map<ValueId, TensorFloat32>* outputs) {
   std::vector<ValueId> inputBufferIDs;
@@ -177,22 +192,22 @@ absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes, id<MTLDevice> de
   std::map<ValueId, BHWC> outputDimensions;
   CompiledModel raw_model;
   raw_model.nodes = nodes;
-  for(const auto& input : inputs) {
+  for (const auto& input : inputs) {
     raw_model.tensor_shapes[input.first] = input.second.shape;
   }
-  for(const auto& output : *outputs) {
+  for (const auto& output : *outputs) {
     outputDimensions[output.first] = output.second.shape;
     raw_model.tensor_shapes[output.first] = output.second.shape;
   }
   CompiledModel optimized_model;
-  RETURN_IF_ERROR(
-      ValidateOptimizeModel(inputBufferIDs, outputBufferIDs, raw_model, &optimized_model));
+  RETURN_IF_ERROR(ValidateOptimizeModel(inputBufferIDs, outputBufferIDs,
+                                        raw_model, &optimized_model));

   CalculationsPrecision precision = CalculationsPrecision::F32;

   InferenceContext inference_context;
-  RETURN_IF_ERROR(inference_context.CompileModelWithDevice(device, optimized_model, inputBufferIDs,
-                                                           outputBufferIDs, precision));
+  RETURN_IF_ERROR(inference_context.CompileModelWithDevice(
+      device, optimized_model, inputBufferIDs, outputBufferIDs, precision));
   std::map<ValueId, BHWC> inputDimensions;
   std::map<ValueId, std::vector<float>> inputBuffersCPU;
   std::map<ValueId, id<MTLBuffer>> inputBuffersGPU;
@@ -200,11 +215,12 @@ absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes, id<MTLDevice> de
     const auto& src = input.second;
     inputDimensions[input.first] = src.shape;
     const int paddedDepth = AlignByN(src.shape.c, 4);
-    NSUInteger elementsCount = src.shape.w * src.shape.h * paddedDepth * src.shape.b;
+    NSUInteger elementsCount =
+        src.shape.w * src.shape.h * paddedDepth * src.shape.b;
     std::vector<float> src_gpu(elementsCount);
     id<MTLBuffer> inputBuffer;
-    RETURN_IF_ERROR(
-        ConvertToPHWC4(absl::MakeConstSpan(src.data), src.shape, absl::MakeSpan(src_gpu)));
+    RETURN_IF_ERROR(ConvertToPHWC4(absl::MakeConstSpan(src.data), src.shape,
+                                   absl::MakeSpan(src_gpu)));
     inputBuffer = [device newBufferWithBytes:src_gpu.data()
                                       length:(elementsCount * sizeof(float))
                                      options:MTLResourceStorageModeShared];
@@ -218,8 +234,9 @@ absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes, id<MTLDevice> de
     const BHWC& dims = outputDimension.second;
     const NSUInteger outputDataSize =
         dims.b * dims.w * dims.h * AlignByN(dims.c, 4) * sizeof(float);
-    outputBuffers[key] = [device newBufferWithLength:outputDataSize
-                                              options:MTLResourceStorageModeShared];
+    outputBuffers[key] =
+        [device newBufferWithLength:outputDataSize
+                            options:MTLResourceStorageModeShared];
   }

   // Inference itself.
@@ -228,7 +245,8 @@ absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes, id<MTLDevice> de
   inputOutputBuffers.insert(outputBuffers.begin(), outputBuffers.end());
   id<MTLCommandQueue> commandQueue = [device newCommandQueue];
   id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
-  id<MTLComputeCommandEncoder> commandEncoder = [commandBuffer computeCommandEncoder];
+  id<MTLComputeCommandEncoder> commandEncoder =
+      [commandBuffer computeCommandEncoder];
   inference_context.EncodeWithEncoder(commandEncoder, inputOutputBuffers);
   [commandEncoder endEncoding];
   [commandBuffer commit];
@@ -241,9 +259,134 @@ absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes, id<MTLDevice> de
     auto& dst = output.second;
     dst.shape = dim;
     dst.data.resize(dst.shape.DimensionsProduct());
-    float* outputPointer = reinterpret_cast<float*>([outputBuffers[output.first] contents]);
-    RETURN_IF_ERROR(ConvertFromPHWC4(absl::MakeConstSpan(outputPointer, elementsCount), dst.shape,
-                                     absl::MakeSpan(dst.data)));
+    float* outputPointer =
+        reinterpret_cast<float*>([outputBuffers[output.first] contents]);
+    RETURN_IF_ERROR(
+        ConvertFromPHWC4(absl::MakeConstSpan(outputPointer, elementsCount),
+                         dst.shape, absl::MakeSpan(dst.data)));
   }
   return absl::OkStatus();
 }
+
+MetalExecutionEnvironment::MetalExecutionEnvironment() {
+  device_ = MTLCreateSystemDefaultDevice();
+  std::string device_name = std::string([[device_ name] UTF8String]);
+  GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info_);
+}
+
+std::vector<CalculationsPrecision>
+MetalExecutionEnvironment::GetSupportedPrecisions() const {
+  return {CalculationsPrecision::F32, CalculationsPrecision::F32_F16,
+          CalculationsPrecision::F16};
+}
+
+std::vector<TensorStorageType> MetalExecutionEnvironment::GetSupportedStorages()
+    const {
+  return {TensorStorageType::BUFFER};
+}
+
+// returns storage types that support zero clamping when reading OOB in HW
+// (Height/Width) dimensions.
+std::vector<TensorStorageType>
+MetalExecutionEnvironment::GetSupportedStoragesWithHWZeroClampSupport() const {
+  return {};
+}
+
+absl::Status MetalExecutionEnvironment::ExecuteGPUOperation(
+    const std::vector<TensorFloat32>& src_cpu,
+    std::unique_ptr<ComputeTaskDescriptor>&& operation,
+    const std::vector<BHWC>& dst_sizes,
+    const std::vector<TensorFloat32*>& dst_cpu) {
+  const OperationDef op_def = operation->definition;
+  std::vector<MetalSpatialTensor> src(src_cpu.size());
+  for (int i = 0; i < src_cpu.size(); ++i) {
+    auto src_shape = src_cpu[i].shape;
+    if (src_shape.b != 1 && !op_def.IsBatchSupported()) {
+      return absl::InvalidArgumentError(
+          "Layout doesn't have Batch dimension, but shape.b != 1");
+    }
+    RETURN_IF_ERROR(
+        CreateTensor(device_, src_shape, op_def.src_tensors[i], &src[i]));
+    RETURN_IF_ERROR(src[i].WriteData(src_cpu[i]));
+  }
+
+  std::vector<MetalSpatialTensor> dst(dst_cpu.size());
+  for (int i = 0; i < dst_cpu.size(); ++i) {
+    auto dst_shape = dst_sizes[i];
+    if (dst_shape.b != 1 && !op_def.IsBatchSupported()) {
+      return absl::InvalidArgumentError(
+          "Layout doesn't have Batch dimension, but shape.b != 1");
+    }
+    RETURN_IF_ERROR(
+        CreateTensor(device_, dst_shape, op_def.dst_tensors[i], &dst[i]));
+  }
+
+  std::map<ValueId, BHWC> tensor_shapes;
+  NodeDescriptor metal_node;
+  metal_node.task = std::move(operation);
+  metal_node.src_tensors_ids.resize(src_cpu.size());
+  for (int i = 0; i < src_cpu.size(); ++i) {
+    metal_node.src_tensors_ids[i] = i;
+    tensor_shapes[i] = src_cpu[i].shape;
+  }
+  metal_node.dst_tensors_ids.resize(dst_cpu.size());
+  for (int i = 0; i < dst_cpu.size(); ++i) {
+    metal_node.dst_tensors_ids[i] = src_cpu.size() + i;
+    tensor_shapes[src_cpu.size() + i] = dst_sizes[i];
+  }
+  metal_node.description = "test_op";
+  metal_node.id = 0;
+
+  std::string buffer_declarations;
+  int index = 0;
+  for (int i = 0; i < metal_node.task->dst_tensors_names.size(); ++i) {
+    buffer_declarations += metal_node.task->dst_tensors_names[i] + "[[buffer(" +
+                           std::to_string(index) + ")]],\n";
+    index++;
+  }
+  for (int i = 0; i < metal_node.task->src_tensors_names.size(); ++i) {
+    buffer_declarations += metal_node.task->src_tensors_names[i] + "[[buffer(" +
+                           std::to_string(index) + ")]],\n";
+    index++;
+  }
+  for (const auto& buffer : metal_node.task->immutable_buffers) {
+    buffer_declarations +=
+        buffer.declaration + "[[buffer(" + std::to_string(index) + ")]],\n";
+    index++;
+  }
+  for (const auto& buffer : metal_node.task->uniform_buffers) {
+    buffer_declarations +=
+        buffer.declaration + "[[buffer(" + std::to_string(index) + ")]],\n";
+    index++;
+  }
+
+  metal_node.task->shader_source = absl::Substitute(
+      metal_node.task->shader_source, "$0", buffer_declarations + "$1", "");
+
+  ComputeTask gpu_task;
+  RETURN_IF_ERROR(
+      gpu_task.CompileWithDevice(device_, metal_node, op_def.precision));
+  RETURN_IF_ERROR(gpu_task.UpdateParamsWithDevice(device_, tensor_shapes));
+  for (int i = 0; i < src_cpu.size(); ++i) {
+    gpu_task.SetSrcTensor(src[i], i);
+  }
+  for (int i = 0; i < dst_cpu.size(); ++i) {
+    gpu_task.SetDstTensor(dst[i], i);
+  }
+
+  id<MTLCommandQueue> command_queue = [device_ newCommandQueue];
+  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
+  id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+  gpu_task.EncodeWithEncoder(encoder);
+  [encoder endEncoding];
+  [command_buffer commit];
+  [command_buffer waitUntilCompleted];
+
+  for (int i = 0; i < dst_cpu.size(); ++i) {
+    dst_cpu[i]->shape = dst_sizes[i];
+    dst_cpu[i]->data = std::vector<float>(dst_sizes[i].DimensionsProduct(), 0);
+    RETURN_IF_ERROR(dst[i].ReadData(dst_cpu[i]));
+  }
+
+  return absl::OkStatus();
+}
@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_TEST_UTIL_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_TEST_UTIL_H_

+#import <Metal/Metal.h>
+
 #include <map>
 #include <vector>

@@ -67,6 +69,48 @@ absl::Status RunGraph(const std::vector<NodeDescriptor>& nodes,
                       const std::map<ValueId, TensorFloat32>& inputs,
                       std::map<ValueId, TensorFloat32>* outputs);

+class MetalExecutionEnvironment {
+ public:
+  MetalExecutionEnvironment();
+  ~MetalExecutionEnvironment() = default;
+
+  std::vector<CalculationsPrecision> GetSupportedPrecisions() const;
+  std::vector<TensorStorageType> GetSupportedStorages() const;
+  // returns storage types that support zero clamping when reading OOB in HW
+  // (Height/Width) dimensions.
+  std::vector<TensorStorageType> GetSupportedStoragesWithHWZeroClampSupport()
+      const;
+
+  const GpuInfo& GetGpuInfo() const { return gpu_info_; }
+
+  absl::Status ExecuteGPUOperation(
+      const std::vector<TensorFloat32>& src_cpu,
+      std::unique_ptr<ComputeTaskDescriptor>&& operation,
+      const std::vector<BHWC>& dst_sizes,
+      const std::vector<TensorFloat32*>& dst_cpu);
+
+  absl::Status ExecuteGPUOperation(
+      const TensorFloat32& src_cpu,
+      std::unique_ptr<ComputeTaskDescriptor>&& operation, const BHWC& dst_size,
+      TensorFloat32* result) {
+    return ExecuteGPUOperation(std::vector<TensorFloat32>{src_cpu},
+                               std::move(operation), dst_size, result);
+  }
+
+  absl::Status ExecuteGPUOperation(
+      const std::vector<TensorFloat32>& src_cpu,
+      std::unique_ptr<ComputeTaskDescriptor>&& operation, const BHWC& dst_size,
+      TensorFloat32* result) {
+    return ExecuteGPUOperation(
+        std::vector<TensorFloat32>{src_cpu}, std::move(operation),
+        std::vector<BHWC>{dst_size}, std::vector<TensorFloat32*>{result});
+  }
+
+ private:
+  id<MTLDevice> device_;
+  GpuInfo gpu_info_;
+};
+
 }  // namespace metal
 }  // namespace gpu
 }  // namespace tflite

@@ -92,22 +92,15 @@ using ::tflite::gpu::metal::CompareVectors;
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);
   auto gpu_op = tflite::gpu::metal::Winograd4x4To36(op_def, attr);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};

-  std::map<ValueId, TensorFloat32> inputs;
-  inputs[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs;
-  outputs[1].shape = BHWC(1, 36, 1, 1);
-  outputs[1].data.resize(36, 0.0f);
-
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  auto status = RunGraph(nodes, device, inputs, &outputs);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto op_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
+  TensorFloat32 gpu_output;
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
+                                        BHWC(1, 36, 1, 1), &gpu_output);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());

-  status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f);
+  status = CompareVectors(dst_tensor.data, gpu_output.data, 1e-6f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }

@@ -162,22 +155,15 @@ using ::tflite::gpu::metal::CompareVectors;
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);
   auto gpu_op = tflite::gpu::metal::Winograd4x4To36TileX6(op_def, attr);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};

-  std::map<ValueId, TensorFloat32> inputs;
-  inputs[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs;
-  outputs[1].shape = BHWC(1, 36, 1, 1);
-  outputs[1].data.resize(36, 0.0f);
-
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  auto status = RunGraph(nodes, device, inputs, &outputs);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto op_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
+  TensorFloat32 gpu_output;
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
+                                        BHWC(1, 36, 1, 1), &gpu_output);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());

-  status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f);
+  status = CompareVectors(dst_tensor.data, gpu_output.data, 1e-6f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }

@@ -233,22 +219,15 @@ using ::tflite::gpu::metal::CompareVectors;
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);
   auto gpu_op = tflite::gpu::metal::Winograd36To4x4(op_def, attr);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};

-  std::map<ValueId, TensorFloat32> inputs;
-  inputs[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs;
-  outputs[1].shape = BHWC(1, 4, 4, 1);
-  outputs[1].data.resize(16, 0.0f);
-
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  auto status = RunGraph(nodes, device, inputs, &outputs);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto op_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
+  TensorFloat32 gpu_output;
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
+                                        BHWC(1, 4, 4, 1), &gpu_output);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());

-  status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-5f);
+  status = CompareVectors(dst_tensor.data, gpu_output.data, 1e-5f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }

@@ -304,22 +283,15 @@ using ::tflite::gpu::metal::CompareVectors;
   op_def.src_tensors.push_back(tensor_descriptor);
   op_def.dst_tensors.push_back(tensor_descriptor);
   auto gpu_op = tflite::gpu::metal::Winograd36To4x4Tile4x1(op_def, attr);
-  std::vector<tflite::gpu::metal::NodeDescriptor> nodes(1);
-  nodes[0].task = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
-  nodes[0].src_tensors_ids = {0};
-  nodes[0].dst_tensors_ids = {1};

-  std::map<ValueId, TensorFloat32> inputs;
-  inputs[0] = src_tensor;
-  std::map<ValueId, TensorFloat32> outputs;
-  outputs[1].shape = BHWC(1, 4, 4, 1);
-  outputs[1].data.resize(16, 0.0f);
-
-  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
-  auto status = RunGraph(nodes, device, inputs, &outputs);
+  tflite::gpu::metal::MetalExecutionEnvironment env;
+  auto op_ptr = absl::make_unique<tflite::gpu::metal::ComputeTaskDescriptor>(std::move(gpu_op));
+  TensorFloat32 gpu_output;
+  auto status = env.ExecuteGPUOperation(src_tensor, std::move(op_ptr),
+                                        BHWC(1, 4, 4, 1), &gpu_output);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());

-  status = CompareVectors(dst_tensor.data, outputs[1].data, 1e-6f);
+  status = CompareVectors(dst_tensor.data, gpu_output.data, 1e-6f);
   XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
 }
