Clarified interface of Arguments.
PiperOrigin-RevId: 328758201
Change-Id: Icc4a06faf9f12c33d64d060df7084a1770e4b246
This commit is contained in:
parent
f08a64f05f
commit
69e75a3a12
@@ -76,7 +76,10 @@ cc_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":arguments",
|
":arguments",
|
||||||
|
":buffer",
|
||||||
|
":device_info",
|
||||||
":gpu_object",
|
":gpu_object",
|
||||||
|
":tensor",
|
||||||
":tensor_type",
|
":tensor_type",
|
||||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
@@ -256,13 +256,6 @@ void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
|
|||||||
object_refs_[name] = {std::move(descriptor_ptr)};
|
object_refs_[name] = {std::move(descriptor_ptr)};
|
||||||
}
|
}
|
||||||
|
|
||||||
void Arguments::AddObject(const std::string& name, AccessType access_type,
|
|
||||||
GPUObjectPtr&& object,
|
|
||||||
GPUObjectDescriptorPtr&& descriptor_ptr) {
|
|
||||||
descriptor_ptr->SetAccess(access_type);
|
|
||||||
objects_[name] = {std::move(object), std::move(descriptor_ptr)};
|
|
||||||
}
|
|
||||||
|
|
||||||
void Arguments::AddObject(const std::string& name,
|
void Arguments::AddObject(const std::string& name,
|
||||||
GPUObjectDescriptorPtr&& descriptor_ptr) {
|
GPUObjectDescriptorPtr&& descriptor_ptr) {
|
||||||
descriptor_ptr->SetAccess(AccessType::READ);
|
descriptor_ptr->SetAccess(AccessType::READ);
|
||||||
|
@@ -39,37 +39,16 @@ class Arguments {
|
|||||||
void AddFloat(const std::string& name, float value = 0.0f);
|
void AddFloat(const std::string& name, float value = 0.0f);
|
||||||
void AddHalf(const std::string& name, half value = half(0.0f));
|
void AddHalf(const std::string& name, half value = half(0.0f));
|
||||||
void AddInt(const std::string& name, int value = 0);
|
void AddInt(const std::string& name, int value = 0);
|
||||||
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
|
|
||||||
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
|
|
||||||
void AddImage2DArray(const std::string& name,
|
|
||||||
const GPUImage2DArrayDescriptor& desc);
|
|
||||||
void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
|
|
||||||
void AddImageBuffer(const std::string& name,
|
|
||||||
const GPUImageBufferDescriptor& desc);
|
|
||||||
void AddCustomMemory(const std::string& name,
|
|
||||||
const GPUCustomMemoryDescriptor& desc);
|
|
||||||
|
|
||||||
void AddObjectRef(const std::string& name, AccessType access_type,
|
void AddObjectRef(const std::string& name, AccessType access_type,
|
||||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||||
void AddObject(const std::string& name, AccessType access_type,
|
|
||||||
GPUObjectPtr&& object,
|
|
||||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
|
||||||
void AddObject(const std::string& name,
|
void AddObject(const std::string& name,
|
||||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||||
|
|
||||||
absl::Status SetInt(const std::string& name, int value);
|
absl::Status SetInt(const std::string& name, int value);
|
||||||
absl::Status SetFloat(const std::string& name, float value);
|
absl::Status SetFloat(const std::string& name, float value);
|
||||||
absl::Status SetHalf(const std::string& name, half value);
|
absl::Status SetHalf(const std::string& name, half value);
|
||||||
absl::Status SetImage2D(const std::string& name, cl_mem memory);
|
|
||||||
absl::Status SetBuffer(const std::string& name, cl_mem memory);
|
|
||||||
absl::Status SetImage2DArray(const std::string& name, cl_mem memory);
|
|
||||||
absl::Status SetImage3D(const std::string& name, cl_mem memory);
|
|
||||||
absl::Status SetImageBuffer(const std::string& name, cl_mem memory);
|
|
||||||
absl::Status SetCustomMemory(const std::string& name, cl_mem memory);
|
|
||||||
absl::Status SetObjectRef(const std::string& name, const GPUObject* object);
|
absl::Status SetObjectRef(const std::string& name, const GPUObject* object);
|
||||||
|
|
||||||
std::string GetListOfArgs();
|
|
||||||
|
|
||||||
absl::Status Bind(cl_kernel kernel, int offset = 0);
|
absl::Status Bind(cl_kernel kernel, int offset = 0);
|
||||||
|
|
||||||
void RenameArgs(const std::string& postfix, std::string* code) const;
|
void RenameArgs(const std::string& postfix, std::string* code) const;
|
||||||
@@ -87,6 +66,25 @@ class Arguments {
|
|||||||
Arguments& operator=(const Arguments&) = delete;
|
Arguments& operator=(const Arguments&) = delete;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
|
||||||
|
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
|
||||||
|
void AddImage2DArray(const std::string& name,
|
||||||
|
const GPUImage2DArrayDescriptor& desc);
|
||||||
|
void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
|
||||||
|
void AddImageBuffer(const std::string& name,
|
||||||
|
const GPUImageBufferDescriptor& desc);
|
||||||
|
void AddCustomMemory(const std::string& name,
|
||||||
|
const GPUCustomMemoryDescriptor& desc);
|
||||||
|
|
||||||
|
absl::Status SetImage2D(const std::string& name, cl_mem memory);
|
||||||
|
absl::Status SetBuffer(const std::string& name, cl_mem memory);
|
||||||
|
absl::Status SetImage2DArray(const std::string& name, cl_mem memory);
|
||||||
|
absl::Status SetImage3D(const std::string& name, cl_mem memory);
|
||||||
|
absl::Status SetImageBuffer(const std::string& name, cl_mem memory);
|
||||||
|
absl::Status SetCustomMemory(const std::string& name, cl_mem memory);
|
||||||
|
|
||||||
|
std::string GetListOfArgs();
|
||||||
|
|
||||||
std::string AddActiveArgument(const std::string& arg_name,
|
std::string AddActiveArgument(const std::string& arg_name,
|
||||||
bool use_f32_for_halfs);
|
bool use_f32_for_halfs);
|
||||||
void AddGPUResources(const std::string& name, const GPUResources& resources);
|
void AddGPUResources(const std::string& name, const GPUResources& resources);
|
||||||
|
@@ -14,85 +14,58 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
|
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include "absl/strings/match.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
|
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
namespace cl {
|
namespace cl {
|
||||||
namespace {
|
|
||||||
struct TestDescriptor : public GPUObjectDescriptor {
|
|
||||||
absl::Status PerformSelector(const std::string& selector,
|
|
||||||
const std::vector<std::string>& args,
|
|
||||||
const std::vector<std::string>& template_args,
|
|
||||||
std::string* result) const override {
|
|
||||||
if (selector == "Length") {
|
|
||||||
*result = "length";
|
|
||||||
return absl::OkStatus();
|
|
||||||
} else if (selector == "Read") {
|
|
||||||
if (args.size() != 1) {
|
|
||||||
return absl::NotFoundError(
|
|
||||||
absl::StrCat("TestDescriptor Read require one argument, but ",
|
|
||||||
args.size(), " was passed"));
|
|
||||||
}
|
|
||||||
*result = absl::StrCat("buffer[", args[0], "]");
|
|
||||||
return absl::OkStatus();
|
|
||||||
} else {
|
|
||||||
return absl::NotFoundError(absl::StrCat(
|
|
||||||
"TestDescriptor don't have selector with name - ", selector));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
GPUResources GetGPUResources(AccessType access_type) const override {
|
|
||||||
GPUResources resources;
|
|
||||||
resources.ints.push_back("length");
|
|
||||||
GPUBufferDescriptor desc;
|
|
||||||
desc.data_type = DataType::FLOAT32;
|
|
||||||
desc.element_size = 4;
|
|
||||||
resources.buffers.push_back({"buffer", desc});
|
|
||||||
return resources;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
TEST(ArgumentsTest, TestSelectorResolve) {
|
TEST(ArgumentsTest, TestSelectorResolve) {
|
||||||
TestDescriptor descriptor;
|
BufferDescriptor desc;
|
||||||
Arguments args;
|
desc.element_type = DataType::FLOAT32;
|
||||||
args.AddObjectRef("object", AccessType::WRITE,
|
desc.element_size = 4;
|
||||||
absl::make_unique<TestDescriptor>(descriptor));
|
desc.memory_type = MemoryType::GLOBAL;
|
||||||
std::string sample_code = R"(
|
|
||||||
if (a < 3) {
|
|
||||||
value = args.object.Read(id);
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
const std::string expected_result = R"(
|
|
||||||
if (a < 3) {
|
|
||||||
value = object_buffer[id];
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
ASSERT_OK(args.TransformToCLCode({}, &sample_code));
|
|
||||||
EXPECT_EQ(sample_code, expected_result);
|
|
||||||
|
|
||||||
std::string cl_arguments = args.GetListOfArgs();
|
Arguments args;
|
||||||
EXPECT_TRUE(cl_arguments.find("__global float4* object_buffer") !=
|
args.AddObjectRef("weights", AccessType::READ,
|
||||||
std::string::npos);
|
absl::make_unique<BufferDescriptor>(std::move(desc)));
|
||||||
|
std::string sample_code = R"(
|
||||||
|
__kernel void main_function($0) {
|
||||||
|
if (a < 3) {
|
||||||
|
value = args.weights.Read(id);
|
||||||
|
}
|
||||||
|
})";
|
||||||
|
|
||||||
|
DeviceInfo device_info;
|
||||||
|
ASSERT_OK(args.TransformToCLCode(device_info, {}, &sample_code));
|
||||||
|
EXPECT_TRUE(absl::StrContains(sample_code, "value = weights_buffer[id];"));
|
||||||
|
EXPECT_TRUE(
|
||||||
|
absl::StrContains(sample_code, "__global float4* weights_buffer"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ArgumentsTest, TestNoSelector) {
|
TEST(ArgumentsTest, TestNoSelector) {
|
||||||
TestDescriptor descriptor;
|
BufferDescriptor desc;
|
||||||
|
desc.element_type = DataType::FLOAT32;
|
||||||
|
desc.element_size = 4;
|
||||||
|
desc.memory_type = MemoryType::GLOBAL;
|
||||||
|
|
||||||
Arguments args;
|
Arguments args;
|
||||||
args.AddObjectRef("object", AccessType::WRITE,
|
args.AddObjectRef("weights", AccessType::READ,
|
||||||
absl::make_unique<TestDescriptor>(descriptor));
|
absl::make_unique<BufferDescriptor>(std::move(desc)));
|
||||||
std::string sample_code = R"(
|
std::string sample_code = R"(
|
||||||
if (a < 3) {
|
if (a < 3) {
|
||||||
value = args.object.Write(id);
|
value = args.weights.UnknownSelector(id);
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
EXPECT_FALSE(args.TransformToCLCode({}, &sample_code).ok());
|
DeviceInfo device_info;
|
||||||
|
EXPECT_FALSE(args.TransformToCLCode(device_info, {}, &sample_code).ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ArgumentsTest, TestRenameArgs) {
|
TEST(ArgumentsTest, TestRenameArgs) {
|
||||||
|
@@ -213,7 +213,6 @@ absl::Status GPUOperation::Compile(const CreationContext& creation_context) {
|
|||||||
RETURN_IF_ERROR(args_.TransformToCLCode(
|
RETURN_IF_ERROR(args_.TransformToCLCode(
|
||||||
creation_context.device->info_,
|
creation_context.device->info_,
|
||||||
{{dst_tensors_names_[0], elementwise_code_}}, &code));
|
{{dst_tensors_names_[0], elementwise_code_}}, &code));
|
||||||
code = absl::Substitute(code, args_.GetListOfArgs());
|
|
||||||
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
|
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
|
||||||
code, "main_function", *creation_context.context,
|
code, "main_function", *creation_context.context,
|
||||||
*creation_context.device, &kernel_));
|
*creation_context.device, &kernel_));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user