Added Buffer to Metal backend.
PiperOrigin-RevId: 339913277 Change-Id: If4a92b3fafe922b5abf00bc8dc04deef1ef6ca6d
This commit is contained in:
parent
baf2bfa96f
commit
1b098d27d3
@ -48,6 +48,32 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "buffer",
|
||||
srcs = ["buffer.mm"],
|
||||
hdrs = ["buffer.h"],
|
||||
copts = DEFAULT_COPTS,
|
||||
sdk_frameworks = ["Metal"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "buffer_test_lib",
|
||||
testonly = 1,
|
||||
srcs = ["buffer_test.mm"],
|
||||
sdk_frameworks = [
|
||||
"XCTest",
|
||||
"Metal",
|
||||
],
|
||||
deps = [
|
||||
":buffer",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "buffer_convert",
|
||||
srcs = ["buffer_convert.mm"],
|
||||
@ -285,6 +311,7 @@ objc_library(
|
||||
name = "common_tests_lib",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"//tensorflow/lite/delegates/gpu/metal:buffer_test.mm",
|
||||
"//tensorflow/lite/delegates/gpu/metal:common_test.mm",
|
||||
"//tensorflow/lite/delegates/gpu/metal:compiled_model_test.mm",
|
||||
"//tensorflow/lite/delegates/gpu/metal:inference_context_test.mm",
|
||||
@ -293,6 +320,8 @@ objc_library(
|
||||
],
|
||||
sdk_frameworks = ["XCTest"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/metal:buffer",
|
||||
"//tensorflow/lite/delegates/gpu/metal:common",
|
||||
"//tensorflow/lite/delegates/gpu/metal:environment",
|
||||
"//tensorflow/lite/delegates/gpu/metal:inference_context",
|
||||
|
95
tensorflow/lite/delegates/gpu/metal/buffer.h
Normal file
95
tensorflow/lite/delegates/gpu/metal/buffer.h
Normal file
@ -0,0 +1,95 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the Licensgoe is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
class Buffer {
|
||||
public:
|
||||
Buffer() {} // just for using Buffer as a class members
|
||||
Buffer(id<MTLBuffer> buffer, size_t size_in_bytes);
|
||||
|
||||
// Move only
|
||||
Buffer(Buffer&& buffer);
|
||||
Buffer& operator=(Buffer&& buffer);
|
||||
Buffer(const Buffer&) = delete;
|
||||
Buffer& operator=(const Buffer&) = delete;
|
||||
|
||||
~Buffer();
|
||||
|
||||
// for profiling and memory statistics
|
||||
uint64_t GetMemorySizeInBytes() const { return size_; }
|
||||
|
||||
id<MTLBuffer> GetMemoryPtr() const { return buffer_; }
|
||||
|
||||
// Writes data to a buffer. Data should point to a region that
|
||||
// has exact size in bytes as size_in_bytes(constructor parameter).
|
||||
template <typename T>
|
||||
absl::Status WriteData(const absl::Span<T> data);
|
||||
|
||||
// Reads data from Buffer into CPU memory.
|
||||
template <typename T>
|
||||
absl::Status ReadData(std::vector<T>* result) const;
|
||||
|
||||
private:
|
||||
void Release();
|
||||
|
||||
id<MTLBuffer> buffer_ = nullptr;
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
absl::Status CreateBuffer(size_t size_in_bytes, const void* data, id<MTLDevice> device,
|
||||
Buffer* result);
|
||||
|
||||
template <typename T>
|
||||
absl::Status Buffer::WriteData(const absl::Span<T> data) {
|
||||
if (size_ != sizeof(T) * data.size()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"absl::Span<T> data size is different from buffer allocated size.");
|
||||
}
|
||||
std::memcpy([buffer_ contents], data.data(), size_);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::Status Buffer::ReadData(std::vector<T>* result) const {
|
||||
if (size_ % sizeof(T) != 0) {
|
||||
return absl::UnknownError("Wrong element size(typename T is not correct?");
|
||||
}
|
||||
|
||||
const int elements_count = size_ / sizeof(T);
|
||||
result->resize(elements_count);
|
||||
std::memcpy(result->data(), [buffer_ contents], size_);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_H_
|
69
tensorflow/lite/delegates/gpu/metal/buffer.mm
Normal file
69
tensorflow/lite/delegates/gpu/metal/buffer.mm
Normal file
@ -0,0 +1,69 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
Buffer::Buffer(id<MTLBuffer> buffer, size_t size_in_bytes)
|
||||
: buffer_(buffer), size_(size_in_bytes) {}
|
||||
|
||||
Buffer::Buffer(Buffer&& buffer) : buffer_(buffer.buffer_), size_(buffer.size_) {
|
||||
buffer.buffer_ = nullptr;
|
||||
buffer.size_ = 0;
|
||||
}
|
||||
|
||||
Buffer& Buffer::operator=(Buffer&& buffer) {
|
||||
if (this != &buffer) {
|
||||
Release();
|
||||
std::swap(size_, buffer.size_);
|
||||
std::swap(buffer_, buffer.buffer_);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Buffer::~Buffer() { Release(); }
|
||||
|
||||
void Buffer::Release() {
|
||||
if (buffer_) {
|
||||
buffer_ = nullptr;
|
||||
size_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status CreateBuffer(size_t size_in_bytes, const void* data,
|
||||
id<MTLDevice> device, Buffer* result) {
|
||||
id<MTLBuffer> buffer;
|
||||
if (data) {
|
||||
buffer = [device newBufferWithBytes:data
|
||||
length:size_in_bytes
|
||||
options:MTLResourceStorageModeShared];
|
||||
} else {
|
||||
buffer = [device newBufferWithLength:size_in_bytes
|
||||
options:MTLResourceStorageModeShared];
|
||||
}
|
||||
|
||||
*result = Buffer(buffer, size_in_bytes);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
70
tensorflow/lite/delegates/gpu/metal/buffer_test.mm
Normal file
70
tensorflow/lite/delegates/gpu/metal/buffer_test.mm
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
@interface BufferTest : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation BufferTest
|
||||
- (void)setUp {
|
||||
[super setUp];
|
||||
}
|
||||
|
||||
using tflite::gpu::half;
|
||||
|
||||
- (void)testBufferF32 {
|
||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||
|
||||
const std::vector<float> data = {1.0f, 2.0f, 3.0f, -4.0f, 5.1f};
|
||||
tflite::gpu::metal::Buffer buffer;
|
||||
XCTAssertTrue(tflite::gpu::metal::CreateBuffer(sizeof(float) * 5, nullptr, device, &buffer).ok());
|
||||
XCTAssertTrue(buffer.WriteData(absl::MakeConstSpan(data.data(), data.size())).ok());
|
||||
std::vector<float> gpu_data;
|
||||
XCTAssertTrue(buffer.ReadData<float>(&gpu_data).ok());
|
||||
|
||||
XCTAssertEqual(gpu_data.size(), data.size());
|
||||
for (int i = 0; i < gpu_data.size(); ++i) {
|
||||
XCTAssertEqual(gpu_data[i], data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
- (void)testBufferF16 {
|
||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||
|
||||
const std::vector<half> data = {half(1.0f), half(2.0f), half(3.0f), half(-4.0f), half(5.1f)};
|
||||
tflite::gpu::metal::Buffer buffer;
|
||||
XCTAssertTrue(tflite::gpu::metal::CreateBuffer(
|
||||
sizeof(tflite::gpu::half) * 5, nullptr, device, &buffer).ok());
|
||||
XCTAssertTrue(buffer.WriteData(absl::MakeConstSpan(data.data(), data.size())).ok());
|
||||
std::vector<half> gpu_data;
|
||||
XCTAssertTrue(buffer.ReadData<half>(&gpu_data).ok());
|
||||
|
||||
XCTAssertEqual(gpu_data.size(), data.size());
|
||||
for (int i = 0; i < gpu_data.size(); ++i) {
|
||||
XCTAssertEqual(gpu_data[i], data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@end
|
Loading…
x
Reference in New Issue
Block a user