Added GpuInfo to arguments of GpuObject.PerformSelector.
PerformSelector can have API-specific elements. Added Metal specialization to Texture2D.PerformReadSelector. Added Metal specialization to TensorLinear.PerformReadSelector. PiperOrigin-RevId: 351830974 Change-Id: Ibae3c0aeb7a423eefc6d341803914648dd77c842
This commit is contained in:
parent
3d28cdc603
commit
c3f98039e6
@ -214,7 +214,7 @@ absl::Status CLArguments::Init(
|
||||
Arguments* args, std::string* code) {
|
||||
RETURN_IF_ERROR(AllocateObjects(*args, context));
|
||||
RETURN_IF_ERROR(AddObjectArgs(args));
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(*args, linkables, code));
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, *args, linkables, code));
|
||||
object_refs_ = std::move(args->object_refs_);
|
||||
args->GetActiveArguments(kArgsPrefix, *code);
|
||||
const bool use_f32_for_halfs = gpu_info.IsPowerVR();
|
||||
@ -273,8 +273,8 @@ absl::Status CLArguments::SetObjectsResources(const Arguments& args) {
|
||||
}
|
||||
|
||||
absl::Status CLArguments::ResolveSelectorsPass(
|
||||
const Arguments& args, const std::map<std::string, std::string>& linkables,
|
||||
std::string* code) {
|
||||
const GpuInfo& gpu_info, const Arguments& args,
|
||||
const std::map<std::string, std::string>& linkables, std::string* code) {
|
||||
std::string result;
|
||||
size_t position = 0;
|
||||
size_t next_position = code->find(kArgsPrefix);
|
||||
@ -305,10 +305,10 @@ absl::Status CLArguments::ResolveSelectorsPass(
|
||||
RETURN_IF_ERROR(ParseArgsInsideBrackets(
|
||||
*code, next_position, &close_bracket_pos, &function_args));
|
||||
for (auto& arg : function_args) {
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, &arg));
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, &arg));
|
||||
}
|
||||
std::string patch;
|
||||
RETURN_IF_ERROR(ResolveSelector(args, linkables, object_name,
|
||||
RETURN_IF_ERROR(ResolveSelector(gpu_info, args, linkables, object_name,
|
||||
selector_name, function_args,
|
||||
template_args, &patch));
|
||||
code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
|
||||
@ -331,7 +331,8 @@ void CLArguments::ResolveObjectNames(
|
||||
}
|
||||
|
||||
absl::Status CLArguments::ResolveSelector(
|
||||
const Arguments& args, const std::map<std::string, std::string>& linkables,
|
||||
const GpuInfo& gpu_info, const Arguments& args,
|
||||
const std::map<std::string, std::string>& linkables,
|
||||
const std::string& object_name, const std::string& selector,
|
||||
const std::vector<std::string>& function_args,
|
||||
const std::vector<std::string>& template_args, std::string* result) {
|
||||
@ -366,14 +367,14 @@ absl::Status CLArguments::ResolveSelector(
|
||||
ReplaceAllWords("X_COORD", x_coord, result);
|
||||
ReplaceAllWords("Y_COORD", y_coord, result);
|
||||
ReplaceAllWords("S_COORD", s_coord, result);
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, result));
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, result));
|
||||
if (selector == "Linking") {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
}
|
||||
std::string patch;
|
||||
RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, function_args,
|
||||
RETURN_IF_ERROR(desc_ptr->PerformSelector(gpu_info, selector, function_args,
|
||||
template_args, &patch));
|
||||
ResolveObjectNames(object_name, names, &patch);
|
||||
*result += patch;
|
||||
|
@ -67,10 +67,10 @@ class CLArguments : public ArgumentsBinder {
|
||||
absl::Status AddObjectArgs(Arguments* args);
|
||||
|
||||
absl::Status ResolveSelectorsPass(
|
||||
const Arguments& args,
|
||||
const GpuInfo& gpu_info, const Arguments& args,
|
||||
const std::map<std::string, std::string>& linkables, std::string* code);
|
||||
absl::Status ResolveSelector(
|
||||
const Arguments& args,
|
||||
const GpuInfo& gpu_info, const Arguments& args,
|
||||
const std::map<std::string, std::string>& linkables,
|
||||
const std::string& object_name, const std::string& selector,
|
||||
const std::vector<std::string>& function_args,
|
||||
|
@ -46,6 +46,7 @@ cc_library(
|
||||
":serialization_base_cc_fbs",
|
||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:gpu_info",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
],
|
||||
)
|
||||
|
@ -40,7 +40,8 @@ GPUResources BufferDescriptor::GetGPUResources() const {
|
||||
}
|
||||
|
||||
absl::Status BufferDescriptor::PerformSelector(
|
||||
const std::string& selector, const std::vector<std::string>& args,
|
||||
const GpuInfo& gpu_info, const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) const {
|
||||
if (selector == "Read") {
|
||||
return PerformReadSelector(args, result);
|
||||
|
@ -42,7 +42,8 @@ struct BufferDescriptor : public GPUObjectDescriptor {
|
||||
BufferDescriptor(BufferDescriptor&& desc) = default;
|
||||
BufferDescriptor& operator=(BufferDescriptor&& desc) = default;
|
||||
|
||||
absl::Status PerformSelector(const std::string& selector,
|
||||
absl::Status PerformSelector(const GpuInfo& gpu_info,
|
||||
const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const override;
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"
|
||||
|
||||
@ -117,7 +118,8 @@ class GPUObjectDescriptor {
|
||||
}
|
||||
|
||||
virtual absl::Status PerformSelector(
|
||||
const std::string& selector, const std::vector<std::string>& args,
|
||||
const GpuInfo& gpu_info, const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const {
|
||||
*result = "";
|
||||
|
@ -164,7 +164,8 @@ GPUResources TensorDescriptor::GetGPUResources() const {
|
||||
}
|
||||
|
||||
absl::Status TensorDescriptor::PerformSelector(
|
||||
const std::string& selector, const std::vector<std::string>& args,
|
||||
const GpuInfo& gpu_info, const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) const {
|
||||
if (selector == "Width") {
|
||||
*result = GetWidth();
|
||||
|
@ -59,7 +59,8 @@ struct TensorDescriptor : public GPUObjectDescriptor {
|
||||
|
||||
bool operator!=(const TensorDescriptor& d) const { return !(*this == d); }
|
||||
|
||||
absl::Status PerformSelector(const std::string& selector,
|
||||
absl::Status PerformSelector(const GpuInfo& gpu_info,
|
||||
const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const override;
|
||||
|
@ -46,13 +46,14 @@ GPUResources TensorLinearDescriptor::GetGPUResources() const {
|
||||
}
|
||||
|
||||
absl::Status TensorLinearDescriptor::PerformSelector(
|
||||
const std::string& selector, const std::vector<std::string>& args,
|
||||
const GpuInfo& gpu_info, const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) const {
|
||||
if (selector == "Length") {
|
||||
*result = "length";
|
||||
return absl::OkStatus();
|
||||
} else if (selector == "Read") {
|
||||
return PerformReadSelector(args, result);
|
||||
return PerformReadSelector(gpu_info, args, result);
|
||||
} else if (selector == "GetPtr") {
|
||||
if (storage_type != LinearStorageType::BUFFER) {
|
||||
return absl::InvalidArgumentError(
|
||||
@ -67,7 +68,8 @@ absl::Status TensorLinearDescriptor::PerformSelector(
|
||||
}
|
||||
|
||||
absl::Status TensorLinearDescriptor::PerformReadSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const {
|
||||
const GpuInfo& gpu_info, const std::vector<std::string>& args,
|
||||
std::string* result) const {
|
||||
if (args.size() != 1) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("TensorLinearDescriptor Read require one argument, but ",
|
||||
@ -77,10 +79,19 @@ absl::Status TensorLinearDescriptor::PerformReadSelector(
|
||||
*result = absl::StrCat("buffer[", args[0], "]");
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
if (gpu_info.IsApiMetal()) {
|
||||
*result = absl::StrCat("tex2d.read(ushort2(", args[0], ", 0))");
|
||||
return absl::OkStatus();
|
||||
} else if (gpu_info.IsApiOpenCl()) {
|
||||
const std::string read =
|
||||
element_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
|
||||
*result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0], ", 0))");
|
||||
*result =
|
||||
absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0], ", 0))");
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::UnimplementedError(
|
||||
"No implementation of TensorLinear.Read for this API.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -48,13 +48,15 @@ struct TensorLinearDescriptor : public GPUObjectDescriptor {
|
||||
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src,
|
||||
int aligned_size = 0);
|
||||
|
||||
absl::Status PerformSelector(const std::string& selector,
|
||||
absl::Status PerformSelector(const GpuInfo& gpu_info,
|
||||
const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const override;
|
||||
|
||||
GPUResources GetGPUResources() const override;
|
||||
absl::Status PerformReadSelector(const std::vector<std::string>& args,
|
||||
absl::Status PerformReadSelector(const GpuInfo& gpu_info,
|
||||
const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
void Release() override;
|
||||
|
@ -32,10 +32,11 @@ GPUResources Texture2DDescriptor::GetGPUResources() const {
|
||||
}
|
||||
|
||||
absl::Status Texture2DDescriptor::PerformSelector(
|
||||
const std::string& selector, const std::vector<std::string>& args,
|
||||
const GpuInfo& gpu_info, const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) const {
|
||||
if (selector == "Read") {
|
||||
return PerformReadSelector(args, result);
|
||||
return PerformReadSelector(gpu_info, args, result);
|
||||
} else {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"Texture2DDescriptor don't have selector with name - ", selector));
|
||||
@ -43,12 +44,18 @@ absl::Status Texture2DDescriptor::PerformSelector(
|
||||
}
|
||||
|
||||
absl::Status Texture2DDescriptor::PerformReadSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const {
|
||||
const GpuInfo& gpu_info, const std::vector<std::string>& args,
|
||||
std::string* result) const {
|
||||
if (args.size() != 2) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("Texture2DDescriptor Read require two arguments, but ",
|
||||
args.size(), " was passed"));
|
||||
}
|
||||
if (gpu_info.IsApiMetal()) {
|
||||
*result =
|
||||
absl::StrCat("tex2d.read(ushort2(", args[0], ", " + args[1] + "))");
|
||||
return absl::OkStatus();
|
||||
} else if (gpu_info.IsApiOpenCl()) {
|
||||
std::string read;
|
||||
switch (element_type) {
|
||||
case DataType::FLOAT32:
|
||||
@ -84,6 +91,10 @@ absl::Status Texture2DDescriptor::PerformReadSelector(
|
||||
*result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0],
|
||||
", " + args[1] + "))");
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::UnimplementedError(
|
||||
"No implementation of Texture2D.Read for this API.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -41,13 +41,15 @@ struct Texture2DDescriptor : public GPUObjectDescriptor {
|
||||
Texture2DDescriptor(Texture2DDescriptor&& desc) = default;
|
||||
Texture2DDescriptor& operator=(Texture2DDescriptor&& desc) = default;
|
||||
|
||||
absl::Status PerformSelector(const std::string& selector,
|
||||
absl::Status PerformSelector(const GpuInfo& gpu_info,
|
||||
const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args,
|
||||
std::string* result) const override;
|
||||
|
||||
GPUResources GetGPUResources() const override;
|
||||
absl::Status PerformReadSelector(const std::vector<std::string>& args,
|
||||
absl::Status PerformReadSelector(const GpuInfo& gpu_info,
|
||||
const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
void Release() override;
|
||||
|
@ -195,6 +195,7 @@ objc_library(
|
||||
deps = [
|
||||
":buffer",
|
||||
":gpu_object",
|
||||
":metal_device",
|
||||
":metal_spatial_tensor",
|
||||
":texture2d",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
|
@ -43,8 +43,7 @@ absl::Status ComputeTask::Compile(CalculationsPrecision precision,
|
||||
task_desc_->AssembleCode();
|
||||
const std::map<std::string, std::string> linkables = {
|
||||
{task_desc_->dst_tensors_names[0], task_desc_->elementwise_code}};
|
||||
RETURN_IF_ERROR(metal_args_.Init(device->device(), linkables,
|
||||
&task_desc_->args,
|
||||
RETURN_IF_ERROR(metal_args_.Init(linkables, device, &task_desc_->args,
|
||||
&task_desc_->shader_source));
|
||||
task_desc_->args.ReleaseCPURepresentation();
|
||||
NSString* barrier;
|
||||
|
@ -155,11 +155,12 @@ absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
|
||||
constexpr char MetalArguments::kArgsPrefix[];
|
||||
|
||||
absl::Status MetalArguments::Init(
|
||||
id<MTLDevice> device, const std::map<std::string, std::string>& linkables,
|
||||
const std::map<std::string, std::string>& linkables, MetalDevice* device,
|
||||
Arguments* args, std::string* code) {
|
||||
RETURN_IF_ERROR(AllocateObjects(*args, device));
|
||||
RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
|
||||
RETURN_IF_ERROR(AddObjectArgs(args));
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(*args, linkables, code));
|
||||
RETURN_IF_ERROR(
|
||||
ResolveSelectorsPass(device->GetInfo(), *args, linkables, code));
|
||||
object_refs_ = std::move(args->object_refs_);
|
||||
args->GetActiveArguments(kArgsPrefix, *code);
|
||||
std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code);
|
||||
@ -445,8 +446,8 @@ absl::Status MetalArguments::SetBuffer(const std::string& name,
|
||||
}
|
||||
|
||||
absl::Status MetalArguments::ResolveSelectorsPass(
|
||||
const Arguments& args, const std::map<std::string, std::string>& linkables,
|
||||
std::string* code) {
|
||||
const GpuInfo& gpu_info, const Arguments& args,
|
||||
const std::map<std::string, std::string>& linkables, std::string* code) {
|
||||
std::string result;
|
||||
size_t position = 0;
|
||||
size_t next_position = code->find(kArgsPrefix);
|
||||
@ -477,10 +478,10 @@ absl::Status MetalArguments::ResolveSelectorsPass(
|
||||
RETURN_IF_ERROR(ParseArgsInsideBrackets(
|
||||
*code, next_position, &close_bracket_pos, &function_args));
|
||||
for (auto& arg : function_args) {
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, &arg));
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, &arg));
|
||||
}
|
||||
std::string patch;
|
||||
RETURN_IF_ERROR(ResolveSelector(args, linkables, object_name,
|
||||
RETURN_IF_ERROR(ResolveSelector(gpu_info, args, linkables, object_name,
|
||||
selector_name, function_args,
|
||||
template_args, &patch));
|
||||
code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
|
||||
@ -494,7 +495,8 @@ absl::Status MetalArguments::ResolveSelectorsPass(
|
||||
}
|
||||
|
||||
absl::Status MetalArguments::ResolveSelector(
|
||||
const Arguments& args, const std::map<std::string, std::string>& linkables,
|
||||
const GpuInfo& gpu_info, const Arguments& args,
|
||||
const std::map<std::string, std::string>& linkables,
|
||||
const std::string& object_name, const std::string& selector,
|
||||
const std::vector<std::string>& function_args,
|
||||
const std::vector<std::string>& template_args, std::string* result) {
|
||||
@ -529,14 +531,14 @@ absl::Status MetalArguments::ResolveSelector(
|
||||
ReplaceAllWords("X_COORD", x_coord, result);
|
||||
ReplaceAllWords("Y_COORD", y_coord, result);
|
||||
ReplaceAllWords("S_COORD", s_coord, result);
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, result));
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, result));
|
||||
if (selector == "Linking") {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
}
|
||||
std::string patch;
|
||||
RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, function_args,
|
||||
RETURN_IF_ERROR(desc_ptr->PerformSelector(gpu_info, selector, function_args,
|
||||
template_args, &patch));
|
||||
ResolveObjectNames(object_name, names, &patch);
|
||||
*result += patch;
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/metal_device.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
@ -34,9 +35,8 @@ class MetalArguments : public ArgumentsBinder {
|
||||
public:
|
||||
MetalArguments() = default;
|
||||
|
||||
absl::Status Init(id<MTLDevice> device,
|
||||
const std::map<std::string, std::string>& linkables,
|
||||
Arguments* args, std::string* code);
|
||||
absl::Status Init(const std::map<std::string, std::string>& linkables,
|
||||
MetalDevice* device, Arguments* args, std::string* code);
|
||||
|
||||
// Move only
|
||||
MetalArguments(MetalArguments&& args) = default;
|
||||
@ -88,11 +88,11 @@ class MetalArguments : public ArgumentsBinder {
|
||||
absl::Status SetObjectsResources(const Arguments& args);
|
||||
|
||||
absl::Status ResolveSelectorsPass(
|
||||
const Arguments& args,
|
||||
const GpuInfo& gpu_info, const Arguments& args,
|
||||
const std::map<std::string, std::string>& linkables, std::string* code);
|
||||
|
||||
absl::Status ResolveSelector(
|
||||
const Arguments& args,
|
||||
const GpuInfo& gpu_info, const Arguments& args,
|
||||
const std::map<std::string, std::string>& linkables,
|
||||
const std::string& object_name, const std::string& selector,
|
||||
const std::vector<std::string>& function_args,
|
||||
|
Loading…
Reference in New Issue
Block a user