Removed virtual method GPUObjectDescriptor* GetGPUDescriptor() for GPUObject.

PiperOrigin-RevId: 317672008
Change-Id: Iae9189569290c5d50dd1d2645dceac5385b2641c
This commit is contained in:
Raman Sarokin 2020-06-22 09:41:48 -07:00 committed by TensorFlower Gardener
parent 8498c64d77
commit e25fcc8393
8 changed files with 62 additions and 44 deletions

View File

@ -221,8 +221,9 @@ void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
}
void Arguments::AddObject(const std::string& name, AccessType access_type,
GPUObjectPtr&& object) {
objects_[name] = {access_type, std::move(object)};
GPUObjectPtr&& object,
GPUObjectDescriptorPtr&& descriptor_ptr) {
objects_[name] = {access_type, std::move(object), std::move(descriptor_ptr)};
}
void Arguments::AddGPUResources(const std::string& name,
@ -411,7 +412,8 @@ absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix) {
return absl::InvalidArgumentError(
absl::StrCat("Object name collision. Name - ", name));
}
objects_[name] = {v.second.access_type, std::move(v.second.obj_ptr)};
objects_[name] = {v.second.access_type, std::move(v.second.obj_ptr),
std::move(v.second.descriptor)};
}
for (const auto& v : args.int_values_) {
AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
@ -677,7 +679,7 @@ absl::Status Arguments::ResolveSelector(
desc_ptr = it->second.descriptor.get();
access_type = it->second.access_type;
} else if (auto it = objects_.find(object_name); it != objects_.end()) {
desc_ptr = it->second.obj_ptr->GetGPUDescriptor();
desc_ptr = it->second.descriptor.get();
access_type = it->second.access_type;
} else {
return absl::NotFoundError(
@ -760,8 +762,7 @@ absl::Status Arguments::ResolveSelectorsPass(
absl::Status Arguments::AddObjectArgs() {
for (auto& t : objects_) {
AddGPUResources(t.first,
t.second.obj_ptr->GetGPUDescriptor()->GetGPUResources(
t.second.access_type));
t.second.descriptor->GetGPUResources(t.second.access_type));
RETURN_IF_ERROR(SetGPUResources(
t.first, t.second.obj_ptr->GetGPUResources(t.second.access_type)));
}

View File

@ -50,7 +50,8 @@ class Arguments {
void AddObjectRef(const std::string& name, AccessType access_type,
GPUObjectDescriptorPtr&& descriptor_ptr);
void AddObject(const std::string& name, AccessType access_type,
GPUObjectPtr&& object);
GPUObjectPtr&& object,
GPUObjectDescriptorPtr&& descriptor_ptr);
absl::Status SetInt(const std::string& name, int value);
absl::Status SetFloat(const std::string& name, float value);
@ -162,6 +163,7 @@ class Arguments {
struct ObjectArg {
AccessType access_type;
GPUObjectPtr obj_ptr;
GPUObjectDescriptorPtr descriptor;
};
std::map<std::string, ObjectArg> objects_;
};

View File

@ -149,7 +149,6 @@ class GPUObject {
GPUObject(const GPUObject&) = delete;
GPUObject& operator=(const GPUObject&) = delete;
virtual ~GPUObject() = default;
virtual const GPUObjectDescriptor* GetGPUDescriptor() const = 0;
virtual GPUResourcesWithValue GetGPUResources(
AccessType access_type) const = 0;
};

View File

@ -69,17 +69,17 @@ absl::Status CreatePReLU(const CreationContext& creation_context,
template <DataType T>
absl::Status PReLU::UploadParameters(
const tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context) {
LinearStorageCreateInfo create_info;
create_info.storage_type =
TensorLinearDescriptor desc;
desc.storage_type =
DeduceLinearStorageType(definition_.GetPrimaryStorageType());
create_info.data_type = definition_.GetPrimaryDataType();
RETURN_IF_ERROR(
CreateLinearStorage(create_info, parameters, context, &alpha_));
desc.element_type = definition_.GetPrimaryDataType();
RETURN_IF_ERROR(CreateLinearStorage(desc, parameters, context, &alpha_));
LinearStorage lt;
RETURN_IF_ERROR(CreateLinearStorage(create_info, parameters, context, &lt));
RETURN_IF_ERROR(CreateLinearStorage(desc, parameters, context, &lt));
args_.AddObject("alpha", AccessType::READ,
absl::make_unique<LinearStorage>(std::move(lt)));
absl::make_unique<LinearStorage>(std::move(lt)),
absl::make_unique<TensorLinearDescriptor>(desc));
return absl::OkStatus();
}

View File

@ -378,14 +378,15 @@ absl::Status Winograd4x4To36::UploadBt(CLContext* context) {
bt_aligned.data[y * 8 + 7] = 0.0f;
}
LinearStorageCreateInfo create_info;
create_info.storage_type = LinearStorageType::TEXTURE_2D;
create_info.data_type = definition_.GetDataType();
TensorLinearDescriptor desc;
desc.storage_type = LinearStorageType::TEXTURE_2D;
desc.element_type = definition_.GetDataType();
LinearStorage lt;
RETURN_IF_ERROR(CreateLinearStorage(create_info, bt_aligned, context, &lt));
RETURN_IF_ERROR(CreateLinearStorage(desc, bt_aligned, context, &lt));
args_.AddObject("bt", AccessType::READ,
absl::make_unique<LinearStorage>(std::move(lt)));
absl::make_unique<LinearStorage>(std::move(lt)),
absl::make_unique<TensorLinearDescriptor>(desc));
return absl::OkStatus();
}
@ -492,13 +493,14 @@ absl::Status Winograd36To4x4::UploadAt(CLContext* context) {
at_aligned.data[y * 8 + 7] = 0.0f;
}
LinearStorageCreateInfo create_info;
create_info.storage_type = LinearStorageType::TEXTURE_2D;
create_info.data_type = definition_.GetDataType();
TensorLinearDescriptor desc;
desc.storage_type = LinearStorageType::TEXTURE_2D;
desc.element_type = definition_.GetDataType();
LinearStorage lt;
RETURN_IF_ERROR(CreateLinearStorage(create_info, at_aligned, context, &lt));
RETURN_IF_ERROR(CreateLinearStorage(desc, at_aligned, context, &lt));
args_.AddObject("at", AccessType::READ,
absl::make_unique<LinearStorage>(std::move(lt)));
absl::make_unique<LinearStorage>(std::move(lt)),
absl::make_unique<TensorLinearDescriptor>(desc));
return absl::OkStatus();
}
@ -550,14 +552,15 @@ absl::Status CreateWinograd36To4x4(
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
Winograd36To4x4* result) {
*result = Winograd36To4x4(definition);
LinearStorageCreateInfo create_info;
create_info.storage_type = LinearStorageType::TEXTURE_2D;
create_info.data_type = definition.GetDataType();
TensorLinearDescriptor desc;
desc.storage_type = LinearStorageType::TEXTURE_2D;
desc.element_type = definition.GetDataType();
LinearStorage lt;
RETURN_IF_ERROR(
CreateLinearStorage(create_info, biases, creation_context.context, &lt));
CreateLinearStorage(desc, biases, creation_context.context, &lt));
result->args_.AddObject("biases", AccessType::READ,
absl::make_unique<LinearStorage>(std::move(lt)));
absl::make_unique<LinearStorage>(std::move(lt)),
absl::make_unique<TensorLinearDescriptor>(desc));
return result->UploadAt(creation_context.context);
}

