Added Buffer to Metal backend.

PiperOrigin-RevId: 339913277
Change-Id: If4a92b3fafe922b5abf00bc8dc04deef1ef6ca6d
Raman Sarokin 2020-10-30 11:58:30 -07:00 committed by TensorFlower Gardener
parent baf2bfa96f
commit 1b098d27d3
4 changed files with 263 additions and 0 deletions

tensorflow/lite/delegates/gpu/metal/BUILD

@@ -48,6 +48,32 @@ cc_library(
],
)
objc_library(
name = "buffer",
srcs = ["buffer.mm"],
hdrs = ["buffer.h"],
copts = DEFAULT_COPTS,
sdk_frameworks = ["Metal"],
deps = [
"//tensorflow/lite/delegates/gpu/common:status",
"@com_google_absl//absl/types:span",
],
)
objc_library(
name = "buffer_test_lib",
testonly = 1,
srcs = ["buffer_test.mm"],
sdk_frameworks = [
"XCTest",
"Metal",
],
deps = [
":buffer",
"//tensorflow/lite/delegates/gpu/common:types",
],
)
objc_library(
name = "buffer_convert",
srcs = ["buffer_convert.mm"],
@@ -285,6 +311,7 @@ objc_library(
name = "common_tests_lib",
testonly = 1,
srcs = [
"//tensorflow/lite/delegates/gpu/metal:buffer_test.mm",
"//tensorflow/lite/delegates/gpu/metal:common_test.mm",
"//tensorflow/lite/delegates/gpu/metal:compiled_model_test.mm",
"//tensorflow/lite/delegates/gpu/metal:inference_context_test.mm",
@@ -293,6 +320,8 @@ objc_library(
],
sdk_frameworks = ["XCTest"],
deps = [
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/metal:buffer",
"//tensorflow/lite/delegates/gpu/metal:common",
"//tensorflow/lite/delegates/gpu/metal:environment",
"//tensorflow/lite/delegates/gpu/metal:inference_context",

tensorflow/lite/delegates/gpu/metal/buffer.h

@@ -0,0 +1,95 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_H_
#include <cstring>
#include <string>
#include <vector>
#import <Metal/Metal.h>
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
namespace metal {
class Buffer {
public:
Buffer() {}  // just for using Buffer as a class member
Buffer(id<MTLBuffer> buffer, size_t size_in_bytes);
// Move only
Buffer(Buffer&& buffer);
Buffer& operator=(Buffer&& buffer);
Buffer(const Buffer&) = delete;
Buffer& operator=(const Buffer&) = delete;
~Buffer();
// for profiling and memory statistics
uint64_t GetMemorySizeInBytes() const { return size_; }
id<MTLBuffer> GetMemoryPtr() const { return buffer_; }
// Writes data to the buffer. The data span must cover exactly the same
// number of bytes as size_in_bytes (the constructor parameter).
template <typename T>
absl::Status WriteData(const absl::Span<T> data);
// Reads data from Buffer into CPU memory.
template <typename T>
absl::Status ReadData(std::vector<T>* result) const;
private:
void Release();
id<MTLBuffer> buffer_ = nullptr;
size_t size_;
};
absl::Status CreateBuffer(size_t size_in_bytes, const void* data, id<MTLDevice> device,
Buffer* result);
template <typename T>
absl::Status Buffer::WriteData(const absl::Span<T> data) {
if (size_ != sizeof(T) * data.size()) {
return absl::InvalidArgumentError(
"absl::Span<T> data size is different from buffer allocated size.");
}
std::memcpy([buffer_ contents], data.data(), size_);
return absl::OkStatus();
}
template <typename T>
absl::Status Buffer::ReadData(std::vector<T>* result) const {
if (size_ % sizeof(T) != 0) {
return absl::UnknownError("Wrong element size (is typename T correct?).");
}
const int elements_count = size_ / sizeof(T);
result->resize(elements_count);
std::memcpy(result->data(), [buffer_ contents], size_);
return absl::OkStatus();
}
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_H_
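A minimal usage sketch of the API declared above (illustrative only, not part of this commit; it assumes a default Metal device is available and relies on the headers already pulled in by buffer.h plus <vector>):
// Illustrative sketch only; not part of the commit.
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
const std::vector<float> src = {1.0f, 2.0f, 3.0f, 4.0f};
tflite::gpu::metal::Buffer buffer;
// Allocate a shared-storage MTLBuffer sized for src, with no initial contents.
absl::Status status =
    tflite::gpu::metal::CreateBuffer(sizeof(float) * src.size(), nullptr, device, &buffer);
// Upload: the span must cover exactly the allocated size, otherwise WriteData fails.
if (status.ok()) status = buffer.WriteData(absl::MakeConstSpan(src.data(), src.size()));
// Download back into CPU memory; sizeof(T) must evenly divide the buffer size.
std::vector<float> dst;
if (status.ok()) status = buffer.ReadData<float>(&dst);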

tensorflow/lite/delegates/gpu/metal/buffer.mm

@@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
#include <utility>
namespace tflite {
namespace gpu {
namespace metal {
Buffer::Buffer(id<MTLBuffer> buffer, size_t size_in_bytes)
: buffer_(buffer), size_(size_in_bytes) {}
Buffer::Buffer(Buffer&& buffer) : buffer_(buffer.buffer_), size_(buffer.size_) {
buffer.buffer_ = nullptr;
buffer.size_ = 0;
}
Buffer& Buffer::operator=(Buffer&& buffer) {
if (this != &buffer) {
Release();
std::swap(size_, buffer.size_);
std::swap(buffer_, buffer.buffer_);
}
return *this;
}
Buffer::~Buffer() { Release(); }
void Buffer::Release() {
if (buffer_) {
buffer_ = nullptr;
size_ = 0;
}
}
absl::Status CreateBuffer(size_t size_in_bytes, const void* data,
id<MTLDevice> device, Buffer* result) {
id<MTLBuffer> buffer;
if (data) {
buffer = [device newBufferWithBytes:data
length:size_in_bytes
options:MTLResourceStorageModeShared];
} else {
buffer = [device newBufferWithLength:size_in_bytes
options:MTLResourceStorageModeShared];
}
*result = Buffer(buffer, size_in_bytes);
return absl::OkStatus();
}
} // namespace metal
} // namespace gpu
} // namespace tflite
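Since Buffer is move-only, ownership of the underlying MTLBuffer transfers on std::move and the moved-from object is left empty, which is what the move constructor and Release() above implement; a short illustrative sketch (not part of this commit):
// Illustrative sketch only; not part of the commit.
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
tflite::gpu::metal::Buffer a;
tflite::gpu::metal::CreateBuffer(/*size_in_bytes=*/256, nullptr, device, &a).IgnoreError();
tflite::gpu::metal::Buffer b = std::move(a);  // b now owns the MTLBuffer.
// After the move: a.GetMemoryPtr() is nil and a.GetMemorySizeInBytes() is 0.
// Copy construction/assignment are deleted, so the GPU buffer cannot be aliased by accident.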

tensorflow/lite/delegates/gpu/metal/buffer_test.mm

@@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#import <XCTest/XCTest.h>
#import <Metal/Metal.h>
#include <vector>
#include <iostream>
@interface BufferTest : XCTestCase
@end
@implementation BufferTest
- (void)setUp {
[super setUp];
}
using tflite::gpu::half;
- (void)testBufferF32 {
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
const std::vector<float> data = {1.0f, 2.0f, 3.0f, -4.0f, 5.1f};
tflite::gpu::metal::Buffer buffer;
XCTAssertTrue(tflite::gpu::metal::CreateBuffer(sizeof(float) * 5, nullptr, device, &buffer).ok());
XCTAssertTrue(buffer.WriteData(absl::MakeConstSpan(data.data(), data.size())).ok());
std::vector<float> gpu_data;
XCTAssertTrue(buffer.ReadData<float>(&gpu_data).ok());
XCTAssertEqual(gpu_data.size(), data.size());
for (int i = 0; i < gpu_data.size(); ++i) {
XCTAssertEqual(gpu_data[i], data[i]);
}
}
- (void)testBufferF16 {
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
const std::vector<half> data = {half(1.0f), half(2.0f), half(3.0f), half(-4.0f), half(5.1f)};
tflite::gpu::metal::Buffer buffer;
XCTAssertTrue(tflite::gpu::metal::CreateBuffer(
sizeof(tflite::gpu::half) * 5, nullptr, device, &buffer).ok());
XCTAssertTrue(buffer.WriteData(absl::MakeConstSpan(data.data(), data.size())).ok());
std::vector<half> gpu_data;
XCTAssertTrue(buffer.ReadData<half>(&gpu_data).ok());
XCTAssertEqual(gpu_data.size(), data.size());
for (int i = 0; i < gpu_data.size(); ++i) {
XCTAssertEqual(gpu_data[i], data[i]);
}
}
@end