Added Texture2D to Metal.
PiperOrigin-RevId: 351820359 Change-Id: I941c9508eab3f52d1192182c70e7b3bd8b6ceb3f
This commit is contained in:
parent
5efecaf1a1
commit
fcc9c486b5
tensorflow/lite/delegates/gpu/metal
@ -73,7 +73,7 @@ objc_library(
|
||||
"Metal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:gpu_info",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
],
|
||||
)
|
||||
@ -196,6 +196,7 @@ objc_library(
|
||||
":buffer",
|
||||
":gpu_object",
|
||||
":metal_spatial_tensor",
|
||||
":texture2d",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:arguments",
|
||||
@ -250,6 +251,38 @@ objc_library(
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "texture2d",
|
||||
srcs = ["texture2d.cc"],
|
||||
hdrs = ["texture2d.h"],
|
||||
copts = DEFAULT_COPTS + [
|
||||
"-ObjC++",
|
||||
],
|
||||
sdk_frameworks = ["Metal"],
|
||||
deps = [
|
||||
":common",
|
||||
":gpu_object",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "texture2d_test_lib",
|
||||
testonly = 1,
|
||||
srcs = ["texture2d_test.mm"],
|
||||
sdk_frameworks = [
|
||||
"XCTest",
|
||||
"Metal",
|
||||
],
|
||||
deps = [
|
||||
":texture2d",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "TestBinary",
|
||||
testonly = 1,
|
||||
@ -282,6 +315,7 @@ objc_library(
|
||||
"//tensorflow/lite/delegates/gpu/metal:buffer_test.mm",
|
||||
"//tensorflow/lite/delegates/gpu/metal:common_test.mm",
|
||||
"//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor_test.mm",
|
||||
"//tensorflow/lite/delegates/gpu/metal:texture2d_test.mm",
|
||||
],
|
||||
hdrs = [
|
||||
],
|
||||
@ -293,6 +327,7 @@ objc_library(
|
||||
"//tensorflow/lite/delegates/gpu/metal:common",
|
||||
"//tensorflow/lite/delegates/gpu/metal:inference_context",
|
||||
"//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor",
|
||||
"//tensorflow/lite/delegates/gpu/metal:texture2d",
|
||||
"//tensorflow/lite/delegates/gpu/metal/kernels:test_util",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -43,6 +44,9 @@ absl::Status CreateComputeProgram(id<MTLDevice> device, NSString* code, NSString
|
||||
NSDictionary<NSString*, NSString*>* macros,
|
||||
id<MTLComputePipelineState>* program);
|
||||
|
||||
int PixelFormatToSizeInBytes(MTLPixelFormat pixel_format);
|
||||
MTLPixelFormat DataTypeToRGBAPixelFormat(DataType type, bool normalized = false);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -90,6 +90,49 @@ absl::Status CreateComputeProgram(id<MTLDevice> device, NSString* code, NSString
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
int PixelFormatToSizeInBytes(MTLPixelFormat pixel_format) {
|
||||
if (pixel_format == MTLPixelFormatRGBA32Uint ||
|
||||
pixel_format == MTLPixelFormatRGBA32Sint ||
|
||||
pixel_format == MTLPixelFormatRGBA32Float) {
|
||||
return 16;
|
||||
} else if (pixel_format == MTLPixelFormatRGBA16Unorm ||
|
||||
pixel_format == MTLPixelFormatRGBA16Snorm ||
|
||||
pixel_format == MTLPixelFormatRGBA16Uint ||
|
||||
pixel_format == MTLPixelFormatRGBA16Sint ||
|
||||
pixel_format == MTLPixelFormatRGBA16Float) {
|
||||
return 8;
|
||||
} else if (pixel_format == MTLPixelFormatRGBA8Unorm ||
|
||||
pixel_format == MTLPixelFormatRGBA8Snorm ||
|
||||
pixel_format == MTLPixelFormatRGBA8Uint ||
|
||||
pixel_format == MTLPixelFormatRGBA8Sint) {
|
||||
return 4;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
MTLPixelFormat DataTypeToRGBAPixelFormat(DataType type, bool normalized) {
|
||||
switch (type) {
|
||||
case DataType::FLOAT32:
|
||||
return MTLPixelFormatRGBA32Float;
|
||||
case DataType::FLOAT16:
|
||||
return MTLPixelFormatRGBA16Float;
|
||||
case DataType::INT8:
|
||||
return normalized ? MTLPixelFormatRGBA8Snorm : MTLPixelFormatRGBA8Sint;
|
||||
case DataType::UINT8:
|
||||
return normalized ? MTLPixelFormatRGBA8Unorm : MTLPixelFormatRGBA8Uint;
|
||||
case DataType::INT16:
|
||||
return normalized ? MTLPixelFormatRGBA16Snorm : MTLPixelFormatRGBA16Sint;
|
||||
case DataType::UINT16:
|
||||
return normalized ? MTLPixelFormatRGBA16Unorm : MTLPixelFormatRGBA16Uint;
|
||||
case DataType::INT32:
|
||||
return MTLPixelFormatRGBA32Sint;
|
||||
case DataType::UINT32:
|
||||
return MTLPixelFormatRGBA32Uint;
|
||||
default:
|
||||
return MTLPixelFormatInvalid;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/texture2d.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
@ -129,6 +130,15 @@ absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const auto* texture_desc = dynamic_cast<const Texture2DDescriptor*>(desc);
|
||||
if (texture_desc) {
|
||||
Texture2D gpu_texture;
|
||||
RETURN_IF_ERROR(
|
||||
gpu_texture.CreateFromTexture2DDescriptor(*texture_desc, device));
|
||||
*result = absl::make_unique<Texture2D>(std::move(gpu_texture));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
|
||||
if (tensor_desc) {
|
||||
MetalSpatialTensor gpu_tensor;
|
||||
|
179
tensorflow/lite/delegates/gpu/metal/texture2d.cc
Normal file
179
tensorflow/lite/delegates/gpu/metal/texture2d.cc
Normal file
@ -0,0 +1,179 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/texture2d.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
namespace {
|
||||
|
||||
// Creates new 4-channel 2D texture with cl_channel_type elements
|
||||
absl::Status CreateTexture2D(int width, int height, DataType type, void* data,
|
||||
id<MTLDevice> device, Texture2D* result) {
|
||||
MTLPixelFormat pixel_format = DataTypeToRGBAPixelFormat(type);
|
||||
|
||||
MTLTextureDescriptor* texture_desc =
|
||||
[MTLTextureDescriptor texture2DDescriptorWithPixelFormat:pixel_format
|
||||
width:width
|
||||
height:height
|
||||
mipmapped:NO];
|
||||
texture_desc.textureType = MTLTextureType2D;
|
||||
texture_desc.usage = MTLTextureUsageShaderRead;
|
||||
texture_desc.storageMode = MTLStorageModePrivate;
|
||||
|
||||
id<MTLTexture> texture = [device newTextureWithDescriptor:texture_desc];
|
||||
if (!texture) {
|
||||
return absl::UnknownError("Failed to allocate id<MTLTexture>");
|
||||
}
|
||||
|
||||
if (data) {
|
||||
MTLRegion region = {
|
||||
{0, 0, 0},
|
||||
{static_cast<NSUInteger>(width), static_cast<NSUInteger>(height), 1}};
|
||||
const int pixel_size = PixelFormatToSizeInBytes(pixel_format);
|
||||
[texture replaceRegion:region
|
||||
mipmapLevel:0
|
||||
withBytes:data
|
||||
bytesPerRow:width * pixel_size];
|
||||
|
||||
if (!texture) {
|
||||
return absl::UnknownError("Failed to upload data to id<MTLTexture>");
|
||||
}
|
||||
}
|
||||
|
||||
*result = Texture2D(texture, width, height, pixel_format);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Texture2D::Texture2D(id<MTLTexture> texture, int width, int height,
|
||||
MTLPixelFormat pixel_format)
|
||||
: texture_(texture),
|
||||
width_(width),
|
||||
height_(height),
|
||||
pixel_format_(pixel_format) {}
|
||||
|
||||
Texture2D::Texture2D(Texture2D&& texture)
|
||||
: texture_(texture.texture_),
|
||||
width_(texture.width_),
|
||||
height_(texture.height_),
|
||||
pixel_format_(texture.pixel_format_) {
|
||||
texture.texture_ = nullptr;
|
||||
texture.width_ = 0;
|
||||
texture.height_ = 0;
|
||||
}
|
||||
|
||||
Texture2D& Texture2D::operator=(Texture2D&& texture) {
|
||||
if (this != &texture) {
|
||||
Release();
|
||||
std::swap(pixel_format_, texture.pixel_format_);
|
||||
std::swap(width_, texture.width_);
|
||||
std::swap(height_, texture.height_);
|
||||
std::swap(texture_, texture.texture_);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Texture2D::Release() {
|
||||
if (texture_) {
|
||||
texture_ = nullptr;
|
||||
width_ = 0;
|
||||
height_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status Texture2D::GetGPUResources(
|
||||
const GPUObjectDescriptor* obj_ptr,
|
||||
GPUResourcesWithValue* resources) const {
|
||||
const auto* texture_desc = dynamic_cast<const Texture2DDescriptor*>(obj_ptr);
|
||||
if (!texture_desc) {
|
||||
return absl::InvalidArgumentError("Expected Texture2DDescriptor on input.");
|
||||
}
|
||||
|
||||
resources->images2d.push_back({"tex2d", texture_});
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Texture2D::CreateFromTexture2DDescriptor(
|
||||
const Texture2DDescriptor& desc, id<MTLDevice> device) {
|
||||
width_ = desc.size.x;
|
||||
height_ = desc.size.y;
|
||||
pixel_format_ = DataTypeToRGBAPixelFormat(desc.element_type, desc.normalized);
|
||||
uint8_t* data_ptr = desc.data.empty()
|
||||
? nullptr
|
||||
: const_cast<unsigned char*>(desc.data.data());
|
||||
|
||||
MTLTextureDescriptor* texture_desc =
|
||||
[MTLTextureDescriptor texture2DDescriptorWithPixelFormat:pixel_format_
|
||||
width:width_
|
||||
height:height_
|
||||
mipmapped:NO];
|
||||
texture_desc.textureType = MTLTextureType2D;
|
||||
texture_desc.usage = MTLTextureUsageShaderRead;
|
||||
texture_desc.storageMode = MTLStorageModePrivate;
|
||||
|
||||
texture_ = [device newTextureWithDescriptor:texture_desc];
|
||||
if (!texture_) {
|
||||
return absl::UnknownError("Failed to allocate id<MTLTexture>");
|
||||
}
|
||||
|
||||
if (data_ptr) {
|
||||
MTLRegion region = {
|
||||
{0, 0, 0},
|
||||
{static_cast<NSUInteger>(width_), static_cast<NSUInteger>(height_), 1}};
|
||||
const int pixel_size = PixelFormatToSizeInBytes(pixel_format_);
|
||||
[texture_ replaceRegion:region
|
||||
mipmapLevel:0
|
||||
withBytes:data_ptr
|
||||
bytesPerRow:width_ * pixel_size];
|
||||
|
||||
if (!texture_) {
|
||||
return absl::UnknownError("Failed to upload data to id<MTLTexture>");
|
||||
}
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Creates new 4-channel 2D texture with f32 elements
|
||||
absl::Status CreateTexture2DRGBA32F(int width, int height, id<MTLDevice> device,
|
||||
Texture2D* result) {
|
||||
return CreateTexture2D(width, height, DataType::FLOAT32, nullptr, device,
|
||||
result);
|
||||
}
|
||||
|
||||
// Creates new 4-channel 2D texture with f16 elements
|
||||
absl::Status CreateTexture2DRGBA16F(int width, int height, id<MTLDevice> device,
|
||||
Texture2D* result) {
|
||||
return CreateTexture2D(width, height, DataType::FLOAT16, nullptr, device,
|
||||
result);
|
||||
}
|
||||
|
||||
absl::Status CreateTexture2DRGBA(DataType type, int width, int height,
|
||||
id<MTLDevice> device, Texture2D* result) {
|
||||
return CreateTexture2D(width, height, type, nullptr, device, result);
|
||||
}
|
||||
|
||||
absl::Status CreateTexture2DRGBA(DataType type, int width, int height,
|
||||
void* data, id<MTLDevice> device,
|
||||
Texture2D* result) {
|
||||
return CreateTexture2D(width, height, type, data, device, result);
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
122
tensorflow/lite/delegates/gpu/metal/texture2d.h
Normal file
122
tensorflow/lite/delegates/gpu/metal/texture2d.h
Normal file
@ -0,0 +1,122 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_TEXTURE2D_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_TEXTURE2D_H_
|
||||
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/common.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
// Texture2D represent formatted GPU data storage.
|
||||
// Texture2D is moveable but not copyable.
|
||||
class Texture2D : public GPUObject {
|
||||
public:
|
||||
Texture2D() {} // just for using Texture2D as a class members
|
||||
Texture2D(id<MTLTexture> texture, int width, int height, MTLPixelFormat pixel_format);
|
||||
|
||||
// Move only
|
||||
Texture2D(Texture2D&& texture);
|
||||
Texture2D& operator=(Texture2D&& texture);
|
||||
Texture2D(const Texture2D&) = delete;
|
||||
Texture2D& operator=(const Texture2D&) = delete;
|
||||
|
||||
~Texture2D() override { Release(); }
|
||||
|
||||
// Writes data to a texture. Data should point to a region that
|
||||
// has exact width * height * sizeof(pixel) bytes.
|
||||
template <typename T>
|
||||
absl::Status WriteData(const absl::Span<T> data);
|
||||
|
||||
// Reads data from Texture2D into CPU memory.
|
||||
template <typename T>
|
||||
absl::Status ReadData(std::vector<T>* result) const;
|
||||
|
||||
absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr,
|
||||
GPUResourcesWithValue* resources) const override;
|
||||
|
||||
absl::Status CreateFromTexture2DDescriptor(const Texture2DDescriptor& desc, id<MTLDevice> device);
|
||||
|
||||
private:
|
||||
void Release();
|
||||
|
||||
id<MTLTexture> texture_ = nullptr;
|
||||
int width_;
|
||||
int height_;
|
||||
MTLPixelFormat pixel_format_;
|
||||
};
|
||||
|
||||
// Creates new 4-channel 2D texture with f32 elements
|
||||
absl::Status CreateTexture2DRGBA32F(int width, int height, id<MTLDevice> device, Texture2D* result);
|
||||
|
||||
// Creates new 4-channel 2D texture with f16 elements
|
||||
absl::Status CreateTexture2DRGBA16F(int width, int height, id<MTLDevice> device, Texture2D* result);
|
||||
|
||||
absl::Status CreateTexture2DRGBA(DataType type, int width, int height, id<MTLDevice> device,
|
||||
Texture2D* result);
|
||||
|
||||
absl::Status CreateTexture2DRGBA(DataType type, int width, int height, void* data,
|
||||
id<MTLDevice> device, Texture2D* result);
|
||||
|
||||
template <typename T>
|
||||
absl::Status Texture2D::WriteData(const absl::Span<T> data) {
|
||||
const int pixel_size = PixelFormatToSizeInBytes(pixel_format_);
|
||||
if (width_ * height_ * pixel_size != data.size() * sizeof(T)) {
|
||||
return absl::InvalidArgumentError(
|
||||
"absl::Span<T> data size is different from texture allocated size.");
|
||||
}
|
||||
|
||||
MTLRegion region = {{0, 0, 0},
|
||||
{static_cast<NSUInteger>(width_), static_cast<NSUInteger>(height_), 1}};
|
||||
[texture_ replaceRegion:region
|
||||
mipmapLevel:0
|
||||
withBytes:data.data()
|
||||
bytesPerRow:width_ * pixel_size];
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::Status Texture2D::ReadData(std::vector<T>* result) const {
|
||||
const int pixel_size = PixelFormatToSizeInBytes(pixel_format_);
|
||||
if (pixel_size % sizeof(T) != 0) {
|
||||
return absl::InvalidArgumentError("Pixel format is different.");
|
||||
}
|
||||
result->resize(width_ * height_ * (pixel_size / sizeof(T)));
|
||||
|
||||
MTLRegion region = {{0, 0, 0},
|
||||
{static_cast<NSUInteger>(width_), static_cast<NSUInteger>(height_), 1}};
|
||||
[texture_ getBytes:result->data()
|
||||
bytesPerRow:width_ * pixel_size
|
||||
fromRegion:region
|
||||
mipmapLevel:0];
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_TEXTURE2D_H_
|
70
tensorflow/lite/delegates/gpu/metal/texture2d_test.mm
Normal file
70
tensorflow/lite/delegates/gpu/metal/texture2d_test.mm
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/texture2d.h"
|
||||
|
||||
#import <Metal/Metal.h>
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
|
||||
@interface Texture2DTest : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation Texture2DTest
|
||||
- (void)setUp {
|
||||
[super setUp];
|
||||
}
|
||||
|
||||
using tflite::gpu::half;
|
||||
|
||||
- (void)testTexture2DF32 {
|
||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||
|
||||
const std::vector<float> data = {1.0, 2.0, 3.0, -4.0, 5.1, 6.7, 4.1, 6.17};
|
||||
tflite::gpu::metal::Texture2D texture;
|
||||
XCTAssertTrue(tflite::gpu::metal::CreateTexture2DRGBA32F(1, 2, device, &texture).ok());
|
||||
XCTAssertTrue(texture.WriteData(absl::MakeConstSpan(data.data(), data.size())).ok());
|
||||
std::vector<float> gpu_data;
|
||||
XCTAssertTrue(texture.ReadData<float>(&gpu_data).ok());
|
||||
|
||||
XCTAssertEqual(gpu_data.size(), data.size());
|
||||
for (int i = 0; i < gpu_data.size(); ++i) {
|
||||
XCTAssertEqual(gpu_data[i], data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
- (void)testTexture2DF16 {
|
||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||
|
||||
const std::vector<half> data = {half(1.4), half(2.1), half(2.2), half(1.34),
|
||||
half(20.1), half(2.24), half(0.1), half(0.2)};
|
||||
|
||||
tflite::gpu::metal::Texture2D texture;
|
||||
XCTAssertTrue(tflite::gpu::metal::CreateTexture2DRGBA16F(2, 1, device, &texture).ok());
|
||||
XCTAssertTrue(texture.WriteData(absl::MakeConstSpan(data.data(), data.size())).ok());
|
||||
std::vector<half> gpu_data;
|
||||
XCTAssertTrue(texture.ReadData<half>(&gpu_data).ok());
|
||||
|
||||
XCTAssertEqual(gpu_data.size(), data.size());
|
||||
for (int i = 0; i < gpu_data.size(); ++i) {
|
||||
XCTAssertEqual(gpu_data[i], data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@end
|
Loading…
Reference in New Issue
Block a user