View File

@ -76,10 +76,7 @@ absl::Status TensorLinearDescriptor::PerformReadSelector(
LinearStorage::LinearStorage(int depth, LinearStorageType storage_type,
DataType data_type)
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {
desc_.storage_type = storage_type;
desc_.element_type = data_type;
}
: depth_(depth), storage_type_(storage_type), data_type_(data_type) {}
LinearStorage::LinearStorage(LinearStorage&& storage)
: GPUObject(std::move(storage)),
@ -89,8 +86,7 @@ LinearStorage::LinearStorage(LinearStorage&& storage)
depth_(storage.depth_),
name_(std::move(storage.name_)),
storage_type_(storage.storage_type_),
data_type_(storage.data_type_),
desc_(storage.desc_) {
data_type_(storage.data_type_) {
storage.memory_ = nullptr;
}
@ -103,7 +99,6 @@ LinearStorage& LinearStorage::operator=(LinearStorage&& storage) {
name_ = std::move(storage.name_);
std::swap(storage_type_, storage.storage_type_);
std::swap(data_type_, storage.data_type_);
desc_ = storage.desc_;
GPUObject::operator=(std::move(storage));
}
return *this;

View File

@ -92,9 +92,6 @@ class LinearStorage : public GPUObject {
std::string ReadLinearFLT4(const std::string& z_coord) const;
std::string GetDeclaration() const;
const GPUObjectDescriptor* GetGPUDescriptor() const override {
return &desc_;
}
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
private:
@ -115,7 +112,6 @@ class LinearStorage : public GPUObject {
std::string name_;
LinearStorageType storage_type_;
DataType data_type_;
TensorLinearDescriptor desc_;
};
absl::Status CreateBufferLinearStorage(int size, DataType data_type, void* data,
@ -152,6 +148,31 @@ absl::Status CreateLinearStorage(const LinearStorageCreateInfo& creation_info,
return absl::OkStatus();
}
template <DataType T>
absl::Status CreateLinearStorage(const TensorLinearDescriptor& descriptor,
const tflite::gpu::Tensor<Linear, T>& tensor,
CLContext* context, LinearStorage* result) {
LinearStorageCreateInfo creation_info;
creation_info.storage_type = descriptor.storage_type;
creation_info.data_type = descriptor.element_type;
int size = creation_info.aligned_size != 0 ? creation_info.aligned_size
: tensor.shape.v;
const int depth = DivideRoundUp(size, 4);
if (creation_info.data_type == DataType::FLOAT32) {
std::vector<float4> gpu_data(depth);
CopyLinearFLT4(tensor, absl::MakeSpan(gpu_data));
RETURN_IF_ERROR(CreateLinearStorage(creation_info, depth, gpu_data.data(),
context, result));
} else {
std::vector<half4> gpu_data(depth);
CopyLinearFLT4(tensor, absl::MakeSpan(gpu_data));
RETURN_IF_ERROR(CreateLinearStorage(creation_info, depth, gpu_data.data(),
context, result));
}
result->SetName(creation_info.name);
return absl::OkStatus();
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -58,9 +58,6 @@ class Tensor : public GPUObject {
virtual ~Tensor() { Release(); }
const GPUObjectDescriptor* GetGPUDescriptor() const override {
return &descriptor_;
}
GPUResourcesWithValue GetGPUResources(AccessType access_type) const override;
int Width() const { return shape_.w; }