Clarified interface of Arguments.

PiperOrigin-RevId: 328758201
Change-Id: Icc4a06faf9f12c33d64d060df7084a1770e4b246
This commit is contained in:
Raman Sarokin 2020-08-27 09:59:32 -07:00 committed by TensorFlower Gardener
parent f08a64f05f
commit 69e75a3a12
5 changed files with 55 additions and 89 deletions

View File

@ -76,7 +76,10 @@ cc_test(
],
deps = [
":arguments",
":buffer",
":device_info",
":gpu_object",
":tensor",
":tensor_type",
"//tensorflow/lite/delegates/gpu/common:data_type",
"@com_google_absl//absl/strings",

View File

@ -256,13 +256,6 @@ void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
object_refs_[name] = {std::move(descriptor_ptr)};
}
// Registers an owned GPU object under `name`.
//
// The descriptor's access is first set to `access_type`; then the object and
// its descriptor are handed over (via std::move) to the `objects_` entry keyed
// by `name`, replacing whatever that entry held before.
void Arguments::AddObject(const std::string& name, AccessType access_type,
                          GPUObjectPtr&& object,
                          GPUObjectDescriptorPtr&& descriptor_ptr) {
  descriptor_ptr->SetAccess(access_type);
  auto& entry = objects_[name];
  entry = {std::move(object), std::move(descriptor_ptr)};
}
void Arguments::AddObject(const std::string& name,
GPUObjectDescriptorPtr&& descriptor_ptr) {
descriptor_ptr->SetAccess(AccessType::READ);

View File

@ -39,37 +39,16 @@ class Arguments {
void AddFloat(const std::string& name, float value = 0.0f);
void AddHalf(const std::string& name, half value = half(0.0f));
void AddInt(const std::string& name, int value = 0);
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
void AddImage2DArray(const std::string& name,
const GPUImage2DArrayDescriptor& desc);
void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
void AddImageBuffer(const std::string& name,
const GPUImageBufferDescriptor& desc);
void AddCustomMemory(const std::string& name,
const GPUCustomMemoryDescriptor& desc);
void AddObjectRef(const std::string& name, AccessType access_type,
GPUObjectDescriptorPtr&& descriptor_ptr);
void AddObject(const std::string& name, AccessType access_type,
GPUObjectPtr&& object,
GPUObjectDescriptorPtr&& descriptor_ptr);
void AddObject(const std::string& name,
GPUObjectDescriptorPtr&& descriptor_ptr);
absl::Status SetInt(const std::string& name, int value);
absl::Status SetFloat(const std::string& name, float value);
absl::Status SetHalf(const std::string& name, half value);
absl::Status SetImage2D(const std::string& name, cl_mem memory);
absl::Status SetBuffer(const std::string& name, cl_mem memory);
absl::Status SetImage2DArray(const std::string& name, cl_mem memory);
absl::Status SetImage3D(const std::string& name, cl_mem memory);
absl::Status SetImageBuffer(const std::string& name, cl_mem memory);
absl::Status SetCustomMemory(const std::string& name, cl_mem memory);
absl::Status SetObjectRef(const std::string& name, const GPUObject* object);
std::string GetListOfArgs();
absl::Status Bind(cl_kernel kernel, int offset = 0);
void RenameArgs(const std::string& postfix, std::string* code) const;
@ -87,6 +66,25 @@ class Arguments {
Arguments& operator=(const Arguments&) = delete;
private:
void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
void AddImage2DArray(const std::string& name,
const GPUImage2DArrayDescriptor& desc);
void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
void AddImageBuffer(const std::string& name,
const GPUImageBufferDescriptor& desc);
void AddCustomMemory(const std::string& name,
const GPUCustomMemoryDescriptor& desc);
absl::Status SetImage2D(const std::string& name, cl_mem memory);
absl::Status SetBuffer(const std::string& name, cl_mem memory);
absl::Status SetImage2DArray(const std::string& name, cl_mem memory);
absl::Status SetImage3D(const std::string& name, cl_mem memory);
absl::Status SetImageBuffer(const std::string& name, cl_mem memory);
absl::Status SetCustomMemory(const std::string& name, cl_mem memory);
std::string GetListOfArgs();
std::string AddActiveArgument(const std::string& arg_name,
bool use_f32_for_halfs);
void AddGPUResources(const std::string& name, const GPUResources& resources);

View File

@ -14,85 +14,58 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
#include <cstdint>
#include <string>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/match.h"
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
namespace tflite {
namespace gpu {
namespace cl {
namespace {
// Test-only GPUObjectDescriptor that understands two selectors, "Length" and
// "Read", and advertises one int resource and one buffer resource.
struct TestDescriptor : public GPUObjectDescriptor {
  // Translates `selector` into a code snippet stored in `*result`.
  //   "Length" -> "length"
  //   "Read"   -> "buffer[<args[0]>]" (exactly one entry in `args` required)
  // Returns NotFoundError for any other selector, or when "Read" is given a
  // number of arguments other than one. `template_args` is not consulted.
  absl::Status PerformSelector(const std::string& selector,
                               const std::vector<std::string>& args,
                               const std::vector<std::string>& template_args,
                               std::string* result) const override {
    if (selector == "Length") {
      *result = "length";
      return absl::OkStatus();
    }
    if (selector != "Read") {
      return absl::NotFoundError(absl::StrCat(
          "TestDescriptor don't have selector with name - ", selector));
    }
    // From here on the selector is "Read".
    const auto arg_count = args.size();
    if (arg_count != 1) {
      return absl::NotFoundError(
          absl::StrCat("TestDescriptor Read require one argument, but ",
                       arg_count, " was passed"));
    }
    *result = absl::StrCat("buffer[", args.front(), "]");
    return absl::OkStatus();
  }

  // Reports the resources this descriptor needs: an int named "length" and a
  // buffer named "buffer" described as FLOAT32 with element_size 4.
  // `access_type` does not influence the result.
  GPUResources GetGPUResources(AccessType access_type) const override {
    GPUBufferDescriptor buffer_desc;
    buffer_desc.data_type = DataType::FLOAT32;
    buffer_desc.element_size = 4;

    GPUResources result;
    result.ints.push_back("length");
    result.buffers.push_back({"buffer", buffer_desc});
    return result;
  }
};
} // namespace
// Checks that a `args.<name>.Read(id)` selector in kernel source is rewritten
// by Arguments::TransformToCLCode into a `<name>_buffer[id]` access, and that
// a `__global float4* <name>_buffer` argument shows up in the generated
// argument list / code.
//
// NOTE(review): this body declares `args` and `sample_code` twice and calls
// TransformToCLCode with two different argument lists (with and without a
// DeviceInfo). It looks like the pre-change (TestDescriptor + GetListOfArgs)
// and post-change (BufferDescriptor + StrContains) versions of the test shown
// together, not a single compilable body — confirm against the real file.
TEST(ArgumentsTest, TestSelectorResolve) {
TestDescriptor descriptor;
Arguments args;
args.AddObjectRef("object", AccessType::WRITE,
absl::make_unique<TestDescriptor>(descriptor));
std::string sample_code = R"(
if (a < 3) {
value = args.object.Read(id);
}
)";
const std::string expected_result = R"(
if (a < 3) {
value = object_buffer[id];
}
)";
ASSERT_OK(args.TransformToCLCode({}, &sample_code));
EXPECT_EQ(sample_code, expected_result);
BufferDescriptor desc;
desc.element_type = DataType::FLOAT32;
desc.element_size = 4;
desc.memory_type = MemoryType::GLOBAL;
std::string cl_arguments = args.GetListOfArgs();
EXPECT_TRUE(cl_arguments.find("__global float4* object_buffer") !=
std::string::npos);
Arguments args;
args.AddObjectRef("weights", AccessType::READ,
absl::make_unique<BufferDescriptor>(std::move(desc)));
std::string sample_code = R"(
__kernel void main_function($0) {
if (a < 3) {
value = args.weights.Read(id);
}
})";
DeviceInfo device_info;
ASSERT_OK(args.TransformToCLCode(device_info, {}, &sample_code));
EXPECT_TRUE(absl::StrContains(sample_code, "value = weights_buffer[id];"));
EXPECT_TRUE(
absl::StrContains(sample_code, "__global float4* weights_buffer"));
}
// Expects Arguments::TransformToCLCode to return a non-OK status when the
// kernel source uses a selector the registered descriptor does not handle:
// `Write` on a TestDescriptor (which only accepts "Length" and "Read"), and
// `UnknownSelector` on a BufferDescriptor (presumably unsupported there —
// BufferDescriptor is not visible here; verify).
//
// NOTE(review): two object refs ("object" and "weights") are registered and
// TransformToCLCode is called with two different argument lists. This looks
// like the pre-change and post-change versions of the test shown together,
// not a single compilable body — confirm against the real file.
TEST(ArgumentsTest, TestNoSelector) {
TestDescriptor descriptor;
BufferDescriptor desc;
desc.element_type = DataType::FLOAT32;
desc.element_size = 4;
desc.memory_type = MemoryType::GLOBAL;
Arguments args;
args.AddObjectRef("object", AccessType::WRITE,
absl::make_unique<TestDescriptor>(descriptor));
args.AddObjectRef("weights", AccessType::READ,
absl::make_unique<BufferDescriptor>(std::move(desc)));
std::string sample_code = R"(
if (a < 3) {
value = args.object.Write(id);
value = args.weights.UnknownSelector(id);
}
)";
EXPECT_FALSE(args.TransformToCLCode({}, &sample_code).ok());
DeviceInfo device_info;
EXPECT_FALSE(args.TransformToCLCode(device_info, {}, &sample_code).ok());
}
TEST(ArgumentsTest, TestRenameArgs) {

View File

@ -213,7 +213,6 @@ absl::Status GPUOperation::Compile(const CreationContext& creation_context) {
RETURN_IF_ERROR(args_.TransformToCLCode(
creation_context.device->info_,
{{dst_tensors_names_[0], elementwise_code_}}, &code));
code = absl::Substitute(code, args_.GetListOfArgs());
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_));