Using common GPUObjectDescriptor for Metal backend.

PiperOrigin-RevId: 340553367
Change-Id: Ice0a9785d2c2ad5803eada50d8cb1240828f6753
This commit is contained in:
Raman Sarokin 2020-11-03 17:04:01 -08:00 committed by TensorFlower Gardener
parent a64738beb2
commit ef62b6727a
11 changed files with 46 additions and 156 deletions

View File

@ -30,5 +30,18 @@ std::string MemoryTypeToCLType(MemoryType type) {
return "";
}
std::string MemoryTypeToMetalType(MemoryType type) {
switch (type) {
case MemoryType::GLOBAL:
return "device";
case MemoryType::CONSTANT:
return "constant";
break;
case MemoryType::LOCAL:
return "threadgroup";
}
return "";
}
} // namespace gpu
} // namespace tflite

View File

@ -25,6 +25,8 @@ namespace gpu {
std::string MemoryTypeToCLType(MemoryType type);
std::string MemoryTypeToMetalType(MemoryType type);
} // namespace gpu
} // namespace tflite

View File

@ -43,8 +43,8 @@ cc_library(
srcs = ["arguments.cc"],
hdrs = ["arguments.h"],
deps = [
":gpu_object_desc",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
],
)
@ -57,6 +57,7 @@ objc_library(
deps = [
":gpu_object",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
"@com_google_absl//absl/types:span",
],
)
@ -208,22 +209,11 @@ objc_library(
hdrs = ["gpu_object.h"],
copts = DEFAULT_COPTS,
sdk_frameworks = ["Metal"],
deps = [
":gpu_object_desc",
"//tensorflow/lite/delegates/gpu/common:access_type",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status",
],
)
cc_library(
name = "gpu_object_desc",
srcs = ["gpu_object_desc.cc"],
hdrs = ["gpu_object_desc.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:access_type",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
],
)
@ -276,9 +266,10 @@ objc_library(
sdk_frameworks = ["Metal"],
deps = [
":arguments",
":gpu_object_desc",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
"//tensorflow/lite/delegates/gpu/common/task:util",
"@com_google_absl//absl/strings",
],
)

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <string>
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
namespace tflite {
namespace gpu {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
#include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
namespace tflite {
@ -59,6 +60,8 @@ class Buffer : public GPUObject {
absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr,
GPUResourcesWithValue* resources) const override;
absl::Status CreateFromBufferDescriptor(const BufferDescriptor& desc, id<MTLDevice> device);
private:
void Release();

View File

@ -49,10 +49,29 @@ void Buffer::Release() {
absl::Status Buffer::GetGPUResources(const GPUObjectDescriptor* obj_ptr,
GPUResourcesWithValue* resources) const {
const auto* buffer_desc = dynamic_cast<const BufferDescriptor*>(obj_ptr);
if (!buffer_desc) {
return absl::InvalidArgumentError("Expected BufferDescriptor on input.");
}
resources->buffers.push_back({"buffer", buffer_});
return absl::OkStatus();
}
absl::Status Buffer::CreateFromBufferDescriptor(const BufferDescriptor& desc,
id<MTLDevice> device) {
size_ = desc.size;
if (desc.data.empty()) {
buffer_ = [device newBufferWithLength:size_
options:MTLResourceStorageModeShared];
} else {
buffer_ = [device newBufferWithBytes:desc.data.data()
length:size_
options:MTLResourceStorageModeShared];
}
return absl::OkStatus();
}
absl::Status CreateBuffer(size_t size_in_bytes, const void* data,
id<MTLDevice> device, Buffer* result) {
id<MTLBuffer> buffer;

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
namespace tflite {
namespace gpu {

View File

@ -1,37 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h"
namespace tflite {
namespace gpu {
namespace metal {
std::string MemoryTypeToMetalType(MemoryType type) {
switch (type) {
case MemoryType::GLOBAL:
return "device";
case MemoryType::CONSTANT:
return "constant";
break;
case MemoryType::LOCAL:
return "threadgroup";
}
return "";
}
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@ -1,102 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_GPU_OBJECT_DESC_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_GPU_OBJECT_DESC_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
namespace metal {
enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
std::string MemoryTypeToMetalType(MemoryType type);
struct GPUBufferDescriptor {
DataType data_type;
AccessType access_type;
int element_size;
MemoryType memory_type = MemoryType::GLOBAL;
std::vector<std::string> attributes;
};
struct GPUResources {
std::vector<std::string> ints;
std::vector<std::string> floats;
std::vector<std::pair<std::string, GPUBufferDescriptor>> buffers;
std::vector<std::string> GetNames() const {
std::vector<std::string> names = ints;
names.insert(names.end(), floats.begin(), floats.end());
for (const auto& obj : buffers) {
names.push_back(obj.first);
}
return names;
}
};
class GPUObjectDescriptor {
public:
GPUObjectDescriptor() = default;
GPUObjectDescriptor(const GPUObjectDescriptor&) = default;
GPUObjectDescriptor& operator=(const GPUObjectDescriptor&) = default;
GPUObjectDescriptor(GPUObjectDescriptor&& obj_desc) = default;
GPUObjectDescriptor& operator=(GPUObjectDescriptor&& obj_desc) = default;
virtual ~GPUObjectDescriptor() = default;
void SetStateVar(const std::string& key, const std::string& value) const {
state_vars_[key] = value;
}
virtual std::string PerformConstExpr(const std::string& const_expr) const {
return "";
}
virtual absl::Status PerformSelector(
const std::string& selector, const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const {
*result = "";
return absl::OkStatus();
}
virtual GPUResources GetGPUResources() const { return GPUResources(); }
virtual void Release() {}
void SetAccess(AccessType access_type) { access_type_ = access_type; }
AccessType GetAccess() const { return access_type_; }
protected:
mutable std::map<std::string, std::string> state_vars_;
AccessType access_type_;
};
using GPUObjectDescriptorPtr = std::unique_ptr<GPUObjectDescriptor>;
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_GPU_OBJECT_DESC_H_

View File

@ -22,8 +22,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h"
namespace tflite {
namespace gpu {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
namespace tflite {
namespace gpu {