Added generic arguments to abstract int/float uniforms.
PiperOrigin-RevId: 313327440 Change-Id: I12c82d0499b3ed9eb4f839cf8016a87bd0ea4807
This commit is contained in:
parent
1aef0ba436
commit
8f31b06f53
@ -38,6 +38,20 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "arguments",
|
||||||
|
srcs = ["arguments.cc"],
|
||||||
|
hdrs = ["arguments.h"],
|
||||||
|
deps = [
|
||||||
|
":opencl_wrapper",
|
||||||
|
":util",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:types",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:util",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "buffer",
|
name = "buffer",
|
||||||
srcs = ["buffer.cc"],
|
srcs = ["buffer.cc"],
|
||||||
|
173
tensorflow/lite/delegates/gpu/cl/arguments.cc
Normal file
173
tensorflow/lite/delegates/gpu/cl/arguments.cc
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
|
||||||
|
|
||||||
|
#include "absl/strings/ascii.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace cl {
|
||||||
|
namespace {
|
||||||
|
std::string GetNextWord(const std::string& code, size_t first_position) {
|
||||||
|
size_t pos = first_position;
|
||||||
|
char t = code[pos];
|
||||||
|
while (absl::ascii_isalnum(t) || t == '_') {
|
||||||
|
pos++;
|
||||||
|
t = code[pos];
|
||||||
|
}
|
||||||
|
return code.substr(first_position, pos - first_position);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Arguments::Arguments(Arguments&& args)
|
||||||
|
: int_values_(std::move(args.int_values_)),
|
||||||
|
shared_int4s_data_(std::move(args.shared_int4s_data_)),
|
||||||
|
float_values_(std::move(args.float_values_)),
|
||||||
|
shared_float4s_data_(std::move(args.shared_float4s_data_)) {}
|
||||||
|
// Move assignment: replaces this object's argument tables and packed shared
// storage with those of `args`. Self-assignment is a no-op.
Arguments& Arguments::operator=(Arguments&& args) {
  if (this == &args) {
    return *this;
  }
  int_values_ = std::move(args.int_values_);
  shared_int4s_data_ = std::move(args.shared_int4s_data_);
  float_values_ = std::move(args.float_values_);
  shared_float4s_data_ = std::move(args.shared_float4s_data_);
  return *this;
}
|
||||||
|
|
||||||
|
void Arguments::AddFloat(const std::string& name, float value) {
|
||||||
|
float_values_[name].value = value;
|
||||||
|
}
|
||||||
|
void Arguments::AddInt(const std::string& name, int value) {
|
||||||
|
int_values_[name].value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Updates the value of the previously registered int argument `name`.
//
// When the argument is active, the corresponding slot of the packed shared
// int storage is updated as well. Returns NotFoundError if no int argument
// named `name` has been registered.
absl::Status Arguments::SetInt(const std::string& name, int value) {
  const auto found = int_values_.find(name);
  if (found != int_values_.end()) {
    IntValue& entry = found->second;
    entry.value = value;
    if (entry.active) {
      shared_int4s_data_[entry.offset] = value;
    }
    return absl::OkStatus();
  }
  return absl::NotFoundError(absl::StrCat("No argument with name - ", name));
}
|
||||||
|
|
||||||
|
// Updates the value of the previously registered float argument `name`.
//
// When the argument is active, the corresponding slot of the packed shared
// float storage is updated as well. Returns NotFoundError if no float
// argument named `name` has been registered.
absl::Status Arguments::SetFloat(const std::string& name, float value) {
  const auto found = float_values_.find(name);
  if (found != float_values_.end()) {
    FloatValue& entry = found->second;
    entry.value = value;
    if (entry.active) {
      shared_float4s_data_[entry.offset] = value;
    }
    return absl::OkStatus();
  }
  return absl::NotFoundError(absl::StrCat("No argument with name - ", name));
}
|
||||||
|
|
||||||
|
std::string Arguments::GetListOfArgs() {
|
||||||
|
std::string result;
|
||||||
|
for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
|
||||||
|
absl::StrAppend(&result, ",\n int4 shared_int4_", i);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < shared_float4s_data_.size() / 4; ++i) {
|
||||||
|
absl::StrAppend(&result, ",\n float4 shared_float4_", i);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Binds the packed shared storage to `kernel` as consecutive kernel
// arguments, starting at argument index `offset`.
//
// Every group of four ints is bound first, then every group of four floats,
// matching the order of the declarations produced by GetListOfArgs().
// Returns UnknownError naming the failing argument index if clSetKernelArg
// reports anything other than CL_SUCCESS.
absl::Status Arguments::Bind(cl_kernel kernel, int offset) {
  const size_t int4_count = shared_int4s_data_.size() / 4;
  for (size_t i = 0; i < int4_count; ++i) {
    const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4,
                                          &shared_int4s_data_[i * 4]);
    if (error_code != CL_SUCCESS) {
      return absl::UnknownError(absl::StrCat(
          "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
          "(at index - ", offset, ")"));
    }
    offset++;
  }
  const size_t float4_count = shared_float4s_data_.size() / 4;
  for (size_t i = 0; i < float4_count; ++i) {
    // The size is expressed in terms of the element type actually stored
    // (float), not int32_t, so it stays correct by construction.
    const int error_code = clSetKernelArg(kernel, offset, sizeof(float) * 4,
                                          &shared_float4s_data_[i * 4]);
    if (error_code != CL_SUCCESS) {
      return absl::UnknownError(absl::StrCat(
          "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
          "(at index - ", offset, ")"));
    }
    offset++;
  }
  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
std::string Arguments::AddActiveArgument(const std::string& arg_name) {
|
||||||
|
if (auto it = int_values_.find(arg_name); it != int_values_.end()) {
|
||||||
|
int int_index;
|
||||||
|
if (it->second.active) {
|
||||||
|
int_index = it->second.offset;
|
||||||
|
} else {
|
||||||
|
it->second.active = true;
|
||||||
|
it->second.offset = shared_int4s_data_.size();
|
||||||
|
int_index = it->second.offset;
|
||||||
|
shared_int4s_data_.push_back(it->second.value);
|
||||||
|
}
|
||||||
|
std::string index = std::to_string(int_index / 4);
|
||||||
|
std::string postfixes[4] = {"x", "y", "z", "w"};
|
||||||
|
return "shared_int4_" + index + "." + postfixes[int_index % 4];
|
||||||
|
}
|
||||||
|
if (auto it = float_values_.find(arg_name); it != float_values_.end()) {
|
||||||
|
int float_index;
|
||||||
|
if (it->second.active) {
|
||||||
|
float_index = it->second.offset;
|
||||||
|
} else {
|
||||||
|
it->second.active = true;
|
||||||
|
it->second.offset = shared_float4s_data_.size();
|
||||||
|
float_index = it->second.offset;
|
||||||
|
shared_float4s_data_.push_back(it->second.value);
|
||||||
|
}
|
||||||
|
std::string index = std::to_string(float_index / 4);
|
||||||
|
std::string postfixes[4] = {"x", "y", "z", "w"};
|
||||||
|
return "shared_float4_" + index + "." + postfixes[float_index % 4];
|
||||||
|
}
|
||||||
|
return arg_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Arguments::ResolveArgsPass(std::string* code) {
|
||||||
|
std::string result;
|
||||||
|
constexpr char kPrefix[] = "args.";
|
||||||
|
size_t position = 0;
|
||||||
|
size_t next_position = code->find(kPrefix);
|
||||||
|
while (next_position != std::string::npos) {
|
||||||
|
size_t arg_pos = next_position;
|
||||||
|
next_position += strlen(kPrefix);
|
||||||
|
std::string object_name = GetNextWord(*code, next_position);
|
||||||
|
std::string new_name = AddActiveArgument(object_name);
|
||||||
|
code->replace(arg_pos, object_name.size() + strlen(kPrefix), new_name);
|
||||||
|
position = arg_pos + new_name.size();
|
||||||
|
next_position = code->find(kPrefix, position);
|
||||||
|
}
|
||||||
|
|
||||||
|
int shared_int4s_aligned_size = AlignByN(shared_int4s_data_.size(), 4);
|
||||||
|
shared_int4s_data_.resize(shared_int4s_aligned_size);
|
||||||
|
int shared_float4s_aligned_size = AlignByN(shared_float4s_data_.size(), 4);
|
||||||
|
shared_float4s_data_.resize(shared_float4s_aligned_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
88
tensorflow/lite/delegates/gpu/cl/arguments.h
Normal file
88
tensorflow/lite/delegates/gpu/cl/arguments.h
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_ARGUMENTS_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_ARGUMENTS_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/util.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace cl {
|
||||||
|
|
||||||
|
class Arguments {
|
||||||
|
public:
|
||||||
|
Arguments() = default;
|
||||||
|
void AddFloat(const std::string& name, float value = 0.0f);
|
||||||
|
void AddInt(const std::string& name, int value = 0);
|
||||||
|
|
||||||
|
absl::Status SetInt(const std::string& name, int value);
|
||||||
|
absl::Status SetFloat(const std::string& name, float value);
|
||||||
|
|
||||||
|
std::string GetListOfArgs();
|
||||||
|
|
||||||
|
absl::Status Bind(cl_kernel kernel, int offset);
|
||||||
|
|
||||||
|
void ResolveArgsPass(std::string* code);
|
||||||
|
|
||||||
|
// Move only
|
||||||
|
Arguments(Arguments&& args);
|
||||||
|
Arguments& operator=(Arguments&& args);
|
||||||
|
Arguments(const Arguments&) = delete;
|
||||||
|
Arguments& operator=(const Arguments&) = delete;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string AddActiveArgument(const std::string& arg_name);
|
||||||
|
|
||||||
|
struct IntValue {
|
||||||
|
int value;
|
||||||
|
|
||||||
|
// many uniforms generated automatically and not used
|
||||||
|
// to reduce amount of data transferred we adding this optimization
|
||||||
|
bool active = false;
|
||||||
|
|
||||||
|
// offset to shared uniform storage.
|
||||||
|
uint32_t offset = -1;
|
||||||
|
};
|
||||||
|
std::map<std::string, IntValue> int_values_;
|
||||||
|
std::vector<int32_t> shared_int4s_data_;
|
||||||
|
|
||||||
|
struct FloatValue {
|
||||||
|
float value;
|
||||||
|
|
||||||
|
// many uniforms generated automatically and not used
|
||||||
|
// to reduce amount of data transferred we adding this optimization
|
||||||
|
bool active = false;
|
||||||
|
|
||||||
|
// offset to shared uniform storage.
|
||||||
|
uint32_t offset = -1;
|
||||||
|
};
|
||||||
|
std::map<std::string, FloatValue> float_values_;
|
||||||
|
std::vector<float> shared_float4s_data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_ARGUMENTS_H_
|
@ -65,6 +65,7 @@ class CLKernel {
|
|||||||
int GetPrivateMemorySize() const { return private_memory_size_; }
|
int GetPrivateMemorySize() const { return private_memory_size_; }
|
||||||
int GetMaxWorkGroupSize() const { return max_work_group_size_; }
|
int GetMaxWorkGroupSize() const { return max_work_group_size_; }
|
||||||
|
|
||||||
|
int GetBindingCounter() const { return binding_counter_; }
|
||||||
void ResetBindingCounter() { binding_counter_ = 0; }
|
void ResetBindingCounter() { binding_counter_ = 0; }
|
||||||
|
|
||||||
// Do not use this function
|
// Do not use this function
|
||||||
|
@ -1290,8 +1290,10 @@ cc_library(
|
|||||||
":gpu_operation",
|
":gpu_operation",
|
||||||
":util",
|
":util",
|
||||||
":work_group_picking",
|
":work_group_picking",
|
||||||
|
"//tensorflow/lite/delegates/gpu/cl:arguments",
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||||
"//tensorflow/lite/delegates/gpu/common:types",
|
"//tensorflow/lite/delegates/gpu/common:types",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
|
||||||
|
|
||||||
@ -27,37 +29,45 @@ namespace {
|
|||||||
|
|
||||||
std::string GetTransposeCode(
|
std::string GetTransposeCode(
|
||||||
const OperationDef& op_def, const TransposeAttributes& attr,
|
const OperationDef& op_def, const TransposeAttributes& attr,
|
||||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
const std::vector<ElementwiseOperation*>& linked_operations,
|
||||||
TensorCodeGenerator src_tensor(
|
Arguments* args) {
|
||||||
"src_data",
|
TensorCodeGenerator src_tensor("src_data",
|
||||||
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
|
WHSBPoint{"args.src_width", "args.src_height",
|
||||||
op_def.src_tensors[0]);
|
"args.src_slices", "args.src_batch"},
|
||||||
TensorCodeGenerator dst_tensor(
|
op_def.src_tensors[0]);
|
||||||
"dst_data",
|
TensorCodeGenerator dst_tensor("dst_data",
|
||||||
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
|
WHSBPoint{"args.dst_width", "args.dst_height",
|
||||||
op_def.dst_tensors[0]);
|
"args.dst_slices", "args.dst_batch"},
|
||||||
|
op_def.dst_tensors[0]);
|
||||||
|
|
||||||
|
args->AddInt("src_width");
|
||||||
|
args->AddInt("src_height");
|
||||||
|
args->AddInt("src_slices");
|
||||||
|
args->AddInt("src_batch");
|
||||||
|
args->AddInt("dst_width");
|
||||||
|
args->AddInt("dst_height");
|
||||||
|
args->AddInt("dst_slices");
|
||||||
|
args->AddInt("dst_batch");
|
||||||
|
args->AddInt("dst_channels");
|
||||||
|
|
||||||
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
|
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
|
||||||
std::string c = GetCommonDefines(op_def.precision);
|
std::string c = GetCommonDefines(op_def.precision);
|
||||||
c += "__kernel void main_function(\n";
|
c += "__kernel void main_function(\n";
|
||||||
c += src_tensor.GetDeclaration(AccessType::READ);
|
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||||
c += GetArgsDeclaration(linked_operations);
|
c += GetArgsDeclaration(linked_operations);
|
||||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
c += dst_tensor.GetDeclaration(AccessType::WRITE);
|
||||||
c += " int4 src_size, \n";
|
c += "$0) {\n";
|
||||||
c += " int4 dst_size, \n";
|
|
||||||
c += " int src_channels, \n";
|
|
||||||
c += " int dst_channels \n";
|
|
||||||
c += ") {\n";
|
|
||||||
if (op_def.IsBatchSupported()) {
|
if (op_def.IsBatchSupported()) {
|
||||||
c += " int linear_id = get_global_id(0);\n";
|
c += " int linear_id = get_global_id(0);\n";
|
||||||
c += " int X = linear_id / dst_size.w;\n";
|
c += " int X = linear_id / args.dst_batch;\n";
|
||||||
c += " int B = linear_id % dst_size.w;\n";
|
c += " int B = linear_id % args.dst_batch;\n";
|
||||||
} else {
|
} else {
|
||||||
c += " int X = get_global_id(0);\n";
|
c += " int X = get_global_id(0);\n";
|
||||||
}
|
}
|
||||||
c += " int Y = get_global_id(1);\n";
|
c += " int Y = get_global_id(1);\n";
|
||||||
c += " int Z = get_global_id(2);\n";
|
c += " int Z = get_global_id(2);\n";
|
||||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) { \n";
|
c += " if (X >= args.dst_width || Y >= args.dst_height || Z >= "
|
||||||
|
"args.dst_slices) { \n";
|
||||||
c += " return; \n";
|
c += " return; \n";
|
||||||
c += " } \n";
|
c += " } \n";
|
||||||
c += " FLT temps[4];\n";
|
c += " FLT temps[4];\n";
|
||||||
@ -83,7 +93,7 @@ std::string GetTransposeCode(
|
|||||||
} else {
|
} else {
|
||||||
c += " for (int i = 0; i < 4; ++i) {\n";
|
c += " for (int i = 0; i < 4; ++i) {\n";
|
||||||
c += " int dst_channel = Z * 4 + i;\n";
|
c += " int dst_channel = Z * 4 + i;\n";
|
||||||
c += " if (dst_channel < dst_channels) {;\n";
|
c += " if (dst_channel < args.dst_channels) {;\n";
|
||||||
const std::string bhwc[] = {"B", "Y", "X", "dst_channel"};
|
const std::string bhwc[] = {"B", "Y", "X", "dst_channel"};
|
||||||
std::string src_b = op_def.IsBatchSupported() ? bhwc[remap[0]] : "";
|
std::string src_b = op_def.IsBatchSupported() ? bhwc[remap[0]] : "";
|
||||||
c += " int s_y = " + bhwc[remap[1]] + ";\n";
|
c += " int s_y = " + bhwc[remap[1]] + ";\n";
|
||||||
@ -100,24 +110,27 @@ std::string GetTransposeCode(
|
|||||||
}
|
}
|
||||||
c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
|
c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
|
||||||
std::string x_3dcoord =
|
std::string x_3dcoord =
|
||||||
op_def.IsBatchSupported() ? "X * dst_size.w + B" : "X";
|
op_def.IsBatchSupported() ? "X * args.dst_batch + B" : "X";
|
||||||
const LinkingContext context{"result", x_3dcoord, "Y", "Z"};
|
const LinkingContext context{"result", x_3dcoord, "Y", "Z"};
|
||||||
c += PostProcess(linked_operations, context);
|
c += PostProcess(linked_operations, context);
|
||||||
c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", batch_id);
|
c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", batch_id);
|
||||||
c += "}\n";
|
c += "}\n";
|
||||||
return c;
|
args->ResolveArgsPass(&c);
|
||||||
|
return absl::Substitute(c, args->GetListOfArgs());
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Transpose::Transpose(Transpose&& operation)
|
Transpose::Transpose(Transpose&& operation)
|
||||||
: GPUOperation(std::move(operation)),
|
: GPUOperation(std::move(operation)),
|
||||||
attr_(operation.attr_),
|
attr_(operation.attr_),
|
||||||
|
args_(std::move(operation.args_)),
|
||||||
kernel_(std::move(operation.kernel_)),
|
kernel_(std::move(operation.kernel_)),
|
||||||
work_group_size_(operation.work_group_size_) {}
|
work_group_size_(operation.work_group_size_) {}
|
||||||
|
|
||||||
Transpose& Transpose::operator=(Transpose&& operation) {
|
Transpose& Transpose::operator=(Transpose&& operation) {
|
||||||
if (this != &operation) {
|
if (this != &operation) {
|
||||||
attr_ = operation.attr_;
|
attr_ = operation.attr_;
|
||||||
|
args_ = std::move(operation.args_);
|
||||||
kernel_ = std::move(operation.kernel_);
|
kernel_ = std::move(operation.kernel_);
|
||||||
std::swap(work_group_size_, operation.work_group_size_);
|
std::swap(work_group_size_, operation.work_group_size_);
|
||||||
GPUOperation::operator=(std::move(operation));
|
GPUOperation::operator=(std::move(operation));
|
||||||
@ -126,21 +139,28 @@ Transpose& Transpose::operator=(Transpose&& operation) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Transpose::Compile(const CreationContext& creation_context) {
|
absl::Status Transpose::Compile(const CreationContext& creation_context) {
|
||||||
const auto code = GetTransposeCode(definition_, attr_, linked_operations_);
|
const auto code =
|
||||||
|
GetTransposeCode(definition_, attr_, linked_operations_, &args_);
|
||||||
return creation_context.cache->GetOrCreateCLKernel(
|
return creation_context.cache->GetOrCreateCLKernel(
|
||||||
code, "main_function", *creation_context.context,
|
code, "main_function", *creation_context.context,
|
||||||
*creation_context.device, &kernel_);
|
*creation_context.device, &kernel_);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Transpose::BindArguments() {
|
absl::Status Transpose::BindArguments() {
|
||||||
|
RETURN_IF_ERROR(args_.SetInt("src_width", src_[0]->Width()));
|
||||||
|
RETURN_IF_ERROR(args_.SetInt("src_height", src_[0]->Height()));
|
||||||
|
RETURN_IF_ERROR(args_.SetInt("src_slices", src_[0]->Slices()));
|
||||||
|
RETURN_IF_ERROR(args_.SetInt("src_batch", src_[0]->Batch()));
|
||||||
|
RETURN_IF_ERROR(args_.SetInt("dst_width", dst_[0]->Width()));
|
||||||
|
RETURN_IF_ERROR(args_.SetInt("dst_height", dst_[0]->Height()));
|
||||||
|
RETURN_IF_ERROR(args_.SetInt("dst_slices", dst_[0]->Slices()));
|
||||||
|
RETURN_IF_ERROR(args_.SetInt("dst_batch", dst_[0]->Batch()));
|
||||||
|
RETURN_IF_ERROR(args_.SetInt("dst_channels", dst_[0]->Channels()));
|
||||||
kernel_.ResetBindingCounter();
|
kernel_.ResetBindingCounter();
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
|
RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
|
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels()));
|
|
||||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Channels()));
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_
|
||||||
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_
|
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
@ -43,6 +44,7 @@ class Transpose : public GPUOperation {
|
|||||||
int3 GetGridSize() const;
|
int3 GetGridSize() const;
|
||||||
|
|
||||||
TransposeAttributes attr_;
|
TransposeAttributes attr_;
|
||||||
|
Arguments args_;
|
||||||
CLKernel kernel_;
|
CLKernel kernel_;
|
||||||
int3 work_group_size_;
|
int3 work_group_size_;
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user