Added device info to TransformToCLCode function.
Storing half parameters on PowerVR as float32 values. PiperOrigin-RevId: 316119180 Change-Id: I60e48cbd7e16cbb35b960acfd31e78d1dc379854
This commit is contained in:
parent
2b7fb42e3b
commit
588854df78
@ -43,6 +43,7 @@ cc_library(
|
||||
srcs = ["arguments.cc"],
|
||||
hdrs = ["arguments.h"],
|
||||
deps = [
|
||||
":cl_device",
|
||||
":gpu_object",
|
||||
":opencl_wrapper",
|
||||
":tensor_type",
|
||||
|
@ -283,8 +283,12 @@ absl::Status Arguments::SetHalf(const std::string& name, half value) {
|
||||
}
|
||||
it->second.value = value;
|
||||
if (it->second.active) {
|
||||
if (it->second.store_as_f32) {
|
||||
shared_float4s_data_[it->second.offset] = value;
|
||||
} else {
|
||||
shared_half4s_data_[it->second.offset] = value;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
@ -436,10 +440,11 @@ absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix) {
|
||||
}
|
||||
|
||||
absl::Status Arguments::TransformToCLCode(
|
||||
const DeviceInfo& device_info,
|
||||
const std::map<std::string, std::string>& linkables, std::string* code) {
|
||||
RETURN_IF_ERROR(AddObjectArgs());
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(linkables, code));
|
||||
ResolveArgsPass(code);
|
||||
ResolveArgsPass(device_info, code);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
@ -568,7 +573,8 @@ absl::Status Arguments::Bind(cl_kernel kernel, int offset) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::string Arguments::AddActiveArgument(const std::string& arg_name) {
|
||||
std::string Arguments::AddActiveArgument(const std::string& arg_name,
|
||||
bool use_f32_for_halfs) {
|
||||
if (auto it = int_values_.find(arg_name); it != int_values_.end()) {
|
||||
int int_index;
|
||||
if (it->second.active) {
|
||||
@ -603,26 +609,39 @@ std::string Arguments::AddActiveArgument(const std::string& arg_name) {
|
||||
half_index = it->second.offset;
|
||||
} else {
|
||||
it->second.active = true;
|
||||
if (use_f32_for_halfs) {
|
||||
it->second.store_as_f32 = true;
|
||||
it->second.offset = shared_float4s_data_.size();
|
||||
shared_float4s_data_.push_back(it->second.value);
|
||||
} else {
|
||||
it->second.offset = shared_half4s_data_.size();
|
||||
half_index = it->second.offset;
|
||||
shared_half4s_data_.push_back(it->second.value);
|
||||
}
|
||||
half_index = it->second.offset;
|
||||
}
|
||||
std::string index = std::to_string(half_index / 4);
|
||||
std::string postfixes[4] = {"x", "y", "z", "w"};
|
||||
if (it->second.store_as_f32) {
|
||||
return "(half)(shared_float4_" + index + "." + postfixes[half_index % 4] +
|
||||
")";
|
||||
} else {
|
||||
return "shared_half4_" + index + "." + postfixes[half_index % 4];
|
||||
}
|
||||
}
|
||||
return arg_name;
|
||||
}
|
||||
|
||||
void Arguments::ResolveArgsPass(std::string* code) {
|
||||
std::string result;
|
||||
void Arguments::ResolveArgsPass(const DeviceInfo& device_info,
|
||||
std::string* code) {
|
||||
bool use_f32_for_half_arguments = device_info.vendor == Vendor::POWERVR;
|
||||
size_t position = 0;
|
||||
size_t next_position = code->find(kArgsPrefix);
|
||||
while (next_position != std::string::npos) {
|
||||
size_t arg_pos = next_position;
|
||||
next_position += strlen(kArgsPrefix);
|
||||
std::string object_name = GetNextWord(*code, next_position);
|
||||
std::string new_name = AddActiveArgument(object_name);
|
||||
std::string new_name =
|
||||
AddActiveArgument(object_name, use_f32_for_half_arguments);
|
||||
code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
|
||||
position = arg_pos + new_name.size();
|
||||
next_position = code->find(kArgsPrefix, position);
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/util.h"
|
||||
@ -69,6 +70,7 @@ class Arguments {
|
||||
absl::Status Merge(Arguments&& args, const std::string& postfix);
|
||||
|
||||
absl::Status TransformToCLCode(
|
||||
const DeviceInfo& device_info,
|
||||
const std::map<std::string, std::string>& linkables, std::string* code);
|
||||
|
||||
// Move only
|
||||
@ -78,7 +80,8 @@ class Arguments {
|
||||
Arguments& operator=(const Arguments&) = delete;
|
||||
|
||||
private:
|
||||
std::string AddActiveArgument(const std::string& arg_name);
|
||||
std::string AddActiveArgument(const std::string& arg_name,
|
||||
bool use_f32_for_halfs);
|
||||
void AddGPUResources(const std::string& name, const GPUResources& resources);
|
||||
|
||||
absl::Status SetGPUResources(const std::string& name,
|
||||
@ -86,7 +89,7 @@ class Arguments {
|
||||
|
||||
absl::Status AddObjectArgs();
|
||||
|
||||
void ResolveArgsPass(std::string* code);
|
||||
void ResolveArgsPass(const DeviceInfo& device_info, std::string* code);
|
||||
absl::Status ResolveSelectorsPass(
|
||||
const std::map<std::string, std::string>& linkables, std::string* code);
|
||||
|
||||
@ -135,6 +138,9 @@ class Arguments {
|
||||
// to reduce amount of data transferred we adding this optimization
|
||||
bool active = false;
|
||||
|
||||
// some devices have issues with half parameters.
|
||||
bool store_as_f32 = false;
|
||||
|
||||
// offset to shared uniform storage.
|
||||
uint32_t offset = -1;
|
||||
};
|
||||
|
@ -91,7 +91,8 @@ Softmax& Softmax::operator=(Softmax&& kernel) {
|
||||
absl::Status Softmax::Compile(const CreationContext& creation_context) {
|
||||
std::string code =
|
||||
GetSoftmaxKernelCode(definition_, linked_operations_, &args_);
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode({}, &code));
|
||||
RETURN_IF_ERROR(
|
||||
args_.TransformToCLCode(creation_context.device->GetInfo(), {}, &code));
|
||||
code = absl::Substitute(code, args_.GetListOfArgs());
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
|
@ -130,8 +130,9 @@ absl::Status Transpose::Compile(const CreationContext& creation_context) {
|
||||
element_wise_code += "{\n" + code + "\n}\n";
|
||||
RETURN_IF_ERROR(args_.Merge(std::move(link_args), postfix));
|
||||
}
|
||||
RETURN_IF_ERROR(
|
||||
args_.TransformToCLCode({{"dst_tensor", element_wise_code}}, &code));
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
|
||||
{{"dst_tensor", element_wise_code}},
|
||||
&code));
|
||||
code = absl::Substitute(code, args_.GetListOfArgs());
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
|
@ -391,7 +391,8 @@ absl::Status Winograd4x4To36::Compile(const CreationContext& creation_context) {
|
||||
RETURN_IF_ERROR(UploadBt(creation_context.context));
|
||||
std::string code =
|
||||
GetWinograd4x4To36Code(definition_, linked_operations_, &args_);
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode({}, &code));
|
||||
RETURN_IF_ERROR(
|
||||
args_.TransformToCLCode(creation_context.device->GetInfo(), {}, &code));
|
||||
code = absl::Substitute(code, args_.GetListOfArgs());
|
||||
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", options, *creation_context.context,
|
||||
|
Loading…
Reference in New Issue
Block a user