GPUObjectDescriptor moved to common gpu folder.
PiperOrigin-RevId: 340373222 Change-Id: I243e98d7ece3acbc5d744a5d620e3f6b6ace12ed
This commit is contained in:
parent
cb4df1e1d5
commit
cfd834264d
tensorflow/lite/delegates/gpu
@ -52,13 +52,13 @@ cc_library(
|
||||
srcs = ["arguments.cc"],
|
||||
hdrs = ["arguments.h"],
|
||||
deps = [
|
||||
":gpu_object_desc",
|
||||
":serialization_cc_fbs",
|
||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -373,22 +373,11 @@ cc_library(
|
||||
name = "gpu_object",
|
||||
hdrs = ["gpu_object.h"],
|
||||
deps = [
|
||||
":gpu_object_desc",
|
||||
":opencl_wrapper",
|
||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_object_desc",
|
||||
hdrs = ["gpu_object_desc.h"],
|
||||
deps = [
|
||||
":serialization_cc_fbs",
|
||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
|
||||
],
|
||||
)
|
||||
|
||||
@ -513,6 +502,10 @@ flatbuffer_cc_library(
|
||||
srcs = ["serialization.fbs"],
|
||||
flatc_args = [
|
||||
"--scoped-enums",
|
||||
"-I ./",
|
||||
],
|
||||
includes = [
|
||||
"//tensorflow/lite/delegates/gpu/common/task:serialization_base_cc_fbs_includes",
|
||||
],
|
||||
)
|
||||
|
||||
@ -577,11 +570,11 @@ cc_library(
|
||||
srcs = ["tensor_type.cc"],
|
||||
hdrs = ["tensor_type.h"],
|
||||
deps = [
|
||||
":gpu_object_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -635,12 +628,12 @@ cc_library(
|
||||
srcs = ["util.cc"],
|
||||
hdrs = ["util.h"],
|
||||
deps = [
|
||||
":gpu_object_desc",
|
||||
":opencl_wrapper",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:gpu_object_desc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
|
@ -20,10 +20,10 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
|
||||
|
@ -21,11 +21,11 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
@ -33,16 +33,16 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
data::AccessType ToFB(AccessType type) {
|
||||
tflite::gpu::data::AccessType ToFB(AccessType type) {
|
||||
switch (type) {
|
||||
case AccessType::READ:
|
||||
return data::AccessType::READ;
|
||||
return tflite::gpu::data::AccessType::READ;
|
||||
case AccessType::WRITE:
|
||||
return data::AccessType::WRITE;
|
||||
return tflite::gpu::data::AccessType::WRITE;
|
||||
case AccessType::READ_WRITE:
|
||||
return data::AccessType::READ_WRITE;
|
||||
return tflite::gpu::data::AccessType::READ_WRITE;
|
||||
default:
|
||||
return data::AccessType::READ_WRITE;
|
||||
return tflite::gpu::data::AccessType::READ_WRITE;
|
||||
}
|
||||
}
|
||||
|
||||
@ -165,13 +165,13 @@ DataType ToEnum(data::DataType type) {
|
||||
}
|
||||
}
|
||||
|
||||
AccessType ToEnum(data::AccessType type) {
|
||||
AccessType ToEnum(tflite::gpu::data::AccessType type) {
|
||||
switch (type) {
|
||||
case data::AccessType::READ:
|
||||
case tflite::gpu::data::AccessType::READ:
|
||||
return AccessType::READ;
|
||||
case data::AccessType::WRITE:
|
||||
case tflite::gpu::data::AccessType::WRITE:
|
||||
return AccessType::WRITE;
|
||||
case data::AccessType::READ_WRITE:
|
||||
case tflite::gpu::data::AccessType::READ_WRITE:
|
||||
return AccessType::READ_WRITE;
|
||||
}
|
||||
}
|
||||
@ -292,25 +292,27 @@ flatbuffers::Offset<data::Int3> Encode(
|
||||
return int3_builder.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<data::GPUObjectDescriptor> Encode(
|
||||
flatbuffers::Offset<tflite::gpu::data::GPUObjectDescriptor> Encode(
|
||||
const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) {
|
||||
std::vector<flatbuffers::Offset<data::StateVariable>> state_vars_fb;
|
||||
std::vector<flatbuffers::Offset<tflite::gpu::data::StateVariable>>
|
||||
state_vars_fb;
|
||||
for (auto& v0 : desc.state_vars_) {
|
||||
auto key_fb = builder->CreateString(v0.first);
|
||||
auto value_fb = builder->CreateString(v0.second);
|
||||
data::StateVariableBuilder state_builder(*builder);
|
||||
tflite::gpu::data::StateVariableBuilder state_builder(*builder);
|
||||
state_builder.add_key(key_fb);
|
||||
state_builder.add_value(value_fb);
|
||||
state_vars_fb.push_back(state_builder.Finish());
|
||||
}
|
||||
auto state_vars_fb_vec = builder->CreateVector(state_vars_fb);
|
||||
data::GPUObjectDescriptorBuilder obj_builder(*builder);
|
||||
tflite::gpu::data::GPUObjectDescriptorBuilder obj_builder(*builder);
|
||||
obj_builder.add_state_vars(state_vars_fb_vec);
|
||||
obj_builder.add_access_type(ToFB(desc.access_type_));
|
||||
return obj_builder.Finish();
|
||||
}
|
||||
|
||||
void Decode(const data::GPUObjectDescriptor* fb_obj, GPUObjectDescriptor* obj) {
|
||||
void Decode(const tflite::gpu::data::GPUObjectDescriptor* fb_obj,
|
||||
GPUObjectDescriptor* obj) {
|
||||
obj->access_type_ = ToEnum(fb_obj->access_type());
|
||||
for (auto state_fb : *fb_obj->state_vars()) {
|
||||
std::string key(state_fb->key()->c_str(), state_fb->key()->size());
|
||||
|
@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
include "tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs";
|
||||
|
||||
namespace tflite.gpu.cl.data;
|
||||
|
||||
table Int4 {
|
||||
@ -50,12 +52,6 @@ table HalfValue {
|
||||
active:bool;
|
||||
}
|
||||
|
||||
enum AccessType : byte {
|
||||
READ = 0,
|
||||
WRITE = 1,
|
||||
READ_WRITE = 2,
|
||||
}
|
||||
|
||||
enum DataType : byte {
|
||||
UNKNOWN = 0,
|
||||
FLOAT32 = 1,
|
||||
@ -68,18 +64,8 @@ enum MemoryType : byte {
|
||||
LOCAL = 2,
|
||||
}
|
||||
|
||||
table StateVariable {
|
||||
key:string;
|
||||
value:string;
|
||||
}
|
||||
|
||||
table GPUObjectDescriptor {
|
||||
state_vars:[StateVariable];
|
||||
access_type:AccessType;
|
||||
}
|
||||
|
||||
table BufferDescriptor {
|
||||
base_obj:GPUObjectDescriptor;
|
||||
base_obj:tflite.gpu.data.GPUObjectDescriptor;
|
||||
element_type:DataType;
|
||||
element_size:int32;
|
||||
memory_type:MemoryType;
|
||||
@ -89,7 +75,7 @@ table BufferDescriptor {
|
||||
}
|
||||
|
||||
table Texture2DDescriptor {
|
||||
base_obj:GPUObjectDescriptor;
|
||||
base_obj:tflite.gpu.data.GPUObjectDescriptor;
|
||||
element_type:DataType;
|
||||
normalized:bool;
|
||||
normalized_type:DataType;
|
||||
@ -103,7 +89,7 @@ enum LinearStorageType : byte {
|
||||
}
|
||||
|
||||
table TensorLinearDescriptor {
|
||||
base_obj:GPUObjectDescriptor;
|
||||
base_obj:tflite.gpu.data.GPUObjectDescriptor;
|
||||
storage_type:LinearStorageType;
|
||||
element_type:DataType;
|
||||
memory_type:MemoryType;
|
||||
@ -138,7 +124,7 @@ table BHWDC {
|
||||
}
|
||||
|
||||
table TensorDescriptor {
|
||||
base_obj:GPUObjectDescriptor;
|
||||
base_obj:tflite.gpu.data.GPUObjectDescriptor;
|
||||
data_type:DataType;
|
||||
storage_type:TensorStorageType;
|
||||
layout:Layout;
|
||||
|
@ -19,9 +19,9 @@ limitations under the License.
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
|
||||
namespace tflite {
|
||||
|
@ -39,6 +39,7 @@ cc_binary(
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -19,10 +19,10 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
|
||||
|
25
tensorflow/lite/delegates/gpu/common/task/BUILD
Normal file
25
tensorflow/lite/delegates/gpu/common/task/BUILD
Normal file
@ -0,0 +1,25 @@
|
||||
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_object_desc",
|
||||
hdrs = ["gpu_object_desc.h"],
|
||||
deps = [
|
||||
":serialization_base_cc_fbs",
|
||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
],
|
||||
)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "serialization_base_cc_fbs",
|
||||
srcs = ["serialization_base.fbs"],
|
||||
flatc_args = [
|
||||
"--scoped-enums",
|
||||
],
|
||||
)
|
@ -13,18 +13,18 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_DESC_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_DESC_H_
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_GPU_OBJECT_DESC_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_GPU_OBJECT_DESC_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
@ -132,9 +132,9 @@ class GPUObjectDescriptor {
|
||||
AccessType GetAccess() const { return access_type_; }
|
||||
|
||||
protected:
|
||||
friend flatbuffers::Offset<data::GPUObjectDescriptor> Encode(
|
||||
friend flatbuffers::Offset<tflite::gpu::data::GPUObjectDescriptor> Encode(
|
||||
const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder);
|
||||
friend void Decode(const data::GPUObjectDescriptor* fb_obj,
|
||||
friend void Decode(const tflite::gpu::data::GPUObjectDescriptor* fb_obj,
|
||||
GPUObjectDescriptor* obj);
|
||||
mutable std::map<std::string, std::string> state_vars_;
|
||||
AccessType access_type_;
|
||||
@ -146,4 +146,4 @@ using GPUObjectDescriptorPtr = std::unique_ptr<GPUObjectDescriptor>;
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_DESC_H_
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_GPU_OBJECT_DESC_H_
|
@ -0,0 +1,31 @@
|
||||
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
namespace tflite.gpu.data;
|
||||
|
||||
enum AccessType : byte {
|
||||
READ = 0,
|
||||
WRITE = 1,
|
||||
READ_WRITE = 2,
|
||||
}
|
||||
|
||||
table StateVariable {
|
||||
key:string;
|
||||
value:string;
|
||||
}
|
||||
|
||||
table GPUObjectDescriptor {
|
||||
state_vars:[StateVariable];
|
||||
access_type:AccessType;
|
||||
}
|
Loading…
Reference in New Issue
Block a user