TFLite GPU OpenGL: Add vector<float4> to Variable.

PiperOrigin-RevId: 260837358
This commit is contained in:
Juhyun Lee 2019-07-30 18:13:47 -07:00 committed by TensorFlower Gardener
parent 6b1dc74337
commit 9f3f6745ee
5 changed files with 92 additions and 2 deletions

View File

@ -65,6 +65,7 @@ struct VariableTypeGetter {
std::string operator()(float) const { return "float"; }
std::string operator()(const float2&) const { return "vec2"; }
std::string operator()(const float4&) const { return "vec4"; }
std::string operator()(const std::vector<float4>&) const { return "vec4"; }
};
// Returns GLSL uniform type of the given variable.

View File

@ -57,14 +57,17 @@ struct ParameterSetter {
return TFLITE_GPU_CALL_GL(glProgramUniform1i, program_id, uniform_id,
value);
}
Status operator()(const int2& value) {
return TFLITE_GPU_CALL_GL(glProgramUniform2i, program_id, uniform_id,
value.x, value.y);
}
Status operator()(const int4& value) {
return TFLITE_GPU_CALL_GL(glProgramUniform4i, program_id, uniform_id,
value.x, value.y, value.z, value.w);
}
Status operator()(const std::vector<int2>& value) {
std::vector<GLint> ints(value.size() * 2, 0);
for (int i = 0; i < value.size(); ++i) {
@ -74,27 +77,44 @@ struct ParameterSetter {
return TFLITE_GPU_CALL_GL(glProgramUniform2iv, program_id, uniform_id,
ints.size(), ints.data());
}
Status operator()(unsigned int value) {
return TFLITE_GPU_CALL_GL(glProgramUniform1ui, program_id, uniform_id,
value);
}
Status operator()(const uint4& value) {
return TFLITE_GPU_CALL_GL(glProgramUniform4ui, program_id, uniform_id,
value.x, value.y, value.z, value.w);
}
Status operator()(float value) {
return TFLITE_GPU_CALL_GL(glProgramUniform1f, program_id, uniform_id,
value);
}
Status operator()(const float2& value) {
return TFLITE_GPU_CALL_GL(glProgramUniform2f, program_id, uniform_id,
value.x, value.y);
}
Status operator()(const float4& value) {
return TFLITE_GPU_CALL_GL(glProgramUniform4f, program_id, uniform_id,
value.x, value.y, value.z, value.w);
}
Status operator()(const std::vector<float4>& value) {
std::vector<GLfloat> floats(value.size() * 4, 0);
for (int i = 0; i < value.size(); ++i) {
floats[i * 4] = value[i].x;
floats[i * 4 + 1] = value[i].y;
floats[i * 4 + 2] = value[i].z;
floats[i * 4 + 3] = value[i].w;
}
return TFLITE_GPU_CALL_GL(glProgramUniform4fv, program_id, uniform_id,
floats.size(), floats.data());
}
const GLuint program_id;
const GLint uniform_id;
};

View File

@ -37,12 +37,14 @@ struct ParameterValueGetter {
data.add_data(offset);
return data.Finish().Union();
}
Offset<void> operator()(const int2& value) {
auto offset = builder->CreateVector(std::vector<int32_t>{value.x, value.y});
data::DataInt32Builder data(*builder);
data.add_data(offset);
return data.Finish().Union();
}
Offset<void> operator()(const int4& value) {
auto offset = builder->CreateVector(
std::vector<int32_t>{value.x, value.y, value.z, value.w});
@ -50,6 +52,7 @@ struct ParameterValueGetter {
data.add_data(offset);
return data.Finish().Union();
}
Offset<void> operator()(const std::vector<int2>& value) {
std::vector<int32_t> d(value.size() * 2);
for (size_t i = 0; i < value.size(); ++i) {
@ -61,12 +64,14 @@ struct ParameterValueGetter {
data.add_data(offset);
return data.Finish().Union();
}
Offset<void> operator()(uint32_t value) {
auto offset = builder->CreateVector(std::vector<uint32_t>{value});
data::DataUint32Builder data(*builder);
data.add_data(offset);
return data.Finish().Union();
}
Offset<void> operator()(const uint4& value) {
auto offset = builder->CreateVector(
std::vector<uint32_t>{value.x, value.y, value.z, value.w});
@ -74,18 +79,21 @@ struct ParameterValueGetter {
data.add_data(offset);
return data.Finish().Union();
}
Offset<void> operator()(float value) {
auto offset = builder->CreateVector(std::vector<float>{value});
data::DataFloatBuilder data(*builder);
data.add_data(offset);
return data.Finish().Union();
}
Offset<void> operator()(const float2& value) {
auto offset = builder->CreateVector(std::vector<float>{value.x, value.y});
data::DataFloatBuilder data(*builder);
data.add_data(offset);
return data.Finish().Union();
}
Offset<void> operator()(const float4& value) {
auto offset = builder->CreateVector(
std::vector<float>{value.x, value.y, value.z, value.w});
@ -94,6 +102,20 @@ struct ParameterValueGetter {
return data.Finish().Union();
}
Offset<void> operator()(const std::vector<float4>& value) {
std::vector<float> d(value.size() * 4);
for (size_t i = 0; i < value.size(); ++i) {
d[i * 4] = value[i].x;
d[i * 4 + 1] = value[i].y;
d[i * 4 + 2] = value[i].z;
d[i * 4 + 3] = value[i].w;
}
auto offset = builder->CreateVector(d);
data::DataFloatBuilder data(*builder);
data.add_data(offset);
return data.Finish().Union();
}
::flatbuffers::FlatBufferBuilder* builder;
};
@ -101,60 +123,84 @@ struct DataVariantTypeGetter {
data::DataVariant operator()(int32_t) const {
return data::DataVariant::DataInt32;
}
data::DataVariant operator()(const int2&) const {
return data::DataVariant::DataInt32;
}
data::DataVariant operator()(const int4&) const {
return data::DataVariant::DataInt32;
}
data::DataVariant operator()(const std::vector<int2>&) const {
return data::DataVariant::DataInt32;
}
data::DataVariant operator()(uint32_t) const {
return data::DataVariant::DataUint32;
}
data::DataVariant operator()(const uint4&) const {
return data::DataVariant::DataUint32;
}
data::DataVariant operator()(float) const {
return data::DataVariant::DataFloat;
}
data::DataVariant operator()(const float2&) const {
return data::DataVariant::DataFloat;
}
data::DataVariant operator()(const float4&) const {
return data::DataVariant::DataFloat;
}
data::DataVariant operator()(const std::vector<float4>&) const {
return data::DataVariant::DataFloat;
}
};
struct ParameterTypeGetter {
data::ParameterType operator()(int32_t) const {
return data::ParameterType::INT32;
}
data::ParameterType operator()(const int2&) const {
return data::ParameterType::INT32;
}
data::ParameterType operator()(const int4&) const {
return data::ParameterType::INT32;
}
data::ParameterType operator()(const std::vector<int2>&) const {
return data::ParameterType::INT32_2;
}
data::ParameterType operator()(uint32_t) const {
return data::ParameterType::UINT32;
}
data::ParameterType operator()(const uint4&) const {
return data::ParameterType::UINT32;
}
data::ParameterType operator()(float) const {
return data::ParameterType::FLOAT32;
}
data::ParameterType operator()(const float2&) const {
return data::ParameterType::FLOAT32;
}
data::ParameterType operator()(const float4&) const {
return data::ParameterType::FLOAT32;
}
data::ParameterType operator()(const std::vector<float4>&) const {
return data::ParameterType::FLOAT32;
}
};
data::DataType ToFB(DataType type) {

View File

@ -70,14 +70,17 @@ struct ParameterComparator {
bool operator()(int32_t value) const {
return value == absl::get<int32_t>(a.value);
}
bool operator()(const int2& value) const {
auto v = absl::get<int2>(a.value);
return value.x == v.x && value.y == v.y;
}
bool operator()(const int4& value) const {
auto v = absl::get<int4>(a.value);
return value.x == v.x && value.y == v.y && value.z == v.z && value.w == v.w;
}
bool operator()(const std::vector<int2>& value) const {
auto v = absl::get<std::vector<int2>>(a.value);
if (v.size() != value.size()) {
@ -90,24 +93,43 @@ struct ParameterComparator {
}
return true;
}
bool operator()(uint32_t value) const {
return value == absl::get<uint32_t>(a.value);
}
bool operator()(const uint4& value) const {
auto v = absl::get<uint4>(a.value);
return value.x == v.x && value.y == v.y && value.z == v.z && value.w == v.w;
}
bool operator()(float value) const {
return value == absl::get<float>(a.value);
}
bool operator()(float2 value) const {
auto v = absl::get<float2>(a.value);
return value.x == v.x && value.y == v.y;
}
bool operator()(const float4& value) const {
auto v = absl::get<float4>(a.value);
return value.x == v.x && value.y == v.y && value.z == v.z && value.w == v.w;
}
bool operator()(const std::vector<float4>& value) const {
auto v = absl::get<std::vector<float4>>(a.value);
if (v.size() != value.size()) {
return false;
}
for (int i = 0; i < v.size(); ++i) {
if (v[i].x != value[i].x || v[i].y != value[i].y) {
return false;
}
}
return true;
}
Variable a;
};

View File

@ -28,8 +28,9 @@ namespace gpu {
namespace gl {
struct Variable {
using ValueType = absl::variant<int32_t, int2, int4, uint32_t, uint4, float,
float2, float4, std::vector<int2>>;
using ValueType =
absl::variant<int32_t, int2, int4, uint32_t, uint4, float, float2, float4,
std::vector<int2>, std::vector<float4>>;
std::string name;
ValueType value;