Removed virtual method GPUObjectDescriptor* GetGPUDescriptor() for GPUObject.
PiperOrigin-RevId: 317672008 Change-Id: Iae9189569290c5d50dd1d2645dceac5385b2641c
This commit is contained in:
parent
8498c64d77
commit
e25fcc8393
|
@ -221,8 +221,9 @@ void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
|
|||
}
|
||||
|
||||
void Arguments::AddObject(const std::string& name, AccessType access_type,
|
||||
GPUObjectPtr&& object) {
|
||||
objects_[name] = {access_type, std::move(object)};
|
||||
GPUObjectPtr&& object,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr) {
|
||||
objects_[name] = {access_type, std::move(object), std::move(descriptor_ptr)};
|
||||
}
|
||||
|
||||
void Arguments::AddGPUResources(const std::string& name,
|
||||
|
@ -411,7 +412,8 @@ absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix) {
|
|||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Object name collision. Name - ", name));
|
||||
}
|
||||
objects_[name] = {v.second.access_type, std::move(v.second.obj_ptr)};
|
||||
objects_[name] = {v.second.access_type, std::move(v.second.obj_ptr),
|
||||
std::move(v.second.descriptor)};
|
||||
}
|
||||
for (const auto& v : args.int_values_) {
|
||||
AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
|
||||
|
@ -677,7 +679,7 @@ absl::Status Arguments::ResolveSelector(
|
|||
desc_ptr = it->second.descriptor.get();
|
||||
access_type = it->second.access_type;
|
||||
} else if (auto it = objects_.find(object_name); it != objects_.end()) {
|
||||
desc_ptr = it->second.obj_ptr->GetGPUDescriptor();
|
||||
desc_ptr = it->second.descriptor.get();
|
||||
access_type = it->second.access_type;
|
||||
} else {
|
||||
return absl::NotFoundError(
|
||||
|
@ -760,8 +762,7 @@ absl::Status Arguments::ResolveSelectorsPass(
|
|||
absl::Status Arguments::AddObjectArgs() {
|
||||
for (auto& t : objects_) {
|
||||
AddGPUResources(t.first,
|
||||
t.second.obj_ptr->GetGPUDescriptor()->GetGPUResources(
|
||||
t.second.access_type));
|
||||
t.second.descriptor->GetGPUResources(t.second.access_type));
|
||||
RETURN_IF_ERROR(SetGPUResources(
|
||||
t.first, t.second.obj_ptr->GetGPUResources(t.second.access_type)));
|
||||
}
|
||||
|
|
|
@ -50,7 +50,8 @@ class Arguments {
|
|||
void AddObjectRef(const std::string& name, AccessType access_type,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||
void AddObject(const std::string& name, AccessType access_type,
|
||||
GPUObjectPtr&& object);
|
||||
GPUObjectPtr&& object,
|
||||
GPUObjectDescriptorPtr&& descriptor_ptr);
|
||||
|
||||
absl::Status SetInt(const std::string& name, int value);
|
||||
absl::Status SetFloat(const std::string& name, float value);
|
||||
|
@ -162,6 +163,7 @@ class Arguments {
|
|||
struct ObjectArg {
|
||||
AccessType access_type;
|
||||
GPUObjectPtr obj_ptr;
|
||||
GPUObjectDescriptorPtr descriptor;
|
||||
};
|
||||
std::map<std::string, ObjectArg> objects_;
|
||||
};
|
||||
|
|
|
@ -149,7 +149,6 @@ class GPUObject {
|
|||
GPUObject(const GPUObject&) = delete;
|
||||
GPUObject& operator=(const GPUObject&) = delete;
|
||||
virtual ~GPUObject() = default;
|
||||
virtual const GPUObjectDescriptor* GetGPUDescriptor() const = 0;
|
||||
virtual GPUResourcesWithValue GetGPUResources(
|
||||
AccessType access_type) const = 0;
|
||||
};
|
||||
|
|
|
@ -69,17 +69,17 @@ absl::Status CreatePReLU(const CreationContext& creation_context,
|
|||
template <DataType T>
|
||||
absl::Status PReLU::UploadParameters(
|
||||
const tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context) {
|
||||
LinearStorageCreateInfo create_info;
|
||||
create_info.storage_type =
|
||||
TensorLinearDescriptor desc;
|
||||
desc.storage_type =
|
||||
DeduceLinearStorageType(definition_.GetPrimaryStorageType());
|
||||
create_info.data_type = definition_.GetPrimaryDataType();
|
||||
RETURN_IF_ERROR(
|
||||
CreateLinearStorage(create_info, parameters, context, &alpha_));
|
||||
desc.element_type = definition_.GetPrimaryDataType();
|
||||
RETURN_IF_ERROR(CreateLinearStorage(desc, parameters, context, &alpha_));
|
||||
|
||||
LinearStorage lt;
|
||||
RETURN_IF_ERROR(CreateLinearStorage(create_info, parameters, context, <));
|
||||
RETURN_IF_ERROR(CreateLinearStorage(desc, parameters, context, <));
|
||||
args_.AddObject("alpha", AccessType::READ,
|
||||
absl::make_unique<LinearStorage>(std::move(lt)));
|
||||
absl::make_unique<LinearStorage>(std::move(lt)),
|
||||
absl::make_unique<TensorLinearDescriptor>(desc));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -378,14 +378,15 @@ absl::Status Winograd4x4To36::UploadBt(CLContext* context) {
|
|||
bt_aligned.data[y * 8 + 7] = 0.0f;
|
||||
}
|
||||
|
||||
LinearStorageCreateInfo create_info;
|
||||
create_info.storage_type = LinearStorageType::TEXTURE_2D;
|
||||
create_info.data_type = definition_.GetDataType();
|
||||
TensorLinearDescriptor desc;
|
||||
desc.storage_type = LinearStorageType::TEXTURE_2D;
|
||||
desc.element_type = definition_.GetDataType();
|
||||
|
||||
LinearStorage lt;
|
||||
RETURN_IF_ERROR(CreateLinearStorage(create_info, bt_aligned, context, <));
|
||||
RETURN_IF_ERROR(CreateLinearStorage(desc, bt_aligned, context, <));
|
||||
args_.AddObject("bt", AccessType::READ,
|
||||
absl::make_unique<LinearStorage>(std::move(lt)));
|
||||
absl::make_unique<LinearStorage>(std::move(lt)),
|
||||
absl::make_unique<TensorLinearDescriptor>(desc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -492,13 +493,14 @@ absl::Status Winograd36To4x4::UploadAt(CLContext* context) {
|
|||
at_aligned.data[y * 8 + 7] = 0.0f;
|
||||
}
|
||||
|
||||
LinearStorageCreateInfo create_info;
|
||||
create_info.storage_type = LinearStorageType::TEXTURE_2D;
|
||||
create_info.data_type = definition_.GetDataType();
|
||||
TensorLinearDescriptor desc;
|
||||
desc.storage_type = LinearStorageType::TEXTURE_2D;
|
||||
desc.element_type = definition_.GetDataType();
|
||||
LinearStorage lt;
|
||||
RETURN_IF_ERROR(CreateLinearStorage(create_info, at_aligned, context, <));
|
||||
RETURN_IF_ERROR(CreateLinearStorage(desc, at_aligned, context, <));
|
||||
args_.AddObject("at", AccessType::READ,
|
||||
absl::make_unique<LinearStorage>(std::move(lt)));
|
||||
absl::make_unique<LinearStorage>(std::move(lt)),
|
||||
absl::make_unique<TensorLinearDescriptor>(desc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -550,14 +552,15 @@ absl::Status CreateWinograd36To4x4(
|
|||
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
|
||||
Winograd36To4x4* result) {
|
||||
*result = Winograd36To4x4(definition);
|
||||
LinearStorageCreateInfo create_info;
|
||||
create_info.storage_type = LinearStorageType::TEXTURE_2D;
|
||||
create_info.data_type = definition.GetDataType();
|
||||
TensorLinearDescriptor desc;
|
||||
desc.storage_type = LinearStorageType::TEXTURE_2D;
|
||||
desc.element_type = definition.GetDataType();
|
||||
LinearStorage lt;
|
||||
RETURN_IF_ERROR(
|
||||
CreateLinearStorage(create_info, biases, creation_context.context, <));
|
||||
CreateLinearStorage(desc, biases, creation_context.context, <));
|
||||
result->args_.AddObject("biases", AccessType::READ,
|
||||
absl::make_unique<LinearStorage>(std::move(lt)));
|
||||
absl::make_unique<LinearStorage>(std::move(lt)),
|
||||
absl::make_unique<TensorLinearDescriptor>(desc));
|
||||
return result->UploadAt(creation_context.context);
|
||||
}
|
||||
|
||||
|
|
|
@ -76,10 +76,7 @@ absl::Status TensorLinearDescriptor::PerformReadSelector(
|
|||
|
||||
LinearStorage::LinearStorage(int depth, LinearStorageType storage_type,
|
||||
DataType data_type)
|
||||
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {
|
||||
desc_.storage_type = storage_type;
|
||||
desc_.element_type = data_type;
|
||||
}
|
||||
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {}
|
||||
|
||||
LinearStorage::LinearStorage(LinearStorage&& storage)
|
||||
: GPUObject(std::move(storage)),
|
||||
|
@ -89,8 +86,7 @@ LinearStorage::LinearStorage(LinearStorage&& storage)
|
|||
depth_(storage.depth_),
|
||||
name_(std::move(storage.name_)),
|
||||
storage_type_(storage.storage_type_),
|
||||
data_type_(storage.data_type_),
|
||||
desc_(storage.desc_) {
|
||||
data_type_(storage.data_type_) {
|
||||
storage.memory_ = nullptr;
|
||||
}
|
||||
|
||||
|
@ -103,7 +99,6 @@ LinearStorage& LinearStorage::operator=(LinearStorage&& storage) {
|
|||
name_ = std::move(storage.name_);
|
||||
std::swap(storage_type_, storage.storage_type_);
|
||||
std::swap(data_type_, storage.data_type_);
|
||||
desc_ = storage.desc_;
|
||||
GPUObject::operator=(std::move(storage));
|
||||
}
|
||||
return *this;
|
||||
|
|
|
@ -92,9 +92,6 @@ class LinearStorage : public GPUObject {
|
|||
std::string ReadLinearFLT4(const std::string& z_coord) const;
|
||||
std::string GetDeclaration() const;
|
||||
|
||||
const GPUObjectDescriptor* GetGPUDescriptor() const override {
|
||||
return &desc_;
|
||||
}
|
||||
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
|
||||
|
||||
private:
|
||||
|
@ -115,7 +112,6 @@ class LinearStorage : public GPUObject {
|
|||
std::string name_;
|
||||
LinearStorageType storage_type_;
|
||||
DataType data_type_;
|
||||
TensorLinearDescriptor desc_;
|
||||
};
|
||||
|
||||
absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data,
|
||||
|
@ -152,6 +148,31 @@ absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info,
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <DataType T>
|
||||
absl::Status CreateLinearStorage(const TensorLinearDescriptor& descriptor,
|
||||
const tflite::gpu::Tensor<Linear, T>& tensor,
|
||||
CLContext* context, LinearStorage* result) {
|
||||
LinearStorageCreateInfo creation_info;
|
||||
creation_info.storage_type = descriptor.storage_type;
|
||||
creation_info.data_type = descriptor.element_type;
|
||||
int size = creation_info.aligned_size != 0 ? creation_info.aligned_size
|
||||
: tensor.shape.v;
|
||||
const int depth = DivideRoundUp(size, 4);
|
||||
if (creation_info.data_type == DataType::FLOAT32) {
|
||||
std::vector<float4> gpu_data(depth);
|
||||
CopyLinearFLT4(tensor, absl::MakeSpan(gpu_data));
|
||||
RETURN_IF_ERROR(CreateLinearStorage(creation_info, depth, gpu_data.data(),
|
||||
context, result));
|
||||
} else {
|
||||
std::vector<half4> gpu_data(depth);
|
||||
CopyLinearFLT4(tensor, absl::MakeSpan(gpu_data));
|
||||
RETURN_IF_ERROR(CreateLinearStorage(creation_info, depth, gpu_data.data(),
|
||||
context, result));
|
||||
}
|
||||
result->SetName(creation_info.name);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
|
|
@ -58,9 +58,6 @@ class Tensor : public GPUObject {
|
|||
|
||||
virtual ~Tensor() { Release(); }
|
||||
|
||||
const GPUObjectDescriptor* GetGPUDescriptor() const override {
|
||||
return &descriptor_;
|
||||
}
|
||||
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
|
||||
|
||||
int Width() const { return shape_.w; }
|
||||
|
|
Loading…
Reference in New Issue