Clarified interface of Arguments.
PiperOrigin-RevId: 328758201 Change-Id: Icc4a06faf9f12c33d64d060df7084a1770e4b246
This commit is contained in:
parent
f08a64f05f
commit
69e75a3a12
@ -76,7 +76,10 @@ cc_test(
|
||||
],
|
||||
deps = [
|
||||
":arguments",
|
||||
":buffer",
|
||||
":device_info",
|
||||
":gpu_object",
|
||||
":tensor",
|
||||
":tensor_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -256,13 +256,6 @@ void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
|
||||
object_refs_[name] = {std::move(descriptor_ptr)};
|
||||
}
|
||||
|
||||
void Arguments::AddObject(const std::string& name, AccessType access_type,
|
||||
GPUObjectPtr&& object,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr) {
|
||||
descriptor_ptr->SetAccess(access_type);
|
||||
objects_[name] = {std::move(object), std::move(descriptor_ptr)};
|
||||
}
|
||||
|
||||
void Arguments::AddObject(const std::string& name,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr) {
|
||||
descriptor_ptr->SetAccess(AccessType::READ);
|
||||
|
@ -39,37 +39,16 @@ class Arguments {
|
||||
void AddFloat(const std::string& name, float value = 0.0f);
|
||||
void AddHalf(const std::string& name, half value = half(0.0f));
|
||||
void AddInt(const std::string& name, int value = 0);
|
||||
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
|
||||
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
|
||||
void AddImage2DArray(const std::string& name,
|
||||
const GPUImage2DArrayDescriptor& desc);
|
||||
void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
|
||||
void AddImageBuffer(const std::string& name,
|
||||
const GPUImageBufferDescriptor& desc);
|
||||
void AddCustomMemory(const std::string& name,
|
||||
const GPUCustomMemoryDescriptor& desc);
|
||||
|
||||
void AddObjectRef(const std::string& name, AccessType access_type,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||
void AddObject(const std::string& name, AccessType access_type,
|
||||
GPUObjectPtr&& object,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||
void AddObject(const std::string& name,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||
|
||||
absl::Status SetInt(const std::string& name, int value);
|
||||
absl::Status SetFloat(const std::string& name, float value);
|
||||
absl::Status SetHalf(const std::string& name, half value);
|
||||
absl::Status SetImage2D(const std::string& name, cl_mem memory);
|
||||
absl::Status SetBuffer(const std::string& name, cl_mem memory);
|
||||
absl::Status SetImage2DArray(const std::string& name, cl_mem memory);
|
||||
absl::Status SetImage3D(const std::string& name, cl_mem memory);
|
||||
absl::Status SetImageBuffer(const std::string& name, cl_mem memory);
|
||||
absl::Status SetCustomMemory(const std::string& name, cl_mem memory);
|
||||
absl::Status SetObjectRef(const std::string& name, const GPUObject* object);
|
||||
|
||||
std::string GetListOfArgs();
|
||||
|
||||
absl::Status Bind(cl_kernel kernel, int offset = 0);
|
||||
|
||||
void RenameArgs(const std::string& postfix, std::string* code) const;
|
||||
@ -87,6 +66,25 @@ class Arguments {
|
||||
Arguments& operator=(const Arguments&) = delete;
|
||||
|
||||
private:
|
||||
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
|
||||
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
|
||||
void AddImage2DArray(const std::string& name,
|
||||
const GPUImage2DArrayDescriptor& desc);
|
||||
void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
|
||||
void AddImageBuffer(const std::string& name,
|
||||
const GPUImageBufferDescriptor& desc);
|
||||
void AddCustomMemory(const std::string& name,
|
||||
const GPUCustomMemoryDescriptor& desc);
|
||||
|
||||
absl::Status SetImage2D(const std::string& name, cl_mem memory);
|
||||
absl::Status SetBuffer(const std::string& name, cl_mem memory);
|
||||
absl::Status SetImage2DArray(const std::string& name, cl_mem memory);
|
||||
absl::Status SetImage3D(const std::string& name, cl_mem memory);
|
||||
absl::Status SetImageBuffer(const std::string& name, cl_mem memory);
|
||||
absl::Status SetCustomMemory(const std::string& name, cl_mem memory);
|
||||
|
||||
std::string GetListOfArgs();
|
||||
|
||||
std::string AddActiveArgument(const std::string& arg_name,
|
||||
bool use_f32_for_halfs);
|
||||
void AddGPUResources(const std::string& name, const GPUResources& resources);
|
||||
|
@ -14,85 +14,58 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
struct TestDescriptor : public GPUObjectDescriptor {
|
||||
absl::Status PerformSelector(const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const override {
|
||||
if (selector == "Length") {
|
||||
*result = "length";
|
||||
return absl::OkStatus();
|
||||
} else if (selector == "Read") {
|
||||
if (args.size() != 1) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("TestDescriptor Read require one argument, but ",
|
||||
args.size(), " was passed"));
|
||||
}
|
||||
*result = absl::StrCat("buffer[", args[0], "]");
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"TestDescriptor don't have selector with name - ", selector));
|
||||
}
|
||||
}
|
||||
|
||||
GPUResources GetGPUResources(AccessType access_type) const override {
|
||||
GPUResources resources;
|
||||
resources.ints.push_back("length");
|
||||
GPUBufferDescriptor desc;
|
||||
desc.data_type = DataType::FLOAT32;
|
||||
desc.element_size = 4;
|
||||
resources.buffers.push_back({"buffer", desc});
|
||||
return resources;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(ArgumentsTest, TestSelectorResolve) {
|
||||
TestDescriptor descriptor;
|
||||
Arguments args;
|
||||
args.AddObjectRef("object", AccessType::WRITE,
|
||||
absl::make_unique<TestDescriptor>(descriptor));
|
||||
std::string sample_code = R"(
|
||||
if (a < 3) {
|
||||
value = args.object.Read(id);
|
||||
}
|
||||
)";
|
||||
const std::string expected_result = R"(
|
||||
if (a < 3) {
|
||||
value = object_buffer[id];
|
||||
}
|
||||
)";
|
||||
ASSERT_OK(args.TransformToCLCode({}, &sample_code));
|
||||
EXPECT_EQ(sample_code, expected_result);
|
||||
BufferDescriptor desc;
|
||||
desc.element_type = DataType::FLOAT32;
|
||||
desc.element_size = 4;
|
||||
desc.memory_type = MemoryType::GLOBAL;
|
||||
|
||||
std::string cl_arguments = args.GetListOfArgs();
|
||||
EXPECT_TRUE(cl_arguments.find("__global float4* object_buffer") !=
|
||||
std::string::npos);
|
||||
Arguments args;
|
||||
args.AddObjectRef("weights", AccessType::READ,
|
||||
absl::make_unique<BufferDescriptor>(std::move(desc)));
|
||||
std::string sample_code = R"(
|
||||
__kernel void main_function($0) {
|
||||
if (a < 3) {
|
||||
value = args.weights.Read(id);
|
||||
}
|
||||
})";
|
||||
|
||||
DeviceInfo device_info;
|
||||
ASSERT_OK(args.TransformToCLCode(device_info, {}, &sample_code));
|
||||
EXPECT_TRUE(absl::StrContains(sample_code, "value = weights_buffer[id];"));
|
||||
EXPECT_TRUE(
|
||||
absl::StrContains(sample_code, "__global float4* weights_buffer"));
|
||||
}
|
||||
|
||||
TEST(ArgumentsTest, TestNoSelector) {
|
||||
TestDescriptor descriptor;
|
||||
BufferDescriptor desc;
|
||||
desc.element_type = DataType::FLOAT32;
|
||||
desc.element_size = 4;
|
||||
desc.memory_type = MemoryType::GLOBAL;
|
||||
|
||||
Arguments args;
|
||||
args.AddObjectRef("object", AccessType::WRITE,
|
||||
absl::make_unique<TestDescriptor>(descriptor));
|
||||
args.AddObjectRef("weights", AccessType::READ,
|
||||
absl::make_unique<BufferDescriptor>(std::move(desc)));
|
||||
std::string sample_code = R"(
|
||||
if (a < 3) {
|
||||
value = args.object.Write(id);
|
||||
value = args.weights.UnknownSelector(id);
|
||||
}
|
||||
)";
|
||||
EXPECT_FALSE(args.TransformToCLCode({}, &sample_code).ok());
|
||||
DeviceInfo device_info;
|
||||
EXPECT_FALSE(args.TransformToCLCode(device_info, {}, &sample_code).ok());
|
||||
}
|
||||
|
||||
TEST(ArgumentsTest, TestRenameArgs) {
|
||||
|
@ -213,7 +213,6 @@ absl::Status GPUOperation::Compile(const CreationContext& creation_context) {
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode(
|
||||
creation_context.device->info_,
|
||||
{{dst_tensors_names_[0], elementwise_code_}}, &code));
|
||||
code = absl::Substitute(code, args_.GetListOfArgs());
|
||||
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
*creation_context.device, &kernel_));
|
||||
|
Loading…
x
Reference in New Issue
Block a user