Using common GPUObjectDescriptor for Metal backend.
PiperOrigin-RevId: 340553367 Change-Id: Ice0a9785d2c2ad5803eada50d8cb1240828f6753
This commit is contained in:
parent
a64738beb2
commit
ef62b6727a
@ -30,5 +30,18 @@ std::string MemoryTypeToCLType(MemoryType type) {
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string MemoryTypeToMetalType(MemoryType type) {
|
||||
switch (type) {
|
||||
case MemoryType::GLOBAL:
|
||||
return "device";
|
||||
case MemoryType::CONSTANT:
|
||||
return "constant";
|
||||
break;
|
||||
case MemoryType::LOCAL:
|
||||
return "threadgroup";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -25,6 +25,8 @@ namespace gpu {
|
||||
|
||||
std::string MemoryTypeToCLType(MemoryType type);
|
||||
|
||||
std::string MemoryTypeToMetalType(MemoryType type);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -43,8 +43,8 @@ cc_library(
|
||||
srcs = ["arguments.cc"],
|
||||
hdrs = ["arguments.h"],
|
||||
deps = [
|
||||
":gpu_object_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
|
||||
],
|
||||
)
|
||||
|
||||
@ -57,6 +57,7 @@ objc_library(
|
||||
deps = [
|
||||
":gpu_object",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
@ -208,22 +209,11 @@ objc_library(
|
||||
hdrs = ["gpu_object.h"],
|
||||
copts = DEFAULT_COPTS,
|
||||
sdk_frameworks = ["Metal"],
|
||||
deps = [
|
||||
":gpu_object_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_object_desc",
|
||||
srcs = ["gpu_object_desc.cc"],
|
||||
hdrs = ["gpu_object_desc.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
|
||||
],
|
||||
)
|
||||
|
||||
@ -276,9 +266,10 @@ objc_library(
|
||||
sdk_frameworks = ["Metal"],
|
||||
deps = [
|
||||
":arguments",
|
||||
":gpu_object_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:util",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -59,6 +60,8 @@ class Buffer : public GPUObject {
|
||||
absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr,
|
||||
GPUResourcesWithValue* resources) const override;
|
||||
|
||||
absl::Status CreateFromBufferDescriptor(const BufferDescriptor& desc, id<MTLDevice> device);
|
||||
|
||||
private:
|
||||
void Release();
|
||||
|
||||
|
@ -49,10 +49,29 @@ void Buffer::Release() {
|
||||
|
||||
absl::Status Buffer::GetGPUResources(const GPUObjectDescriptor* obj_ptr,
|
||||
GPUResourcesWithValue* resources) const {
|
||||
const auto* buffer_desc = dynamic_cast<const BufferDescriptor*>(obj_ptr);
|
||||
if (!buffer_desc) {
|
||||
return absl::InvalidArgumentError("Expected BufferDescriptor on input.");
|
||||
}
|
||||
|
||||
resources->buffers.push_back({"buffer", buffer_});
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Buffer::CreateFromBufferDescriptor(const BufferDescriptor& desc,
|
||||
id<MTLDevice> device) {
|
||||
size_ = desc.size;
|
||||
if (desc.data.empty()) {
|
||||
buffer_ = [device newBufferWithLength:size_
|
||||
options:MTLResourceStorageModeShared];
|
||||
} else {
|
||||
buffer_ = [device newBufferWithBytes:desc.data.data()
|
||||
length:size_
|
||||
options:MTLResourceStorageModeShared];
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status CreateBuffer(size_t size_in_bytes, const void* data,
|
||||
id<MTLDevice> device, Buffer* result) {
|
||||
id<MTLBuffer> buffer;
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
@ -1,37 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
std::string MemoryTypeToMetalType(MemoryType type) {
|
||||
switch (type) {
|
||||
case MemoryType::GLOBAL:
|
||||
return "device";
|
||||
case MemoryType::CONSTANT:
|
||||
return "constant";
|
||||
break;
|
||||
case MemoryType::LOCAL:
|
||||
return "threadgroup";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -1,102 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_GPU_OBJECT_DESC_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_GPU_OBJECT_DESC_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
|
||||
|
||||
std::string MemoryTypeToMetalType(MemoryType type);
|
||||
|
||||
struct GPUBufferDescriptor {
|
||||
DataType data_type;
|
||||
AccessType access_type;
|
||||
int element_size;
|
||||
MemoryType memory_type = MemoryType::GLOBAL;
|
||||
std::vector<std::string> attributes;
|
||||
};
|
||||
|
||||
struct GPUResources {
|
||||
std::vector<std::string> ints;
|
||||
std::vector<std::string> floats;
|
||||
std::vector<std::pair<std::string, GPUBufferDescriptor>> buffers;
|
||||
|
||||
std::vector<std::string> GetNames() const {
|
||||
std::vector<std::string> names = ints;
|
||||
names.insert(names.end(), floats.begin(), floats.end());
|
||||
for (const auto& obj : buffers) {
|
||||
names.push_back(obj.first);
|
||||
}
|
||||
return names;
|
||||
}
|
||||
};
|
||||
|
||||
class GPUObjectDescriptor {
|
||||
public:
|
||||
GPUObjectDescriptor() = default;
|
||||
GPUObjectDescriptor(const GPUObjectDescriptor&) = default;
|
||||
GPUObjectDescriptor& operator=(const GPUObjectDescriptor&) = default;
|
||||
GPUObjectDescriptor(GPUObjectDescriptor&& obj_desc) = default;
|
||||
GPUObjectDescriptor& operator=(GPUObjectDescriptor&& obj_desc) = default;
|
||||
|
||||
virtual ~GPUObjectDescriptor() = default;
|
||||
|
||||
void SetStateVar(const std::string& key, const std::string& value) const {
|
||||
state_vars_[key] = value;
|
||||
}
|
||||
|
||||
virtual std::string PerformConstExpr(const std::string& const_expr) const {
|
||||
return "";
|
||||
}
|
||||
|
||||
virtual absl::Status PerformSelector(
|
||||
const std::string& selector, const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const {
|
||||
*result = "";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
virtual GPUResources GetGPUResources() const { return GPUResources(); }
|
||||
|
||||
virtual void Release() {}
|
||||
|
||||
void SetAccess(AccessType access_type) { access_type_ = access_type; }
|
||||
AccessType GetAccess() const { return access_type_; }
|
||||
|
||||
protected:
|
||||
mutable std::map<std::string, std::string> state_vars_;
|
||||
AccessType access_type_;
|
||||
};
|
||||
|
||||
using GPUObjectDescriptorPtr = std::unique_ptr<GPUObjectDescriptor>;
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_GPU_OBJECT_DESC_H_
|
@ -22,8 +22,8 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
Loading…
Reference in New Issue
Block a user