Removed virtual method GPUObjectDescriptor* GetGPUDescriptor() for GPUObject.

PiperOrigin-RevId: 317672008
Change-Id: Iae9189569290c5d50dd1d2645dceac5385b2641c
This commit is contained in:
Raman Sarokin 2020-06-22 09:41:48 -07:00 committed by TensorFlower Gardener
parent 8498c64d77
commit e25fcc8393
8 changed files with 62 additions and 44 deletions

View File

@ -221,8 +221,9 @@ void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
} }
void Arguments::AddObject(const std::string& name, AccessType access_type, void Arguments::AddObject(const std::string& name, AccessType access_type,
GPUObjectPtr&& object) { GPUObjectPtr&& object,
objects_[name] = {access_type, std::move(object)}; GPUObjectDescriptorPtr&& descriptor_ptr) {
objects_[name] = {access_type, std::move(object), std::move(descriptor_ptr)};
} }
void Arguments::AddGPUResources(const std::string& name, void Arguments::AddGPUResources(const std::string& name,
@ -411,7 +412,8 @@ absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
absl::StrCat("Object name collision. Name - ", name)); absl::StrCat("Object name collision. Name - ", name));
} }
objects_[name] = {v.second.access_type, std::move(v.second.obj_ptr)}; objects_[name] = {v.second.access_type, std::move(v.second.obj_ptr),
std::move(v.second.descriptor)};
} }
for (const auto& v : args.int_values_) { for (const auto& v : args.int_values_) {
AddInt(RenameArg(object_names, postfix, v.first), v.second.value); AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
@ -677,7 +679,7 @@ absl::Status Arguments::ResolveSelector(
desc_ptr = it->second.descriptor.get(); desc_ptr = it->second.descriptor.get();
access_type = it->second.access_type; access_type = it->second.access_type;
} else if (auto it = objects_.find(object_name); it != objects_.end()) { } else if (auto it = objects_.find(object_name); it != objects_.end()) {
desc_ptr = it->second.obj_ptr->GetGPUDescriptor(); desc_ptr = it->second.descriptor.get();
access_type = it->second.access_type; access_type = it->second.access_type;
} else { } else {
return absl::NotFoundError( return absl::NotFoundError(
@ -760,8 +762,7 @@ absl::Status Arguments::ResolveSelectorsPass(
absl::Status Arguments::AddObjectArgs() { absl::Status Arguments::AddObjectArgs() {
for (auto& t : objects_) { for (auto& t : objects_) {
AddGPUResources(t.first, AddGPUResources(t.first,
t.second.obj_ptr->GetGPUDescriptor()->GetGPUResources( t.second.descriptor->GetGPUResources(t.second.access_type));
t.second.access_type));
RETURN_IF_ERROR(SetGPUResources( RETURN_IF_ERROR(SetGPUResources(
t.first, t.second.obj_ptr->GetGPUResources(t.second.access_type))); t.first, t.second.obj_ptr->GetGPUResources(t.second.access_type)));
} }

View File

@ -50,7 +50,8 @@ class Arguments {
void AddObjectRef(const std::string& name, AccessType access_type, void AddObjectRef(const std::string& name, AccessType access_type,
GPUObjectDescriptorPtr&& descriptor_ptr); GPUObjectDescriptorPtr&& descriptor_ptr);
void AddObject(const std::string& name, AccessType access_type, void AddObject(const std::string& name, AccessType access_type,
GPUObjectPtr&& object); GPUObjectPtr&& object,
GPUObjectDescriptorPtr&& descriptor_ptr);
absl::Status SetInt(const std::string& name, int value); absl::Status SetInt(const std::string& name, int value);
absl::Status SetFloat(const std::string& name, float value); absl::Status SetFloat(const std::string& name, float value);
@ -162,6 +163,7 @@ class Arguments {
struct ObjectArg { struct ObjectArg {
AccessType access_type; AccessType access_type;
GPUObjectPtr obj_ptr; GPUObjectPtr obj_ptr;
GPUObjectDescriptorPtr descriptor;
}; };
std::map<std::string, ObjectArg> objects_; std::map<std::string, ObjectArg> objects_;
}; };

View File

@ -149,7 +149,6 @@ class GPUObject {
GPUObject(const GPUObject&) = delete; GPUObject(const GPUObject&) = delete;
GPUObject& operator=(const GPUObject&) = delete; GPUObject& operator=(const GPUObject&) = delete;
virtual ~GPUObject() = default; virtual ~GPUObject() = default;
virtual const GPUObjectDescriptor* GetGPUDescriptor() const = 0;
virtual GPUResourcesWithValue GetGPUResources( virtual GPUResourcesWithValue GetGPUResources(
AccessType access_type) const = 0; AccessType access_type) const = 0;
}; };

View File

@ -69,17 +69,17 @@ absl::Status CreatePReLU(const CreationContext& creation_context,
template <DataType T> template <DataType T>
absl::Status PReLU::UploadParameters( absl::Status PReLU::UploadParameters(
const tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context) { const tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context) {
LinearStorageCreateInfo create_info; TensorLinearDescriptor desc;
create_info.storage_type = desc.storage_type =
DeduceLinearStorageType(definition_.GetPrimaryStorageType()); DeduceLinearStorageType(definition_.GetPrimaryStorageType());
create_info.data_type = definition_.GetPrimaryDataType(); desc.element_type = definition_.GetPrimaryDataType();
RETURN_IF_ERROR( RETURN_IF_ERROR(CreateLinearStorage(desc, parameters, context, &alpha_));
CreateLinearStorage(create_info, parameters, context, &alpha_));
LinearStorage lt; LinearStorage lt;
RETURN_IF_ERROR(CreateLinearStorage(create_info, parameters, context, &lt)); RETURN_IF_ERROR(CreateLinearStorage(desc, parameters, context, &lt));
args_.AddObject("alpha", AccessType::READ, args_.AddObject("alpha", AccessType::READ,
absl::make_unique<LinearStorage>(std::move(lt))); absl::make_unique<LinearStorage>(std::move(lt)),
absl::make_unique<TensorLinearDescriptor>(desc));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -378,14 +378,15 @@ absl::Status Winograd4x4To36::UploadBt(CLContext* context) {
bt_aligned.data[y * 8 + 7] = 0.0f; bt_aligned.data[y * 8 + 7] = 0.0f;
} }
LinearStorageCreateInfo create_info; TensorLinearDescriptor desc;
create_info.storage_type = LinearStorageType::TEXTURE_2D; desc.storage_type = LinearStorageType::TEXTURE_2D;
create_info.data_type = definition_.GetDataType(); desc.element_type = definition_.GetDataType();
LinearStorage lt; LinearStorage lt;
RETURN_IF_ERROR(CreateLinearStorage(create_info, bt_aligned, context, &lt)); RETURN_IF_ERROR(CreateLinearStorage(desc, bt_aligned, context, &lt));
args_.AddObject("bt", AccessType::READ, args_.AddObject("bt", AccessType::READ,
absl::make_unique<LinearStorage>(std::move(lt))); absl::make_unique<LinearStorage>(std::move(lt)),
absl::make_unique<TensorLinearDescriptor>(desc));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -492,13 +493,14 @@ absl::Status Winograd36To4x4::UploadAt(CLContext* context) {
at_aligned.data[y * 8 + 7] = 0.0f; at_aligned.data[y * 8 + 7] = 0.0f;
} }
LinearStorageCreateInfo create_info; TensorLinearDescriptor desc;
create_info.storage_type = LinearStorageType::TEXTURE_2D; desc.storage_type = LinearStorageType::TEXTURE_2D;
create_info.data_type = definition_.GetDataType(); desc.element_type = definition_.GetDataType();
LinearStorage lt; LinearStorage lt;
RETURN_IF_ERROR(CreateLinearStorage(create_info, at_aligned, context, &lt)); RETURN_IF_ERROR(CreateLinearStorage(desc, at_aligned, context, &lt));
args_.AddObject("at", AccessType::READ, args_.AddObject("at", AccessType::READ,
absl::make_unique<LinearStorage>(std::move(lt))); absl::make_unique<LinearStorage>(std::move(lt)),
absl::make_unique<TensorLinearDescriptor>(desc));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -550,14 +552,15 @@ absl::Status CreateWinograd36To4x4(
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases, const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
Winograd36To4x4* result) { Winograd36To4x4* result) {
*result = Winograd36To4x4(definition); *result = Winograd36To4x4(definition);
LinearStorageCreateInfo create_info; TensorLinearDescriptor desc;
create_info.storage_type = LinearStorageType::TEXTURE_2D; desc.storage_type = LinearStorageType::TEXTURE_2D;
create_info.data_type = definition.GetDataType(); desc.element_type = definition.GetDataType();
LinearStorage lt; LinearStorage lt;
RETURN_IF_ERROR( RETURN_IF_ERROR(
CreateLinearStorage(create_info, biases, creation_context.context, &lt)); CreateLinearStorage(desc, biases, creation_context.context, &lt));
result->args_.AddObject("biases", AccessType::READ, result->args_.AddObject("biases", AccessType::READ,
absl::make_unique<LinearStorage>(std::move(lt))); absl::make_unique<LinearStorage>(std::move(lt)),
absl::make_unique<TensorLinearDescriptor>(desc));
return result->UploadAt(creation_context.context); return result->UploadAt(creation_context.context);
} }

View File

@ -76,10 +76,7 @@ absl::Status TensorLinearDescriptor::PerformReadSelector(
LinearStorage::LinearStorage(int depth, LinearStorageType storage_type, LinearStorage::LinearStorage(int depth, LinearStorageType storage_type,
DataType data_type) DataType data_type)
: depth_(depth), storage_type_(storage_type), data_type_(data_type) { : depth_(depth), storage_type_(storage_type), data_type_(data_type) {}
desc_.storage_type = storage_type;
desc_.element_type = data_type;
}
LinearStorage::LinearStorage(LinearStorage&& storage) LinearStorage::LinearStorage(LinearStorage&& storage)
: GPUObject(std::move(storage)), : GPUObject(std::move(storage)),
@ -89,8 +86,7 @@ LinearStorage::LinearStorage(LinearStorage&& storage)
depth_(storage.depth_), depth_(storage.depth_),
name_(std::move(storage.name_)), name_(std::move(storage.name_)),
storage_type_(storage.storage_type_), storage_type_(storage.storage_type_),
data_type_(storage.data_type_), data_type_(storage.data_type_) {
desc_(storage.desc_) {
storage.memory_ = nullptr; storage.memory_ = nullptr;
} }
@ -103,7 +99,6 @@ LinearStorage& LinearStorage::operator=(LinearStorage&& storage) {
name_ = std::move(storage.name_); name_ = std::move(storage.name_);
std::swap(storage_type_, storage.storage_type_); std::swap(storage_type_, storage.storage_type_);
std::swap(data_type_, storage.data_type_); std::swap(data_type_, storage.data_type_);
desc_ = storage.desc_;
GPUObject::operator=(std::move(storage)); GPUObject::operator=(std::move(storage));
} }
return *this; return *this;

View File

@ -92,9 +92,6 @@ class LinearStorage : public GPUObject {
std::string ReadLinearFLT4(const std::string& z_coord) const; std::string ReadLinearFLT4(const std::string& z_coord) const;
std::string GetDeclaration() const; std::string GetDeclaration() const;
const GPUObjectDescriptor* GetGPUDescriptor() const override {
return &desc_;
}
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override; GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
private: private:
@ -115,7 +112,6 @@ class LinearStorage : public GPUObject {
std::string name_; std::string name_;
LinearStorageType storage_type_; LinearStorageType storage_type_;
DataType data_type_; DataType data_type_;
TensorLinearDescriptor desc_;
}; };
absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data, absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data,
@ -152,6 +148,31 @@ absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info,
return absl::OkStatus(); return absl::OkStatus();
} }
template <DataType T>
absl::Status CreateLinearStorage(const TensorLinearDescriptor& descriptor,
const tflite::gpu::Tensor<Linear, T>& tensor,
CLContext* context, LinearStorage* result) {
LinearStorageCreateInfo creation_info;
creation_info.storage_type = descriptor.storage_type;
creation_info.data_type = descriptor.element_type;
int size = creation_info.aligned_size != 0 ? creation_info.aligned_size
: tensor.shape.v;
const int depth = DivideRoundUp(size, 4);
if (creation_info.data_type == DataType::FLOAT32) {
std::vector<float4> gpu_data(depth);
CopyLinearFLT4(tensor, absl::MakeSpan(gpu_data));
RETURN_IF_ERROR(CreateLinearStorage(creation_info, depth, gpu_data.data(),
context, result));
} else {
std::vector<half4> gpu_data(depth);
CopyLinearFLT4(tensor, absl::MakeSpan(gpu_data));
RETURN_IF_ERROR(CreateLinearStorage(creation_info, depth, gpu_data.data(),
context, result));
}
result->SetName(creation_info.name);
return absl::OkStatus();
}
} // namespace cl } // namespace cl
} // namespace gpu } // namespace gpu
} // namespace tflite } // namespace tflite

View File

@ -58,9 +58,6 @@ class Tensor : public GPUObject {
virtual ~Tensor() { Release(); } virtual ~Tensor() { Release(); }
const GPUObjectDescriptor* GetGPUDescriptor() const override {
return &descriptor_;
}
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override; GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
int Width() const { return shape_.w; } int Width() const { return shape_.w; }