Added GpuInfo to arguments of GpuObject.PerformSelector.

PerformSelector can have API specific elements.
Added Metal specialization to Texture2D.PerformReadSelector.
Added Metal specialization to TensorLinear.PerformReadSelector.

PiperOrigin-RevId: 351830974
Change-Id: Ibae3c0aeb7a423eefc6d341803914648dd77c842
This commit is contained in:
Raman Sarokin 2021-01-14 10:44:48 -08:00 committed by TensorFlower Gardener
parent 3d28cdc603
commit c3f98039e6
16 changed files with 116 additions and 80 deletions

View File

@ -214,7 +214,7 @@ absl::Status CLArguments::Init(
Arguments* args, std::string* code) {
RETURN_IF_ERROR(AllocateObjects(*args, context));
RETURN_IF_ERROR(AddObjectArgs(args));
RETURN_IF_ERROR(ResolveSelectorsPass(*args, linkables, code));
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, *args, linkables, code));
object_refs_ = std::move(args->object_refs_);
args->GetActiveArguments(kArgsPrefix, *code);
const bool use_f32_for_halfs = gpu_info.IsPowerVR();
@ -273,8 +273,8 @@ absl::Status CLArguments::SetObjectsResources(const Arguments& args) {
}
absl::Status CLArguments::ResolveSelectorsPass(
const Arguments& args, const std::map<std::string, std::string>& linkables,
std::string* code) {
const GpuInfo& gpu_info, const Arguments& args,
const std::map<std::string, std::string>& linkables, std::string* code) {
std::string result;
size_t position = 0;
size_t next_position = code->find(kArgsPrefix);
@ -305,10 +305,10 @@ absl::Status CLArguments::ResolveSelectorsPass(
RETURN_IF_ERROR(ParseArgsInsideBrackets(
*code, next_position, &close_bracket_pos, &function_args));
for (auto& arg : function_args) {
RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, &arg));
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, &arg));
}
std::string patch;
RETURN_IF_ERROR(ResolveSelector(args, linkables, object_name,
RETURN_IF_ERROR(ResolveSelector(gpu_info, args, linkables, object_name,
selector_name, function_args,
template_args, &patch));
code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
@ -331,7 +331,8 @@ void CLArguments::ResolveObjectNames(
}
absl::Status CLArguments::ResolveSelector(
const Arguments& args, const std::map<std::string, std::string>& linkables,
const GpuInfo& gpu_info, const Arguments& args,
const std::map<std::string, std::string>& linkables,
const std::string& object_name, const std::string& selector,
const std::vector<std::string>& function_args,
const std::vector<std::string>& template_args, std::string* result) {
@ -366,14 +367,14 @@ absl::Status CLArguments::ResolveSelector(
ReplaceAllWords("X_COORD", x_coord, result);
ReplaceAllWords("Y_COORD", y_coord, result);
ReplaceAllWords("S_COORD", s_coord, result);
RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, result));
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, result));
if (selector == "Linking") {
return absl::OkStatus();
}
}
}
std::string patch;
RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, function_args,
RETURN_IF_ERROR(desc_ptr->PerformSelector(gpu_info, selector, function_args,
template_args, &patch));
ResolveObjectNames(object_name, names, &patch);
*result += patch;

View File

