Removed virtual method GPUObjectDescriptor* GetGPUDescriptor() for GPUObject.
PiperOrigin-RevId: 317672008 Change-Id: Iae9189569290c5d50dd1d2645dceac5385b2641c
This commit is contained in:
parent
8498c64d77
commit
e25fcc8393
|
@ -221,8 +221,9 @@ void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
void Arguments::AddObject(const std::string& name, AccessType access_type,
|
void Arguments::AddObject(const std::string& name, AccessType access_type,
|
||||||
GPUObjectPtr&& object) {
|
GPUObjectPtr&& object,
|
||||||
objects_[name] = {access_type, std::move(object)};
|
GPUObjectDescriptorPtr&& descriptor_ptr) {
|
||||||
|
objects_[name] = {access_type, std::move(object), std::move(descriptor_ptr)};
|
||||||
}
|
}
|
||||||
|
|
||||||
void Arguments::AddGPUResources(const std::string& name,
|
void Arguments::AddGPUResources(const std::string& name,
|
||||||
|
@ -411,7 +412,8 @@ absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix) {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
absl::StrCat("Object name collision. Name - ", name));
|
absl::StrCat("Object name collision. Name - ", name));
|
||||||
}
|
}
|
||||||
objects_[name] = {v.second.access_type, std::move(v.second.obj_ptr)};
|
objects_[name] = {v.second.access_type, std::move(v.second.obj_ptr),
|
||||||
|
std::move(v.second.descriptor)};
|
||||||
}
|
}
|
||||||
for (const auto& v : args.int_values_) {
|
for (const auto& v : args.int_values_) {
|
||||||
AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
|
AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
|
||||||
|
@ -677,7 +679,7 @@ absl::Status Arguments::ResolveSelector(
|
||||||
desc_ptr = it->second.descriptor.get();
|
desc_ptr = it->second.descriptor.get();
|
||||||
access_type = it->second.access_type;
|
access_type = it->second.access_type;
|
||||||
} else if (auto it = objects_.find(object_name); it != objects_.end()) {
|
} else if (auto it = objects_.find(object_name); it != objects_.end()) {
|
||||||
desc_ptr = it->second.obj_ptr->GetGPUDescriptor();
|
desc_ptr = it->second.descriptor.get();
|
||||||
access_type = it->second.access_type;
|
access_type = it->second.access_type;
|
||||||
} else {
|
} else {
|
||||||
return absl::NotFoundError(
|
return absl::NotFoundError(
|
||||||
|
@ -760,8 +762,7 @@ absl::Status Arguments::ResolveSelectorsPass(
|
||||||
absl::Status Arguments::AddObjectArgs() {
|
absl::Status Arguments::AddObjectArgs() {
|
||||||
for (auto& t : objects_) {
|
for (auto& t : objects_) {
|
||||||
AddGPUResources(t.first,
|
AddGPUResources(t.first,
|
||||||
t.second.obj_ptr->GetGPUDescriptor()->GetGPUResources(
|
t.second.descriptor->GetGPUResources(t.second.access_type));
|
||||||
t.second.access_type));
|
|
||||||
RETURN_IF_ERROR(SetGPUResources(
|
RETURN_IF_ERROR(SetGPUResources(
|
||||||
t.first, t.second.obj_ptr->GetGPUResources(t.second.access_type)));
|
t.first, t.second.obj_ptr->GetGPUResources(t.second.access_type)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,7 +50,8 @@ class Arguments {
|
||||||
void AddObjectRef(const std::string& name, AccessType access_type,
|
void AddObjectRef(const std::string& name, AccessType access_type,
|
||||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||||
void AddObject(const std::string& name, AccessType access_type,
|
void AddObject(const std::string& name, AccessType access_type,
|
||||||
GPUObjectPtr&& object);
|
GPUObjectPtr&& object,
|
||||||
|
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||||
|
|
||||||
absl::Status SetInt(const std::string& name, int value);
|
absl::Status SetInt(const std::string& name, int value);
|
||||||
absl::Status SetFloat(const std::string& name, float value);
|
absl::Status SetFloat(const std::string& name, float value);
|
||||||
|
@ -162,6 +163,7 @@ class Arguments {
|
||||||
struct ObjectArg {
|
struct ObjectArg {
|
||||||
AccessType access_type;
|
AccessType access_type;
|
||||||
GPUObjectPtr obj_ptr;
|
GPUObjectPtr obj_ptr;
|
||||||
|
GPUObjectDescriptorPtr descriptor;
|
||||||
};
|
};
|
||||||
std::map<std::string, ObjectArg> objects_;
|
std::map<std::string, ObjectArg> objects_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -149,7 +149,6 @@ class GPUObject {
|
||||||
GPUObject(const GPUObject&) = delete;
|
GPUObject(const GPUObject&) = delete;
|
||||||
GPUObject& operator=(const GPUObject&) = delete;
|
GPUObject& operator=(const GPUObject&) = delete;
|
||||||
virtual ~GPUObject() = default;
|
virtual ~GPUObject() = default;
|
||||||
virtual const GPUObjectDescriptor* GetGPUDescriptor() const = 0;
|
|
||||||
virtual GPUResourcesWithValue GetGPUResources(
|
virtual GPUResourcesWithValue GetGPUResources(
|
||||||
AccessType access_type) const = 0;
|
AccessType access_type) const = 0;
|
||||||
};
|
};
|
||||||
|
|
|
@ -69,17 +69,17 @@ absl::Status CreatePReLU(const CreationContext& creation_context,
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
absl::Status PReLU::UploadParameters(
|
absl::Status PReLU::UploadParameters(
|
||||||
const tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context) {
|
const tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context) {
|
||||||
LinearStorageCreateInfo create_info;
|
TensorLinearDescriptor desc;
|
||||||
create_info.storage_type =
|
desc.storage_type =
|
||||||
DeduceLinearStorageType(definition_.GetPrimaryStorageType());
|
DeduceLinearStorageType(definition_.GetPrimaryStorageType());
|
||||||
create_info.data_type = definition_.GetPrimaryDataType();
|
desc.element_type = definition_.GetPrimaryDataType();
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CreateLinearStorage(desc, parameters, context, &alpha_));
|
||||||
CreateLinearStorage(create_info, parameters, context, &alpha_));
|
|
||||||
|
|
||||||
LinearStorage lt;
|
LinearStorage lt;
|
||||||
RETURN_IF_ERROR(CreateLinearStorage(create_info, parameters, context, <));
|
RETURN_IF_ERROR(CreateLinearStorage(desc, parameters, context, <));
|
||||||
args_.AddObject("alpha", AccessType::READ,
|
args_.AddObject("alpha", AccessType::READ,
|
||||||
absl::make_unique<LinearStorage>(std::move(lt)));
|
absl::make_unique<LinearStorage>(std::move(lt)),
|
||||||
|
absl::make_unique<TensorLinearDescriptor>(desc));
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -378,14 +378,15 @@ absl::Status Winograd4x4To36::UploadBt(CLContext* context) {
|
||||||
bt_aligned.data[y * 8 + 7] = 0.0f;
|
bt_aligned.data[y * 8 + 7] = 0.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
LinearStorageCreateInfo create_info;
|
TensorLinearDescriptor desc;
|
||||||
create_info.storage_type = LinearStorageType::TEXTURE_2D;
|
desc.storage_type = LinearStorageType::TEXTURE_2D;
|
||||||
create_info.data_type = definition_.GetDataType();
|
desc.element_type = definition_.GetDataType();
|
||||||
|
|
||||||
LinearStorage lt;
|
LinearStorage lt;
|
||||||
RETURN_IF_ERROR(CreateLinearStorage(create_info, bt_aligned, context, <));
|
RETURN_IF_ERROR(CreateLinearStorage(desc, bt_aligned, context, <));
|
||||||
args_.AddObject("bt", AccessType::READ,
|
args_.AddObject("bt", AccessType::READ,
|
||||||
absl::make_unique<LinearStorage>(std::move(lt)));
|
absl::make_unique<LinearStorage>(std::move(lt)),
|
||||||
|
absl::make_unique<TensorLinearDescriptor>(desc));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -492,13 +493,14 @@ absl::Status Winograd36To4x4::UploadAt(CLContext* context) {
|
||||||
at_aligned.data[y * 8 + 7] = 0.0f;
|
at_aligned.data[y * 8 + 7] = 0.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
LinearStorageCreateInfo create_info;
|
TensorLinearDescriptor desc;
|
||||||
create_info.storage_type = LinearStorageType::TEXTURE_2D;
|
desc.storage_type = LinearStorageType::TEXTURE_2D;
|
||||||
create_info.data_type = definition_.GetDataType();
|
desc.element_type = definition_.GetDataType();
|
||||||
LinearStorage lt;
|
LinearStorage lt;
|
||||||
RETURN_IF_ERROR(CreateLinearStorage(create_info, at_aligned, context, <));
|
RETURN_IF_ERROR(CreateLinearStorage(desc, at_aligned, context, <));
|
||||||
args_.AddObject("at", AccessType::READ,
|
args_.AddObject("at", AccessType::READ,
|
||||||
absl::make_unique<LinearStorage>(std::move(lt)));
|
absl::make_unique<LinearStorage>(std::move(lt)),
|
||||||
|
absl::make_unique<TensorLinearDescriptor>(desc));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -550,14 +552,15 @@ absl::Status CreateWinograd36To4x4(
|
||||||
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
|
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
|
||||||
Winograd36To4x4* result) {
|
Winograd36To4x4* result) {
|
||||||
*result = Winograd36To4x4(definition);
|
*result = Winograd36To4x4(definition);
|
||||||
LinearStorageCreateInfo create_info;
|
TensorLinearDescriptor desc;
|
||||||
create_info.storage_type = LinearStorageType::TEXTURE_2D;
|
desc.storage_type = LinearStorageType::TEXTURE_2D;
|
||||||
create_info.data_type = definition.GetDataType();
|
desc.element_type = definition.GetDataType();
|
||||||
LinearStorage lt;
|
LinearStorage lt;
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
CreateLinearStorage(create_info, biases, creation_context.context, <));
|
CreateLinearStorage(desc, biases, creation_context.context, <));
|
||||||
result->args_.AddObject("biases", AccessType::READ,
|
result->args_.AddObject("biases", AccessType::READ,
|
||||||
absl::make_unique<LinearStorage>(std::move(lt)));
|
absl::make_unique<LinearStorage>(std::move(lt)),
|
||||||
|
absl::make_unique<TensorLinearDescriptor>(desc));
|
||||||
return result->UploadAt(creation_context.context);
|
return result->UploadAt(creation_context.context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -76,10 +76,7 @@ absl::Status TensorLinearDescriptor::PerformReadSelector(
|
||||||
|
|
||||||
LinearStorage::LinearStorage(int depth, LinearStorageType storage_type,
|
LinearStorage::LinearStorage(int depth, LinearStorageType storage_type,
|
||||||
DataType data_type)
|
DataType data_type)
|
||||||
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {
|
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {}
|
||||||
desc_.storage_type = storage_type;
|
|
||||||
desc_.element_type = data_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
LinearStorage::LinearStorage(LinearStorage&& storage)
|
LinearStorage::LinearStorage(LinearStorage&& storage)
|
||||||
: GPUObject(std::move(storage)),
|
: GPUObject(std::move(storage)),
|
||||||
|
@ -89,8 +86,7 @@ LinearStorage::LinearStorage(LinearStorage&& storage)
|
||||||
depth_(storage.depth_),
|
depth_(storage.depth_),
|
||||||
name_(std::move(storage.name_)),
|
name_(std::move(storage.name_)),
|
||||||
storage_type_(storage.storage_type_),
|
storage_type_(storage.storage_type_),
|
||||||
data_type_(storage.data_type_),
|
data_type_(storage.data_type_) {
|
||||||
desc_(storage.desc_) {
|
|
||||||
storage.memory_ = nullptr;
|
storage.memory_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,7 +99,6 @@ LinearStorage& LinearStorage::operator=(LinearStorage&& storage) {
|
||||||
name_ = std::move(storage.name_);
|
name_ = std::move(storage.name_);
|
||||||
std::swap(storage_type_, storage.storage_type_);
|
std::swap(storage_type_, storage.storage_type_);
|
||||||
std::swap(data_type_, storage.data_type_);
|
std::swap(data_type_, storage.data_type_);
|
||||||
desc_ = storage.desc_;
|
|
||||||
GPUObject::operator=(std::move(storage));
|
GPUObject::operator=(std::move(storage));
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
|
|
|
@ -92,9 +92,6 @@ class LinearStorage : public GPUObject {
|
||||||
std::string ReadLinearFLT4(const std::string& z_coord) const;
|
std::string ReadLinearFLT4(const std::string& z_coord) const;
|
||||||
std::string GetDeclaration() const;
|
std::string GetDeclaration() const;
|
||||||
|
|
||||||
const GPUObjectDescriptor* GetGPUDescriptor() const override {
|
|
||||||
return &desc_;
|
|
||||||
}
|
|
||||||
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
|
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -115,7 +112,6 @@ class LinearStorage : public GPUObject {
|
||||||
std::string name_;
|
std::string name_;
|
||||||
LinearStorageType storage_type_;
|
LinearStorageType storage_type_;
|
||||||
DataType data_type_;
|
DataType data_type_;
|
||||||
TensorLinearDescriptor desc_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data,
|
absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data,
|
||||||
|
@ -152,6 +148,31 @@ absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info,
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <DataType T>
|
||||||
|
absl::Status CreateLinearStorage(const TensorLinearDescriptor& descriptor,
|
||||||
|
const tflite::gpu::Tensor<Linear, T>& tensor,
|
||||||
|
CLContext* context, LinearStorage* result) {
|
||||||
|
LinearStorageCreateInfo creation_info;
|
||||||
|
creation_info.storage_type = descriptor.storage_type;
|
||||||
|
creation_info.data_type = descriptor.element_type;
|
||||||
|
int size = creation_info.aligned_size != 0 ? creation_info.aligned_size
|
||||||
|
: tensor.shape.v;
|
||||||
|
const int depth = DivideRoundUp(size, 4);
|
||||||
|
if (creation_info.data_type == DataType::FLOAT32) {
|
||||||
|
std::vector<float4> gpu_data(depth);
|
||||||
|
CopyLinearFLT4(tensor, absl::MakeSpan(gpu_data));
|
||||||
|
RETURN_IF_ERROR(CreateLinearStorage(creation_info, depth, gpu_data.data(),
|
||||||
|
context, result));
|
||||||
|
} else {
|
||||||
|
std::vector<half4> gpu_data(depth);
|
||||||
|
CopyLinearFLT4(tensor, absl::MakeSpan(gpu_data));
|
||||||
|
RETURN_IF_ERROR(CreateLinearStorage(creation_info, depth, gpu_data.data(),
|
||||||
|
context, result));
|
||||||
|
}
|
||||||
|
result->SetName(creation_info.name);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
|
@ -58,9 +58,6 @@ class Tensor : public GPUObject {
|
||||||
|
|
||||||
virtual ~Tensor() { Release(); }
|
virtual ~Tensor() { Release(); }
|
||||||
|
|
||||||
const GPUObjectDescriptor* GetGPUDescriptor() const override {
|
|
||||||
return &descriptor_;
|
|
||||||
}
|
|
||||||
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
|
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
|
||||||
|
|
||||||
int Width() const { return shape_.w; }
|
int Width() const { return shape_.w; }
|
||||||
|
|
Loading…
Reference in New Issue