Added Texture2D to Metal.
PiperOrigin-RevId: 351820359 Change-Id: I941c9508eab3f52d1192182c70e7b3bd8b6ceb3f
This commit is contained in:
parent
5efecaf1a1
commit
fcc9c486b5
tensorflow/lite/delegates/gpu/metal
@ -73,7 +73,7 @@ objc_library(
|
|||||||
"Metal",
|
"Metal",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/delegates/gpu/common:gpu_info",
|
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -196,6 +196,7 @@ objc_library(
|
|||||||
":buffer",
|
":buffer",
|
||||||
":gpu_object",
|
":gpu_object",
|
||||||
":metal_spatial_tensor",
|
":metal_spatial_tensor",
|
||||||
|
":texture2d",
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
"//tensorflow/lite/delegates/gpu/common:util",
|
"//tensorflow/lite/delegates/gpu/common:util",
|
||||||
"//tensorflow/lite/delegates/gpu/common/task:arguments",
|
"//tensorflow/lite/delegates/gpu/common/task:arguments",
|
||||||
@ -250,6 +251,38 @@ objc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "texture2d",
|
||||||
|
srcs = ["texture2d.cc"],
|
||||||
|
hdrs = ["texture2d.h"],
|
||||||
|
copts = DEFAULT_COPTS + [
|
||||||
|
"-ObjC++",
|
||||||
|
],
|
||||||
|
sdk_frameworks = ["Metal"],
|
||||||
|
deps = [
|
||||||
|
":common",
|
||||||
|
":gpu_object",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "texture2d_test_lib",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["texture2d_test.mm"],
|
||||||
|
sdk_frameworks = [
|
||||||
|
"XCTest",
|
||||||
|
"Metal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":texture2d",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:types",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
name = "TestBinary",
|
name = "TestBinary",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
@ -282,6 +315,7 @@ objc_library(
|
|||||||
"//tensorflow/lite/delegates/gpu/metal:buffer_test.mm",
|
"//tensorflow/lite/delegates/gpu/metal:buffer_test.mm",
|
||||||
"//tensorflow/lite/delegates/gpu/metal:common_test.mm",
|
"//tensorflow/lite/delegates/gpu/metal:common_test.mm",
|
||||||
"//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor_test.mm",
|
"//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor_test.mm",
|
||||||
|
"//tensorflow/lite/delegates/gpu/metal:texture2d_test.mm",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
],
|
],
|
||||||
@ -293,6 +327,7 @@ objc_library(
|
|||||||
"//tensorflow/lite/delegates/gpu/metal:common",
|
"//tensorflow/lite/delegates/gpu/metal:common",
|
||||||
"//tensorflow/lite/delegates/gpu/metal:inference_context",
|
"//tensorflow/lite/delegates/gpu/metal:inference_context",
|
||||||
"//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor",
|
"//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor",
|
||||||
|
"//tensorflow/lite/delegates/gpu/metal:texture2d",
|
||||||
"//tensorflow/lite/delegates/gpu/metal/kernels:test_util",
|
"//tensorflow/lite/delegates/gpu/metal/kernels:test_util",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -43,6 +44,9 @@ absl::Status CreateComputeProgram(id<MTLDevice> device, NSString* code, NSString
|
|||||||
NSDictionary<NSString*, NSString*>* macros,
|
NSDictionary<NSString*, NSString*>* macros,
|
||||||
id<MTLComputePipelineState>* program);
|
id<MTLComputePipelineState>* program);
|
||||||
|
|
||||||
|
int PixelFormatToSizeInBytes(MTLPixelFormat pixel_format);
|
||||||
|
MTLPixelFormat DataTypeToRGBAPixelFormat(DataType type, bool normalized = false);
|
||||||
|
|
||||||
} // namespace metal
|
} // namespace metal
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -90,6 +90,49 @@ absl::Status CreateComputeProgram(id<MTLDevice> device, NSString* code, NSString
|
|||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int PixelFormatToSizeInBytes(MTLPixelFormat pixel_format) {
|
||||||
|
if (pixel_format == MTLPixelFormatRGBA32Uint ||
|
||||||
|
pixel_format == MTLPixelFormatRGBA32Sint ||
|
||||||
|
pixel_format == MTLPixelFormatRGBA32Float) {
|
||||||
|
return 16;
|
||||||
|
} else if (pixel_format == MTLPixelFormatRGBA16Unorm ||
|
||||||
|
pixel_format == MTLPixelFormatRGBA16Snorm ||
|
||||||
|
pixel_format == MTLPixelFormatRGBA16Uint ||
|
||||||
|
pixel_format == MTLPixelFormatRGBA16Sint ||
|
||||||
|
pixel_format == MTLPixelFormatRGBA16Float) {
|
||||||
|
return 8;
|
||||||
|
} else if (pixel_format == MTLPixelFormatRGBA8Unorm ||
|
||||||
|
pixel_format == MTLPixelFormatRGBA8Snorm ||
|
||||||
|
pixel_format == MTLPixelFormatRGBA8Uint ||
|
||||||
|
pixel_format == MTLPixelFormatRGBA8Sint) {
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
MTLPixelFormat DataTypeToRGBAPixelFormat(DataType type, bool normalized) {
|
||||||
|
switch (type) {
|
||||||
|
case DataType::FLOAT32:
|
||||||
|
return MTLPixelFormatRGBA32Float;
|
||||||
|
case DataType::FLOAT16:
|
||||||
|
return MTLPixelFormatRGBA16Float;
|
||||||
|
case DataType::INT8:
|
||||||
|
return normalized ? MTLPixelFormatRGBA8Snorm : MTLPixelFormatRGBA8Sint;
|
||||||
|
case DataType::UINT8:
|
||||||
|
return normalized ? MTLPixelFormatRGBA8Unorm : MTLPixelFormatRGBA8Uint;
|
||||||
|
case DataType::INT16:
|
||||||
|
return normalized ? MTLPixelFormatRGBA16Snorm : MTLPixelFormatRGBA16Sint;
|
||||||
|
case DataType::UINT16:
|
||||||
|
return normalized ? MTLPixelFormatRGBA16Unorm : MTLPixelFormatRGBA16Uint;
|
||||||
|
case DataType::INT32:
|
||||||
|
return MTLPixelFormatRGBA32Sint;
|
||||||
|
case DataType::UINT32:
|
||||||
|
return MTLPixelFormatRGBA32Uint;
|
||||||
|
default:
|
||||||
|
return MTLPixelFormatInvalid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace metal
|
} // namespace metal
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
|
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
|
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/texture2d.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
@ -129,6 +130,15 @@ absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
|
|||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto* texture_desc = dynamic_cast<const Texture2DDescriptor*>(desc);
|
||||||
|
if (texture_desc) {
|
||||||
|
Texture2D gpu_texture;
|
||||||
|
RETURN_IF_ERROR(
|
||||||
|
gpu_texture.CreateFromTexture2DDescriptor(*texture_desc, device));
|
||||||
|
*result = absl::make_unique<Texture2D>(std::move(gpu_texture));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
|
const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
|
||||||
if (tensor_desc) {
|
if (tensor_desc) {
|
||||||
MetalSpatialTensor gpu_tensor;
|
MetalSpatialTensor gpu_tensor;
|
||||||
|
179
tensorflow/lite/delegates/gpu/metal/texture2d.cc
Normal file
179
tensorflow/lite/delegates/gpu/metal/texture2d.cc
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/texture2d.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace metal {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Creates new 4-channel 2D texture with cl_channel_type elements
|
||||||
|
absl::Status CreateTexture2D(int width, int height, DataType type, void* data,
|
||||||
|
id<MTLDevice> device, Texture2D* result) {
|
||||||
|
MTLPixelFormat pixel_format = DataTypeToRGBAPixelFormat(type);
|
||||||
|
|
||||||
|
MTLTextureDescriptor* texture_desc =
|
||||||
|
[MTLTextureDescriptor texture2DDescriptorWithPixelFormat:pixel_format
|
||||||
|
width:width
|
||||||
|
height:height
|
||||||
|
mipmapped:NO];
|
||||||
|
texture_desc.textureType = MTLTextureType2D;
|
||||||
|
texture_desc.usage = MTLTextureUsageShaderRead;
|
||||||
|
texture_desc.storageMode = MTLStorageModePrivate;
|
||||||
|
|
||||||
|
id<MTLTexture> texture = [device newTextureWithDescriptor:texture_desc];
|
||||||
|
if (!texture) {
|
||||||
|
return absl::UnknownError("Failed to allocate id<MTLTexture>");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data) {
|
||||||
|
MTLRegion region = {
|
||||||
|
{0, 0, 0},
|
||||||
|
{static_cast<NSUInteger>(width), static_cast<NSUInteger>(height), 1}};
|
||||||
|
const int pixel_size = PixelFormatToSizeInBytes(pixel_format);
|
||||||
|
[texture replaceRegion:region
|
||||||
|
mipmapLevel:0
|
||||||
|
withBytes:data
|
||||||
|
bytesPerRow:width * pixel_size];
|
||||||
|
|
||||||
|
if (!texture) {
|
||||||
|
return absl::UnknownError("Failed to upload data to id<MTLTexture>");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*result = Texture2D(texture, width, height, pixel_format);
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Texture2D::Texture2D(id<MTLTexture> texture, int width, int height,
|
||||||
|
MTLPixelFormat pixel_format)
|
||||||
|
: texture_(texture),
|
||||||
|
width_(width),
|
||||||
|
height_(height),
|
||||||
|
pixel_format_(pixel_format) {}
|
||||||
|
|
||||||
|
Texture2D::Texture2D(Texture2D&& texture)
|
||||||
|
: texture_(texture.texture_),
|
||||||
|
width_(texture.width_),
|
||||||
|
height_(texture.height_),
|
||||||
|
pixel_format_(texture.pixel_format_) {
|
||||||
|
texture.texture_ = nullptr;
|
||||||
|
texture.width_ = 0;
|
||||||
|
texture.height_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Texture2D& Texture2D::operator=(Texture2D&& texture) {
|
||||||
|
if (this != &texture) {
|
||||||
|
Release();
|
||||||
|
std::swap(pixel_format_, texture.pixel_format_);
|
||||||
|
std::swap(width_, texture.width_);
|
||||||
|
std::swap(height_, texture.height_);
|
||||||
|
std::swap(texture_, texture.texture_);
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Texture2D::Release() {
|
||||||
|
if (texture_) {
|
||||||
|
texture_ = nullptr;
|
||||||
|
width_ = 0;
|
||||||
|
height_ = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Texture2D::GetGPUResources(
|
||||||
|
const GPUObjectDescriptor* obj_ptr,
|
||||||
|
GPUResourcesWithValue* resources) const {
|
||||||
|
const auto* texture_desc = dynamic_cast<const Texture2DDescriptor*>(obj_ptr);
|
||||||
|
if (!texture_desc) {
|
||||||
|
return absl::InvalidArgumentError("Expected Texture2DDescriptor on input.");
|
||||||
|
}
|
||||||
|
|
||||||
|
resources->images2d.push_back({"tex2d", texture_});
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Texture2D::CreateFromTexture2DDescriptor(
|
||||||
|
const Texture2DDescriptor& desc, id<MTLDevice> device) {
|
||||||
|
width_ = desc.size.x;
|
||||||
|
height_ = desc.size.y;
|
||||||
|
pixel_format_ = DataTypeToRGBAPixelFormat(desc.element_type, desc.normalized);
|
||||||
|
uint8_t* data_ptr = desc.data.empty()
|
||||||
|
? nullptr
|
||||||
|
: const_cast<unsigned char*>(desc.data.data());
|
||||||
|
|
||||||
|
MTLTextureDescriptor* texture_desc =
|
||||||
|
[MTLTextureDescriptor texture2DDescriptorWithPixelFormat:pixel_format_
|
||||||
|
width:width_
|
||||||
|
height:height_
|
||||||
|
mipmapped:NO];
|
||||||
|
texture_desc.textureType = MTLTextureType2D;
|
||||||
|
texture_desc.usage = MTLTextureUsageShaderRead;
|
||||||
|
texture_desc.storageMode = MTLStorageModePrivate;
|
||||||
|
|
||||||
|
texture_ = [device newTextureWithDescriptor:texture_desc];
|
||||||
|
if (!texture_) {
|
||||||
|
return absl::UnknownError("Failed to allocate id<MTLTexture>");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data_ptr) {
|
||||||
|
MTLRegion region = {
|
||||||
|
{0, 0, 0},
|
||||||
|
{static_cast<NSUInteger>(width_), static_cast<NSUInteger>(height_), 1}};
|
||||||
|
const int pixel_size = PixelFormatToSizeInBytes(pixel_format_);
|
||||||
|
[texture_ replaceRegion:region
|
||||||
|
mipmapLevel:0
|
||||||
|
withBytes:data_ptr
|
||||||
|
bytesPerRow:width_ * pixel_size];
|
||||||
|
|
||||||
|
if (!texture_) {
|
||||||
|
return absl::UnknownError("Failed to upload data to id<MTLTexture>");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates new 4-channel 2D texture with f32 elements
|
||||||
|
absl::Status CreateTexture2DRGBA32F(int width, int height, id<MTLDevice> device,
|
||||||
|
Texture2D* result) {
|
||||||
|
return CreateTexture2D(width, height, DataType::FLOAT32, nullptr, device,
|
||||||
|
result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates new 4-channel 2D texture with f16 elements
|
||||||
|
absl::Status CreateTexture2DRGBA16F(int width, int height, id<MTLDevice> device,
|
||||||
|
Texture2D* result) {
|
||||||
|
return CreateTexture2D(width, height, DataType::FLOAT16, nullptr, device,
|
||||||
|
result);
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status CreateTexture2DRGBA(DataType type, int width, int height,
|
||||||
|
id<MTLDevice> device, Texture2D* result) {
|
||||||
|
return CreateTexture2D(width, height, type, nullptr, device, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status CreateTexture2DRGBA(DataType type, int width, int height,
|
||||||
|
void* data, id<MTLDevice> device,
|
||||||
|
Texture2D* result) {
|
||||||
|
return CreateTexture2D(width, height, type, data, device, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
122
tensorflow/lite/delegates/gpu/metal/texture2d.h
Normal file
122
tensorflow/lite/delegates/gpu/metal/texture2d.h
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_TEXTURE2D_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_TEXTURE2D_H_
|
||||||
|
|
||||||
|
#import <Metal/Metal.h>
|
||||||
|
|
||||||
|
#include "absl/types/span.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/common.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
|
// Texture2D represent formatted GPU data storage.
|
||||||
|
// Texture2D is moveable but not copyable.
|
||||||
|
class Texture2D : public GPUObject {
|
||||||
|
public:
|
||||||
|
Texture2D() {} // just for using Texture2D as a class members
|
||||||
|
Texture2D(id<MTLTexture> texture, int width, int height, MTLPixelFormat pixel_format);
|
||||||
|
|
||||||
|
// Move only
|
||||||
|
Texture2D(Texture2D&& texture);
|
||||||
|
Texture2D& operator=(Texture2D&& texture);
|
||||||
|
Texture2D(const Texture2D&) = delete;
|
||||||
|
Texture2D& operator=(const Texture2D&) = delete;
|
||||||
|
|
||||||
|
~Texture2D() override { Release(); }
|
||||||
|
|
||||||
|
// Writes data to a texture. Data should point to a region that
|
||||||
|
// has exact width * height * sizeof(pixel) bytes.
|
||||||
|
template <typename T>
|
||||||
|
absl::Status WriteData(const absl::Span<T> data);
|
||||||
|
|
||||||
|
// Reads data from Texture2D into CPU memory.
|
||||||
|
template <typename T>
|
||||||
|
absl::Status ReadData(std::vector<T>* result) const;
|
||||||
|
|
||||||
|
absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr,
|
||||||
|
GPUResourcesWithValue* resources) const override;
|
||||||
|
|
||||||
|
absl::Status CreateFromTexture2DDescriptor(const Texture2DDescriptor& desc, id<MTLDevice> device);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Release();
|
||||||
|
|
||||||
|
id<MTLTexture> texture_ = nullptr;
|
||||||
|
int width_;
|
||||||
|
int height_;
|
||||||
|
MTLPixelFormat pixel_format_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Creates new 4-channel 2D texture with f32 elements
|
||||||
|
absl::Status CreateTexture2DRGBA32F(int width, int height, id<MTLDevice> device, Texture2D* result);
|
||||||
|
|
||||||
|
// Creates new 4-channel 2D texture with f16 elements
|
||||||
|
absl::Status CreateTexture2DRGBA16F(int width, int height, id<MTLDevice> device, Texture2D* result);
|
||||||
|
|
||||||
|
absl::Status CreateTexture2DRGBA(DataType type, int width, int height, id<MTLDevice> device,
|
||||||
|
Texture2D* result);
|
||||||
|
|
||||||
|
absl::Status CreateTexture2DRGBA(DataType type, int width, int height, void* data,
|
||||||
|
id<MTLDevice> device, Texture2D* result);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
absl::Status Texture2D::WriteData(const absl::Span<T> data) {
|
||||||
|
const int pixel_size = PixelFormatToSizeInBytes(pixel_format_);
|
||||||
|
if (width_ * height_ * pixel_size != data.size() * sizeof(T)) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"absl::Span<T> data size is different from texture allocated size.");
|
||||||
|
}
|
||||||
|
|
||||||
|
MTLRegion region = {{0, 0, 0},
|
||||||
|
{static_cast<NSUInteger>(width_), static_cast<NSUInteger>(height_), 1}};
|
||||||
|
[texture_ replaceRegion:region
|
||||||
|
mipmapLevel:0
|
||||||
|
withBytes:data.data()
|
||||||
|
bytesPerRow:width_ * pixel_size];
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
absl::Status Texture2D::ReadData(std::vector<T>* result) const {
|
||||||
|
const int pixel_size = PixelFormatToSizeInBytes(pixel_format_);
|
||||||
|
if (pixel_size % sizeof(T) != 0) {
|
||||||
|
return absl::InvalidArgumentError("Pixel format is different.");
|
||||||
|
}
|
||||||
|
result->resize(width_ * height_ * (pixel_size / sizeof(T)));
|
||||||
|
|
||||||
|
MTLRegion region = {{0, 0, 0},
|
||||||
|
{static_cast<NSUInteger>(width_), static_cast<NSUInteger>(height_), 1}};
|
||||||
|
[texture_ getBytes:result->data()
|
||||||
|
bytesPerRow:width_ * pixel_size
|
||||||
|
fromRegion:region
|
||||||
|
mipmapLevel:0];
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_TEXTURE2D_H_
|
70
tensorflow/lite/delegates/gpu/metal/texture2d_test.mm
Normal file
70
tensorflow/lite/delegates/gpu/metal/texture2d_test.mm
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/texture2d.h"
|
||||||
|
|
||||||
|
#import <Metal/Metal.h>
|
||||||
|
#import <XCTest/XCTest.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
|
||||||
|
@interface Texture2DTest : XCTestCase
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation Texture2DTest
|
||||||
|
- (void)setUp {
|
||||||
|
[super setUp];
|
||||||
|
}
|
||||||
|
|
||||||
|
using tflite::gpu::half;
|
||||||
|
|
||||||
|
- (void)testTexture2DF32 {
|
||||||
|
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||||
|
|
||||||
|
const std::vector<float> data = {1.0, 2.0, 3.0, -4.0, 5.1, 6.7, 4.1, 6.17};
|
||||||
|
tflite::gpu::metal::Texture2D texture;
|
||||||
|
XCTAssertTrue(tflite::gpu::metal::CreateTexture2DRGBA32F(1, 2, device, &texture).ok());
|
||||||
|
XCTAssertTrue(texture.WriteData(absl::MakeConstSpan(data.data(), data.size())).ok());
|
||||||
|
std::vector<float> gpu_data;
|
||||||
|
XCTAssertTrue(texture.ReadData<float>(&gpu_data).ok());
|
||||||
|
|
||||||
|
XCTAssertEqual(gpu_data.size(), data.size());
|
||||||
|
for (int i = 0; i < gpu_data.size(); ++i) {
|
||||||
|
XCTAssertEqual(gpu_data[i], data[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testTexture2DF16 {
|
||||||
|
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||||
|
|
||||||
|
const std::vector<half> data = {half(1.4), half(2.1), half(2.2), half(1.34),
|
||||||
|
half(20.1), half(2.24), half(0.1), half(0.2)};
|
||||||
|
|
||||||
|
tflite::gpu::metal::Texture2D texture;
|
||||||
|
XCTAssertTrue(tflite::gpu::metal::CreateTexture2DRGBA16F(2, 1, device, &texture).ok());
|
||||||
|
XCTAssertTrue(texture.WriteData(absl::MakeConstSpan(data.data(), data.size())).ok());
|
||||||
|
std::vector<half> gpu_data;
|
||||||
|
XCTAssertTrue(texture.ReadData<half>(&gpu_data).ok());
|
||||||
|
|
||||||
|
XCTAssertEqual(gpu_data.size(), data.size());
|
||||||
|
for (int i = 0; i < gpu_data.size(); ++i) {
|
||||||
|
XCTAssertEqual(gpu_data[i], data[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
Loading…
Reference in New Issue
Block a user