@ -67,10 +67,10 @@ class CLArguments : public ArgumentsBinder {
absl::Status AddObjectArgs(Arguments* args);
absl::Status ResolveSelectorsPass(
const Arguments& args,
const GpuInfo& gpu_info, const Arguments& args,
const std::map<std::string, std::string>& linkables, std::string* code);
absl::Status ResolveSelector(
const Arguments& args,
const GpuInfo& gpu_info, const Arguments& args,
const std::map<std::string, std::string>& linkables,
const std::string& object_name, const std::string& selector,
const std::vector<std::string>& function_args,

View File

@ -46,6 +46,7 @@ cc_library(
":serialization_base_cc_fbs",
"//tensorflow/lite/delegates/gpu/common:access_type",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:gpu_info",
"//tensorflow/lite/delegates/gpu/common:status",
],
)

View File

@ -40,7 +40,8 @@ GPUResources BufferDescriptor::GetGPUResources() const {
}
absl::Status BufferDescriptor::PerformSelector(
const std::string& selector, const std::vector<std::string>& args,
const GpuInfo& gpu_info, const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const {
if (selector == "Read") {
return PerformReadSelector(args, result);

View File

@ -42,7 +42,8 @@ struct BufferDescriptor : public GPUObjectDescriptor {
BufferDescriptor(BufferDescriptor&& desc) = default;
BufferDescriptor& operator=(BufferDescriptor&& desc) = default;
absl::Status PerformSelector(const std::string& selector,
absl::Status PerformSelector(const GpuInfo& gpu_info,
const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const override;

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"
@ -117,7 +118,8 @@ class GPUObjectDescriptor {
}
virtual absl::Status PerformSelector(
const std::string& selector, const std::vector<std::string>& args,
const GpuInfo& gpu_info, const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const {
*result = "";

View File

@ -164,7 +164,8 @@ GPUResources TensorDescriptor::GetGPUResources() const {
}
absl::Status TensorDescriptor::PerformSelector(
const std::string& selector, const std::vector<std::string>& args,
const GpuInfo& gpu_info, const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const {
if (selector == "Width") {
*result = GetWidth();

View File

@ -59,7 +59,8 @@ struct TensorDescriptor : public GPUObjectDescriptor {
bool operator!=(const TensorDescriptor& d) const { return !(*this == d); }
absl::Status PerformSelector(const std::string& selector,
absl::Status PerformSelector(const GpuInfo& gpu_info,
const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const override;

View File

@ -46,13 +46,14 @@ GPUResources TensorLinearDescriptor::GetGPUResources() const {
}
absl::Status TensorLinearDescriptor::PerformSelector(
const std::string& selector, const std::vector<std::string>& args,
const GpuInfo& gpu_info, const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const {
if (selector == "Length") {
*result = "length";
return absl::OkStatus();
} else if (selector == "Read") {
return PerformReadSelector(args, result);
return PerformReadSelector(gpu_info, args, result);
} else if (selector == "GetPtr") {
if (storage_type != LinearStorageType::BUFFER) {
return absl::InvalidArgumentError(
@ -67,7 +68,8 @@ absl::Status TensorLinearDescriptor::PerformSelector(
}
absl::Status TensorLinearDescriptor::PerformReadSelector(
const std::vector<std::string>& args, std::string* result) const {
const GpuInfo& gpu_info, const std::vector<std::string>& args,
std::string* result) const {
if (args.size() != 1) {
return absl::NotFoundError(
absl::StrCat("TensorLinearDescriptor Read require one argument, but ",
@ -77,10 +79,19 @@ absl::Status TensorLinearDescriptor::PerformReadSelector(
*result = absl::StrCat("buffer[", args[0], "]");
return absl::OkStatus();
} else {
const std::string read =
element_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
*result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0], ", 0))");
return absl::OkStatus();
if (gpu_info.IsApiMetal()) {
*result = absl::StrCat("tex2d.read(ushort2(", args[0], ", 0))");
return absl::OkStatus();
} else if (gpu_info.IsApiOpenCl()) {
const std::string read =
element_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
*result =
absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0], ", 0))");
return absl::OkStatus();
} else {
return absl::UnimplementedError(
"No implementation of TensorLinear.Read for this API.");
}
}
}

View File

@ -48,13 +48,15 @@ struct TensorLinearDescriptor : public GPUObjectDescriptor {
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src,
int aligned_size = 0);
absl::Status PerformSelector(const std::string& selector,
absl::Status PerformSelector(const GpuInfo& gpu_info,
const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const override;
GPUResources GetGPUResources() const override;
absl::Status PerformReadSelector(const std::vector<std::string>& args,
absl::Status PerformReadSelector(const GpuInfo& gpu_info,
const std::vector<std::string>& args,
std::string* result) const;
void Release() override;

View File

@ -32,10 +32,11 @@ GPUResources Texture2DDescriptor::GetGPUResources() const {
}
absl::Status Texture2DDescriptor::PerformSelector(
const std::string& selector, const std::vector<std::string>& args,
const GpuInfo& gpu_info, const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const {
if (selector == "Read") {
return PerformReadSelector(args, result);
return PerformReadSelector(gpu_info, args, result);
} else {
return absl::NotFoundError(absl::StrCat(
"Texture2DDescriptor don't have selector with name - ", selector));
@ -43,47 +44,57 @@ absl::Status Texture2DDescriptor::PerformSelector(
}
absl::Status Texture2DDescriptor::PerformReadSelector(
const std::vector<std::string>& args, std::string* result) const {
const GpuInfo& gpu_info, const std::vector<std::string>& args,
std::string* result) const {
if (args.size() != 2) {
return absl::NotFoundError(
absl::StrCat("Texture2DDescriptor Read require two arguments, but ",
args.size(), " was passed"));
}
std::string read;
switch (element_type) {
case DataType::FLOAT32:
read = "read_imagef";
break;
case DataType::FLOAT16:
read = "read_imageh";
break;
case DataType::INT8:
case DataType::INT16:
case DataType::INT32:
if (normalized) {
read = normalized_type == DataType::FLOAT16 ? "read_imageh"
: "read_imagef";
} else {
read = "read_imagei";
}
break;
case DataType::UINT8:
case DataType::UINT16:
case DataType::UINT32:
if (normalized) {
read = normalized_type == DataType::FLOAT16 ? "read_imageh"
: "read_imagef";
} else {
read = "read_imageui";
}
break;
default:
read = "unknown_type";
break;
if (gpu_info.IsApiMetal()) {
*result =
absl::StrCat("tex2d.read(ushort2(", args[0], ", " + args[1] + "))");
return absl::OkStatus();
} else if (gpu_info.IsApiOpenCl()) {
std::string read;
switch (element_type) {
case DataType::FLOAT32:
read = "read_imagef";
break;
case DataType::FLOAT16:
read = "read_imageh";
break;
case DataType::INT8:
case DataType::INT16:
case DataType::INT32:
if (normalized) {
read = normalized_type == DataType::FLOAT16 ? "read_imageh"
: "read_imagef";
} else {
read = "read_imagei";
}
break;
case DataType::UINT8:
case DataType::UINT16:
case DataType::UINT32:
if (normalized) {
read = normalized_type == DataType::FLOAT16 ? "read_imageh"
: "read_imagef";
} else {
read = "read_imageui";
}
break;
default:
read = "unknown_type";
break;
}
*result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0],
", " + args[1] + "))");
return absl::OkStatus();
} else {
return absl::UnimplementedError(
"No implementation of Texture2D.Read for this API.");
}
*result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0],
", " + args[1] + "))");
return absl::OkStatus();
}
} // namespace gpu

View File

@ -41,13 +41,15 @@ struct Texture2DDescriptor : public GPUObjectDescriptor {
Texture2DDescriptor(Texture2DDescriptor&& desc) = default;
Texture2DDescriptor& operator=(Texture2DDescriptor&& desc) = default;
absl::Status PerformSelector(const std::string& selector,
absl::Status PerformSelector(const GpuInfo& gpu_info,
const std::string& selector,
const std::vector<std::string>& args,
const std::vector<std::string>& template_args,
std::string* result) const override;
GPUResources GetGPUResources() const override;
absl::Status PerformReadSelector(const std::vector<std::string>& args,
absl::Status PerformReadSelector(const GpuInfo& gpu_info,
const std::vector<std::string>& args,
std::string* result) const;
void Release() override;

View File

@ -195,6 +195,7 @@ objc_library(
deps = [
":buffer",
":gpu_object",
":metal_device",
":metal_spatial_tensor",
":texture2d",
"//tensorflow/lite/delegates/gpu/common:status",

View File

@ -43,8 +43,7 @@ absl::Status ComputeTask::Compile(CalculationsPrecision precision,
task_desc_->AssembleCode();
const std::map<std::string, std::string> linkables = {
{task_desc_->dst_tensors_names[0], task_desc_->elementwise_code}};
RETURN_IF_ERROR(metal_args_.Init(device->device(), linkables,
&task_desc_->args,
RETURN_IF_ERROR(metal_args_.Init(linkables, device, &task_desc_->args,
&task_desc_->shader_source));
task_desc_->args.ReleaseCPURepresentation();
NSString* barrier;

View File

@ -155,11 +155,12 @@ absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
constexpr char MetalArguments::kArgsPrefix[];
absl::Status MetalArguments::Init(
id<MTLDevice> device, const std::map<std::string, std::string>& linkables,
const std::map<std::string, std::string>& linkables, MetalDevice* device,
Arguments* args, std::string* code) {
RETURN_IF_ERROR(AllocateObjects(*args, device));
RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
RETURN_IF_ERROR(AddObjectArgs(args));
RETURN_IF_ERROR(ResolveSelectorsPass(*args, linkables, code));
RETURN_IF_ERROR(
ResolveSelectorsPass(device->GetInfo(), *args, linkables, code));
object_refs_ = std::move(args->object_refs_);
args->GetActiveArguments(kArgsPrefix, *code);
std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code);
@ -445,8 +446,8 @@ absl::Status MetalArguments::SetBuffer(const std::string& name,
}
absl::Status MetalArguments::ResolveSelectorsPass(
const Arguments& args, const std::map<std::string, std::string>& linkables,
std::string* code) {
const GpuInfo& gpu_info, const Arguments& args,
const std::map<std::string, std::string>& linkables, std::string* code) {
std::string result;
size_t position = 0;
size_t next_position = code->find(kArgsPrefix);
@ -477,10 +478,10 @@ absl::Status MetalArguments::ResolveSelectorsPass(
RETURN_IF_ERROR(ParseArgsInsideBrackets(
*code, next_position, &close_bracket_pos, &function_args));
for (auto& arg : function_args) {
RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, &arg));
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, &arg));
}
std::string patch;
RETURN_IF_ERROR(ResolveSelector(args, linkables, object_name,
RETURN_IF_ERROR(ResolveSelector(gpu_info, args, linkables, object_name,
selector_name, function_args,
template_args, &patch));
code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
@ -494,7 +495,8 @@ absl::Status MetalArguments::ResolveSelectorsPass(
}
absl::Status MetalArguments::ResolveSelector(
const Arguments& args, const std::map<std::string, std::string>& linkables,
const GpuInfo& gpu_info, const Arguments& args,
const std::map<std::string, std::string>& linkables,
const std::string& object_name, const std::string& selector,
const std::vector<std::string>& function_args,
const std::vector<std::string>& template_args, std::string* result) {
@ -529,14 +531,14 @@ absl::Status MetalArguments::ResolveSelector(
ReplaceAllWords("X_COORD", x_coord, result);
ReplaceAllWords("Y_COORD", y_coord, result);
ReplaceAllWords("S_COORD", s_coord, result);
RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, result));
RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, result));
if (selector == "Linking") {
return absl::OkStatus();
}
}
}
std::string patch;
RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, function_args,
RETURN_IF_ERROR(desc_ptr->PerformSelector(gpu_info, selector, function_args,
template_args, &patch));
ResolveObjectNames(object_name, names, &patch);
*result += patch;

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
#include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_device.h"
namespace tflite {
namespace gpu {
@ -34,9 +35,8 @@ class MetalArguments : public ArgumentsBinder {
public:
MetalArguments() = default;
absl::Status Init(id<MTLDevice> device,
const std::map<std::string, std::string>& linkables,
Arguments* args, std::string* code);
absl::Status Init(const std::map<std::string, std::string>& linkables,
MetalDevice* device, Arguments* args, std::string* code);
// Move only
MetalArguments(MetalArguments&& args) = default;
@ -88,11 +88,11 @@ class MetalArguments : public ArgumentsBinder {
absl::Status SetObjectsResources(const Arguments& args);
absl::Status ResolveSelectorsPass(
const Arguments& args,
const GpuInfo& gpu_info, const Arguments& args,
const std::map<std::string, std::string>& linkables, std::string* code);
absl::Status ResolveSelector(
const Arguments& args,
const GpuInfo& gpu_info, const Arguments& args,
const std::map<std::string, std::string>& linkables,
const std::string& object_name, const std::string& selector,
const std::vector<std::string>& function_args,