Published the GPU delegates.
PiperOrigin-RevId: 240848313
parent fd2db21368
commit fb772b781b
115
tensorflow/lite/delegates/gpu/BUILD
Normal file
@@ -0,0 +1,115 @@
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

# Primary purpose of this config is to replace ::util::Status with our custom
# light implementation ::tflite::gpu::StatusLite to reduce binary size. Besides
# that, certain features that were hard to communicate without full open source
# were hidden away too such as compiled models, serialization, and metadata.
# While the latter will be fully available with the open source release, the
# former will have to stay until absl::Status is released.
config_setting(
    name = "tflite_gpu_binary_release",
    values = {"copt": "-DTFLITE_GPU_BINARY_RELEASE"},
)

cc_library(
    name = "gl_delegate",
    srcs = ["gl_delegate.cc"],
    hdrs = ["gl_delegate.h"],
    linkopts = select({
        "//tensorflow:android": [
            "-lEGL",
            "-lGLESv3",
        ],
        "//conditions:default": [],
    }),
    deps = [
        "@com_google_absl//absl/types:span",
        "//tensorflow/lite:kernel_api",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/delegates/gpu/common:convert",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_builder",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:tensor",
        "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
        "//tensorflow/lite/delegates/gpu/gl:api",
        "//tensorflow/lite/delegates/gpu/gl:command_queue",
        "//tensorflow/lite/delegates/gpu/gl:compiler",
        "//tensorflow/lite/delegates/gpu/gl:egl_environment",
        "//tensorflow/lite/delegates/gpu/gl:gl_call",
        "//tensorflow/lite/delegates/gpu/gl/converters:bhwc_to_phwc4",
        "//tensorflow/lite/delegates/gpu/gl/converters:phwc4_to_bhwc",
        "//tensorflow/lite/delegates/gpu/gl/kernels:registry",
        "//tensorflow/lite/delegates/gpu/gl/workgroups:best_effort_calculator",
    ] + select({
        "//conditions:default": [
            "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs",
            "//tensorflow/lite/delegates/gpu/gl:metadata_cc_fbs",
            "//tensorflow/lite/delegates/gpu/gl:workgroups_cc_fbs",
            "@flatbuffers",
            "//tensorflow/lite/schema:schema_fbs",
        ],
        ":tflite_gpu_binary_release": [],
    }),
)

objc_library(
    name = "metal_delegate",
    srcs = ["metal_delegate.mm"],
    hdrs = ["metal_delegate.h"],
    copts = ["-std=c++11"],
    sdk_frameworks = ["Metal"],
    deps = [
        "//tensorflow/lite:kernel_api",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/delegates/gpu/common:convert",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_builder",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:tensor",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
        "//tensorflow/lite/delegates/gpu/metal:api",
        "//tensorflow/lite/delegates/gpu/metal:buffer_convert",
        "//tensorflow/lite/delegates/gpu/metal:compiled_model",
        "//tensorflow/lite/delegates/gpu/metal:inference_context",
        "@com_google_absl//absl/types:span",
    ],
)

# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always :libtflite_gpu_gl.so
cc_binary(
    name = "libtflite_gpu_gl.so",
    linkopts = select({
        "//tensorflow:android": [
            "-lEGL",
            "-lGLESv3",
        ],
        "//conditions:default": [],
    }),
    linkshared = 1,
    linkstatic = 1,
    tags = [
        "nobuilder",
        "notap",
    ],
    deps = [":gl_delegate"],
)

# build -c opt --config ios_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always :libtflite_gpu_metal.so
cc_binary(
    name = "libtflite_gpu_metal.so",
    linkshared = 1,
    linkstatic = 1,
    tags = [
        "nobuilder",
        "notap",
    ],
    deps = [":metal_delegate"],
)
205
tensorflow/lite/delegates/gpu/README.md
Normal file
@@ -0,0 +1,205 @@
# TFLite on GPU

TensorFlow Lite (TFLite) supports several hardware accelerators. This document
describes how to use the GPU backend via the TFLite delegate APIs on Android and
iOS.

GPUs are designed for high throughput on massively parallelizable workloads.
Thus, they are well-suited for deep neural nets, which consist of a huge number
of operators, each working on some input tensor(s) that can be easily divided
into smaller workloads and carried out in parallel, typically resulting in lower
latency. In the best case, inference on the GPU may run fast enough to become
suitable for real-time applications that were previously out of reach.

GPUs do their computation with 16-bit or 32-bit floating point numbers and,
unlike CPUs, do not require quantization for optimal performance. If quantizing
your neural network was not an option because of the accuracy lost to reduced
precision, that concern disappears when running the model on the GPU.

Another benefit of GPU inference is power efficiency. GPUs carry out
computations in a very efficient and optimized way, so they consume less power
and generate less heat than the same task run on CPUs.

TFLite on GPU supports the following ops in 16-bit and 32-bit float precision:

* `ADD v1`
* `AVERAGE_POOL_2D v1`
* `CONCATENATION v1`
* `CONV_2D v1`
* `DEPTHWISE_CONV_2D v1-2`
* `FULLY_CONNECTED v1`
* `LOGISTIC v1`
* `LSTM v2 (Basic LSTM only)`
* `MAX_POOL_2D v1`
* `MUL v1`
* `PAD v1`
* `PRELU v1`
* `RELU v1`
* `RELU6 v1`
* `RESHAPE v1`
* `RESIZE_BILINEAR v1`
* `SOFTMAX v1`
* `STRIDED_SLICE v1`
* `SUB v1`
* `TRANSPOSE_CONV v1`

## Basic Usage

Using TFLite on GPU is as simple as getting the GPU delegate via
`TfLiteGpuDelegateCreate()` and then passing it to
`Interpreter::ModifyGraphWithDelegate()` instead of calling
`Interpreter::AllocateTensors()`:

```c++
////////
// Set up interpreter.
auto model = FlatBufferModel::BuildFromFile(model_path);
ops::builtin::BuiltinOpResolver op_resolver;
std::unique_ptr<Interpreter> interpreter;
InterpreterBuilder(*model, op_resolver)(&interpreter);

////////
// NEW: Prepare GPU delegate.
auto* delegate = TfLiteGpuDelegateCreate(/*options=*/nullptr);
if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return;

////////
// Run inference.
WriteToInputTensor(interpreter->typed_input_tensor<float>(0));
if (interpreter->Invoke() != kTfLiteOk) return;
ReadFromOutputTensor(interpreter->typed_output_tensor<float>(0));

////////
// Clean up.
TfLiteGpuDelegateDelete(delegate);
```

*IMPORTANT:* When calling `Interpreter::ModifyGraphWithDelegate()` or
`Interpreter::Invoke()`, the caller must have an `EGLContext` current on the
calling thread, and `Interpreter::Invoke()` must be called from the same
`EGLContext`. If no such `EGLContext` exists, the delegate will create one
internally, but then the developer must ensure that `Interpreter::Invoke()` is
always called from the same thread that `Interpreter::ModifyGraphWithDelegate()`
was called from.

## Building and Runtime

The TFLite GPU backend uses OpenGL compute shaders and therefore requires
OpenGL ES 3.1 or higher.

```sh
bazel build --config android_arm64 //path/to/your:project
```

On iOS, Metal shaders are used instead; these were introduced with iOS 8. Thus,
compilation flags should look like:

```sh
bazel build --config ios_arm64 //path/to/your:project
```

## Advanced Usage: Delegate Options

There are GPU options that can be set and passed on to
`TfLiteGpuDelegateCreate()`. When the options are set to `nullptr`, as shown in
Basic Usage, they translate to:

```c++
const TfLiteGpuDelegateOptions kDefaultOptions = {
    .metadata = nullptr,
    .compile_options = {
        .precision_loss_allowed = 0,  // false
        .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,
        .dynamic_batch_enabled = 0,  // false
    },
};
```

Similarly for `NewTfLiteMetalDelegate()`:

```c++
const TfLiteMetalDelegateOptions kDefaultOptions = {
    .precision_loss_allowed = 0,  // false
    .wait_type = TFLITE_METAL_WAIT_TYPE_SLEEP,
};
```

While it is convenient to just supply `nullptr`, it is recommended to set the
options explicitly to avoid unexpected behavior if the default values change.
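
As a concrete illustration, here is a minimal sketch (not taken from the
original docs) of passing explicit options to the OpenGL delegate instead of
`nullptr`. The field names are the ones listed above; enabling
`precision_loss_allowed` is merely an example choice:

```c++
const TfLiteGpuDelegateOptions options = {
    .metadata = nullptr,
    .compile_options = {
        .precision_loss_allowed = 1,  // true: allow reduced (fp16) precision
        .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,
        .dynamic_batch_enabled = 0,  // false
    },
};
auto* delegate = TfLiteGpuDelegateCreate(&options);
if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return;
```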

## Advanced Usage: Input/Output Buffers (C++)

To do computation on the GPU, data must be made available to the GPU, which
often means performing a memory copy. It is desirable not to cross the CPU/GPU
memory boundary if possible, as this can take up a significant amount of time.
Usually, such crossing is inevitable, but in some special cases, one copy or the
other can be omitted.

If the network's input is an image already loaded in GPU memory, e.g. a GPU
texture containing the camera feed, it can stay in GPU memory without ever
entering CPU memory. Similarly, if the network's output is in the form of a
renderable image, e.g.
[image style transfer](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf),
it can be displayed on the screen directly.

To let users achieve the best performance, TFLite makes it possible to read
from and write to the delegate's hardware buffer directly, bypassing avoidable
memory copies.

Assuming the camera input is in GPU memory as `GL_TEXTURE_2D`, it must first be
converted to a shader storage buffer object (SSBO) for OpenGL, or to a
`MTLBuffer` object for Metal. A `TfLiteTensor` can be associated with a
user-prepared SSBO or `MTLBuffer` via `TfLiteGpuDelegateBindBufferToTensor()`
or `TfLiteMetalDelegateBindBufferToTensor()`, respectively.

*IMPORTANT:* These must be called before
`Interpreter::ModifyGraphWithDelegate()`.

*IMPORTANT:* By default, the inference output is copied from GPU memory to CPU
memory implicitly by the framework. This behavior can be turned off by calling
`Interpreter::SetAllowBufferHandleOutput(true)` during initialization. In that
case, copying the inference output from GPU memory to CPU memory requires an
explicit `Interpreter::EnsureTensorDataIsReadable()` call for each output
tensor.
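
For illustration, the `user_provided_input_buffer` used below could be a plain
GL shader storage buffer. The following is a hedged sketch (standard OpenGL ES
calls, not part of the original document); `input_size_bytes` is an assumed
placeholder for the tensor's byte size, and an EGL context is assumed to be
current on the calling thread:

```c++
GLuint user_provided_input_buffer = 0;
glGenBuffers(1, &user_provided_input_buffer);
glBindBuffer(GL_SHADER_STORAGE_BUFFER, user_provided_input_buffer);
// Allocate storage; the camera texture would later be converted into it.
glBufferData(GL_SHADER_STORAGE_BUFFER, input_size_bytes, nullptr,
             GL_STREAM_COPY);
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
```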

```c++
////////
// Prepare GPU delegate.
auto* delegate = TfLiteGpuDelegateCreate(nullptr);
interpreter->SetAllowBufferHandleOutput(true);  // disable default gpu->cpu copy
#if defined(__ANDROID__)
if (TfLiteGpuDelegateBindBufferToTensor(delegate, user_provided_input_buffer, interpreter->inputs()[0]) != kTfLiteOk) return;
if (TfLiteGpuDelegateBindBufferToTensor(delegate, user_provided_output_buffer, interpreter->outputs()[0]) != kTfLiteOk) return;
#elif defined(__APPLE__)
if (TfLiteMetalDelegateBindBufferToTensor(delegate, user_provided_input_buffer, interpreter->inputs()[0]) != kTfLiteOk) return;
if (TfLiteMetalDelegateBindBufferToTensor(delegate, user_provided_output_buffer, interpreter->outputs()[0]) != kTfLiteOk) return;
#endif
if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return;

////////
// Run inference.
if (interpreter->Invoke() != kTfLiteOk) return;
```
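
When `SetAllowBufferHandleOutput(true)` is used but a particular output is still
needed on the CPU, it has to be pulled back explicitly, as described above. A
minimal sketch, assuming the first output should be read as floats:

```c++
const int output_index = interpreter->outputs()[0];
if (interpreter->EnsureTensorDataIsReadable(output_index) != kTfLiteOk) return;
const float* output_data = interpreter->typed_output_tensor<float>(0);
```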

## Tips and Tricks

* Some operations that are trivial on the CPU may be costly on the GPU. One
  class of such operations is the various forms of reshape (including
  `BATCH_TO_SPACE`, `SPACE_TO_BATCH`, `SPACE_TO_DEPTH`, etc.). If those ops are
  in the network only to help the network architect reason about the model, it
  is worth removing them for performance.

* On the GPU, tensor data is sliced into 4-channel planes. Thus, a computation
  on a tensor of shape `[B, H, W, 5]` will perform about the same as on a
  tensor of shape `[B, H, W, 8]`, but significantly worse than on
  `[B, H, W, 4]`.

* In that sense, if the camera hardware supports image frames in RGBA, feeding
  that 4-channel input is significantly faster, as a memory copy (from
  3-channel RGB to 4-channel RGBX) can be avoided.

* For performance
  [best practices](https://www.tensorflow.org/lite/performance/best_practices),
  do not hesitate to re-train your classifier with a mobile-optimized network
  architecture. That is a significant part of optimization for on-device
  inference.
154
tensorflow/lite/delegates/gpu/common/BUILD
Normal file
@@ -0,0 +1,154 @@
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

cc_library(
    name = "convert",
    srcs = ["convert.cc"],
    hdrs = ["convert.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:tensor",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common:util",
        "@FP16",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "data_type",
    srcs = ["data_type.cc"],
    hdrs = ["data_type.h"],
)

cc_library(
    name = "model",
    hdrs = ["model.h"],
    deps = [
        ":data_type",
        ":shape",
        ":status",
        ":tensor",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:any",
    ],
)

cc_test(
    name = "model_test",
    srcs = ["model_test.cc"],
    deps = [
        ":model",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "model_builder",
    srcs = ["model_builder.cc"],
    hdrs = ["model_builder.h"],
    deps = [
        ":data_type",
        ":model",
        ":operations",
        ":shape",
        ":status",
        ":tensor",
        "//tensorflow/lite:context",
        "//tensorflow/lite:kernel_api",
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
    ],
)

# TODO(impjdi): Add unit test for model_builder.

cc_library(
    name = "model_transformer",
    srcs = ["model_transformer.cc"],
    hdrs = ["model_transformer.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:model",
        "@com_google_absl//absl/strings",
    ],
)

# TODO(impjdi): Add unit test for model_transformer.

cc_library(
    name = "operations",
    srcs = ["operations.cc"],
    hdrs = ["operations.h"],
    deps = [
        ":data_type",
        ":model",
        ":shape",
        ":status",
        "@com_google_absl//absl/types:variant",
    ],
)

# TODO(impjdi): Add unit test for operations.

cc_library(
    name = "shape",
    srcs = ["shape.cc"],
    hdrs = ["shape.h"],
    deps = [
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "shape_test",
    srcs = ["shape_test.cc"],
    deps = [
        ":shape",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "status",
    hdrs = ["status.h"],
)

cc_library(
    name = "tensor",
    hdrs = ["tensor.h"],
    deps = [
        ":data_type",
        ":shape",
    ],
)

cc_library(
    name = "types",
    hdrs = ["types.h"],
    deps = [
        "@FP16",
    ],
)

cc_library(
    name = "util",
    hdrs = ["util.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:types",
    ],
)

cc_test(
    name = "util_test",
    srcs = ["util_test.cc"],
    deps = [
        ":util",
        "@com_google_googletest//:gtest_main",
    ],
)
506
tensorflow/lite/delegates/gpu/common/convert.cc
Normal file
@@ -0,0 +1,506 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/convert.h"

#include <fp16.h>
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {
namespace {

constexpr int kPhwc4ChannelsInPlane = 4;
constexpr int kPhwo4i4ChannelsInPlane = 4;
constexpr int kPiohw4ChannelsInPlane = 4;

}  // namespace

uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape) {
  return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
         AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
}

uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape) {
  return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) *
         AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w;
}

// Layout is Po,H,W,OI4x4.
Status ConvertToPHWO4I4(absl::Span<const float> in, const OHWI& shape,
                        absl::Span<float> out) {
  if (in.size() != shape.DimensionsProduct()) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertToPHWO4I4: Input data size does not match expected size: ",
        in.size(), " != ", shape.DimensionsProduct()));
  }
  if (out.size() != GetElementsSizeForPHWO4I4(shape)) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertToPHWO4I4: Output data size does not match expected size: ",
        out.size(), " != ", GetElementsSizeForPHWO4I4(shape)));
  }

  float* output = out.data();
  for (int p = 0; p < IntegralDivideRoundUp(shape.o, kPhwo4i4ChannelsInPlane);
       ++p) {
    for (int h = 0; h < shape.h; ++h) {
      for (int w = 0; w < shape.w; ++w) {
        for (int c = 0;
             c < IntegralDivideRoundUp(shape.i, kPhwo4i4ChannelsInPlane); ++c) {
          for (int co = 0; co < kPhwo4i4ChannelsInPlane; ++co) {
            for (int ci = 0; ci < kPhwo4i4ChannelsInPlane; ++ci) {
              float value = 0;
              if (c * kPhwo4i4ChannelsInPlane + ci < shape.i &&
                  p * kPhwo4i4ChannelsInPlane + co < shape.o) {
                // tensor is in OHWI
                int tensor_o = p * kPhwo4i4ChannelsInPlane + co;
                int tensor_i = c * kPhwo4i4ChannelsInPlane + ci;
                value = in[shape.LinearIndex({tensor_o, h, w, tensor_i})];
              }
              (*output++) = value;
            }
          }
        }
      }
    }
  }
  return OkStatus();
}

std::vector<float> ConvertToPHWO4I4(
    const Tensor<OHWI, DataType::FLOAT32>& tensor) {
  std::vector<float> transposed(GetElementsSizeForPHWO4I4(tensor.shape));
  ConvertToPHWO4I4(tensor.data, tensor.shape,
                   absl::MakeSpan(transposed.data(), transposed.size()))
      .IgnoreError();
  return transposed;
}

uint3 Get3DSizeForPHWO4I4(const OHWI& shape) {
  return uint3(AlignByN(shape.i, 4), shape.h * shape.w,
               IntegralDivideRoundUp(shape.o, 4));
}

// Layout is Po,H,W,OI4x4.
Status ConvertToPHWO4I4(absl::Span<const float> in, const IHWO& shape,
                        absl::Span<float> out) {
  if (in.size() != shape.DimensionsProduct()) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertToPHWO4I4: Input data size does not match expected size: ",
        in.size(), " != ", shape.DimensionsProduct()));
  }
  if (out.size() != GetElementsSizeForPHWO4I4(shape)) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertToPHWO4I4: Output data size does not match expected size: ",
        out.size(), " != ", GetElementsSizeForPHWO4I4(shape)));
  }

  const int dst_depth = IntegralDivideRoundUp(shape.o, 4);
  const int src_depth = IntegralDivideRoundUp(shape.i, 4);

  float* output = out.data();
  for (int f = 0; f < dst_depth; ++f) {
    for (int y = 0; y < shape.h; ++y) {
      for (int x = 0; x < shape.w; ++x) {
        for (int ch = 0; ch < src_depth; ++ch) {
          for (int co = 0; co < 4; ++co) {
            for (int ci = 0; ci < 4; ++ci) {
              const int src_channel = ch * 4 + ci;
              const int dst_channel = f * 4 + co;
              float value = 0;
              if (src_channel < shape.i && dst_channel < shape.o) {
                // tensor is in IHWO
                value = in[shape.LinearIndex({src_channel, y, x, dst_channel})];
              }
              (*output++) = value;
            }
          }
        }
      }
    }
  }
  return OkStatus();
}

std::vector<float> ConvertToPHWO4I4(
    const Tensor<IHWO, DataType::FLOAT32>& tensor) {
  std::vector<float> transposed(GetElementsSizeForPHWO4I4(tensor.shape));
  ConvertToPHWO4I4(tensor.data, tensor.shape,
                   absl::MakeSpan(transposed.data(), transposed.size()))
      .IgnoreError();
  return transposed;
}

uint32_t GetElementsSizeForPIOHW4(const OHWI& shape) {
  return AlignByN(shape.o * shape.i, kPiohw4ChannelsInPlane) * shape.h *
         shape.w;
}

Status ConvertToPIOHW4(absl::Span<const float> in, const OHWI& shape,
                       absl::Span<float> out) {
  if (in.size() != shape.DimensionsProduct()) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertToPIOHW4: Input data size does not match expected size: ",
        in.size(), " != ", shape.DimensionsProduct()));
  }
  if (out.size() != GetElementsSizeForPIOHW4(shape)) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertToPIOHW4: Output data size does not match expected size: ",
        out.size(), " != ", GetElementsSizeForPIOHW4(shape)));
  }

  int32_t output_channels = shape.o * shape.i;
  int32_t num_planes =
      IntegralDivideRoundUp(output_channels, kPiohw4ChannelsInPlane);
  float* output = out.data();
  for (int p = 0; p < num_planes; ++p) {
    for (int h = 0; h < shape.h; ++h) {
      for (int w = 0; w < shape.w; ++w) {
        for (int c = 0; c < kPiohw4ChannelsInPlane; ++c) {
          int output_c = p * kPiohw4ChannelsInPlane + c;
          (*output++) = output_c >= output_channels
                            ? 0
                            : in[shape.LinearIndex({output_c % shape.o, h, w,
                                                    output_c / shape.o})];
        }
      }
    }
  }
  return OkStatus();
}

std::vector<float> ConvertToPIOHW4(
    const Tensor<OHWI, DataType::FLOAT32>& tensor) {
  std::vector<float> transposed(GetElementsSizeForPIOHW4(tensor.shape));
  ConvertToPIOHW4(tensor.data, tensor.shape,
                  absl::MakeSpan(transposed.data(), transposed.size()))
      .IgnoreError();
  return transposed;
}

template <typename T>
Status ValidateConvertToPHWC4(absl::Span<const float> in, const BHWC& shape,
                              absl::Span<T> out) {
  if (in.size() != shape.DimensionsProduct()) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertToPHWC4: Input data size does not match expected size: ",
        in.size(), " != ", shape.DimensionsProduct()));
  }
  if (out.size() != GetElementsSizeForPHWC4(shape)) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertToPHWC4: Output data size does not match expected size: ",
        out.size(), " != ", GetElementsSizeForPHWC4(shape)));
  }
  return OkStatus();
}

// Layout is Pc,H,W,C4 where P - is a plane based on channels.
Status ConvertToPHWC4(absl::Span<const float> in, const BHWC& shape,
                      absl::Span<float> out) {
  RETURN_IF_ERROR(ValidateConvertToPHWC4(in, shape, out));
  if (shape.c == 4) {
    std::memcpy(out.data(), in.data(),
                shape.DimensionsProduct() * sizeof(float));
    return OkStatus();
  }
  // Layout is Pc,H,W,C4 where P - is a plane based on channels.
  int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane);
  const int num_pixels = shape.h * shape.w;
  // A layer is a set of kPhwc4ChannelsInPlane channels images.
  const int num_full_planes = shape.c / kPhwc4ChannelsInPlane;
  for (int b = 0; b < shape.b; b++) {
    float* dest =
        out.data() + b * num_pixels * num_planes * kPhwc4ChannelsInPlane;
    for (int p = 0; p < num_full_planes; p++) {
      const float* src =
          in.data() + shape.LinearIndex({b, 0, 0, p * kPhwc4ChannelsInPlane});
      for (int i = 0; i < num_pixels; i++) {
        std::memcpy(dest, src, kPhwc4ChannelsInPlane * sizeof(float));
        src += shape.c;
        dest += kPhwc4ChannelsInPlane;
      }
    }
  }

  // Padding last kPhwc4ChannelsInPlane-channel layer to multiple of
  // kPhwc4ChannelsInPlane.
  const int padded_size = num_pixels * num_planes * kPhwc4ChannelsInPlane;
  const int remaining_channels =
      shape.c - num_full_planes * kPhwc4ChannelsInPlane;
  if (remaining_channels == 0) {
    return OkStatus();
  }
  for (int b = 0; b < shape.b; b++) {
    const float* src =
        in.data() +
        shape.LinearIndex({b, 0, 0, num_full_planes * kPhwc4ChannelsInPlane});
    float* dest = out.data() + b * padded_size +
                  num_pixels * num_full_planes * kPhwc4ChannelsInPlane;
    for (int p = 0; p < num_pixels; p++) {
      std::memcpy(dest, src, remaining_channels * sizeof(float));
      std::memset(dest + remaining_channels, 0,
                  (4 - remaining_channels) * sizeof(float));
      src += shape.c;
      dest += kPhwc4ChannelsInPlane;
    }
  }
  return OkStatus();
}

// Layout is Pc,H,W,C4 where P - is a plane based on channels.
Status ConvertToPHWC4Half(absl::Span<const float> in, const BHWC& shape,
                          absl::Span<HalfBits> out) {
  RETURN_IF_ERROR(ValidateConvertToPHWC4(in, shape, out));

  // Layout is Pc,H,W,C4 where P - is a plane based on channels.
  int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane);
  const int num_pixels = shape.h * shape.w;
  // A layer is a set of kPhwc4ChannelsInPlane channels images.
  const int num_full_planes = shape.c / kPhwc4ChannelsInPlane;
  for (int b = 0; b < shape.b; b++) {
    HalfBits* dest =
        out.data() + b * num_pixels * num_planes * kPhwc4ChannelsInPlane;
    for (int p = 0; p < num_full_planes; p++) {
      const float* src =
          in.data() + shape.LinearIndex({b, 0, 0, p * kPhwc4ChannelsInPlane});
      for (int i = 0; i < num_pixels; i++) {
        dest[0] = fp16_ieee_from_fp32_value(src[0]);
        dest[1] = fp16_ieee_from_fp32_value(src[1]);
        dest[2] = fp16_ieee_from_fp32_value(src[2]);
        dest[3] = fp16_ieee_from_fp32_value(src[3]);
        src += shape.c;
        dest += kPhwc4ChannelsInPlane;
      }
    }
  }

  // Padding last kPhwc4ChannelsInPlane-channel layer to multiple of
  // kPhwc4ChannelsInPlane.
  const int padded_size = num_pixels * num_planes * kPhwc4ChannelsInPlane;
  const int remaining_channels =
      shape.c - num_full_planes * kPhwc4ChannelsInPlane;
  if (remaining_channels == 0) {
    return OkStatus();
  }

  for (int b = 0; b < shape.b; b++) {
    const float* src =
        in.data() +
        shape.LinearIndex({b, 0, 0, num_full_planes * kPhwc4ChannelsInPlane});
    HalfBits* dest = out.data() + b * padded_size +
                     num_pixels * num_full_planes * kPhwc4ChannelsInPlane;
    switch (remaining_channels) {
      case 1:
        for (int p = 0; p < num_pixels; p++) {
          dest[0] = fp16_ieee_from_fp32_value(src[0]);
          dest[1] = 0;
          dest[2] = 0;
          dest[3] = 0;
          src += shape.c;
          dest += kPhwc4ChannelsInPlane;
        }
        break;
      case 2:
        for (int p = 0; p < num_pixels; p++) {
          dest[0] = fp16_ieee_from_fp32_value(src[0]);
          dest[1] = fp16_ieee_from_fp32_value(src[1]);
          dest[2] = 0;
          dest[3] = 0;
          src += shape.c;
          dest += kPhwc4ChannelsInPlane;
        }
        break;
      case 3:
        for (int p = 0; p < num_pixels; p++) {
          dest[0] = fp16_ieee_from_fp32_value(src[0]);
          dest[1] = fp16_ieee_from_fp32_value(src[1]);
          dest[2] = fp16_ieee_from_fp32_value(src[2]);
          dest[3] = 0;
          src += shape.c;
          dest += kPhwc4ChannelsInPlane;
        }
        break;
      default:
        return UnimplementedError(
            "ConvertToPHWC4Half: Unsupported channels per planes count.");
    }
  }
  return OkStatus();
}

std::vector<float> ConvertToPHWC4(
    const Tensor<BHWC, DataType::FLOAT32>& tensor) {
  std::vector<float> transposed(GetElementsSizeForPHWC4(tensor.shape));
  ConvertToPHWC4(tensor.data, tensor.shape,
                 absl::MakeSpan(transposed.data(), transposed.size()))
      .IgnoreError();
  // TODO(akulik): Maybe safer to return Status.
  return transposed;
}

std::vector<float> ConvertToPHWC4(
    const Tensor<HWC, DataType::FLOAT32>& tensor) {
  const BHWC batched_shape =
      BHWC(1, tensor.shape.h, tensor.shape.w, tensor.shape.c);
  std::vector<float> transposed(GetElementsSizeForPHWC4(batched_shape));
  ConvertToPHWC4(tensor.data, batched_shape,
                 absl::MakeSpan(transposed.data(), transposed.size()))
      .IgnoreError();
  // TODO(akulik): Maybe safer to return Status.
  return transposed;
}

uint32_t GetElementsSizeForPHWC4(const BHWC& shape) {
  return shape.b * shape.h * shape.w * AlignByN(shape.c, kPhwc4ChannelsInPlane);
}

template <typename T>
Status ValidateConvertFromPHWC4(absl::Span<const T> in, const BHWC& shape,
                                absl::Span<float> out) {
  if (in.size() != GetElementsSizeForPHWC4(shape)) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertFromPHWC4: Input data size does not match expected size: ",
        in.size(), " != ", GetElementsSizeForPHWC4(shape)));
  }
  if (out.size() != shape.DimensionsProduct()) {
    return InvalidArgumentError(absl::StrCat(
        "ConvertFromPHWC4: Output data size does not match expected size: ",
        out.size(), " != ", shape.DimensionsProduct()));
  }
  return OkStatus();
}

Status ConvertFromPHWC4(absl::Span<const float> in, const BHWC& shape,
                        absl::Span<float> out) {
  RETURN_IF_ERROR(ValidateConvertFromPHWC4(in, shape, out));
  if (shape.c == 4) {
    std::memcpy(out.data(), in.data(),
                shape.DimensionsProduct() * sizeof(float));
    return OkStatus();
  }

  int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane);
  const int num_pixels = shape.h * shape.w;
  const int padded_size = num_pixels * num_planes * kPhwc4ChannelsInPlane;
  // A layer is a set of kPhwc4ChannelsInPlane channels images.
  const int num_full_planes = shape.c / kPhwc4ChannelsInPlane;
  for (int b = 0; b < shape.b; b++) {
    const float* src = in.data() + b * padded_size;
    for (int p = 0; p < num_full_planes; p++) {
      float* dest =
          out.data() + shape.LinearIndex({b, 0, 0, p * kPhwc4ChannelsInPlane});
      for (int i = 0; i < num_pixels; i++) {
        std::memcpy(dest, src, kPhwc4ChannelsInPlane * sizeof(float));
        src += kPhwc4ChannelsInPlane;
        dest += shape.c;
      }
    }
  }

  // Unpadding last kPhwc4ChannelsInPlane-channel plane
  const int remaining_channels =
      shape.c - num_full_planes * kPhwc4ChannelsInPlane;
  if (remaining_channels == 0) {
    return OkStatus();
  }
  for (int b = 0; b < shape.b; b++) {
    const float* src = in.data() + b * padded_size +
                       num_pixels * num_full_planes * kPhwc4ChannelsInPlane;
    float* dest =
        out.data() +
        shape.LinearIndex({b, 0, 0, num_full_planes * kPhwc4ChannelsInPlane});
    for (int p = 0; p < num_pixels; p++) {
      std::memcpy(dest, src, remaining_channels * sizeof(float));
      src += kPhwc4ChannelsInPlane;
      dest += shape.c;
    }
  }
  return OkStatus();
}

Status ConvertFromPHWC4Half(absl::Span<const HalfBits> in, const BHWC& shape,
                            absl::Span<float> out) {
  RETURN_IF_ERROR(ValidateConvertFromPHWC4(in, shape, out));
  int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane);
  const int num_pixels = shape.h * shape.w;
  const int padded_size = num_pixels * num_planes * kPhwc4ChannelsInPlane;
  // A layer is a set of kPhwc4ChannelsInPlane channels images.
  const int num_full_planes = shape.c / kPhwc4ChannelsInPlane;
  for (int b = 0; b < shape.b; b++) {
    const HalfBits* src = in.data() + b * padded_size;
    for (int p = 0; p < num_full_planes; p++) {
      float* dest =
          out.data() + shape.LinearIndex({b, 0, 0, p * kPhwc4ChannelsInPlane});
      for (int i = 0; i < num_pixels; i++) {
        dest[0] = fp16_ieee_to_fp32_value(src[0]);
        dest[1] = fp16_ieee_to_fp32_value(src[1]);
        dest[2] = fp16_ieee_to_fp32_value(src[2]);
        dest[3] = fp16_ieee_to_fp32_value(src[3]);
        src += kPhwc4ChannelsInPlane;
        dest += shape.c;
      }
    }
  }

  // Unpadding last kPhwc4ChannelsInPlane-channel plane
  const int remaining_channels =
      shape.c - num_full_planes * kPhwc4ChannelsInPlane;
  if (remaining_channels == 0) {
    return OkStatus();
  }
  for (int b = 0; b < shape.b; b++) {
    const HalfBits* src = in.data() + b * padded_size +
                          num_pixels * num_full_planes * kPhwc4ChannelsInPlane;
    float* dest =
        out.data() +
        shape.LinearIndex({b, 0, 0, num_full_planes * kPhwc4ChannelsInPlane});
    switch (remaining_channels) {
      case 1:
        for (int p = 0; p < num_pixels; p++) {
          dest[0] = fp16_ieee_to_fp32_value(src[0]);
          src += kPhwc4ChannelsInPlane;
          dest += shape.c;
        }
        break;
      case 2:
        for (int p = 0; p < num_pixels; p++) {
          dest[0] = fp16_ieee_to_fp32_value(src[0]);
          dest[1] = fp16_ieee_to_fp32_value(src[1]);
          src += kPhwc4ChannelsInPlane;
          dest += shape.c;
        }
        break;
      case 3:
        for (int p = 0; p < num_pixels; p++) {
          dest[0] = fp16_ieee_to_fp32_value(src[0]);
          dest[1] = fp16_ieee_to_fp32_value(src[1]);
          dest[2] = fp16_ieee_to_fp32_value(src[2]);
          src += kPhwc4ChannelsInPlane;
          dest += shape.c;
        }
        break;
      default:
        return UnimplementedError(
            "ConvertToPHWC4Half: Unsupported channels per planes count.");
    }
  }
  return OkStatus();
}

}  // namespace gpu
}  // namespace tflite
97
tensorflow/lite/delegates/gpu/common/convert.h
Normal file
@@ -0,0 +1,97 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_

#include <vector>

#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"

namespace tflite {
namespace gpu {

// PHWC4 layout is where channels are grouped by 4 in a row and P stands for
// a plane that was derived by dividing channels by 4.
::tflite::gpu::Status ConvertToPHWC4(absl::Span<const float> in,
                                     const BHWC& shape, absl::Span<float> out);
::tflite::gpu::Status ConvertToPHWC4Half(
    absl::Span<const float> in, const BHWC& shape,
    absl::Span<::tflite::gpu::HalfBits> out);

// @return number of elements when shape is converted into PHWC4.
uint32_t GetElementsSizeForPHWC4(const BHWC& shape);

// Operation is opposite to ConvertToPHWC4.
::tflite::gpu::Status ConvertFromPHWC4(absl::Span<const float> in,
                                       const BHWC& shape,
                                       absl::Span<float> out);
::tflite::gpu::Status ConvertFromPHWC4Half(
    absl::Span<const ::tflite::gpu::HalfBits> in, const BHWC& shape,
    absl::Span<float> out);

// Convenience wrapper around a method above.
std::vector<float> ConvertToPHWC4(
    const Tensor<BHWC, DataType::FLOAT32>& tensor);
std::vector<float> ConvertToPHWC4(const Tensor<HWC, DataType::FLOAT32>& tensor);

// @return number of elements when shape is converted into PIOHW4.
uint32_t GetElementsSizeForPIOHW4(const OHWI& shape);

// PIOHW4 layout re-arranges weights in groups by 4, where outer dimension is
// P which is OxI/4.
::tflite::gpu::Status ConvertToPIOHW4(absl::Span<const float> in,
                                      const OHWI& shape, absl::Span<float> out);

// Convenience wrapper around a method above.
std::vector<float> ConvertToPIOHW4(
    const Tensor<OHWI, DataType::FLOAT32>& tensor);

// @return number of elements when shape is converted into PHWO4I4.
uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape);

// Layout is Po,H,W,OI4x4.
::tflite::gpu::Status ConvertToPHWO4I4(absl::Span<const float> in,
                                       const OHWI& shape,
                                       absl::Span<float> out);

// Convenience wrapper around a method above.
std::vector<float> ConvertToPHWO4I4(
    const Tensor<OHWI, DataType::FLOAT32>& tensor);

// @return (x,y,z) size for PHWO4I4 to access elements where each element
// consists of 4 values.
::tflite::gpu::uint3 Get3DSizeForPHWO4I4(const OHWI& shape);

// @return number of elements when shape is converted into PHWO4I4.
uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape);

// Layout is Po,H,W,OI4x4.
::tflite::gpu::Status ConvertToPHWO4I4(absl::Span<const float> in,
                                       const IHWO& shape,
                                       absl::Span<float> out);

// Convenience wrapper around a method above.
std::vector<float> ConvertToPHWO4I4(
    const Tensor<IHWO, DataType::FLOAT32>& tensor);

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_
78
tensorflow/lite/delegates/gpu/common/data_type.cc
Normal file
@@ -0,0 +1,78 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/data_type.h"

#include <stddef.h>
#include <string>

namespace tflite {
namespace gpu {

size_t SizeOf(DataType data_type) {
  switch (data_type) {
    case DataType::UINT8:
    case DataType::INT8:
      return 1;
    case DataType::FLOAT16:
    case DataType::INT16:
    case DataType::UINT16:
      return 2;
    case DataType::FLOAT32:
    case DataType::INT32:
    case DataType::UINT32:
      return 4;
    case DataType::FLOAT64:
    case DataType::INT64:
    case DataType::UINT64:
      return 8;
    case DataType::UNKNOWN:
      return 0;
  }
  return 0;
}

std::string ToString(DataType data_type) {
  switch (data_type) {
    case DataType::FLOAT16:
      return "float16";
    case DataType::FLOAT32:
      return "float32";
    case DataType::FLOAT64:
      return "float64";
    case DataType::INT16:
      return "int16";
    case DataType::INT32:
      return "int32";
    case DataType::INT64:
      return "int64";
    case DataType::INT8:
      return "int8";
    case DataType::UINT16:
      return "uint16";
    case DataType::UINT32:
      return "uint32";
    case DataType::UINT64:
      return "uint64";
    case DataType::UINT8:
      return "uint8";
    case DataType::UNKNOWN:
      return "unknown";
  }
  return "undefined";
}

}  // namespace gpu
}  // namespace tflite
47
tensorflow/lite/delegates/gpu/common/data_type.h
Normal file
@@ -0,0 +1,47 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_DATA_TYPE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_DATA_TYPE_H_

#include <stddef.h>
#include <string>

namespace tflite {
namespace gpu {

enum class DataType {
  UNKNOWN = 0,
  FLOAT16 = 1,
  FLOAT32 = 2,
  FLOAT64 = 3,
  UINT8 = 4,
  INT8 = 5,
  UINT16 = 6,
  INT16 = 7,
  UINT32 = 8,
  INT32 = 9,
  UINT64 = 10,
  INT64 = 11,
};

size_t SizeOf(DataType type);

std::string ToString(DataType t);

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_DATA_TYPE_H_
547
tensorflow/lite/delegates/gpu/common/model.h
Normal file
@@ -0,0 +1,547 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/types/any.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
// There is yet another representation of CNN graph. The primary purpose of this
|
||||||
|
// representation is to simplify graph manipulation.
|
||||||
|
|
||||||
|
using ValueId = uint32_t;
|
||||||
|
|
||||||
|
using NodeId = uint32_t;
|
||||||
|
|
||||||
|
// Connects tensor's producer and operation that depends on this tensor.
|
||||||
|
template <typename TensorT>
|
||||||
|
struct Value {
|
||||||
|
using TensorType = TensorT;
|
||||||
|
|
||||||
|
const ValueId id;
|
||||||
|
|
||||||
|
TensorType tensor;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Operation {
|
||||||
|
std::string type;
|
||||||
|
|
||||||
|
absl::any attributes;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Node {
|
||||||
|
const NodeId id;
|
||||||
|
|
||||||
|
Operation operation;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Graph is DAG that consists of nodes and values. Each value may have a single
|
||||||
|
// producer node and multiple consumer nodes. Therefore, each node may have
|
||||||
|
// multiple input and output values.
|
||||||
|
//
|
||||||
|
// Value that does not have a producer is a graph's input. Value that does not
|
||||||
|
// have a consumer is a graph's output.
|
||||||
|
//
|
||||||
|
// Interface provides methods for graph introspection and manipulation. Abstract
|
||||||
|
// interface makes allows subgraphs representation to ensure safe manipulations.
|
||||||
|
template <typename TensorT>
|
||||||
|
class Graph {
|
||||||
|
public:
|
||||||
|
virtual ~Graph() = default;
|
||||||
|
|
||||||
|
// @return a collection of nodes in this graph.
|
||||||
|
virtual std::vector<Node*> nodes() const = 0;
|
||||||
|
|
||||||
|
// @return a collection of values in this graph.
|
||||||
|
virtual std::vector<Value<TensorT>*> values() const = 0;
|
||||||
|
|
||||||
|
// @return graph inputs, that are values without producers.
|
||||||
|
virtual std::vector<Value<TensorT>*> inputs() const = 0;
|
||||||
|
|
||||||
|
// @return graph outputs, that are values without consumers.
|
||||||
|
virtual std::vector<Value<TensorT>*> outputs() const = 0;
|
||||||
|
|
||||||
|
// @return inputs into the given node. Returns empty vector for deleted node.
|
||||||
|
virtual std::vector<Value<TensorT>*> FindInputs(NodeId id) const = 0;
|
||||||
|
|
||||||
|
// @return outputs from the given node. Returns empty vector for deleted node.
|
||||||
|
virtual std::vector<Value<TensorT>*> FindOutputs(NodeId id) const = 0;
|
||||||
|
|
||||||
|
virtual bool IsGraphInput(ValueId id) const = 0;
|
||||||
|
|
||||||
|
virtual bool IsGraphOutput(ValueId id) const = 0;
|
||||||
|
|
||||||
|
// @return producer of the given value. Returns nullptr for deleted value.
|
||||||
|
virtual Node* FindProducer(ValueId id) const = 0;
|
||||||
|
|
||||||
|
// @return consumers of the given value. Returns empty vector for deleted
|
||||||
|
// value.
|
||||||
|
virtual std::vector<Node*> FindConsumers(ValueId id) const = 0;
|
||||||
|
|
||||||
|
// @return a node or nullptr if node with the given id is not present.
|
||||||
|
virtual Node* GetNode(NodeId id) const = 0;
|
||||||
|
|
||||||
|
// @return a value or nullptr if value with the given id is not present.
|
||||||
|
virtual Value<TensorT>* GetValue(ValueId id) const = 0;
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Graph manipulation functions are below
|
||||||
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// @return new node created in this graph
|
||||||
|
// NOTE: nodes should be created in the topological order, e.g. node A that
|
||||||
|
// depends on a value from node B should be created after node B.
|
||||||
|
virtual Node* NewNode() = 0;
|
||||||
|
|
||||||
|
// @return new value created in this graph
|
||||||
|
virtual Value<TensorT>* NewValue() = 0;
|
||||||
|
|
||||||
|
// Sets a producer for the given value. A value may have only a single
// producer. If the value already had another producer, the producer is
// reassigned appropriately. If the value did not have a producer, it is
// removed from the graph's inputs.
|
||||||
|
virtual Status SetProducer(NodeId producer, ValueId value) = 0;
|
||||||
|
|
||||||
|
// Removes a producer for the given value. Value becomes producer-less and
|
||||||
|
// therefore becomes graph's input.
|
||||||
|
virtual Status RemoveProducer(ValueId value) = 0;
|
||||||
|
|
||||||
|
// Sets a consumer for the given value. There could be multiple consumers
|
||||||
|
// for a value.
|
||||||
|
virtual Status AddConsumer(NodeId consumer, ValueId value) = 0;
|
||||||
|
|
||||||
|
// Removes a consumer for the given value. If value does not have any
|
||||||
|
// consumers it becomes graph's output.
|
||||||
|
virtual Status RemoveConsumer(NodeId consumer, ValueId value) = 0;
|
||||||
|
|
||||||
|
// Removes node from this graph. For all input values this node will be
|
||||||
|
// removed from consumers and for all output values a producer will be
|
||||||
|
// removed.
|
||||||
|
virtual Status DeleteNode(NodeId id) = 0;
|
||||||
|
|
||||||
|
// Removes value from this graph. It will be removed from inputs for all
// dependent nodes. A node that was a producer of this value will lose its
// output.
|
||||||
|
virtual Status DeleteValue(ValueId id) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Implementation of the Graph interface. It keeps values and nodes referenced
// by their index in a vector. Therefore, nodes and values are never deleted,
// but rather erased in place, so their indices remain stable.
|
||||||
|
//
|
||||||
|
// It is possible to re-use removed indices, but it is not implemented yet.
|
||||||
|
template <typename TensorT>
|
||||||
|
class Model : public Graph<TensorT> {
|
||||||
|
public:
|
||||||
|
const std::string& name() const { return name_; }
|
||||||
|
|
||||||
|
void set_name(std::string name) { name_ = std::move(name); }
|
||||||
|
|
||||||
|
std::vector<Value<TensorT>*> values() const final {
|
||||||
|
return FilterValues([](const ValueDef&) { return true; });
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Node*> nodes() const final {
|
||||||
|
return FilterNodes([](const NodeDef&) { return true; });
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Value<TensorT>*> inputs() const final {
|
||||||
|
return FilterValues(
|
||||||
|
[](const ValueDef& v) { return v.producer == nullptr; });
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Value<TensorT>*> outputs() const final {
|
||||||
|
return FilterValues([](const ValueDef& v) { return v.consumers.empty(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsGraphInput(ValueId id) const final {
|
||||||
|
if (id >= values_.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return values_[id].producer == nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsGraphOutput(ValueId id) const final {
|
||||||
|
if (id >= values_.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return values_[id].consumers.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
Node* GetNode(NodeId id) const final {
|
||||||
|
if (id >= nodes_.size()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return nodes_[id].node.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value<TensorT>* GetValue(ValueId id) const final {
|
||||||
|
if (id >= values_.size()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return values_[id].value.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
Node* NewNode() final {
|
||||||
|
NodeDef def;
|
||||||
|
def.node =
|
||||||
|
absl::make_unique<Node>(Node{static_cast<NodeId>(nodes_.size()), {}});
|
||||||
|
Node* node = def.node.get();
|
||||||
|
nodes_.push_back(std::move(def));
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value<TensorT>* NewValue() final {
|
||||||
|
ValueDef def;
|
||||||
|
def.value = absl::make_unique<Value<TensorT>>(
|
||||||
|
Value<TensorT>{static_cast<ValueId>(values_.size()), {}});
|
||||||
|
Value<TensorT>* value = def.value.get();
|
||||||
|
values_.push_back(std::move(def));
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Value<TensorT>*> FindInputs(NodeId id) const final {
|
||||||
|
if (id >= nodes_.size()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return nodes_[id].inputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Value<TensorT>*> FindOutputs(NodeId id) const final {
|
||||||
|
if (id >= nodes_.size()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return nodes_[id].outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
Node* FindProducer(ValueId id) const final {
|
||||||
|
if (id >= values_.size()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return values_[id].producer;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Node*> FindConsumers(ValueId id) const final {
|
||||||
|
if (id >= values_.size()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return values_[id].consumers;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status SetProducer(NodeId producer, ValueId value) final {
|
||||||
|
ValueDef* v;
|
||||||
|
RETURN_IF_ERROR(LookupValue(value, &v));
|
||||||
|
Value<TensorT>* value_ptr = v->value.get();
|
||||||
|
NodeDef* n;
|
||||||
|
RETURN_IF_ERROR(LookupNode(producer, &n));
|
||||||
|
Node* node_ptr = n->node.get();
|
||||||
|
|
||||||
|
// Check if this node is already the producer of the value.
|
||||||
|
if (node_ptr == v->producer) {
|
||||||
|
return InvalidArgumentError("Node is already a producer of the value");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the node is a consumer of this value.
|
||||||
|
if (std::find(n->inputs.begin(), n->inputs.end(), value_ptr) !=
|
||||||
|
n->inputs.end()) {
|
||||||
|
return InvalidArgumentError("Node is a consumer of the value");
|
||||||
|
}
|
||||||
|
// TODO(akulik): detect circular dependency?
|
||||||
|
|
||||||
|
if (v->producer != nullptr) {
|
||||||
|
// The value is no longer produced by its previous producer.
|
||||||
|
Erase(&nodes_[v->producer->id].outputs, value_ptr);
|
||||||
|
}
|
||||||
|
v->producer = node_ptr;
|
||||||
|
n->outputs.push_back(value_ptr);
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RemoveProducer(ValueId value) final {
|
||||||
|
ValueDef* v;
|
||||||
|
RETURN_IF_ERROR(LookupValue(value, &v));
|
||||||
|
Value<TensorT>* value_ptr = v->value.get();
|
||||||
|
if (v->producer == nullptr) {
|
||||||
|
return InvalidArgumentError("Value does not have a producer");
|
||||||
|
}
|
||||||
|
Erase(&nodes_[v->producer->id].outputs, value_ptr);
|
||||||
|
v->producer = nullptr;
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AddConsumer(NodeId consumer, ValueId value) final {
|
||||||
|
ValueDef* v;
|
||||||
|
RETURN_IF_ERROR(LookupValue(value, &v));
|
||||||
|
Value<TensorT>* value_ptr = v->value.get();
|
||||||
|
NodeDef* n;
|
||||||
|
RETURN_IF_ERROR(LookupNode(consumer, &n));
|
||||||
|
Node* node_ptr = n->node.get();
|
||||||
|
|
||||||
|
// Check if the node is a producer of this value.
|
||||||
|
if (node_ptr == v->producer) {
|
||||||
|
return InvalidArgumentError("Node is a producer of the value");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the node is already a consumer of this value.
|
||||||
|
if (std::find(n->inputs.begin(), n->inputs.end(), value_ptr) !=
|
||||||
|
n->inputs.end()) {
|
||||||
|
return InvalidArgumentError("Node is already a consumer of the value");
|
||||||
|
}
|
||||||
|
|
||||||
|
n->inputs.push_back(value_ptr);
|
||||||
|
v->consumers.push_back(node_ptr);
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RemoveConsumer(NodeId consumer, ValueId value) final {
|
||||||
|
ValueDef* v;
|
||||||
|
RETURN_IF_ERROR(LookupValue(value, &v));
|
||||||
|
Value<TensorT>* value_ptr = v->value.get();
|
||||||
|
NodeDef* n;
|
||||||
|
RETURN_IF_ERROR(LookupNode(consumer, &n));
|
||||||
|
Node* node_ptr = n->node.get();
|
||||||
|
if (std::find(n->inputs.begin(), n->inputs.end(), value_ptr) ==
|
||||||
|
n->inputs.end()) {
|
||||||
|
return InvalidArgumentError("Node is not a consumer of the value");
|
||||||
|
}
|
||||||
|
Erase(&n->inputs, value_ptr);
|
||||||
|
Erase(&v->consumers, node_ptr);
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status DeleteNode(NodeId id) final {
|
||||||
|
NodeDef* n;
|
||||||
|
RETURN_IF_ERROR(LookupNode(id, &n));
|
||||||
|
Node* node_ptr = n->node.get();
|
||||||
|
for (auto value : n->inputs) {
|
||||||
|
Erase(&values_[value->id].consumers, node_ptr);
|
||||||
|
}
|
||||||
|
for (auto value : n->outputs) {
|
||||||
|
values_[value->id].producer = nullptr;
|
||||||
|
}
|
||||||
|
n->inputs.clear();
|
||||||
|
n->outputs.clear();
|
||||||
|
n->node.reset();
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status DeleteValue(ValueId id) final {
|
||||||
|
ValueDef* v;
|
||||||
|
RETURN_IF_ERROR(LookupValue(id, &v));
|
||||||
|
Value<TensorT>* value_ptr = v->value.get();
|
||||||
|
if (v->producer != nullptr) {
|
||||||
|
Erase(&nodes_[v->producer->id].outputs, value_ptr);
|
||||||
|
}
|
||||||
|
if (!v->consumers.empty()) {
|
||||||
|
for (auto node : v->consumers) {
|
||||||
|
Erase(&nodes_[node->id].inputs, value_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v->producer = nullptr;
|
||||||
|
v->consumers.clear();
|
||||||
|
v->value.reset();
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status MakeExactCopy(Model<TensorT>* model) const {
|
||||||
|
model->nodes_.clear();
|
||||||
|
model->values_.clear();
|
||||||
|
model->name_ = name_;
|
||||||
|
for (auto& value_def : values_) {
|
||||||
|
model->values_.push_back({});
|
||||||
|
if (value_def.value) {
|
||||||
|
model->values_.back().value =
|
||||||
|
absl::make_unique<Value<TensorT>>(*value_def.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto& node_def : nodes_) {
|
||||||
|
model->nodes_.push_back({});
|
||||||
|
if (node_def.node) {
|
||||||
|
model->nodes_.back().node = absl::make_unique<Node>(*node_def.node);
|
||||||
|
for (auto output : node_def.outputs) {
|
||||||
|
RETURN_IF_ERROR(model->SetProducer(node_def.node->id, output->id));
|
||||||
|
}
|
||||||
|
for (auto input : node_def.inputs) {
|
||||||
|
RETURN_IF_ERROR(model->AddConsumer(node_def.node->id, input->id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct NodeDef {
|
||||||
|
std::vector<Value<TensorT>*> inputs;
|
||||||
|
std::vector<Value<TensorT>*> outputs;
|
||||||
|
std::unique_ptr<Node> node;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ValueDef {
|
||||||
|
Node* producer = nullptr;
|
||||||
|
std::vector<Node*> consumers;
|
||||||
|
std::unique_ptr<Value<TensorT>> value;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void Erase(std::vector<T>* values, T value) {
|
||||||
|
values->erase(std::find(values->begin(), values->end(), value));
|
||||||
|
}
|
||||||
|
|
||||||
|
// @return non-nullptr NodeDef that has valid Node or an error
|
||||||
|
Status LookupNode(NodeId id, NodeDef** node_def) {
|
||||||
|
if (id >= nodes_.size()) {
|
||||||
|
return OutOfRangeError("NodeId is out of range");
|
||||||
|
}
|
||||||
|
auto& n = nodes_[id];
|
||||||
|
if (!n.node) {
|
||||||
|
return OutOfRangeError("Node is already deleted");
|
||||||
|
}
|
||||||
|
*node_def = &n;
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// @return non-nullptr ValueDef that has valid Value or an error
|
||||||
|
Status LookupValue(ValueId id, ValueDef** value_def) {
|
||||||
|
if (id >= values_.size()) {
|
||||||
|
return OutOfRangeError("ValueId is out of range");
|
||||||
|
}
|
||||||
|
auto& v = values_[id];
|
||||||
|
if (!v.value) {
|
||||||
|
return OutOfRangeError("Value is already deleted");
|
||||||
|
}
|
||||||
|
*value_def = &v;
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Pred>
|
||||||
|
std::vector<Value<TensorT>*> FilterValues(const Pred& predicate) const {
|
||||||
|
std::vector<Value<TensorT>*> values;
|
||||||
|
values.reserve(values_.size());
|
||||||
|
for (auto& v : values_) {
|
||||||
|
if (v.value != nullptr && predicate(v)) {
|
||||||
|
values.push_back(v.value.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Pred>
|
||||||
|
std::vector<Node*> FilterNodes(const Pred& predicate) const {
|
||||||
|
std::vector<Node*> nodes;
|
||||||
|
nodes.reserve(nodes_.size());
|
||||||
|
for (auto& n : nodes_) {
|
||||||
|
if (n.node != nullptr && predicate(n)) {
|
||||||
|
nodes.push_back(n.node.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string name_;
|
||||||
|
|
||||||
|
// There are two approaches possible: wrap entire NodeDef and ValueDef into
|
||||||
|
// unique_ptr and store it in values_ and nodes_ or store it by value.
|
||||||
|
// We store it by value here to make introspection calls cheaper.
|
||||||
|
std::vector<ValueDef> values_;
|
||||||
|
std::vector<NodeDef> nodes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Removes the to_remove node, which precedes the to_keep node, only if all of
// to_remove's outputs are consumed exclusively by to_keep. In that case,
// to_keep inherits all of to_remove's inputs.
|
||||||
|
template <typename TensorT>
|
||||||
|
Status RemovePrecedingNode(Graph<TensorT>* graph, const Node* to_remove,
|
||||||
|
const Node* to_keep) {
|
||||||
|
// Make sure all outputs from to_remove are consumed by to_keep.
|
||||||
|
for (auto output : graph->FindOutputs(to_remove->id)) {
|
||||||
|
auto consumers = graph->FindConsumers(output->id);
|
||||||
|
if (consumers.size() > 1 ||
|
||||||
|
(consumers.size() == 1 && consumers[0] != to_keep)) {
|
||||||
|
return InvalidArgumentError(
|
||||||
|
"Output from to_remove node has other consumers");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update all references
|
||||||
|
for (auto input : graph->FindInputs(to_remove->id)) {
|
||||||
|
RETURN_IF_ERROR(graph->AddConsumer(to_keep->id, input->id));
|
||||||
|
}
|
||||||
|
for (auto output : graph->FindOutputs(to_remove->id)) {
|
||||||
|
RETURN_IF_ERROR(graph->DeleteValue(output->id));
|
||||||
|
}
|
||||||
|
return graph->DeleteNode(to_remove->id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Removes the to_remove node, which follows the to_keep node, only if all of
// to_remove's inputs are produced by to_keep. In that case, to_keep inherits
// all of to_remove's outputs.
|
||||||
|
template <typename TensorT>
|
||||||
|
Status RemoveFollowingNode(Graph<TensorT>* graph, const Node* to_remove,
|
||||||
|
const Node* to_keep) {
|
||||||
|
// Make sure all inputs to to_remove are produced by to_keep.
|
||||||
|
for (auto input : graph->FindInputs(to_remove->id)) {
|
||||||
|
Node* producer = graph->FindProducer(input->id);
|
||||||
|
if (producer->id != to_keep->id) {
|
||||||
|
return InvalidArgumentError("To_remove node has other inputs");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto input : graph->FindInputs(to_remove->id)) {
|
||||||
|
RETURN_IF_ERROR(graph->DeleteValue(input->id));
|
||||||
|
}
|
||||||
|
for (auto output : graph->FindOutputs(to_remove->id)) {
|
||||||
|
RETURN_IF_ERROR(graph->SetProducer(to_keep->id, output->id));
|
||||||
|
}
|
||||||
|
return graph->DeleteNode(to_remove->id);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename TensorT>
|
||||||
|
Status AddOutput(Graph<TensorT>* graph, const Node* from_node,
|
||||||
|
Value<TensorT>** output) {
|
||||||
|
auto link = graph->NewValue();
|
||||||
|
RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id));
|
||||||
|
*output = link;
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename TensorT>
|
||||||
|
Status ConnectTwoNodes(Graph<TensorT>* graph, const Node* from_node,
|
||||||
|
const Node* to_node, Value<TensorT>** output) {
|
||||||
|
Value<TensorT>* link;
|
||||||
|
RETURN_IF_ERROR(AddOutput(graph, from_node, &link));
|
||||||
|
RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id));
|
||||||
|
*output = link;
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
using GraphFloat32 = Model<TensorRef<BHWC>>;
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_
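
As a minimal usage sketch of the graph API above (the helper name
BuildTwoNodeChain is hypothetical; the individual calls mirror those exercised
in model_test.cc later in this change), building the chain
graph_input -> n1 -> link -> n2 -> graph_output could look like this:

// Sketch only: assumes model.h and status.h as declared in this change.
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

Status BuildTwoNodeChain(GraphFloat32* graph) {
  // Nodes must be created in topological order (see NewNode()).
  Node* n1 = graph->NewNode();
  Node* n2 = graph->NewNode();

  // Values without a producer become graph inputs; values without a consumer
  // become graph outputs.
  Value<TensorRefFloat32>* graph_input = graph->NewValue();
  Value<TensorRefFloat32>* link = graph->NewValue();
  Value<TensorRefFloat32>* graph_output = graph->NewValue();

  RETURN_IF_ERROR(graph->AddConsumer(n1->id, graph_input->id));
  RETURN_IF_ERROR(graph->SetProducer(n1->id, link->id));
  RETURN_IF_ERROR(graph->AddConsumer(n2->id, link->id));
  RETURN_IF_ERROR(graph->SetProducer(n2->id, graph_output->id));
  return OkStatus();
}

}  // namespace gpu
}  // namespace tflite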
1935
tensorflow/lite/delegates/gpu/common/model_builder.cc
Normal file
File diff suppressed because it is too large
45
tensorflow/lite/delegates/gpu/common/model_builder.h
Normal file
@ -0,0 +1,45 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/context.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
// Validates which operations are supported and returns an array of operations
// to replace with GPU kernels. The caller must free the returned
// TfLiteIntArray.
|
||||||
|
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context);
|
||||||
|
|
||||||
|
// Extracts TFLite delegate execution plan from the input TFLite context and
|
||||||
|
// converts it into generic graph format.
|
||||||
|
Status BuildModel(TfLiteContext* context,
|
||||||
|
const TfLiteDelegateParams* delegate_params,
|
||||||
|
GraphFloat32* graph);
|
||||||
|
|
||||||
|
Status ConvertTfliteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
|
||||||
|
TensorRefFloat32* flow_tensor);
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_
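
To make the intended call sequence concrete, here is a minimal sketch of how a
delegate's preparation step might combine the two entry points above
(PrepareGpuGraph is a hypothetical helper name; the surrounding delegate wiring
is assumed to exist):

// Sketch only: context and delegate_params are assumed to come from the TFLite
// runtime.
Status PrepareGpuGraph(TfLiteContext* context,
                       const TfLiteDelegateParams* delegate_params,
                       GraphFloat32* graph) {
  // Ask which ops can be handed off to the GPU. The caller owns the returned
  // array (it would normally be passed back to the runtime to create the
  // delegate kernel) and must free it.
  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
  TfLiteIntArrayFree(ops_to_replace);

  // Convert the delegated execution plan into the generic graph format.
  return BuildModel(context, delegate_params, graph);
}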
275
tensorflow/lite/delegates/gpu/common/model_test.cc
Normal file
@ -0,0 +1,275 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::UnorderedElementsAre;
|
||||||
|
|
||||||
|
TEST(Model, SingleNode) {
|
||||||
|
// graph_input -> node -> graph_output
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||||
|
|
||||||
|
EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node));
|
||||||
|
EXPECT_THAT(graph.values(), UnorderedElementsAre(graph_input, graph_output));
|
||||||
|
EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input));
|
||||||
|
EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output));
|
||||||
|
EXPECT_THAT(graph.FindInputs(node->id), UnorderedElementsAre(graph_input));
|
||||||
|
EXPECT_THAT(graph.FindOutputs(node->id), UnorderedElementsAre(graph_output));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(graph_input->id), UnorderedElementsAre(node));
|
||||||
|
EXPECT_THAT(graph.FindProducer(graph_output->id), ::testing::Eq(node));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(graph_output->id), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.FindProducer(graph_input->id), ::testing::Eq(nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Model, SingleNodeMultipleOutputs) {
|
||||||
|
// graph_input -> node -> (graph_output1, graph_output2)
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* graph_output1 = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* graph_output2 = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node->id, graph_output1->id).ok());
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node->id, graph_output2->id).ok());
|
||||||
|
EXPECT_THAT(graph.FindOutputs(node->id),
|
||||||
|
UnorderedElementsAre(graph_output1, graph_output2));
|
||||||
|
EXPECT_THAT(graph.FindProducer(graph_output1->id), ::testing::Eq(node));
|
||||||
|
EXPECT_THAT(graph.FindProducer(graph_output2->id), ::testing::Eq(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Model, SetSameConsumer) {
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||||
|
EXPECT_FALSE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Model, RemoveConsumer) {
|
||||||
|
// (graph_input1, graph_input2) -> node
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* graph_input1 = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* graph_input2 = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input1->id).ok());
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input2->id).ok());
|
||||||
|
EXPECT_THAT(graph.FindConsumers(graph_input1->id),
|
||||||
|
UnorderedElementsAre(node));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(graph_input2->id),
|
||||||
|
UnorderedElementsAre(node));
|
||||||
|
EXPECT_THAT(graph.FindInputs(node->id),
|
||||||
|
UnorderedElementsAre(graph_input1, graph_input2));
|
||||||
|
EXPECT_THAT(graph.outputs(), UnorderedElementsAre());
|
||||||
|
|
||||||
|
// Now remove graph_input1
|
||||||
|
ASSERT_TRUE(graph.RemoveConsumer(node->id, graph_input1->id).ok());
|
||||||
|
EXPECT_THAT(graph.FindConsumers(graph_input1->id), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.FindInputs(node->id), UnorderedElementsAre(graph_input2));
|
||||||
|
EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_input1));
|
||||||
|
|
||||||
|
// Can not remove it twice
|
||||||
|
ASSERT_FALSE(graph.RemoveConsumer(node->id, graph_input1->id).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Model, SetSameProducer) {
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||||
|
EXPECT_FALSE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Model, RemoveProducer) {
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||||
|
EXPECT_THAT(graph.inputs(), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.FindProducer(graph_output->id), ::testing::Eq(node));
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph.RemoveProducer(graph_output->id).ok());
|
||||||
|
EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_output));
|
||||||
|
EXPECT_THAT(graph.FindProducer(graph_output->id), ::testing::Eq(nullptr));
|
||||||
|
|
||||||
|
// Can not remove producer twice
|
||||||
|
ASSERT_FALSE(graph.RemoveProducer(graph_output->id).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Model, CircularDependency) {
|
||||||
|
{
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node->id, value->id).ok());
|
||||||
|
EXPECT_FALSE(graph.SetProducer(node->id, value->id).ok());
|
||||||
|
}
|
||||||
|
{
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node->id, value->id).ok());
|
||||||
|
EXPECT_FALSE(graph.AddConsumer(node->id, value->id).ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Model, ReassignValue) {
|
||||||
|
// Before:
|
||||||
|
// graph_input -> node1 -> graph_output
|
||||||
|
// \ -> node2
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node1 = graph.NewNode();
|
||||||
|
Node* node2 = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node1->id, graph_output->id).ok());
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node2->id, graph_input->id).ok());
|
||||||
|
|
||||||
|
// After:
|
||||||
|
// graph_input -> node1
|
||||||
|
// \ -> node2 -> graph_output
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok());
|
||||||
|
|
||||||
|
EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node1, node2));
|
||||||
|
EXPECT_THAT(graph.FindInputs(node1->id), UnorderedElementsAre(graph_input));
|
||||||
|
EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(graph_input));
|
||||||
|
EXPECT_THAT(graph.FindOutputs(node1->id), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.FindOutputs(node2->id), UnorderedElementsAre(graph_output));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(graph_input->id),
|
||||||
|
UnorderedElementsAre(node1, node2));
|
||||||
|
EXPECT_THAT(graph.FindProducer(graph_output->id), ::testing::Eq(node2));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(graph_output->id), UnorderedElementsAre());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Model, DeleteValue) {
|
||||||
|
// graph_input -> node1 -> value -> node2 -> graph_output
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node1 = graph.NewNode();
|
||||||
|
Node* node2 = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok());
|
||||||
|
|
||||||
|
EXPECT_THAT(graph.values(),
|
||||||
|
UnorderedElementsAre(graph_input, graph_output, value));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2));
|
||||||
|
EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(node1));
|
||||||
|
EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(value));
|
||||||
|
EXPECT_THAT(graph.FindOutputs(node1->id), UnorderedElementsAre(value));
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph.DeleteValue(value->id).ok());
|
||||||
|
value = nullptr;
|
||||||
|
EXPECT_THAT(graph.values(), UnorderedElementsAre(graph_input, graph_output));
|
||||||
|
EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.FindOutputs(node1->id), UnorderedElementsAre());
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph.DeleteValue(graph_input->id).ok());
|
||||||
|
graph_input = nullptr;
|
||||||
|
EXPECT_THAT(graph.values(), UnorderedElementsAre(graph_output));
|
||||||
|
EXPECT_THAT(graph.inputs(), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.FindInputs(node1->id), UnorderedElementsAre());
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph.DeleteValue(graph_output->id).ok());
|
||||||
|
graph_output = nullptr;
|
||||||
|
EXPECT_THAT(graph.values(), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.outputs(), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.FindOutputs(node2->id), UnorderedElementsAre());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Model, DeleteNode) {
|
||||||
|
// graph_input -> node1 -> value -> node2 -> graph_output
|
||||||
|
// \-> node3 -> graph_output2
|
||||||
|
GraphFloat32 graph;
|
||||||
|
Node* node1 = graph.NewNode();
|
||||||
|
Node* node2 = graph.NewNode();
|
||||||
|
Node* node3 = graph.NewNode();
|
||||||
|
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* graph_output2 = graph.NewValue();
|
||||||
|
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(node3->id, value->id).ok());
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok());
|
||||||
|
ASSERT_TRUE(graph.SetProducer(node3->id, graph_output2->id).ok());
|
||||||
|
|
||||||
|
EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node1, node2, node3));
|
||||||
|
EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input));
|
||||||
|
EXPECT_THAT(graph.outputs(),
|
||||||
|
UnorderedElementsAre(graph_output, graph_output2));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(value->id),
|
||||||
|
UnorderedElementsAre(node2, node3));
|
||||||
|
EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(node1));
|
||||||
|
EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(value));
|
||||||
|
EXPECT_THAT(graph.FindInputs(node3->id), UnorderedElementsAre(value));
|
||||||
|
|
||||||
|
// graph_input -> node1 -> value -> node2 -> graph_output
|
||||||
|
// graph_output2
|
||||||
|
ASSERT_TRUE(graph.DeleteNode(node3->id).ok());
|
||||||
|
node3 = nullptr;
|
||||||
|
EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node1, node2));
|
||||||
|
EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input, graph_output2));
|
||||||
|
EXPECT_THAT(graph.outputs(),
|
||||||
|
UnorderedElementsAre(graph_output, graph_output2));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2));
|
||||||
|
|
||||||
|
// value -> node2 -> graph_output
|
||||||
|
// graph_input
|
||||||
|
// graph_output2
|
||||||
|
ASSERT_TRUE(graph.DeleteNode(node1->id).ok());
|
||||||
|
node1 = nullptr;
|
||||||
|
EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node2));
|
||||||
|
EXPECT_THAT(graph.inputs(),
|
||||||
|
UnorderedElementsAre(value, graph_output2, graph_input));
|
||||||
|
EXPECT_THAT(graph.outputs(),
|
||||||
|
UnorderedElementsAre(graph_input, graph_output, graph_output2));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2));
|
||||||
|
EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(nullptr));
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph.DeleteNode(node2->id).ok());
|
||||||
|
node2 = nullptr;
|
||||||
|
EXPECT_THAT(graph.nodes(), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_output, graph_output2,
|
||||||
|
graph_input, value));
|
||||||
|
EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output, graph_output2,
|
||||||
|
graph_input, value));
|
||||||
|
EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre());
|
||||||
|
EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
197
tensorflow/lite/delegates/gpu/common/model_transformer.cc
Normal file
@ -0,0 +1,197 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||||
|
|
||||||
|
#include <deque>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
bool ModelTransformer::Apply(const std::string& name,
|
||||||
|
SequenceTransformation* transformation) {
|
||||||
|
// Seed transformations with starting node. Each node may start a chain of
|
||||||
|
// transformations.
|
||||||
|
for (auto input : graph_->inputs()) {
|
||||||
|
for (auto node : graph_->FindConsumers(input->id)) {
|
||||||
|
AddNodeToProcess(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (!to_process_.empty()) {
|
||||||
|
auto node = graph_->GetNode(to_process_.front());
|
||||||
|
if (node) {
|
||||||
|
if (!ApplyStartingWithNode(name, transformation, node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
to_process_.pop_front();
|
||||||
|
}
|
||||||
|
processed_.clear();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ModelTransformer::Apply(const std::string& name,
|
||||||
|
NodeTransformation* transformation) {
|
||||||
|
// Apply a transformation only to nodes that are present in the graph before
|
||||||
|
// transformation.
|
||||||
|
std::vector<NodeId> nodes;
|
||||||
|
for (auto node : graph_->nodes()) {
|
||||||
|
nodes.push_back(node->id);
|
||||||
|
}
|
||||||
|
for (auto node_id : nodes) {
|
||||||
|
auto node = graph_->GetNode(node_id);
|
||||||
|
if (!node) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto result = transformation->ApplyToNode(node, graph_);
|
||||||
|
if (result.status == TransformStatus::INVALID) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (reporter_) {
|
||||||
|
if (result.status == TransformStatus::APPLIED) {
|
||||||
|
reporter_->AppliedTransformation(name, std::to_string(node_id),
|
||||||
|
result.message);
|
||||||
|
}
|
||||||
|
if (result.status == TransformStatus::DECLINED) {
|
||||||
|
reporter_->DeclinedTransformation(name, std::to_string(node_id),
|
||||||
|
result.message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ModelTransformer::ApplyStartingWithNode(
|
||||||
|
const std::string& name, SequenceTransformation* transformation,
|
||||||
|
Node* begin) {
|
||||||
|
int expected_sequence_length = transformation->ExpectedSequenceLength();
|
||||||
|
|
||||||
|
std::deque<NodeId> sequence;
|
||||||
|
std::vector<Node*> nodes;
|
||||||
|
nodes.reserve(transformation->ExpectedSequenceLength());
|
||||||
|
sequence.push_back(begin->id);
|
||||||
|
|
||||||
|
// Go over the nodes with a sliding window of size expected_sequence_length
// until a node with multiple dependents is found.
|
||||||
|
while (true) {
|
||||||
|
// Apply transformation if possible.
|
||||||
|
if (sequence.size() == expected_sequence_length) {
|
||||||
|
nodes.clear();
|
||||||
|
for (NodeId id : sequence) {
|
||||||
|
// Nodes present in the sequence must be present in the graph. If they are
// not, then a transformation changed the graph without reporting it.
|
||||||
|
Node* node = graph_->GetNode(id);
|
||||||
|
if (node == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
nodes.push_back(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
NodeId first_in_sequence = sequence.front();
|
||||||
|
auto preceding_node =
|
||||||
|
graph_->FindProducer(graph_->FindInputs(first_in_sequence)[0]->id);
|
||||||
|
auto result = transformation->ApplyToNodesSequence(nodes, graph_);
|
||||||
|
if (result.status == TransformStatus::INVALID) {
|
||||||
|
// graph is broken now.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (result.status == TransformStatus::DECLINED) {
|
||||||
|
if (reporter_) {
|
||||||
|
reporter_->DeclinedTransformation(name, absl::StrJoin(sequence, "+"),
|
||||||
|
result.message);
|
||||||
|
}
|
||||||
|
} else if (result.status == TransformStatus::APPLIED) {
|
||||||
|
if (reporter_) {
|
||||||
|
reporter_->AppliedTransformation(name, absl::StrJoin(sequence, "+"),
|
||||||
|
result.message);
|
||||||
|
}
|
||||||
|
// Also remove the first node of the sequence from the set of processed
// nodes. Out of all nodes in a sequence, only the first one may have been
// added to the "processed" set, because the other nodes do not have more
// than one dependent. However, if the sequence changed, processing needs
// to be restarted.
|
||||||
|
processed_.erase(first_in_sequence);
|
||||||
|
// Transformation was successful. Restart sequence from the node that
|
||||||
|
// precedes current sequence.
|
||||||
|
if (preceding_node) {
|
||||||
|
processed_.erase(preceding_node->id);
|
||||||
|
AddNodeToProcess(preceding_node);
|
||||||
|
} else {
|
||||||
|
// This is the first node in the graph. Re-seed transformation.
|
||||||
|
for (auto input : graph_->inputs()) {
|
||||||
|
for (auto node : graph_->FindConsumers(input->id)) {
|
||||||
|
AddNodeToProcess(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to extend current sequence.
|
||||||
|
Node* next_node_in_sequence = nullptr;
|
||||||
|
bool has_multiple_children = false;
|
||||||
|
|
||||||
|
// Check that all outputs from last node are consumed by a single node.
|
||||||
|
for (auto output_value : graph_->FindOutputs(sequence.back())) {
|
||||||
|
for (auto dependent : graph_->FindConsumers(output_value->id)) {
|
||||||
|
if (has_multiple_children) {
|
||||||
|
AddNodeToProcess(dependent);
|
||||||
|
} else if (next_node_in_sequence == nullptr) {
|
||||||
|
next_node_in_sequence = dependent;
|
||||||
|
} else if (next_node_in_sequence != dependent) {
|
||||||
|
// More than one node depends on the output of the last node, therefore
// the sequence stops here and a new one will start. Push all such nodes.
|
||||||
|
has_multiple_children = true;
|
||||||
|
AddNodeToProcess(dependent);
|
||||||
|
AddNodeToProcess(next_node_in_sequence);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now check that next node has inputs only produced by the last node.
|
||||||
|
if (!has_multiple_children && next_node_in_sequence) {
|
||||||
|
for (auto input : graph_->FindInputs(next_node_in_sequence->id)) {
|
||||||
|
auto producer = graph_->FindProducer(input->id);
|
||||||
|
if (producer == nullptr || producer->id != sequence.back()) {
|
||||||
|
has_multiple_children = true;
|
||||||
|
AddNodeToProcess(next_node_in_sequence);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_multiple_children || next_node_in_sequence == nullptr) {
|
||||||
|
// reached end of this transformation sequence.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
sequence.push_back(next_node_in_sequence->id);
|
||||||
|
// Trim the sequence so it does not exceed the expected length.
|
||||||
|
if (sequence.size() > expected_sequence_length) {
|
||||||
|
sequence.pop_front();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
146
tensorflow/lite/delegates/gpu/common/model_transformer.h
Normal file
@ -0,0 +1,146 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_
|
||||||
|
|
||||||
|
#include <deque>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
class TransformationReporter;
|
||||||
|
|
||||||
|
struct TransformationContext {
|
||||||
|
GraphFloat32* graph;
|
||||||
|
TransformationReporter* reporter;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class TransformStatus {
|
||||||
|
// Transformation was not applied due to a trivial condition mismatch.
//
// This is different from the DECLINED status below, which provides an
// in-depth explanation of why a transformation that could have been applied
// was not applied due to some issue.
|
||||||
|
SKIPPED,
|
||||||
|
|
||||||
|
// Transformation was declined, therefore, a model was not modified.
|
||||||
|
DECLINED,
|
||||||
|
|
||||||
|
// Transformation was applied successfully
|
||||||
|
APPLIED,
|
||||||
|
|
||||||
|
// Transformation may have been partially applied, leaving the model in an
// invalid state. This error should be considered unrecoverable.
|
||||||
|
INVALID,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TransformResult {
|
||||||
|
TransformStatus status;
|
||||||
|
std::string message;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class responsible for applying a transformation to a single node.
|
||||||
|
class NodeTransformation {
|
||||||
|
public:
|
||||||
|
virtual ~NodeTransformation() = default;
|
||||||
|
|
||||||
|
virtual TransformResult ApplyToNode(Node* node, GraphFloat32* graph) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class responsible for applying a transformation to a sequence of nodes.
// The nodes are guaranteed to form a chain in which each node depends on the
// previous one and no intermediate node has extra dependents.
|
||||||
|
class SequenceTransformation {
|
||||||
|
public:
|
||||||
|
virtual ~SequenceTransformation() = default;
|
||||||
|
|
||||||
|
// @return number of nodes in a sequence to apply this transformation.
|
||||||
|
virtual int ExpectedSequenceLength() const = 0;
|
||||||
|
|
||||||
|
// Applies a transformation to a sequence of nodes. The transformation
// implementation is free to manipulate the sequence nodes, including adding
// and/or deleting nodes. If nodes at the end and/or beginning of the
// sequence were updated, referential consistency should be maintained by
// updating the relevant references in nodes that precede this sequence or
// depend on the last node of the sequence.
|
||||||
|
virtual TransformResult ApplyToNodesSequence(
|
||||||
|
const std::vector<Node*>& sequence, GraphFloat32* graph) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A class that accumulates decisions or updates made by transformations.
|
||||||
|
class TransformationReporter {
|
||||||
|
public:
|
||||||
|
virtual ~TransformationReporter() = default;
|
||||||
|
|
||||||
|
virtual void DeclinedTransformation(const std::string& transformation,
|
||||||
|
const std::string& node_ids,
|
||||||
|
const std::string& message) = 0;
|
||||||
|
|
||||||
|
virtual void AppliedTransformation(const std::string& transformation,
|
||||||
|
const std::string& node_ids,
|
||||||
|
const std::string& message) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A class designed to perform model transformations.
|
||||||
|
class ModelTransformer {
|
||||||
|
public:
|
||||||
|
ModelTransformer(GraphFloat32* graph, TransformationReporter* reporter)
|
||||||
|
: graph_(graph), reporter_(reporter) {}
|
||||||
|
|
||||||
|
// @return false if the graph is in a broken state and can not be used any
// more.
|
||||||
|
bool Apply(const std::string& name, SequenceTransformation* transformation);
|
||||||
|
|
||||||
|
// @return false if the graph is in a broken state and can not be used any
// more.
|
||||||
|
bool Apply(const std::string& name, NodeTransformation* transformation);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool ApplyStartingWithNode(const std::string& name,
|
||||||
|
SequenceTransformation* transformation,
|
||||||
|
Node* begin);
|
||||||
|
|
||||||
|
void AddNodeToProcess(Node* node) {
|
||||||
|
if (node && processed_.insert(node->id).second) {
|
||||||
|
to_process_.push_back(node->id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphFloat32* graph_;
|
||||||
|
TransformationReporter* reporter_;
|
||||||
|
|
||||||
|
std::deque<NodeId> to_process_;
|
||||||
|
std::unordered_set<NodeId> processed_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class NullTransformationReporter : public TransformationReporter {
|
||||||
|
public:
|
||||||
|
void DeclinedTransformation(const std::string& transformation,
|
||||||
|
const std::string& nodes_id,
|
||||||
|
const std::string& message) override {}
|
||||||
|
|
||||||
|
void AppliedTransformation(const std::string& transformation,
|
||||||
|
const std::string& nodes_id,
|
||||||
|
const std::string& message) override {}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_
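
To show how these pieces compose, here is a minimal sketch (the transformation
itself is a made-up no-op; only the interfaces come from this header) of
defining a NodeTransformation and running it through ModelTransformer:

// Sketch only: a trivial transformation that inspects every node and declines
// to change anything.
class NoopTransformation : public NodeTransformation {
 public:
  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
    // A real transformation would look at node->operation and rewrite the
    // graph here.
    return {TransformStatus::SKIPPED, "nothing to do"};
  }
};

// Usage (graph is a GraphFloat32 built elsewhere):
//   NoopTransformation noop;
//   NullTransformationReporter reporter;
//   ModelTransformer transformer(&graph, &reporter);
//   bool graph_still_valid = transformer.Apply("noop", &noop);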
396
tensorflow/lite/delegates/gpu/common/operations.cc
Normal file
@ -0,0 +1,396 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
Padding2D& Padding2D::operator=(const Padding2D& value) {
|
||||||
|
prepended = value.prepended;
|
||||||
|
appended = value.appended;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Padding2D::operator==(const Padding2D& value) {
|
||||||
|
return this->prepended == value.prepended && this->appended == value.appended;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Padding2D::operator!=(const Padding2D& value) { return !(*this == value); }
|
||||||
|
|
||||||
|
std::string ToString(enum OperationType op) {
|
||||||
|
switch (op) {
|
||||||
|
case OperationType::UNKNOWN:
|
||||||
|
break;
|
||||||
|
case OperationType::ABS:
|
||||||
|
return "abs";
|
||||||
|
case OperationType::ADD:
|
||||||
|
return "add";
|
||||||
|
case OperationType::APPLY_MASK:
|
||||||
|
return "apply_mask";
|
||||||
|
case OperationType::SUB:
|
||||||
|
return "subtract";
|
||||||
|
case OperationType::POOLING_2D:
|
||||||
|
return "pooling_2d";
|
||||||
|
case OperationType::MAX_UNPOOLING_2D:
|
||||||
|
return "max_unpooling";
|
||||||
|
case OperationType::BATCH_NORMALIZATION:
|
||||||
|
return "batch_normalization";
|
||||||
|
case OperationType::CONCAT:
|
||||||
|
return "concat";
|
||||||
|
case OperationType::CONST:
|
||||||
|
return "const";
|
||||||
|
case OperationType::CONVOLUTION_2D:
|
||||||
|
return "convolution_2d";
|
||||||
|
case OperationType::COS:
|
||||||
|
return "cos";
|
||||||
|
case OperationType::DEPTHWISE_CONVOLUTION:
|
||||||
|
return "depthwise_convolution";
|
||||||
|
case OperationType::LOG:
|
||||||
|
return "log";
|
||||||
|
case OperationType::MUL:
|
||||||
|
return "mul";
|
||||||
|
case OperationType::PAD:
|
||||||
|
return "pad";
|
||||||
|
case OperationType::PRELU:
|
||||||
|
return "prelu";
|
||||||
|
case OperationType::RELU:
|
||||||
|
return "relu";
|
||||||
|
case OperationType::RESIZE:
|
||||||
|
return "resize";
|
||||||
|
case OperationType::RESHAPE:
|
||||||
|
return "reshape";
|
||||||
|
case OperationType::RSQRT:
|
||||||
|
return "rsqrt";
|
||||||
|
case OperationType::SIGMOID:
|
||||||
|
return "sigmoid";
|
||||||
|
case OperationType::SIN:
|
||||||
|
return "sin";
|
||||||
|
case OperationType::SLICE:
|
||||||
|
return "slice";
|
||||||
|
case OperationType::SOFT_MAX:
|
||||||
|
return "soft_max";
|
||||||
|
case OperationType::SQRT:
|
||||||
|
return "sqrt";
|
||||||
|
case OperationType::SQUARE:
|
||||||
|
return "square";
|
||||||
|
case OperationType::UPSAMPLE_2D:
|
||||||
|
return "upsample_2d";
|
||||||
|
case OperationType::CONVOLUTION_TRANSPOSED:
|
||||||
|
return "convolution_transposed";
|
||||||
|
case OperationType::MULTIPLY_SCALAR:
|
||||||
|
return "multiply_scalar";
|
||||||
|
case OperationType::FULLY_CONNECTED:
|
||||||
|
return "fully_connected";
|
||||||
|
case OperationType::TANH:
|
||||||
|
return "tanh";
|
||||||
|
case OperationType::LSTM:
|
||||||
|
return "lstm";
|
||||||
|
}
|
||||||
|
return "unknown_operation";
|
||||||
|
}
|
||||||
|
|
||||||
|
OperationType OperationTypeFromString(const std::string& name) {
|
||||||
|
static const auto operations =
|
||||||
|
new std::unordered_map<std::string, OperationType>({
|
||||||
|
{"abs", OperationType::ABS},
|
||||||
|
{"add", OperationType::ADD},
|
||||||
|
{"apply_mask", OperationType::APPLY_MASK},
|
||||||
|
{"batch_normalization", OperationType::BATCH_NORMALIZATION},
|
||||||
|
{"concat", OperationType::CONCAT},
|
||||||
|
{"const", OperationType::CONST},
|
||||||
|
{"convolution_2d", OperationType::CONVOLUTION_2D},
|
||||||
|
{"convolution_transposed", OperationType::CONVOLUTION_TRANSPOSED},
|
||||||
|
{"cos", OperationType::COS},
|
||||||
|
{"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION},
|
||||||
|
{"fully_connected", OperationType::FULLY_CONNECTED},
|
||||||
|
{"log", OperationType::LOG},
|
||||||
|
{"lstm", OperationType::LSTM},
|
||||||
|
{"max_unpooling", OperationType::MAX_UNPOOLING_2D},
|
||||||
|
{"mul", OperationType::MUL},
|
||||||
|
{"multiply_scalar", OperationType::MULTIPLY_SCALAR},
|
||||||
|
{"pad", OperationType::PAD},
|
||||||
|
{"pooling_2d", OperationType::POOLING_2D},
|
||||||
|
{"prelu", OperationType::PRELU},
|
||||||
|
{"relu", OperationType::RELU},
|
||||||
|
{"resize", OperationType::RESIZE},
|
||||||
|
{"reshape", OperationType::RESHAPE},
|
||||||
|
{"rsqrt", OperationType::RSQRT},
|
||||||
|
{"sigmoid", OperationType::SIGMOID},
|
||||||
|
{"sin", OperationType::SIN},
|
||||||
|
{"slice", OperationType::SLICE},
|
||||||
|
{"soft_max", OperationType::SOFT_MAX},
|
||||||
|
{"sqrt", OperationType::SQRT},
|
||||||
|
{"square", OperationType::SQUARE},
|
||||||
|
{"subtract", OperationType::SUB},
|
||||||
|
{"tanh", OperationType::TANH},
|
||||||
|
{"upsample_2d", OperationType::UPSAMPLE_2D},
|
||||||
|
});
|
||||||
|
auto op = operations->find(name);
|
||||||
|
return op == operations->end() ? OperationType::UNKNOWN : op->second;
|
||||||
|
}
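
Since ToString and OperationTypeFromString above use matching names, they
round-trip for every named operation; for example (illustrative only, assuming
<cassert> is available):

assert(OperationTypeFromString(ToString(OperationType::ADD)) ==
       OperationType::ADD);  // "add"
assert(OperationTypeFromString("no_such_op") == OperationType::UNKNOWN);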
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T IntegralDivideRoundUp(T n, T divisor) {
|
||||||
|
return (n - 1) / divisor + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t CalculateOutputSizeBeforeStrides(int32_t input, int32_t kernel,
|
||||||
|
int32_t padding, int32_t dilation) {
|
||||||
|
const int32_t dilated_kernel = (kernel - 1) * dilation + 1;
|
||||||
|
return input + padding - dilated_kernel + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis T>
|
||||||
|
int32_t CalculateOutputWithoutStrides(const BHWC& input,
|
||||||
|
const Convolution2DAttributes& attr) {
|
||||||
|
return CalculateOutputSizeBeforeStrides(
|
||||||
|
input.get<T>(), attr.weights.shape.get<T>(),
|
||||||
|
attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
|
||||||
|
attr.dilations.get<T>());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis T>
|
||||||
|
int32_t CalculateOutputWithoutStrides(const BHWC& input,
|
||||||
|
const Pooling2DAttributes& attr) {
|
||||||
|
return CalculateOutputSizeBeforeStrides(
|
||||||
|
input.get<T>(), attr.kernel.get<T>(),
|
||||||
|
attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
|
||||||
|
/*dilation=*/1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis T>
|
||||||
|
int32_t CalculateOutput(const BHWC& input,
|
||||||
|
const ConvolutionTransposedAttributes& attr) {
|
||||||
|
return (input.get<T>() - 1) * attr.stride.get<T>() -
|
||||||
|
(attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
|
||||||
|
attr.weights.shape.get<T>() + attr.adjacent.get<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int32_t StridedSize(int32_t size, int32_t stride) {
|
||||||
|
return stride == 0 ? -1 : IntegralDivideRoundUp(size, stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis AxisT, typename AttrT>
|
||||||
|
int32_t CalculateOutput(const BHWC& input, const AttrT& attr) {
|
||||||
|
return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
|
||||||
|
attr.strides.template get<AxisT>());
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t CalculateSamePadding(int32_t input, int32_t kernel, int32_t dilation,
|
||||||
|
int32_t stride) {
|
||||||
|
const int32_t dilated_kernel = (kernel - 1) * dilation + 1;
|
||||||
|
return std::max(0, dilated_kernel - (input - 1) % stride - 1);
|
||||||
|
}
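
As a quick numeric check of the formulas above (illustrative values only):

// input = 10, kernel = 3, dilation = 1, stride = 2
//   dilated_kernel = (3 - 1) * 1 + 1 = 3
//   same_padding   = max(0, 3 - (10 - 1) % 2 - 1) = 1
//   unstrided_out  = 10 + 1 - 3 + 1 = 9
//   strided_out    = IntegralDivideRoundUp(9, 2) = 5, i.e. ceil(10 / 2)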
|
||||||
|
|
||||||
|
// Returns the padding that must be added so that the image size stays the
// same.
|
||||||
|
template <Axis AxisT>
|
||||||
|
int32_t CalculateSamePadding(const BHWC& input,
|
||||||
|
const Convolution2DAttributes& attr) {
|
||||||
|
return CalculateSamePadding(
|
||||||
|
input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
|
||||||
|
attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis AxisT>
|
||||||
|
int32_t CalculateSamePadding(const BHWC& input,
|
||||||
|
const ConvolutionTransposedAttributes& attr) {
|
||||||
|
return CalculateSamePadding(input.get<AxisT>(),
|
||||||
|
attr.weights.shape.get<AxisT>(),
|
||||||
|
/*dilation=*/1, attr.stride.get<AxisT>());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis AxisT>
|
||||||
|
int32_t CalculateSamePadding(const BHWC& input,
|
||||||
|
const Pooling2DAttributes& attr) {
|
||||||
|
return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
|
||||||
|
/*dilation=*/1, attr.strides.get<AxisT>());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis AxisT>
|
||||||
|
int32_t CalculateSamePadding(const BHWC& input,
|
||||||
|
const MaxUnpooling2DAttributes& attr) {
|
||||||
|
return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
|
||||||
|
/*dilation=*/1, attr.strides.get<AxisT>());
|
||||||
|
}
|
||||||
|
|
||||||
|
Padding2D MakeSamePadding(const BHWC& input,
|
||||||
|
const ConvolutionTransposedAttributes& attr) {
|
||||||
|
int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
|
||||||
|
int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
|
||||||
|
Padding2D padding;
|
||||||
|
padding.prepended = HW(padding_height / 2, padding_width / 2);
|
||||||
|
padding.appended = HW(padding_height - padding_height / 2,
|
||||||
|
padding_width - padding_width / 2);
|
||||||
|
return padding;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If padding depends on input, convert it into fixed padding.
|
||||||
|
template <class AttrT>
|
||||||
|
Padding2D MakeSamePadding(const BHWC& input, const AttrT& attr) {
|
||||||
|
int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
|
||||||
|
int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
|
||||||
|
Padding2D padding;
|
||||||
|
padding.prepended = HW(padding_height / 2, padding_width / 2);
|
||||||
|
padding.appended = HW(padding_height - padding_height / 2,
|
||||||
|
padding_width - padding_width / 2);
|
||||||
|
return padding;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const MaxUnpooling2DAttributes& attr) {
|
||||||
|
return BHWC(input.b,
|
||||||
|
input.h * attr.strides.h - attr.padding.prepended.h -
|
||||||
|
attr.padding.appended.h,
|
||||||
|
input.w * attr.strides.w - attr.padding.prepended.w -
|
||||||
|
attr.padding.appended.w,
|
||||||
|
input.c);
|
||||||
|
}
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input, const Pooling2DAttributes& attr) {
|
||||||
|
return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
|
||||||
|
CalculateOutput<Axis::WIDTH>(input, attr), input.c);
|
||||||
|
}
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const Convolution2DAttributes& attr) {
|
||||||
|
return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
|
||||||
|
CalculateOutput<Axis::WIDTH>(input, attr),
|
||||||
|
attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
|
||||||
|
}
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const ConvolutionTransposedAttributes& attr) {
|
||||||
|
return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
|
||||||
|
CalculateOutput<Axis::WIDTH>(input, attr),
|
||||||
|
attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
|
||||||
|
}
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const DepthwiseConvolution2DAttributes& attr) {
|
||||||
|
return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
|
||||||
|
CalculateOutput<Axis::WIDTH>(input, attr),
|
||||||
|
attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
|
||||||
|
attr.weights.shape.get<Axis::INPUT_CHANNELS>());
|
||||||
|
}
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) {
|
||||||
|
return BHWC(input.b, StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
|
||||||
|
StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
|
||||||
|
StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
|
||||||
|
}
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr) {
|
||||||
|
return BHWC(input.b, attr.appended.h + attr.prepended.h + input.h,
|
||||||
|
attr.appended.w + attr.prepended.w + input.w,
|
||||||
|
attr.appended.c + attr.prepended.c + input.c);
|
||||||
|
}
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const FullyConnectedAttributes& attr) {
|
||||||
|
return BHWC(input.b, 1, 1, attr.weights.shape.o);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CalculateOutputShape(const std::vector<BHWC>& input,
|
||||||
|
const ConcatAttributes& attr, BHWC* output_shape) {
|
||||||
|
BHWC new_shape = input[0];
|
||||||
|
switch (attr.axis) {
|
||||||
|
case Axis::CHANNELS:
|
||||||
|
for (int i = 1; i < input.size(); i++) {
|
||||||
|
if (input[i].h != new_shape.h || input[i].w != new_shape.w) {
|
||||||
|
return InvalidArgumentError(
|
||||||
|
"Height and Width must be the same when concatenating "
|
||||||
|
"by channels axis");
|
||||||
|
}
|
||||||
|
new_shape.c += input[i].c;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Axis::HEIGHT:
|
||||||
|
for (int i = 1; i < input.size(); i++) {
|
||||||
|
if (input[i].w != new_shape.w || input[i].c != new_shape.c) {
|
||||||
|
return InvalidArgumentError(
|
||||||
|
"Channels and Width must be the same when concatenating "
|
||||||
|
"by height axis");
|
||||||
|
}
|
||||||
|
new_shape.h += input[i].h;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case Axis::WIDTH:
|
||||||
|
for (int i = 1; i < input.size(); i++) {
|
||||||
|
if (input[i].h != new_shape.h || input[i].c != new_shape.c) {
|
||||||
|
return InvalidArgumentError(
|
||||||
|
"Height and Channels must be the same when concatenating "
|
||||||
|
"by width axis");
|
||||||
|
}
|
||||||
|
new_shape.w += input[i].w;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return InvalidArgumentError("Invalid axis");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
*output_shape = new_shape;
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const Convolution2DAttributes& attr) {
|
||||||
|
return MakeSamePadding(input, attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const ConvolutionTransposedAttributes& attr) {
|
||||||
|
return MakeSamePadding(input, attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const DepthwiseConvolution2DAttributes& attr) {
|
||||||
|
return MakeSamePadding(input, attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const Pooling2DAttributes& attr) {
|
||||||
|
return MakeSamePadding(input, attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const MaxUnpooling2DAttributes& attr) {
|
||||||
|
return MakeSamePadding(input, attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
float CalculateResizeScale(int32_t input_size, int32_t output_size,
|
||||||
|
const Upsample2DAttributes& attr) {
|
||||||
|
return attr.align_corners && input_size > 1 && output_size > 1
|
||||||
|
? static_cast<float>(input_size - 1) / (output_size - 1)
|
||||||
|
: static_cast<float>(input_size) / output_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input, const Upsample2DAttributes& attr) {
|
||||||
|
return BHWC(input.b, attr.new_shape.h, attr.new_shape.w, input.c);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
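As an illustration of how the shape helpers above compose, here is a minimal usage sketch (not part of the commit; the sizes are made up for the example):

// Sketch only: assumes the declarations from operations.h and shape.h.
Convolution2DAttributes attr;
attr.weights.shape = OHWI(16, 3, 3, 3);  // 16 output channels, 3x3 kernel, 3 input channels
attr.strides = HW(2, 2);
attr.dilations = HW(1, 1);
attr.padding.prepended = HW(1, 1);
attr.padding.appended = HW(1, 1);

BHWC input(1, 224, 224, 3);
BHWC output = CalculateOutputShape(input, attr);
// output == BHWC(1, 112, 112, 16): (224 + 2 - 3 + 1) rounded up over stride 2 gives 112.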
319
tensorflow/lite/delegates/gpu/common/operations.h
Normal file
319
tensorflow/lite/delegates/gpu/common/operations.h
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATIONS_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATIONS_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/variant.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
// Non-exhaustive list of operations.
|
||||||
|
enum class OperationType {
|
||||||
|
UNKNOWN = 0,
|
||||||
|
ABS,
|
||||||
|
ADD,
|
||||||
|
// TODO(eignasheva): remove APPLY_MASK operation, it should be just MUL.
|
||||||
|
APPLY_MASK,
|
||||||
|
BATCH_NORMALIZATION,
|
||||||
|
CONCAT,
|
||||||
|
CONST,
|
||||||
|
CONVOLUTION_2D,
|
||||||
|
CONVOLUTION_TRANSPOSED,
|
||||||
|
COS,
|
||||||
|
DEPTHWISE_CONVOLUTION,
|
||||||
|
FULLY_CONNECTED,
|
||||||
|
LOG,
|
||||||
|
LSTM,
|
||||||
|
MAX_UNPOOLING_2D,
|
||||||
|
MUL,
|
||||||
|
MULTIPLY_SCALAR,
|
||||||
|
POOLING_2D,
|
||||||
|
PAD,
|
||||||
|
PRELU,
|
||||||
|
RELU,
|
||||||
|
RESHAPE,
|
||||||
|
RESIZE,
|
||||||
|
RSQRT,
|
||||||
|
SIGMOID,
|
||||||
|
SIN,
|
||||||
|
SLICE,
|
||||||
|
SOFT_MAX,
|
||||||
|
SQRT,
|
||||||
|
SQUARE,
|
||||||
|
SUB,
|
||||||
|
TANH,
|
||||||
|
UPSAMPLE_2D,
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string ToString(enum OperationType op);
|
||||||
|
|
||||||
|
OperationType OperationTypeFromString(const std::string& name);
|
||||||
|
|
||||||
|
struct Padding2D {
|
||||||
|
Padding2D() = default;
|
||||||
|
Padding2D& operator=(const Padding2D& value);
|
||||||
|
bool operator==(const Padding2D& value);
|
||||||
|
bool operator!=(const Padding2D& value);
|
||||||
|
|
||||||
|
// Padding values for every axis (if needed), where 'prepended' defines
|
||||||
|
// padding for the beginning of each axis and 'appended' represents end part
|
||||||
|
// of the corresponding axis.
|
||||||
|
HW prepended = HW(-1, -1);
|
||||||
|
HW appended = HW(-1, -1);
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class PoolingType {
|
||||||
|
UNDEFINED = 0,
|
||||||
|
|
||||||
|
// average pooling
|
||||||
|
AVERAGE = 1,
|
||||||
|
|
||||||
|
// max pooling
|
||||||
|
MAX = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Pooling2DAttributes {
|
||||||
|
PoolingType type = PoolingType::UNDEFINED;
|
||||||
|
// Strides for every axis.
|
||||||
|
HW strides = HW(-1, -1);
|
||||||
|
HW kernel = HW(-1, -1);
|
||||||
|
Padding2D padding;
|
||||||
|
// NOTE(akulik): technically the number of outputs from Pooling node indicates
|
||||||
|
// whether indices are needed or not, but I decided to keep it inside
|
||||||
|
// attributes to simplify processing.
|
||||||
|
bool output_indices = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MaxUnpooling2DAttributes {
|
||||||
|
// Strides for every axis.
|
||||||
|
HW strides = HW(-1, -1);
|
||||||
|
HW kernel = HW(-1, -1);
|
||||||
|
Padding2D padding;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ConcatAttributes {
|
||||||
|
// Defines the axis along which to concatenate.
|
||||||
|
Axis axis = Axis::UNKNOWN;
|
||||||
|
};
|
||||||
|
|
||||||
|
// @return shape of a tensor after MaxUnpooling2D operation is applied to
|
||||||
|
// the given input.
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const MaxUnpooling2DAttributes& attr);
|
||||||
|
|
||||||
|
// @return shape of a tensor after Pooling2D operation is applied to the given
|
||||||
|
// input.
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input, const Pooling2DAttributes& attr);
|
||||||
|
|
||||||
|
// @return shape of a tensor after Concat operation is applied to the given
|
||||||
|
// input.
|
||||||
|
Status CalculateOutputShape(const std::vector<BHWC>& input,
|
||||||
|
const ConcatAttributes& attr, BHWC* output_shape);
|
||||||
|
|
||||||
|
// @return padding for pooling operation to make sure the output keeps the same
// shape as the given input.
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const Pooling2DAttributes& attr);
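For example, a minimal sketch (values are illustrative only) of turning SAME padding into an explicit Padding2D for a pooling node:

Pooling2DAttributes pool;
pool.type = PoolingType::MAX;
pool.kernel = HW(3, 3);
pool.strides = HW(1, 1);

BHWC input(1, 224, 224, 8);
Padding2D same = CalculateSamePadding(input, pool);
// same.prepended == HW(1, 1) and same.appended == HW(1, 1),
// so the pooled output stays 224x224.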
|
||||||
|
|
||||||
|
// @return padding for max unpooling operation to make sure the output keeps the
// same shape as the given input.
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const MaxUnpooling2DAttributes& attr);
|
||||||
|
|
||||||
|
struct Convolution2DAttributes {
|
||||||
|
HW strides = HW(1, 1); // Along each axis.
|
||||||
|
HW dilations = HW(1, 1); // Along each axis.
|
||||||
|
Padding2D padding;
|
||||||
|
|
||||||
|
Tensor<OHWI, DataType::FLOAT32> weights;
|
||||||
|
Tensor<Linear, DataType::FLOAT32> bias; // optional
|
||||||
|
};
|
||||||
|
|
||||||
|
// @return shape of a tensor after Convolution2D operation is applied to
|
||||||
|
// the given input.
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const Convolution2DAttributes& attr);
|
||||||
|
|
||||||
|
// @return padding for convolution operation to make sure the output keeps the
// same shape as the given input.
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const Convolution2DAttributes& attr);
|
||||||
|
|
||||||
|
struct ConvolutionTransposedAttributes {
|
||||||
|
HW stride = HW(1, 1); // Along each axis.
|
||||||
|
HW adjacent; // TODO(sorokin): No op on Flow.
|
||||||
|
Padding2D padding;
|
||||||
|
|
||||||
|
Tensor<OHWI, DataType::FLOAT32> weights;
|
||||||
|
Tensor<Linear, DataType::FLOAT32> bias; // optional
|
||||||
|
};
|
||||||
|
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const ConvolutionTransposedAttributes& attr);
|
||||||
|
|
||||||
|
// @return shape of a tensor after ConvolutionTransposed operation is applied to
|
||||||
|
// the given input.
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const ConvolutionTransposedAttributes& attr);
|
||||||
|
|
||||||
|
struct DepthwiseConvolution2DAttributes : public Convolution2DAttributes {};
|
||||||
|
|
||||||
|
// @return shape of a tensor after DepthwiseConvolution2D operation is applied
|
||||||
|
// to the given input.
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const DepthwiseConvolution2DAttributes& attr);
|
||||||
|
|
||||||
|
// @return padding for depthwise convolution operation to make sure the output
// keeps the same shape as the given input.
|
||||||
|
Padding2D CalculateSamePadding(const BHWC& input,
|
||||||
|
const DepthwiseConvolution2DAttributes& attr);
|
||||||
|
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const DepthwiseConvolution2DAttributes& attr);
|
||||||
|
|
||||||
|
// f(x):= {
|
||||||
|
// if x < 0 : x -> alpha * x
|
||||||
|
// if x >= 0 : x -> min(clip, x)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
// - ReLU: clip = 0, alpha = 0
|
||||||
|
// - ReLU6: clip = 6, alpha = 0
|
||||||
|
// - Leaky ReLU: clip = 0, alpha = a
|
||||||
|
struct ReLUAttributes {
|
||||||
|
// clip <= 0 means it is not set.
|
||||||
|
float clip = 0;
|
||||||
|
|
||||||
|
float alpha = 0;
|
||||||
|
};
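A short sketch of the parameterizations listed in the comment above (illustrative only):

ReLUAttributes relu;   // plain ReLU: clip = 0 (unset), alpha = 0
ReLUAttributes relu6;  // ReLU6: f(x) = min(6, max(0, x))
relu6.clip = 6;
ReLUAttributes leaky;  // Leaky ReLU: f(x) = x >= 0 ? x : 0.1f * x
leaky.alpha = 0.1f;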
|
||||||
|
|
||||||
|
struct PReLUAttributes {
|
||||||
|
// clip <= 0 means it is not set.
|
||||||
|
float clip = 0;
|
||||||
|
|
||||||
|
// If alpha is linear, then it is sharded across CHANNELS axis, otherwise
|
||||||
|
// full shape alpha is required.
|
||||||
|
absl::variant<Tensor<Linear, DataType::FLOAT32>,
|
||||||
|
Tensor<HWC, DataType::FLOAT32>>
|
||||||
|
alpha;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SoftMaxAttributes {
|
||||||
|
Axis axis = Axis::UNKNOWN;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum LstmKernelType {
|
||||||
|
FULL = 0,
|
||||||
|
BASIC = 1, // Currently, only basic is supported.
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LstmAttributes {
|
||||||
|
LstmKernelType kernel_type = LstmKernelType::BASIC;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MultiplyScalarAttributes {
|
||||||
|
absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
|
||||||
|
param;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class UpsamplingType {
|
||||||
|
NEAREST = 0,
|
||||||
|
BILINEAR = 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Upsample2DAttributes {
|
||||||
|
HW new_shape;
|
||||||
|
|
||||||
|
UpsamplingType type = UpsamplingType::NEAREST;
|
||||||
|
|
||||||
|
// If true, the centers of the 4 corner pixels of the input and output tensors
|
||||||
|
// are aligned, preserving the values at the corner pixels. Defaults to false.
|
||||||
|
bool align_corners = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
float CalculateResizeScale(int32_t input_size, int32_t output_size,
|
||||||
|
const Upsample2DAttributes& attr);
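For instance (a sketch with made-up sizes), upscaling one axis from 4 to 8:

Upsample2DAttributes up;
up.new_shape = HW(8, 8);
up.align_corners = true;
float scale = CalculateResizeScale(/*input_size=*/4, /*output_size=*/8, up);
// scale == 3.0f / 7.0f (~0.43); with align_corners = false it would be 4.0f / 8.0f = 0.5f.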
|
||||||
|
|
||||||
|
// @return shape of a tensor after upscale operation is applied to the given
|
||||||
|
// input.
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input, const Upsample2DAttributes& attr);
|
||||||
|
|
||||||
|
enum class PaddingContentType {
|
||||||
|
ZEROS = 0,
|
||||||
|
REFLECT = 1,
|
||||||
|
EDGE = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PadAttributes {
|
||||||
|
PaddingContentType type = PaddingContentType::ZEROS;
|
||||||
|
|
||||||
|
HWC prepended;
|
||||||
|
HWC appended;
|
||||||
|
};
|
||||||
|
|
||||||
|
// @return shape of a tensor after Pad operation is applied to the given input.
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr);
|
||||||
|
|
||||||
|
struct ConstTensorAttributes {
|
||||||
|
Tensor<BHWC, DataType::FLOAT32> tensor;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Simple slicing without advanced support for shrinking, reverse slicing etc.
|
||||||
|
struct SliceAttributes {
|
||||||
|
// Specifies start and end dimensions for slicing.
|
||||||
|
HWC starts;
|
||||||
|
HWC ends;
|
||||||
|
|
||||||
|
// Stride should be >= 1.
|
||||||
|
HWC strides;
|
||||||
|
};
|
||||||
|
|
||||||
|
// @return shape of a tensor after Slice2D operation is applied to the given
|
||||||
|
// input.
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr);
|
||||||
|
|
||||||
|
struct AddAttributes {
|
||||||
|
absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
|
||||||
|
param;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FullyConnectedAttributes {
|
||||||
|
Tensor<OHWI, DataType::FLOAT32> weights;
|
||||||
|
Tensor<Linear, DataType::FLOAT32> bias;
|
||||||
|
};
|
||||||
|
|
||||||
|
// @return shape of a tensor after FullyConnected operation is applied to
|
||||||
|
// the given input.
|
||||||
|
BHWC CalculateOutputShape(const BHWC& input,
|
||||||
|
const FullyConnectedAttributes& attr);
|
||||||
|
|
||||||
|
struct ReshapeAttributes {
|
||||||
|
BHWC new_shape;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATIONS_H_
|
115
tensorflow/lite/delegates/gpu/common/shape.cc
Normal file
115
tensorflow/lite/delegates/gpu/common/shape.cc
Normal file
@ -0,0 +1,115 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/shape.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

namespace tflite {
namespace gpu {
namespace {

struct GetAxisByIndexFunc {
  template <Layout T>
  Axis operator()() const {
    return GetAxis<T>(index);
  }
  int32_t index;
};

struct GetIndexByAxisFunc {
  template <Layout T>
  int operator()() const {
    return GetAxisIndex<T>(axis);
  }
  Axis axis;
};

struct NumAxisFunc {
  template <Layout T>
  int operator()() const {
    return Size<T>();
  }
};

}  // namespace

std::string ToString(Axis axis) {
  switch (axis) {
    case Axis::BATCH:
      return "batch";
    case Axis::CHANNELS:
      return "channels";
    case Axis::INPUT_CHANNELS:
      return "input_channels";
    case Axis::OUTPUT_CHANNELS:
      return "output_channels";
    case Axis::HEIGHT:
      return "height";
    case Axis::WIDTH:
      return "width";
    case Axis::VALUE:
      return "value";
    case Axis::UNKNOWN:
      return "unknown";
  }
  return "undefined";
}

std::string ToString(Layout layout) {
  switch (layout) {
    case Layout::SCALAR:
      return "scalar";
    case Layout::LINEAR:
      return "linear";
    case Layout::HW:
      return "hw";
    case Layout::CHW:
      return "chw";
    case Layout::HWC:
      return "hwc";
    case Layout::OHWI:
      return "ohwi";
    case Layout::IHWO:
      return "ihwo";
    case Layout::OIHW:
      return "oihw";
    case Layout::IOHW:
      return "iohw";
    case Layout::BHWC:
      return "bhwc";
    case Layout::UNKNOWN:
      return "unknown";
  }
  return "undefined";
}

Axis GetAxis(Layout layout, int32_t index) {
  return DispatchByLayout(layout, GetAxisByIndexFunc{index});
}

int GetAxisIndex(Layout layout, Axis axis) {
  return DispatchByLayout(layout, GetIndexByAxisFunc{axis});
}

int Size(Layout layout) { return DispatchByLayout(layout, NumAxisFunc()); }

std::string ToString(const Shape& s) {
  return absl::StrCat("{", ToString(s.layout), ", {",
                      absl::StrJoin(s.dimensions, ", "), "}}");
}

}  // namespace gpu
}  // namespace tflite
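The three functors above all follow the same dispatch pattern: any functor with a templated operator() over Layout can be applied to a runtime Layout value. A hypothetical sketch of another such functor (not part of this file; it would live in the same namespace):

struct FirstAxisNameFunc {
  template <Layout T>
  std::string operator()() const {
    // Name of the first axis of the layout, or "none" for empty layouts.
    return Size<T>() > 0 ? ToString(GetAxis<T>(0)) : "none";
  }
};

// std::string name = DispatchByLayout(Layout::BHWC, FirstAxisNameFunc());  // "batch"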
612
tensorflow/lite/delegates/gpu/common/shape.h
Normal file
612
tensorflow/lite/delegates/gpu/common/shape.h
Normal file
@ -0,0 +1,612 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
|
||||||
|
|
||||||
|
#include <sys/types.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <array>
|
||||||
|
#include <functional>
|
||||||
|
#include <numeric>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
enum class Axis {
|
||||||
|
UNKNOWN = 0,
|
||||||
|
CHANNELS = 1,
|
||||||
|
INPUT_CHANNELS = 2,
|
||||||
|
OUTPUT_CHANNELS = 3,
|
||||||
|
HEIGHT = 4,
|
||||||
|
WIDTH = 5,
|
||||||
|
BATCH = 6,
|
||||||
|
VALUE = 7,
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string ToString(Axis t);
|
||||||
|
|
||||||
|
// Layout represents axis order.
|
||||||
|
enum class Layout {
|
||||||
|
UNKNOWN = 0,
|
||||||
|
SCALAR = 1,
|
||||||
|
LINEAR = 2,
|
||||||
|
HW = 3,
|
||||||
|
CHW = 4,
|
||||||
|
HWC = 5,
|
||||||
|
OIHW = 6,
|
||||||
|
OHWI = 7,
|
||||||
|
IHWO = 8,
|
||||||
|
IOHW = 9,
|
||||||
|
BHWC = 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string ToString(Layout l);
|
||||||
|
|
||||||
|
// Returns the number of axes for the fixed layout.
|
||||||
|
template <Layout T>
|
||||||
|
constexpr int Size();
|
||||||
|
|
||||||
|
// Returns the number of axes for the given layout.
|
||||||
|
int Size(Layout layout);
|
||||||
|
|
||||||
|
// Returns Axis for the given index and fixed layout.
|
||||||
|
template <Layout T>
|
||||||
|
constexpr Axis GetAxis(int index);
|
||||||
|
|
||||||
|
// Returns axis for the given layout and index.
|
||||||
|
Axis GetAxis(Layout layout, int32_t index);
|
||||||
|
|
||||||
|
// Returns axis index for the given axis and fixed layout.
|
||||||
|
template <Layout T>
|
||||||
|
constexpr int GetAxisIndex(Axis axis);
|
||||||
|
|
||||||
|
// Returns axis index for the given layout and axis.
|
||||||
|
int GetAxisIndex(Layout layout, Axis axis);
|
||||||
|
|
||||||
|
// Stores Layout(axis set and order) and value for dimensions.
|
||||||
|
struct Shape {
|
||||||
|
Shape() : layout(Layout::UNKNOWN), dimensions() {}
|
||||||
|
|
||||||
|
explicit Shape(Layout t) : layout(t), dimensions(Size(t)) {}
|
||||||
|
|
||||||
|
Shape(Layout t, std::vector<int32_t> d)
|
||||||
|
: layout(t), dimensions(std::move(d)) {}
|
||||||
|
|
||||||
|
bool operator==(const Shape& other) const {
|
||||||
|
return (layout == other.layout) && (dimensions == other.dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!=(const Shape& other) const { return !operator==(other); }
|
||||||
|
|
||||||
|
// All methods below are matching same methods defined in StrongShape to
|
||||||
|
// make sure generic algorithms work both ways.
|
||||||
|
|
||||||
|
// Returns a dimension or -1 if it is not found.
|
||||||
|
template <Axis D>
|
||||||
|
int32_t get() const;
|
||||||
|
int32_t get(Axis d) const;
|
||||||
|
|
||||||
|
template <Axis D>
|
||||||
|
bool set(int32_t t);
|
||||||
|
bool set(Axis d, int32_t t);
|
||||||
|
|
||||||
|
Axis axis(int index) const { return GetAxis(layout, index); }
|
||||||
|
|
||||||
|
int index(Axis d) const { return GetAxisIndex(layout, d); }
|
||||||
|
|
||||||
|
int64_t DimensionsProduct() const {
|
||||||
|
return std::accumulate(dimensions.begin(), dimensions.end(), 1ll,
|
||||||
|
std::multiplies<int64_t>());
|
||||||
|
}
|
||||||
|
|
||||||
|
Layout layout = Layout::UNKNOWN;
|
||||||
|
|
||||||
|
std::vector<int32_t> dimensions;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string ToString(const Shape& s);
|
||||||
|
|
||||||
|
// StrongShape provides convenient explicit access to dimensions stored in
|
||||||
|
// shape, e.g. StrongShape<Layout::HW> s; provides s.h and s.w accessors.
|
||||||
|
//
|
||||||
|
// There is a conversion possible both ways between Shape and StrongShape.
|
||||||
|
//
|
||||||
|
// OIHW oihw; // specific shape
|
||||||
|
// Shape l = oihw.ToShape();
|
||||||
|
//
|
||||||
|
// OHWI other; // notice not the same but compatible shape.
|
||||||
|
// if (!other.Adopt(l)) {
|
||||||
|
// // error handling
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// StrongShape supports the following set of operations:
|
||||||
|
//
|
||||||
|
// // Returns number of axis in the shape class.
|
||||||
|
// static constexpr int size();
|
||||||
|
//
|
||||||
|
// // Returns Axis for the given index or Axis::UNKNOWN if index
|
||||||
|
// // falls outside of the defined range in this shape.
|
||||||
|
// static constexpr Axis axis(int index);
|
||||||
|
//
|
||||||
|
// // Returns index for the given axis or -1 if axis is not defined in this
|
||||||
|
// // shape.
|
||||||
|
// static constexpr int index(Axis d);
|
||||||
|
//
|
||||||
|
// // Getters
|
||||||
|
// int32_t get(int index) const;
|
||||||
|
// int32_t get(Axis d) const;
|
||||||
|
// int32_t get<Axis>() const;
|
||||||
|
//
|
||||||
|
// // Setters that return false if set was not successful.
|
||||||
|
// bool set(int index, int32_t v);
|
||||||
|
// bool set(Axis d, int32_t v);
|
||||||
|
// bool set<Axis>(int32_t v);
|
||||||
|
//
|
||||||
|
// // Returns shape's layout.
|
||||||
|
// static const Layout layout;
|
||||||
|
//
|
||||||
|
// // Turns specific shape into generic shape.
|
||||||
|
// Shape ToShape() const;
|
||||||
|
//
|
||||||
|
// // Copies all dimensions from the given shape.
|
||||||
|
// bool Adopt(const Shape&);
|
||||||
|
//
|
||||||
|
template <Layout L>
|
||||||
|
struct StrongShape;
|
||||||
|
|
||||||
|
using Scalar = StrongShape<Layout::SCALAR>;
|
||||||
|
using Linear = StrongShape<Layout::LINEAR>;
|
||||||
|
using HW = StrongShape<Layout::HW>;
|
||||||
|
|
||||||
|
// Common tensor shape for CNN models working with images.
|
||||||
|
using CHW = StrongShape<Layout::CHW>;
|
||||||
|
using HWC = StrongShape<Layout::HWC>;
|
||||||
|
using BHWC = StrongShape<Layout::BHWC>;
|
||||||
|
|
||||||
|
// Tensor shape used in convolution_2d weights.
|
||||||
|
using OIHW = StrongShape<Layout::OIHW>;
|
||||||
|
using OHWI = StrongShape<Layout::OHWI>;
|
||||||
|
using IHWO = StrongShape<Layout::IHWO>;
|
||||||
|
using IOHW = StrongShape<Layout::IOHW>;
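A minimal usage sketch of these aliases (not part of the header; the values are illustrative):

BHWC shape(1, 8, 8, 3);                        // b = 1, h = 8, w = 8, c = 3
int64_t elements = shape.DimensionsProduct();  // 192
int64_t offset = shape.LinearIndex({0, 2, 3, 1});  // ((0 * 8 + 2) * 8 + 3) * 3 + 1 == 58

HWC hwc;
hwc.Adopt(shape.ToShape());                    // copies h, w, c; the batch axis is dropped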
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Everything below are internal implementation details.
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
namespace internal_shape {
|
||||||
|
|
||||||
|
template <Axis T>
|
||||||
|
struct AxisTraits;
|
||||||
|
|
||||||
|
#define TFLITE_GPU_AXIS_TRAITS(AxisName, HolderName) \
|
||||||
|
template <> \
|
||||||
|
struct AxisTraits<Axis::AxisName> { \
|
||||||
|
struct Holder { \
|
||||||
|
int32_t HolderName; \
|
||||||
|
\
|
||||||
|
protected: \
|
||||||
|
int32_t operator()() const { return HolderName; } \
|
||||||
|
void operator()(int32_t v) { HolderName = v; } \
|
||||||
|
}; \
|
||||||
|
\
|
||||||
|
using dimension_holder_type = Holder; \
|
||||||
|
}
|
||||||
|
|
||||||
|
TFLITE_GPU_AXIS_TRAITS(CHANNELS, c);
|
||||||
|
TFLITE_GPU_AXIS_TRAITS(HEIGHT, h);
|
||||||
|
TFLITE_GPU_AXIS_TRAITS(WIDTH, w);
|
||||||
|
TFLITE_GPU_AXIS_TRAITS(INPUT_CHANNELS, i);
|
||||||
|
TFLITE_GPU_AXIS_TRAITS(OUTPUT_CHANNELS, o);
|
||||||
|
TFLITE_GPU_AXIS_TRAITS(BATCH, b);
|
||||||
|
TFLITE_GPU_AXIS_TRAITS(VALUE, v);
|
||||||
|
|
||||||
|
#undef TFLITE_GPU_AXIS_TRAITS
|
||||||
|
|
||||||
|
template <int N, Axis... As>
|
||||||
|
struct StrongShapeImpl;
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
struct StrongShapeImpl<N> {
|
||||||
|
static constexpr int size() { return N; }
|
||||||
|
|
||||||
|
static constexpr Axis axis(int) { return Axis::UNKNOWN; }
|
||||||
|
|
||||||
|
static constexpr int index(Axis) { return -1; }
|
||||||
|
|
||||||
|
int32_t get(Axis) const { return -1; }
|
||||||
|
|
||||||
|
int32_t get(int) const { return -1; }
|
||||||
|
|
||||||
|
template <Axis B>
|
||||||
|
int32_t get() const {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool set(Axis, int32_t) { return false; }
|
||||||
|
|
||||||
|
bool set(int, int32_t) { return false; }
|
||||||
|
|
||||||
|
template <Axis B>
|
||||||
|
bool set(int32_t) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Used to deduce the number of axes, and to be a child of a proper holder to
// provide access to the dimension by name.
|
||||||
|
template <int N, Axis A, Axis... As>
|
||||||
|
struct StrongShapeImpl<N, A, As...>
|
||||||
|
: public AxisTraits<A>::dimension_holder_type,
|
||||||
|
public StrongShapeImpl<N + 1, As...> {
|
||||||
|
using dimension_holder_type = typename AxisTraits<A>::dimension_holder_type;
|
||||||
|
|
||||||
|
using rest_type = StrongShapeImpl<N + 1, As...>;
|
||||||
|
|
||||||
|
StrongShapeImpl() : dimension_holder_type{0}, rest_type() {}
|
||||||
|
|
||||||
|
template <typename... Ts>
|
||||||
|
explicit StrongShapeImpl(int32_t t, Ts... ts)
|
||||||
|
: dimension_holder_type{t}, rest_type(ts...) {}
|
||||||
|
|
||||||
|
static constexpr Axis axis(int index) {
|
||||||
|
return index == N ? A : rest_type::axis(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int index(Axis d) {
|
||||||
|
return d == A ? N : rest_type::index(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t get(Axis d) const {
|
||||||
|
return d == A ? dimension_holder_type::operator()() : rest_type::get(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis B>
|
||||||
|
int32_t get() const {
|
||||||
|
return B == A ? dimension_holder_type::operator()()
|
||||||
|
: rest_type::template get<B>();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t get(int index) const {
|
||||||
|
return index == N ? dimension_holder_type::operator()()
|
||||||
|
: rest_type::get(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool set(Axis d, int32_t t) {
|
||||||
|
if (d == A) {
|
||||||
|
dimension_holder_type::operator()(t);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return rest_type::set(d, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool set(int index, int32_t t) {
|
||||||
|
if (index == N) {
|
||||||
|
dimension_holder_type::operator()(t);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return rest_type::set(index, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis B>
|
||||||
|
bool set(int32_t t) {
|
||||||
|
if (A == B) {
|
||||||
|
dimension_holder_type::operator()(t);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return rest_type::template set<B>(t);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <Layout T>
|
||||||
|
struct LayoutTraits;
|
||||||
|
|
||||||
|
#define TFLITE_GPU_LAYOUT_TRAITS(LayoutName, ...) \
|
||||||
|
template <> \
|
||||||
|
struct LayoutTraits<Layout::LayoutName> { \
|
||||||
|
using strong_shape_type = StrongShapeImpl<0, __VA_ARGS__>; \
|
||||||
|
}
|
||||||
|
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(HW, Axis::HEIGHT, Axis::WIDTH);
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(OHWI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
|
||||||
|
Axis::INPUT_CHANNELS);
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(OIHW, Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS,
|
||||||
|
Axis::HEIGHT, Axis::WIDTH);
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(IOHW, Axis::INPUT_CHANNELS, Axis::OUTPUT_CHANNELS,
|
||||||
|
Axis::HEIGHT, Axis::WIDTH);
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(IHWO, Axis::INPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
|
||||||
|
Axis::OUTPUT_CHANNELS);
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(CHW, Axis::CHANNELS, Axis::HEIGHT, Axis::WIDTH);
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(HWC, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS);
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(LINEAR, Axis::VALUE);
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(SCALAR, Axis::VALUE);
|
||||||
|
TFLITE_GPU_LAYOUT_TRAITS(BHWC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
|
||||||
|
Axis::CHANNELS);
|
||||||
|
|
||||||
|
#undef TFLITE_GPU_LAYOUT_TRAITS
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct LayoutTraits<Layout::UNKNOWN> {
|
||||||
|
using strong_shape_type = StrongShapeImpl<0>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <Axis A>
|
||||||
|
struct DimensionGetterFixedAxisFunc {
|
||||||
|
template <Layout T>
|
||||||
|
int32_t operator()() const {
|
||||||
|
constexpr int i = GetAxisIndex<T>(A);
|
||||||
|
return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1;
|
||||||
|
}
|
||||||
|
const Shape* l;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DimensionGetterFunc {
|
||||||
|
template <Layout T>
|
||||||
|
int32_t operator()() const {
|
||||||
|
int i = GetAxisIndex<T>(d);
|
||||||
|
return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1;
|
||||||
|
}
|
||||||
|
Axis d;
|
||||||
|
const Shape* l;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <Axis A>
|
||||||
|
struct DimensionSetterFixedAxisFunc {
|
||||||
|
template <Layout T>
|
||||||
|
bool operator()() const {
|
||||||
|
constexpr int i = GetAxisIndex<T>(A);
|
||||||
|
if (i >= 0 && i < l->dimensions.size()) {
|
||||||
|
l->dimensions[i] = v;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Shape* l;
|
||||||
|
int32_t v;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DimensionSetterFunc {
|
||||||
|
template <Layout T>
|
||||||
|
bool operator()() const {
|
||||||
|
int i = GetAxisIndex<T>(d);
|
||||||
|
if (i >= 0 && i < l->dimensions.size()) {
|
||||||
|
l->dimensions[i] = v;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Axis d;
|
||||||
|
Shape* l;
|
||||||
|
int32_t v;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <Layout L>
|
||||||
|
struct ToShapeFunc {
|
||||||
|
template <Layout T>
|
||||||
|
bool operator()() const {
|
||||||
|
for (int i = 0; i < StrongShape<L>::size(); ++i) {
|
||||||
|
int index = GetAxisIndex<T>(StrongShape<L>::axis(i));
|
||||||
|
if (index < 0) return false;
|
||||||
|
shape->set(i, l.dimensions[index]);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
StrongShape<L>* shape;
|
||||||
|
const Shape& l;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace internal_shape
|
||||||
|
|
||||||
|
// template <Axis... As>
|
||||||
|
template <Layout L>
|
||||||
|
struct StrongShape : public internal_shape::LayoutTraits<L>::strong_shape_type {
|
||||||
|
using strong_shape_type =
|
||||||
|
typename internal_shape::LayoutTraits<L>::strong_shape_type;
|
||||||
|
StrongShape() = default;
|
||||||
|
|
||||||
|
template <typename... Ts>
|
||||||
|
explicit StrongShape(Ts... t) : strong_shape_type(t...) {}
|
||||||
|
|
||||||
|
constexpr static Layout layout = L;
|
||||||
|
|
||||||
|
bool operator==(const StrongShape<L>& shape) const {
|
||||||
|
// TODO(akulik): implement better alternative.
|
||||||
|
return this->ToShape() == shape.ToShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!=(const StrongShape<L>& shape) const {
|
||||||
|
// TODO(akulik): implement better alternative.
|
||||||
|
return this->ToShape() != shape.ToShape();
|
||||||
|
}
|
||||||
|
bool empty() const { return DimensionsProduct() == 0; }
|
||||||
|
|
||||||
|
// Turns StrongShape into generic shape.
|
||||||
|
Shape ToShape() const {
|
||||||
|
std::vector<int32_t> dimensions(StrongShape::size());
|
||||||
|
for (int i = 0; i < StrongShape::size(); ++i) {
|
||||||
|
dimensions[i] = StrongShape::get(i);
|
||||||
|
}
|
||||||
|
return Shape(L, std::move(dimensions));
|
||||||
|
}
|
||||||
|
|
||||||
|
// @return all dimensions multiplied
|
||||||
|
int64_t DimensionsProduct() const {
|
||||||
|
int64_t product = 1;
|
||||||
|
for (int i = 0; i < StrongShape::size(); ++i) {
|
||||||
|
product *= StrongShape::get(i);
|
||||||
|
}
|
||||||
|
return product;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Translates given coordinates of the layout into a linear index assuming
|
||||||
|
// dimensions are sorted in tensor access order e.g. if you access
|
||||||
|
// foobar[i][j][k] order of coordinates should be i,j,k.
|
||||||
|
int64_t LinearIndex(
|
||||||
|
const std::array<int32_t, StrongShape::size()>& coordinates) const {
|
||||||
|
int64_t index = coordinates[0];
|
||||||
|
for (int i = 1; i < StrongShape::size(); ++i) {
|
||||||
|
index = index * StrongShape::get(i) + coordinates[i];
|
||||||
|
}
|
||||||
|
return index;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copies all dimensions from the given generic shape into specific shape.
|
||||||
|
// It requires shape to have all axis defined in the given
|
||||||
|
// StrongShape. For example:
|
||||||
|
// - If this shape is OHWI but given shape is OIHW, Adopt will copy all
|
||||||
|
// dimensions and return true.
|
||||||
|
// - If this shape is OIHW but input shape is HW, Adopt will copy H and W
|
||||||
|
// dimensions and return true, but if this shape is HW and given shape
|
||||||
|
// OIHW, then Adopt will return false because not all axis are present in
|
||||||
|
// the input shape.
|
||||||
|
//
|
||||||
|
// @return false if generic shape is not compatible.
|
||||||
|
bool Adopt(const Shape& shape) {
|
||||||
|
return DispatchByLayout(shape.layout,
|
||||||
|
internal_shape::ToShapeFunc<L>{this, shape});
|
||||||
|
}
|
||||||
|
|
||||||
|
// For all axis defined in a given shape copies values to this shape.
|
||||||
|
// Therefore, it is possible to copy dimensions from CHW to BCHW, but not
|
||||||
|
// the other way around.
|
||||||
|
//
|
||||||
|
// BCHW bchw;
|
||||||
|
// CHW chw;
|
||||||
|
// bchw.CopyAllGivenAxis(chw); --> true
|
||||||
|
// chw.CopyAllGivenAxis(bchw); --> false
|
||||||
|
//
|
||||||
|
// @return false if axis in source shape is not defined here, thus value
|
||||||
|
// was not copied.
|
||||||
|
template <Layout B>
|
||||||
|
bool CopyAllGivenAxis(const StrongShape<B>& source) {
|
||||||
|
for (int i = 0; i < source.size(); ++i) {
|
||||||
|
if (!StrongShape::set(source.axis(i), source.get(i))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For all axis defined in this shape copies values from the given shape.
|
||||||
|
//
|
||||||
|
// BCHW bchw;
|
||||||
|
// CHW chw;
|
||||||
|
// bchw.CopyAllDefinedAxis(chw); --> false
|
||||||
|
// chw.CopyAllDefinedAxis(bchw); --> true
|
||||||
|
//
|
||||||
|
// @return false if given shape does not have axis defined here,
|
||||||
|
// therefore a value was not copied.
|
||||||
|
template <Layout B>
|
||||||
|
bool CopyAllDefinedAxis(const StrongShape<B>& source) {
|
||||||
|
for (int i = 0; i < StrongShape::size(); ++i) {
|
||||||
|
int source_index = source.index(StrongShape::axis(i));
|
||||||
|
if (source_index < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
StrongShape::set(i, source.get(source_index)); // always true
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copies values only for matching axis.
|
||||||
|
template <Layout B>
|
||||||
|
void CopyMatchingAxis(const StrongShape<B>& source) {
|
||||||
|
for (int i = 0; i < StrongShape::size(); ++i) {
|
||||||
|
StrongShape::set(source.axis(i), source.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <Layout T>
|
||||||
|
inline std::string ToString(const StrongShape<T>& s) {
|
||||||
|
return ToString(s.ToShape());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Layout L>
|
||||||
|
constexpr Layout StrongShape<L>::layout;
|
||||||
|
|
||||||
|
template <class F>
|
||||||
|
auto DispatchByLayout(Layout type, F f)
|
||||||
|
-> decltype(f.template operator()<Layout::UNKNOWN>()) {
|
||||||
|
switch (type) {
|
||||||
|
case Layout::HW:
|
||||||
|
return f.template operator()<Layout::HW>();
|
||||||
|
case Layout::HWC:
|
||||||
|
return f.template operator()<Layout::HWC>();
|
||||||
|
case Layout::CHW:
|
||||||
|
return f.template operator()<Layout::CHW>();
|
||||||
|
case Layout::OIHW:
|
||||||
|
return f.template operator()<Layout::OIHW>();
|
||||||
|
case Layout::IOHW:
|
||||||
|
return f.template operator()<Layout::IOHW>();
|
||||||
|
case Layout::OHWI:
|
||||||
|
return f.template operator()<Layout::OHWI>();
|
||||||
|
case Layout::IHWO:
|
||||||
|
return f.template operator()<Layout::IHWO>();
|
||||||
|
case Layout::LINEAR:
|
||||||
|
return f.template operator()<Layout::LINEAR>();
|
||||||
|
case Layout::SCALAR:
|
||||||
|
return f.template operator()<Layout::SCALAR>();
|
||||||
|
case Layout::BHWC:
|
||||||
|
return f.template operator()<Layout::BHWC>();
|
||||||
|
case Layout::UNKNOWN:
|
||||||
|
return f.template operator()<Layout::UNKNOWN>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Layout T>
|
||||||
|
constexpr int Size() {
|
||||||
|
return StrongShape<T>::size();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Layout T>
|
||||||
|
constexpr Axis GetAxis(int index) {
|
||||||
|
return StrongShape<T>::axis(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Layout T>
|
||||||
|
constexpr int GetAxisIndex(Axis axis) {
|
||||||
|
return StrongShape<T>::index(axis);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis D>
|
||||||
|
inline int32_t Shape::get() const {
|
||||||
|
return DispatchByLayout(
|
||||||
|
layout, internal_shape::DimensionGetterFixedAxisFunc<D>{this});
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int32_t Shape::get(Axis d) const {
|
||||||
|
return DispatchByLayout(layout, internal_shape::DimensionGetterFunc{d, this});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <Axis D>
|
||||||
|
inline bool Shape::set(int32_t t) {
|
||||||
|
return DispatchByLayout(
|
||||||
|
layout, internal_shape::DimensionSetterFixedAxisFunc<D>{this, t});
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool Shape::set(Axis d, int32_t t) {
|
||||||
|
return DispatchByLayout(layout,
|
||||||
|
internal_shape::DimensionSetterFunc{d, this, t});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
|
123
tensorflow/lite/delegates/gpu/common/shape_test.cc
Normal file
123
tensorflow/lite/delegates/gpu/common/shape_test.cc
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
|
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(OIHW, Smoke) {
|
||||||
|
OIHW OIHW;
|
||||||
|
|
||||||
|
// Test 4 different versions of setters.
|
||||||
|
OIHW.i = 1;
|
||||||
|
ASSERT_TRUE(OIHW.set<Axis::OUTPUT_CHANNELS>(2));
|
||||||
|
ASSERT_TRUE(OIHW.set(Axis::HEIGHT, 3));
|
||||||
|
ASSERT_TRUE(OIHW.set(3, 4));
|
||||||
|
|
||||||
|
// Make sure invalid setters return false.
|
||||||
|
ASSERT_FALSE(OIHW.set(5, 10));
|
||||||
|
ASSERT_FALSE(OIHW.set(Axis::CHANNELS, 10));
|
||||||
|
ASSERT_FALSE(OIHW.set<Axis::CHANNELS>(10));
|
||||||
|
|
||||||
|
// Test 4 different versions of getters
|
||||||
|
EXPECT_EQ(1, OIHW.get(Axis::INPUT_CHANNELS));
|
||||||
|
EXPECT_EQ(2, OIHW.o);
|
||||||
|
EXPECT_EQ(3, OIHW.get(2));
|
||||||
|
EXPECT_EQ(4, OIHW.get<Axis::WIDTH>());
|
||||||
|
|
||||||
|
// Make sure getters that fall outside of a range return invalid axis.
|
||||||
|
EXPECT_EQ(-1, OIHW.get(5));
|
||||||
|
EXPECT_EQ(-1, OIHW.get(Axis::CHANNELS));
|
||||||
|
EXPECT_EQ(-1, OIHW.get<Axis::CHANNELS>());
|
||||||
|
|
||||||
|
// Check axis indices are all correct.
|
||||||
|
ASSERT_EQ(4, OIHW.size());
|
||||||
|
std::vector<Axis> expected = {Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS,
|
||||||
|
Axis::HEIGHT, Axis::WIDTH};
|
||||||
|
for (int i = 0; i < OIHW.size(); ++i) {
|
||||||
|
Axis axis = OIHW.axis(i);
|
||||||
|
ASSERT_EQ(expected[i], axis);
|
||||||
|
ASSERT_EQ(i, OIHW.index(axis));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check equivalent conversions.
|
||||||
|
OHWI ohwi;
|
||||||
|
ASSERT_TRUE(ohwi.CopyAllDefinedAxis(OIHW));
|
||||||
|
EXPECT_EQ(ohwi.o, OIHW.o);
|
||||||
|
EXPECT_EQ(ohwi.i, OIHW.i);
|
||||||
|
EXPECT_EQ(ohwi.h, OIHW.h);
|
||||||
|
EXPECT_EQ(ohwi.w, OIHW.w);
|
||||||
|
|
||||||
|
ohwi = OHWI(10, 20, 30, 40);
|
||||||
|
ASSERT_TRUE(OIHW.CopyAllGivenAxis(ohwi));
|
||||||
|
EXPECT_EQ(ohwi.o, OIHW.o);
|
||||||
|
EXPECT_EQ(ohwi.i, OIHW.i);
|
||||||
|
EXPECT_EQ(ohwi.h, OIHW.h);
|
||||||
|
EXPECT_EQ(ohwi.w, OIHW.w);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Layout, Smoke) {
|
||||||
|
EXPECT_EQ(4, Size<Layout::OIHW>());
|
||||||
|
EXPECT_EQ(4, Size(Layout::OIHW));
|
||||||
|
std::vector<Axis> expected = {Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS,
|
||||||
|
Axis::HEIGHT, Axis::WIDTH};
|
||||||
|
for (int i = 0; i < Size<Layout::OIHW>(); ++i) {
|
||||||
|
Axis axis = GetAxis<Layout::OIHW>(i);
|
||||||
|
ASSERT_EQ(expected[i], axis);
|
||||||
|
ASSERT_EQ(axis, GetAxis(Layout::OIHW, i));
|
||||||
|
ASSERT_EQ(i, GetAxisIndex<Layout::OIHW>(axis));
|
||||||
|
ASSERT_EQ(i, GetAxisIndex(Layout::OIHW, axis));
|
||||||
|
}
|
||||||
|
EXPECT_EQ(Axis::UNKNOWN, GetAxis(Layout::OIHW, 5));
|
||||||
|
EXPECT_EQ(-1, GetAxisIndex<Layout::OIHW>(Axis::CHANNELS));
|
||||||
|
EXPECT_EQ(-1, GetAxisIndex<Layout::OIHW>(Axis::CHANNELS));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Shape, Smoke) {
|
||||||
|
Shape s(Layout::OIHW, {1, 2, 3, 4});
|
||||||
|
EXPECT_TRUE(s.set(Axis::HEIGHT, 10));
|
||||||
|
EXPECT_TRUE(s.set<Axis::WIDTH>(20));
|
||||||
|
EXPECT_FALSE(s.set(Axis::BATCH, 10));
|
||||||
|
EXPECT_FALSE(s.set<Axis::BATCH>(20));
|
||||||
|
|
||||||
|
ASSERT_EQ(10, s.get<Axis::HEIGHT>());
|
||||||
|
ASSERT_EQ(20, s.get(Axis::WIDTH));
|
||||||
|
EXPECT_EQ(20, s.dimensions[3]);
|
||||||
|
|
||||||
|
OIHW oihw(1, 2, 10, 20);
|
||||||
|
Shape s2 = oihw.ToShape();
|
||||||
|
EXPECT_EQ(s2.layout, oihw.layout);
|
||||||
|
EXPECT_EQ(s.layout, s2.layout);
|
||||||
|
EXPECT_EQ(s.dimensions, s2.dimensions);
|
||||||
|
|
||||||
|
// Convert layout into compatible shape.
|
||||||
|
OHWI ohwi;
|
||||||
|
ASSERT_TRUE(ohwi.Adopt(s2));
|
||||||
|
EXPECT_EQ(1, ohwi.o);
|
||||||
|
EXPECT_EQ(2, ohwi.i);
|
||||||
|
EXPECT_EQ(10, ohwi.h);
|
||||||
|
EXPECT_EQ(20, ohwi.w);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
124
tensorflow/lite/delegates/gpu/common/status.h
Normal file
124
tensorflow/lite/delegates/gpu/common/status.h
Normal file
@ -0,0 +1,124 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_

#include <string>

namespace tflite {
namespace gpu {

enum class StatusCode {
  kOk = 0,
  kCancelled = 1,
  kUnknown = 2,
  kInvalidArgument = 3,
  kDeadlineExceeded = 4,
  kNotFound = 5,
  kAlreadyExists = 6,
  kPermissionDenied = 7,
  kResourceExhausted = 8,
  kFailedPrecondition = 9,
  kAborted = 10,
  kOutOfRange = 11,
  kUnimplemented = 12,
  kInternal = 13,
  kUnavailable = 14,
  kDataLoss = 15,
  kUnauthenticated = 16,
  kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20
};

// Lite version of Status without dependency on protobuf.
// TODO(b/128867901): Migrate to absl::Status.
class Status {
 public:
  Status() = default;
  Status(StatusCode code) : code_(code) {}
  Status(StatusCode code, const std::string& error_message)
      : code_(code), error_message_(error_message) {}

  const std::string& error_message() const { return error_message_; }
  StatusCode code() const { return code_; }
  bool ok() const { return code_ == StatusCode::kOk; }

  void IgnoreError() const {}

 private:
  StatusCode code_ = StatusCode::kOk;
  std::string error_message_;
};

#define RETURN_IF_ERROR(status)         \
  {                                     \
    const auto status2 = (status);      \
    if (!status2.ok()) return status2;  \
  }

inline Status OkStatus() { return Status(); }

inline Status AlreadyExistsError(const std::string& message) {
  return Status(StatusCode::kAlreadyExists, message);
}

inline Status DeadlineExceededError(const std::string& message) {
  return Status(StatusCode::kDeadlineExceeded, message);
}

inline Status FailedPreconditionError(const std::string& message) {
  return Status(StatusCode::kFailedPrecondition, message);
}

inline Status InternalError(const std::string& message) {
  return Status(StatusCode::kInternal, message);
}

inline Status InvalidArgumentError(const std::string& message) {
  return Status(StatusCode::kInvalidArgument, message);
}

inline Status NotFoundError(const std::string& message) {
  return Status(StatusCode::kNotFound, message);
}

inline Status OutOfRangeError(const std::string& message) {
  return Status(StatusCode::kOutOfRange, message);
}

inline Status PermissionDeniedError(const std::string& message) {
  return Status(StatusCode::kPermissionDenied, message);
}

inline Status ResourceExhaustedError(const std::string& message) {
  return Status(StatusCode::kResourceExhausted, message);
}

inline Status UnavailableError(const std::string& message) {
  return Status(StatusCode::kUnavailable, message);
}

inline Status UnimplementedError(const std::string& message) {
  return Status(StatusCode::kUnimplemented, message);
}

inline Status UnknownError(const std::string& message) {
  return Status(StatusCode::kUnknown, message);
}

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_
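A small usage sketch for the status helpers above (the function names here are hypothetical and only illustrate the calling convention):

Status CheckBatchSize(int batch) {
  if (batch != 1) {
    return InvalidArgumentError("Only batch size 1 is currently supported.");
  }
  return OkStatus();
}

Status Prepare(int input_batch, int output_batch) {
  RETURN_IF_ERROR(CheckBatchSize(input_batch));
  RETURN_IF_ERROR(CheckBatchSize(output_batch));
  return OkStatus();
}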
94
tensorflow/lite/delegates/gpu/common/tensor.h
Normal file
94
tensorflow/lite/delegates/gpu/common/tensor.h
Normal file
@ -0,0 +1,94 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_

#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"

namespace tflite {
namespace gpu {
namespace internal_tensor {

// Meta-function that maps an element type to the container type used for
// Tensor data.
template <DataType Type>
struct StorageType;

template <>
struct StorageType<DataType::FLOAT32> {
  using value = std::vector<float>;
};

template <>
struct StorageType<DataType::INT32> {
  using value = std::vector<int32_t>;
};

}  // namespace internal_tensor

template <typename ShapeT, DataType Type>
struct Tensor {
  using ShapeType = ShapeT;

  constexpr static DataType kType = Type;

  using TensorStorageType = typename internal_tensor::StorageType<Type>::value;

  // Opaque id of a tensor.
  int64_t id = -1;

  ShapeType shape;

  TensorStorageType data;
};

// TensorRef is a reference to another tensor. If an object should never hold
// tensor data, then TensorRef should be used instead.
template <typename ShapeT>
struct TensorRef {
  using ShapeType = ShapeT;

  DataType type = DataType::UNKNOWN;

  ShapeT shape;

  // Opaque reference to a tensor. Upstream component is responsible for
  // resolving this reference into an actual tensor.
  int64_t ref = -1;
};

template <typename ShapeT, DataType Type>
constexpr DataType Tensor<ShapeT, Type>::kType;

template <typename ShapeT, DataType Type>
Tensor<ShapeT, Type> MakeZeroTensor(const ShapeT& shape) {
  Tensor<ShapeT, Type> tensor;
  tensor.shape = shape;
  tensor.data = typename Tensor<ShapeT, Type>::TensorStorageType(
      shape.DimensionsProduct(), 0);
  return tensor;
}

using TensorFloat32 = Tensor<BHWC, DataType::FLOAT32>;
using TensorRefFloat32 = TensorRef<BHWC>;

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_
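A brief sketch (not part of the commit) of how these containers compose; the shape values are arbitrary examples chosen for illustration.

// Illustrative only: build a zero-initialized float tensor and a reference to it.
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"

namespace tflite {
namespace gpu {

void TensorExample() {
  // 1 batch, 4x4 spatial, 8 channels; data holds 1*4*4*8 zeros.
  TensorFloat32 t = MakeZeroTensor<BHWC, DataType::FLOAT32>(BHWC(1, 4, 4, 8));
  t.id = 0;

  // A TensorRef only describes the tensor; it never owns data.
  TensorRefFloat32 ref;
  ref.type = DataType::FLOAT32;
  ref.shape = t.shape;
  ref.ref = t.id;
}

}  // namespace gpu
}  // namespace tflite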
203
tensorflow/lite/delegates/gpu/common/transformations/BUILD
Normal file
@ -0,0 +1,203 @@
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

cc_library(
    name = "add_bias",
    srcs = ["add_bias.cc"],
    hdrs = ["add_bias.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:any",
    ],
)

cc_library(
    name = "fuse_add_to_conv",
    srcs = ["fuse_add_to_conv.cc"],
    hdrs = ["fuse_add_to_conv.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "fuse_add_to_conv_test",
    srcs = ["fuse_add_to_conv_test.cc"],
    deps = [
        ":fuse_add_to_conv",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "fuse_mul_to_conv",
    srcs = ["fuse_mul_to_conv.cc"],
    hdrs = ["fuse_mul_to_conv.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:tensor",
    ],
)

cc_test(
    name = "fuse_mul_to_conv_test",
    srcs = ["fuse_mul_to_conv_test.cc"],
    deps = [
        ":fuse_mul_to_conv",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "general_transformations",
    srcs = ["general_transformations.cc"],
    hdrs = ["general_transformations.h"],
    deps = [
        ":fuse_add_to_conv",
        ":fuse_mul_to_conv",
        ":make_fully_connected",
        ":make_padding",
        ":merge_padding_with",
        ":remove_noop",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
    ],
)

cc_library(
    name = "make_fully_connected",
    srcs = ["make_fully_connected.cc"],
    hdrs = ["make_fully_connected.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:any",
    ],
)

cc_test(
    name = "make_fully_connected_test",
    srcs = ["make_fully_connected_test.cc"],
    deps = [
        ":make_fully_connected",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "@com_google_absl//absl/types:any",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "make_padding",
    srcs = ["make_padding.cc"],
    hdrs = ["make_padding.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:any",
    ],
)

cc_test(
    name = "make_padding_test",
    srcs = ["make_padding_test.cc"],
    deps = [
        ":make_padding",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "@com_google_absl//absl/types:any",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "matching",
    hdrs = ["matching.h"],
    deps = ["//tensorflow/lite/delegates/gpu/common:model"],
)

cc_library(
    name = "merge_padding_with",
    srcs = ["merge_padding_with.cc"],
    hdrs = ["merge_padding_with.h"],
    deps = [
        ":matching",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:any",
    ],
)

cc_test(
    name = "merge_padding_with_test",
    srcs = ["merge_padding_with_test.cc"],
    deps = [
        ":merge_padding_with",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "@com_google_absl//absl/types:any",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "remove_noop",
    srcs = ["remove_noop.cc"],
    hdrs = ["remove_noop.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/memory",
    ],
)

cc_test(
    name = "remove_noop_test",
    srcs = ["remove_noop_test.cc"],
    deps = [
        ":remove_noop",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "@com_google_googletest//:gtest_main",
    ],
)
@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {
namespace {

template <typename T>
TransformResult FillBias(Node* node) {
  auto& attr = absl::any_cast<T&>(node->operation.attributes);
  if (attr.bias.data.empty()) {
    const int dst_channels = attr.weights.shape.o;
    attr.bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(dst_channels));
    return {TransformStatus::APPLIED, "Added bias"};
  }
  return {TransformStatus::SKIPPED, ""};
}

template TransformResult FillBias<Convolution2DAttributes>(Node* node);
template TransformResult FillBias<ConvolutionTransposedAttributes>(Node* node);
template TransformResult FillBias<DepthwiseConvolution2DAttributes>(Node* node);
template TransformResult FillBias<FullyConnectedAttributes>(Node* node);

class AddBias : public NodeTransformation {
 public:
  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
    if (node->operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      return FillBias<Convolution2DAttributes>(node);
    }
    if (node->operation.type ==
        ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
      return FillBias<ConvolutionTransposedAttributes>(node);
    }
    if (node->operation.type ==
        ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
      return FillBias<DepthwiseConvolution2DAttributes>(node);
    }
    if (node->operation.type == ToString(OperationType::FULLY_CONNECTED)) {
      return FillBias<FullyConnectedAttributes>(node);
    }
    return {TransformStatus::SKIPPED, ""};
  }
};

}  // namespace

std::unique_ptr<NodeTransformation> NewAddBias() {
  return absl::make_unique<AddBias>();
}

}  // namespace gpu
}  // namespace tflite
@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_BIAS_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_BIAS_H_

#include <memory>

#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

namespace tflite {
namespace gpu {

// Makes the optional bias of Conv/Deconv (and similar ops) non-optional, i.e.
// always present.
std::unique_ptr<NodeTransformation> NewAddBias();

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_BIAS_H_
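For context, a minimal sketch (not part of the commit) of running this transformation over a graph; it follows the same ModelTransformer::Apply pattern used in the tests later in this change, the pass name is arbitrary, and the nullptr second argument simply mirrors the tests.

// Illustrative only: ensure every conv-like node in `graph` has a bias tensor.
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"

namespace tflite {
namespace gpu {

void EnsureBiases(GraphFloat32* graph) {
  auto transformation = NewAddBias();
  ModelTransformer transformer(graph, nullptr);  // second argument as in the tests
  transformer.Apply("add_bias", transformation.get());
}

}  // namespace gpu
}  // namespace tflite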
@ -0,0 +1,235 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"

#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {
namespace {

void FuseBiasWithAddAttributes(const AddAttributes& add_attr,
                               const int channels,
                               Tensor<Linear, DataType::FLOAT32>* bias) {
  auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
  auto add_scalar = absl::get_if<float>(&add_attr.param);
  if (bias->data.empty()) {
    *bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(channels));
  }
  for (int d = 0; d < channels; ++d) {
    bias->data[d] += add ? add->data[d] : *add_scalar;
  }
}

class MergeConvolutionWithAdd : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    auto& conv_node = *sequence[0];
    auto& add_node = *sequence[1];
    if (add_node.operation.type != ToString(OperationType::ADD)) {
      return {TransformStatus::SKIPPED, ""};
    }
    AddAttributes add_attr =
        absl::any_cast<AddAttributes>(add_node.operation.attributes);
    if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param) &&
        !absl::get_if<float>(&add_attr.param)) {
      return {TransformStatus::DECLINED,
              "This fuse applicable only for broadcast or scalar addition."};
    }

    if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      Convolution2DAttributes* conv_attr =
          absl::any_cast<Convolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseConvolution2DWithAdd(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
      ConvolutionTransposedAttributes* conv_attr =
          absl::any_cast<ConvolutionTransposedAttributes>(
              &conv_node.operation.attributes);
      FuseConvolutionTransposedWithAdd(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
      DepthwiseConvolution2DAttributes* conv_attr =
          absl::any_cast<DepthwiseConvolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseDepthwiseConvolution2DWithAdd(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::FULLY_CONNECTED)) {
      FullyConnectedAttributes* conv_attr =
          absl::any_cast<FullyConnectedAttributes>(
              &conv_node.operation.attributes);
      FuseFullyConnectedWithAdd(add_attr, conv_attr);
    } else {
      return {TransformStatus::SKIPPED, ""};
    }

    Status status = RemoveFollowingNode(graph, &add_node, &conv_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove add node after convolution: " +
                  status.error_message()};
    }
    return {TransformStatus::APPLIED, ""};
  }
};

class MergeAddWithConvolution : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    auto& conv_node = *sequence[1];
    auto& add_node = *sequence[0];
    if (add_node.operation.type != ToString(OperationType::ADD)) {
      return {TransformStatus::SKIPPED, ""};
    }
    AddAttributes add_attr =
        absl::any_cast<AddAttributes>(add_node.operation.attributes);
    if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param) &&
        !absl::get_if<float>(&add_attr.param)) {
      return {TransformStatus::DECLINED,
              "This fuse applicable only for broadcast or scalar addition."};
    }

    if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      Convolution2DAttributes* conv_attr =
          absl::any_cast<Convolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseAddWithConvolution2D(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
      DepthwiseConvolution2DAttributes* conv_attr =
          absl::any_cast<DepthwiseConvolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseAddWithDepthwiseConvolution2D(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::FULLY_CONNECTED)) {
      FullyConnectedAttributes* conv_attr =
          absl::any_cast<FullyConnectedAttributes>(
              &conv_node.operation.attributes);
      FuseAddWithFullyConnected(add_attr, conv_attr);
    } else {
      return {TransformStatus::SKIPPED, ""};
    }

    Status status = RemovePrecedingNode(graph, &add_node, &conv_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
"Unable to remove add node after convolution: " +
                  status.error_message()};
    }
    return {TransformStatus::APPLIED, ""};
  }
};
}  // namespace

std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd() {
  return absl::make_unique<MergeConvolutionWithAdd>();
}

std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution() {
  return absl::make_unique<MergeAddWithConvolution>();
}

void FuseConvolution2DWithAdd(const AddAttributes& add_attr,
                              Convolution2DAttributes* attr) {
  FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}

void FuseDepthwiseConvolution2DWithAdd(const AddAttributes& add_attr,
                                       DepthwiseConvolution2DAttributes* attr) {
  FuseBiasWithAddAttributes(
      add_attr, attr->weights.shape.o * attr->weights.shape.i, &attr->bias);
}

void FuseConvolutionTransposedWithAdd(const AddAttributes& add_attr,
                                      ConvolutionTransposedAttributes* attr) {
  FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}

void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
                               FullyConnectedAttributes* attr) {
  FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}

void FuseAddWithConvolution2D(const AddAttributes& add_attr,
                              Convolution2DAttributes* attr) {
  auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
  auto add_scalar = absl::get_if<float>(&add_attr.param);
  if (attr->bias.data.empty()) {
    attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
        Linear(attr->weights.shape.o));
  }
  for (int d = 0; d < attr->weights.shape.o; ++d) {
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      const float add_value = add ? add->data[s] : *add_scalar;
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
          attr->bias.data[d] += attr->weights.data[index] * add_value;
        }
      }
    }
  }
}

void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr,
                                       DepthwiseConvolution2DAttributes* attr) {
  auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
  auto add_scalar = absl::get_if<float>(&add_attr.param);
  if (attr->bias.data.empty()) {
    attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
        Linear(attr->weights.shape.o * attr->weights.shape.i));
  }
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    const float add_value = add ? add->data[s] : *add_scalar;
    for (int g = 0; g < attr->weights.shape.o; ++g) {
      const int d = s * attr->weights.shape.o + g;
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({g, k_y, k_x, s});
          attr->bias.data[d] += attr->weights.data[index] * add_value;
        }
      }
    }
  }
}

void FuseAddWithFullyConnected(const AddAttributes& add_attr,
                               FullyConnectedAttributes* attr) {
  auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
  auto add_scalar = absl::get_if<float>(&add_attr.param);
  if (attr->bias.data.empty()) {
    attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
        Linear(attr->weights.shape.o));
  }
  for (int d = 0; d < attr->weights.shape.o; ++d) {
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      const float add_value = add ? add->data[s] : *add_scalar;
      const int index = attr->weights.shape.LinearIndex({d, 0, 0, s});
      attr->bias.data[d] += attr->weights.data[index] * add_value;
    }
  }
}

}  // namespace gpu
}  // namespace tflite
@ -0,0 +1,83 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_ADD_TO_CONV_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_ADD_TO_CONV_H_

#include <memory>

#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

// Fuse Add Scalar or Add Broadcast after Convolution(Convolution2D,
// DepthWise, TransposedConvolution, FullyConnected) into biases of
// convolution.
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd();

// Fuse Add Scalar or Add Broadcast before Convolution(Convolution2D,
// DepthWise, FullyConnected) into biases of
// convolution.
std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution();

// Modify Convolution2DAttributes so that after making convolution with
// modified attributes we will have the same result as convolution
// with old attributes and following add operation.
void FuseConvolution2DWithAdd(const AddAttributes& add_attr,
                              Convolution2DAttributes* attr);

// Modify DepthwiseConvolution2DAttributes so that after making depth wise
// convolution with modified attributes we will have the same result as depth
// wise convolution with old attributes and following add operation.
void FuseDepthwiseConvolution2DWithAdd(const AddAttributes& add_attr,
                                       DepthwiseConvolution2DAttributes* attr);

// Modify ConvolutionTransposedAttributes so that after making convolution
// transposed with modified attributes we will have the same result as
// convolution transposed with old attributes and following add operation.
void FuseConvolutionTransposedWithAdd(const AddAttributes& add_attr,
                                      ConvolutionTransposedAttributes* attr);

// Modify FullyConnectedAttributes so that after making fully connected with
// modified attributes we will have the same result as fully connected
// with old attributes and following add operation.
void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
                               FullyConnectedAttributes* attr);

// Modify Convolution2DAttributes so that after making convolution with
// modified attributes we will have the same result as add operation and
// convolution with old attributes
void FuseAddWithConvolution2D(const AddAttributes& add_attr,
                              Convolution2DAttributes* attr);

// Modify DepthwiseConvolution2DAttributes so that after making depth wise
// convolution with modified attributes we will have the same result as add
// operation and depth wise convolution with old attributes
void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr,
                                       DepthwiseConvolution2DAttributes* attr);

// Modify FullyConnectedAttributes so that after making fully connected
// with modified attributes we will have the same result as add operation and
// fully connected with old attributes
void FuseAddWithFullyConnected(const AddAttributes& add_attr,
                               FullyConnectedAttributes* attr);

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_ADD_TO_CONV_H_
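A short sketch (not part of the commit) of wiring both fuse directions into a graph pass; the function name is hypothetical, the pass names are arbitrary, and the ModelTransformer usage mirrors the tests that follow.

// Illustrative only: fold ADD nodes around conv-like nodes into their biases.
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"

namespace tflite {
namespace gpu {

void FuseAdds(GraphFloat32* graph) {
  ModelTransformer transformer(graph, nullptr);  // second argument as in the tests
  auto conv_then_add = NewMergeConvolutionWithAdd();
  transformer.Apply("merge_convolution_with_add", conv_then_add.get());
  auto add_then_conv = NewMergeAddWithConvolution();
  transformer.Apply("merge_add_with_convolution", add_then_conv.get());
}

}  // namespace gpu
}  // namespace tflite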
@ -0,0 +1,281 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"

using ::testing::FloatNear;
using ::testing::Pointwise;

namespace tflite {
namespace gpu {
namespace {

TEST(MergeConvolutionWithAddTest, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);

  Convolution2DAttributes conv_attr;
  conv_attr.padding.prepended = HW(0, 0);
  conv_attr.padding.appended = HW(0, 0);
  conv_attr.strides = HW(1, 1);
  conv_attr.dilations = HW(1, 1);
  conv_attr.weights.shape = OHWI(16, 3, 2, 8);
  conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct());
  conv_attr.bias.shape = Linear(16);
  conv_attr.bias.data.resize(16);

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(16);
  add_tensor.data.resize(16);
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  auto conv_node = graph.NewNode();
  conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
  conv_node->operation.attributes = conv_attr;
  auto add_node = graph.NewNode();
  add_node->operation.type = ToString(OperationType::ADD);
  add_node->operation.attributes = add_attr;

  ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());

  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
  output->tensor.shape = BHWC(1, 4, 4, 16);

  Value<TensorRefFloat32>* link1;
  ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, add_node, &link1).ok());
  link1->tensor.shape = BHWC(1, 4, 4, 16);

  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());

  auto transformation = NewMergeConvolutionWithAdd();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("merge_convolution_with_add", transformation.get());

  EXPECT_EQ(1, graph.nodes().size());
  EXPECT_EQ(2, graph.values().size());
  EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D),
            graph.nodes()[0]->operation.type);
}

TEST(MergeAddWithConvolutionTest, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);

  Convolution2DAttributes conv_attr;
  conv_attr.padding.prepended = HW(0, 0);
  conv_attr.padding.appended = HW(0, 0);
  conv_attr.strides = HW(1, 1);
  conv_attr.dilations = HW(1, 1);
  conv_attr.weights.shape = OHWI(16, 3, 2, 8);
  conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct());
  conv_attr.bias.shape = Linear(16);
  conv_attr.bias.data.resize(16);

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(8);
  add_tensor.data.resize(8);
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  auto conv_node = graph.NewNode();
  conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
  conv_node->operation.attributes = conv_attr;
  auto add_node = graph.NewNode();
  add_node->operation.type = ToString(OperationType::ADD);
  add_node->operation.attributes = add_attr;

  ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());

  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
  output->tensor.shape = BHWC(1, 4, 4, 16);

  Value<TensorRefFloat32>* link1;
  ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok());
  link1->tensor.shape = BHWC(1, 4, 4, 16);

  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());

  auto transformation = NewMergeAddWithConvolution();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("merge_add_with_convolution", transformation.get());

  EXPECT_EQ(1, graph.nodes().size());
  EXPECT_EQ(2, graph.values().size());
  EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D),
            graph.nodes()[0]->operation.type);
}

TEST(FuseAddAfterConvolution2DTest, Smoke) {
  Convolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.1f, 1.2f};

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(2);
  add_tensor.data = {0.3f, 0.7f};
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  FuseConvolution2DWithAdd(add_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f}));
}

TEST(FuseAddAfterDepthwiseConvolution2DTest, Smoke) {
  DepthwiseConvolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(4);
  attr.bias.data = {1.1f, 1.2f, 1.3f, 1.4f};

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(4);
  add_tensor.data = {0.3f, 0.7f, 0.5f, 0.1f};
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  FuseDepthwiseConvolution2DWithAdd(add_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
  EXPECT_THAT(attr.bias.data,
              Pointwise(FloatNear(1e-6), {1.4f, 1.9f, 1.8f, 1.5f}));
}

TEST(FuseAddAfterConvolutionTransposedTest, Smoke) {
  ConvolutionTransposedAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.1f, 1.2f};

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(2);
  add_tensor.data = {0.3f, 0.7f};
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  FuseConvolutionTransposedWithAdd(add_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f}));
}

TEST(FuseAddAfterFullyConnectedTest, Smoke) {
  FullyConnectedAttributes attr;
  attr.weights.shape = OHWI(2, 1, 1, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.1f, 1.2f};

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(2);
  add_tensor.data = {0.3f, 0.7f};
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  FuseFullyConnectedWithAdd(add_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6), {0.1f, 0.2f, 0.3f, 0.4f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f}));
}

TEST(FuseAddBeforeConvolution2DTest, Smoke) {
  Convolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.1f, 1.2f};

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(2);
  add_tensor.data = {2.0f, 0.5f};
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  FuseAddWithConvolution2D(add_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {2.2f, 4.3f}));
}

TEST(FuseAddBeforeDepthwiseConvolution2DTest, Smoke) {
  DepthwiseConvolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(4);
  attr.bias.data = {1.1f, 1.2f, 1.3f, 1.4f};

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(4);
  add_tensor.data = {0.3f, 0.7f, 0.5f, 0.1f};
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  FuseAddWithDepthwiseConvolution2D(add_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
  EXPECT_THAT(attr.bias.data,
              Pointwise(FloatNear(1e-6), {1.22f, 1.56f, 1.72f, 2.38f}));
}

TEST(FuseAddBeforeFullyConnectedTest, Smoke) {
  FullyConnectedAttributes attr;
  attr.weights.shape = OHWI(2, 1, 1, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.1f, 1.2f};

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(2);
  add_tensor.data = {0.5f, 2.0f};
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  FuseAddWithFullyConnected(add_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6), {0.1f, 0.2f, 0.3f, 0.4f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.55f, 2.15f}));
}

}  // namespace
}  // namespace gpu
}  // namespace tflite
@ -0,0 +1,304 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h"

#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"

namespace tflite {
namespace gpu {
namespace {

class MergeConvolutionWithMul : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    auto& conv_node = *sequence[0];
    auto& mul_node = *sequence[1];
    if (mul_node.operation.type != ToString(OperationType::MUL) &&
        mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
      return {TransformStatus::SKIPPED, ""};
    }

    MultiplyScalarAttributes mul_attr =
        absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
    if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
            &mul_attr.param) &&
        !absl::get_if<float>(&mul_attr.param)) {
      return {
          TransformStatus::DECLINED,
          "This fuse applicable only for broadcast or scalar multiplication."};
    }

    if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      Convolution2DAttributes* conv_attr =
          absl::any_cast<Convolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseConvolution2DWithMultiply(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
      ConvolutionTransposedAttributes* conv_attr =
          absl::any_cast<ConvolutionTransposedAttributes>(
              &conv_node.operation.attributes);
      FuseConvolutionTransposedWithMultiply(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
      DepthwiseConvolution2DAttributes* conv_attr =
          absl::any_cast<DepthwiseConvolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseDepthwiseConvolution2DWithMultiply(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::FULLY_CONNECTED)) {
      FullyConnectedAttributes* conv_attr =
          absl::any_cast<FullyConnectedAttributes>(
              &conv_node.operation.attributes);
      FuseFullyConnectedWithMultiply(mul_attr, conv_attr);
    } else {
      return {TransformStatus::SKIPPED, ""};
    }

    Status status = RemoveFollowingNode(graph, &mul_node, &conv_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove mul node after convolution: " +
                  status.error_message()};
    }
    return {TransformStatus::APPLIED, ""};
  }
};

class MergeMulWithConvolution : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    auto& conv_node = *sequence[1];
    auto& mul_node = *sequence[0];
    if (mul_node.operation.type != ToString(OperationType::MUL) &&
        mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
      return {TransformStatus::SKIPPED, ""};
    }

    MultiplyScalarAttributes mul_attr =
        absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
    if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
            &mul_attr.param) &&
        !absl::get_if<float>(&mul_attr.param)) {
      return {
          TransformStatus::DECLINED,
          "This fuse applicable only for broadcast or scalar multiplication."};
    }

    if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      Convolution2DAttributes* conv_attr =
          absl::any_cast<Convolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseMultiplyWithConvolution2D(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
      ConvolutionTransposedAttributes* conv_attr =
          absl::any_cast<ConvolutionTransposedAttributes>(
              &conv_node.operation.attributes);
      FuseMultiplyWithConvolutionTransposed(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
      DepthwiseConvolution2DAttributes* conv_attr =
          absl::any_cast<DepthwiseConvolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseMultiplyWithDepthwiseConvolution2D(mul_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::FULLY_CONNECTED)) {
      FullyConnectedAttributes* conv_attr =
          absl::any_cast<FullyConnectedAttributes>(
              &conv_node.operation.attributes);
      FuseMultiplyWithFullyConnected(mul_attr, conv_attr);
    } else {
      return {TransformStatus::SKIPPED, ""};
    }

    Status status = RemovePrecedingNode(graph, &mul_node, &conv_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
"Unable to remove mul node after convolution: " +
|
||||||
|
status.error_message()};
|
||||||
|
}
|
||||||
|
return {TransformStatus::APPLIED, ""};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithMul() {
|
||||||
|
return absl::make_unique<MergeConvolutionWithMul>();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution() {
|
||||||
|
return absl::make_unique<MergeMulWithConvolution>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
||||||
|
Convolution2DAttributes* attr) {
|
||||||
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
|
for (int d = 0; d < attr->weights.shape.o; ++d) {
|
||||||
|
const float multiplier = mul ? mul->data[d] : *mul_scalar;
|
||||||
|
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
||||||
|
for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
|
||||||
|
for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
|
||||||
|
const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
|
||||||
|
attr->weights.data[index] *= multiplier;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!attr->bias.data.empty()) {
|
||||||
|
attr->bias.data[d] *= multiplier;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FuseDepthwiseConvolution2DWithMultiply(
|
||||||
|
const MultiplyScalarAttributes& mul_attr,
|
||||||
|
DepthwiseConvolution2DAttributes* attr) {
|
||||||
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
|
for (int g = 0; g < attr->weights.shape.o; ++g) {
|
||||||
|
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
||||||
|
const int d = s * attr->weights.shape.o + g;
|
||||||
|
const float multiplier = mul ? mul->data[d] : *mul_scalar;
|
||||||
|
for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
|
||||||
|
for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
|
||||||
|
const int index = attr->weights.shape.LinearIndex({g, k_y, k_x, s});
|
||||||
|
attr->weights.data[index] *= multiplier;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!attr->bias.data.empty()) {
|
||||||
|
attr->bias.data[d] *= multiplier;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FuseConvolutionTransposedWithMultiply(
|
||||||
|
const MultiplyScalarAttributes& mul_attr,
|
||||||
|
ConvolutionTransposedAttributes* attr) {
|
||||||
|
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
  for (int d = 0; d < attr->weights.shape.o; ++d) {
    const float multiplier = mul ? mul->data[d] : *mul_scalar;
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
          attr->weights.data[index] *= multiplier;
        }
      }
    }
    if (!attr->bias.data.empty()) {
      attr->bias.data[d] *= multiplier;
    }
  }
}

void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
                                    FullyConnectedAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
  for (int d = 0; d < attr->weights.shape.o; ++d) {
    const float multiplier = mul ? mul->data[d] : *mul_scalar;
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      const int index = attr->weights.shape.LinearIndex({d, 0, 0, s});
      attr->weights.data[index] *= multiplier;
    }
    if (!attr->bias.data.empty()) {
      attr->bias.data[d] *= multiplier;
    }
  }
}

void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
                                   Convolution2DAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    const float multiplier = mul ? mul->data[s] : *mul_scalar;
    for (int d = 0; d < attr->weights.shape.o; ++d) {
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
          attr->weights.data[index] *= multiplier;
        }
      }
    }
  }
}

void FuseMultiplyWithDepthwiseConvolution2D(
    const MultiplyScalarAttributes& mul_attr,
    DepthwiseConvolution2DAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    const float multiplier = mul ? mul->data[s] : *mul_scalar;
    for (int g = 0; g < attr->weights.shape.o; ++g) {
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({g, k_y, k_x, s});
          attr->weights.data[index] *= multiplier;
        }
      }
    }
  }
}

void FuseMultiplyWithConvolutionTransposed(
    const MultiplyScalarAttributes& mul_attr,
    ConvolutionTransposedAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    const float multiplier = mul ? mul->data[s] : *mul_scalar;
    for (int d = 0; d < attr->weights.shape.o; ++d) {
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s});
          attr->weights.data[index] *= multiplier;
        }
      }
    }
  }
}

void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
                                    FullyConnectedAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    const float multiplier = mul ? mul->data[s] : *mul_scalar;
    for (int d = 0; d < attr->weights.shape.o; ++d) {
      const int index = attr->weights.shape.LinearIndex({d, 0, 0, s});
      attr->weights.data[index] *= multiplier;
    }
  }
}

}  // namespace gpu
}  // namespace tflite
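The "fuse after" helpers above rely on a simple identity: multiplying a convolution's output channel d by m[d] is the same as multiplying every weight that produces channel d, and its bias, by m[d]. A minimal standalone sketch of that identity for a single output position (plain C++, not part of this commit; the helper name Conv1x1 and the sample numbers are illustrative):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Dot product of one output channel's weights with an input patch, plus bias.
float Conv1x1(const std::vector<float>& w, const std::vector<float>& x,
              float bias) {
  float acc = bias;
  for (std::size_t i = 0; i < w.size(); ++i) acc += w[i] * x[i];
  return acc;
}

int main() {
  const std::vector<float> x = {0.3f, -1.2f};  // input patch (2 channels)
  std::vector<float> w = {0.1f, 0.2f};         // weights of one output channel
  float bias = 1.5f;
  const float m = 0.5f;                        // multiplier for that channel

  // Reference: convolution followed by a per-channel multiply.
  const float reference = Conv1x1(w, x, bias) * m;

  // Fused: the multiply folded into weights and bias, as the transform does.
  for (float& v : w) v *= m;
  bias *= m;
  const float fused = Conv1x1(w, x, bias);

  assert(std::fabs(reference - fused) < 1e-6f);
  return 0;
}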
@@ -0,0 +1,93 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_MUL_TO_CONV_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_MUL_TO_CONV_H_

#include <memory>

#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

// Fuses a Multiply Scalar or Multiply Broadcast that follows a convolution
// (Convolution2D, DepthWise, TransposedConvolution, FullyConnected) into the
// convolution's weights and biases.
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithMul();

// Fuses a Multiply Scalar or Multiply Broadcast that precedes a convolution
// (Convolution2D, DepthWise, TransposedConvolution, FullyConnected) into the
// convolution's weights and biases.
std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution();

// Modifies Convolution2DAttributes so that the convolution with the modified
// attributes produces the same result as the convolution with the old
// attributes followed by the multiply operation.
void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
                                   Convolution2DAttributes* attr);

// Modifies DepthwiseConvolution2DAttributes so that the depthwise convolution
// with the modified attributes produces the same result as the depthwise
// convolution with the old attributes followed by the multiply operation.
void FuseDepthwiseConvolution2DWithMultiply(
    const MultiplyScalarAttributes& mul_attr,
    DepthwiseConvolution2DAttributes* attr);

// Modifies ConvolutionTransposedAttributes so that the transposed convolution
// with the modified attributes produces the same result as the transposed
// convolution with the old attributes followed by the multiply operation.
void FuseConvolutionTransposedWithMultiply(
    const MultiplyScalarAttributes& mul_attr,
    ConvolutionTransposedAttributes* attr);

// Modifies FullyConnectedAttributes so that the fully connected operation with
// the modified attributes produces the same result as the fully connected
// operation with the old attributes followed by the multiply operation.
void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
                                    FullyConnectedAttributes* attr);

// Modifies Convolution2DAttributes so that the convolution with the modified
// attributes produces the same result as the multiply operation followed by
// the convolution with the old attributes.
void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
                                   Convolution2DAttributes* attr);

// Modifies DepthwiseConvolution2DAttributes so that the depthwise convolution
// with the modified attributes produces the same result as the multiply
// operation followed by the depthwise convolution with the old attributes.
void FuseMultiplyWithDepthwiseConvolution2D(
    const MultiplyScalarAttributes& mul_attr,
    DepthwiseConvolution2DAttributes* attr);

// Modifies ConvolutionTransposedAttributes so that the transposed convolution
// with the modified attributes produces the same result as the multiply
// operation followed by the transposed convolution with the old attributes.
void FuseMultiplyWithConvolutionTransposed(
    const MultiplyScalarAttributes& mul_attr,
    ConvolutionTransposedAttributes* attr);

// Modifies FullyConnectedAttributes so that the fully connected operation with
// the modified attributes produces the same result as the multiply operation
// followed by the fully connected operation with the old attributes.
void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
                                    FullyConnectedAttributes* attr);

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_MUL_TO_CONV_H_
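The two families of helpers above differ only in which axis the multiplier is indexed by: the "WithMultiply" (multiply after conv) variants scale by output channel d and also scale the bias, while the "MultiplyWith" (multiply before conv) variants scale by input channel s and leave the bias alone, since the bias is added after the weight-input products. A small standalone sketch of the "before" case (plain C++, not part of this commit; the sample numbers are illustrative):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  // One output channel, two input channels, 1x1 kernel.
  std::vector<float> w = {0.1f, 0.2f};
  const float bias = 1.5f;
  const std::vector<float> m = {0.5f, 2.0f};   // per-input-channel multiplier
  const std::vector<float> x = {0.7f, -0.4f};  // input patch

  // Reference: multiply the input first, then convolve.
  float reference = bias;
  for (std::size_t s = 0; s < w.size(); ++s) reference += w[s] * (m[s] * x[s]);

  // Fused: scale each weight by the multiplier of its input channel;
  // the bias stays untouched because it is added after the products.
  for (std::size_t s = 0; s < w.size(); ++s) w[s] *= m[s];
  float fused = bias;
  for (std::size_t s = 0; s < w.size(); ++s) fused += w[s] * x[s];

  assert(std::fabs(reference - fused) < 1e-6f);
  return 0;
}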
@@ -0,0 +1,303 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"

using ::testing::FloatNear;
using ::testing::Pointwise;

namespace tflite {
namespace gpu {
namespace {

TEST(MergeConvolutionWithMulTest, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);

  Convolution2DAttributes conv_attr;
  conv_attr.padding.prepended = HW(0, 0);
  conv_attr.padding.appended = HW(0, 0);
  conv_attr.strides = HW(1, 1);
  conv_attr.dilations = HW(1, 1);
  conv_attr.weights.shape = OHWI(16, 3, 2, 8);
  conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct());
  conv_attr.bias.shape = Linear(16);
  conv_attr.bias.data.resize(16);

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(16);
  mul_tensor.data.resize(16);
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  auto conv_node = graph.NewNode();
  conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
  conv_node->operation.attributes = conv_attr;
  auto mul_node = graph.NewNode();
  mul_node->operation.type = ToString(OperationType::MUL);
  mul_node->operation.attributes = mul_attr;

  ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());

  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, mul_node, &output).ok());
  output->tensor.shape = BHWC(1, 4, 4, 16);

  Value<TensorRefFloat32>* link1;
  ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, mul_node, &link1).ok());
  link1->tensor.shape = BHWC(1, 4, 4, 16);

  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());

  auto transformation = NewMergeConvolutionWithMul();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("merge_convolution_with_mul", transformation.get());

  EXPECT_EQ(1, graph.nodes().size());
  EXPECT_EQ(2, graph.values().size());
  EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D),
            graph.nodes()[0]->operation.type);
}

TEST(MergeMulWithConvolutionTest, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(8);
  mul_tensor.data.resize(8);
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  Convolution2DAttributes conv_attr;
  conv_attr.padding.prepended = HW(0, 0);
  conv_attr.padding.appended = HW(0, 0);
  conv_attr.strides = HW(1, 1);
  conv_attr.dilations = HW(1, 1);
  conv_attr.weights.shape = OHWI(16, 3, 2, 8);
  conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct());
  conv_attr.bias.shape = Linear(16);
  conv_attr.bias.data.resize(16);

  auto conv_node = graph.NewNode();
  conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
  conv_node->operation.attributes = conv_attr;
  auto mul_node = graph.NewNode();
  mul_node->operation.type = ToString(OperationType::MUL);
  mul_node->operation.attributes = mul_attr;

  ASSERT_TRUE(graph.AddConsumer(mul_node->id, input->id).ok());

  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
  output->tensor.shape = BHWC(1, 4, 4, 16);

  Value<TensorRefFloat32>* link1;
  ASSERT_TRUE(ConnectTwoNodes(&graph, mul_node, conv_node, &link1).ok());
  link1->tensor.shape = BHWC(1, 4, 4, 16);

  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());

  auto transformation = NewMergeMulWithConvolution();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("merge_mul_with_convolution", transformation.get());

  EXPECT_EQ(1, graph.nodes().size());
  EXPECT_EQ(2, graph.values().size());
  EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D),
            graph.nodes()[0]->operation.type);
}

TEST(FuseMulAfterConvolution2DTest, Smoke) {
  Convolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.5f, 2.5f};

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  FuseConvolution2DWithMultiply(mul_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.05f, 0.1f, 0.15f, 0.2f, 1.0f, 1.2f, 1.4f, 1.6f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {0.75f, 5.0f}));
}

TEST(FuseMulAfterDepthwiseConvolution2DTest, Smoke) {
  DepthwiseConvolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(4);
  attr.bias.data = {1.5f, 2.5f, 1.0f, 2.0f};

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(4);
  mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  FuseDepthwiseConvolution2DWithMultiply(mul_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.05f, 0.8f, 0.15f, 1.6f, 1.0f, 0.15f, 1.4f, 0.2f}));
  EXPECT_THAT(attr.bias.data,
              Pointwise(FloatNear(1e-6), {0.75f, 5.0f, 4.0f, 0.5f}));
}

TEST(FuseMulAfterConvolutionTransposedTest, Smoke) {
  ConvolutionTransposedAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.5f, 2.5f};

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  FuseConvolutionTransposedWithMultiply(mul_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.05f, 0.1f, 0.15f, 0.2f, 1.0f, 1.2f, 1.4f, 1.6f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {0.75f, 5.0f}));
}

TEST(FuseMulAfterFullyConnectedTest, Smoke) {
  FullyConnectedAttributes attr;
  attr.weights.shape = OHWI(2, 1, 1, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.5f, 2.5f};

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  FuseFullyConnectedWithMultiply(mul_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6), {0.05f, 0.1f, 0.6f, 0.8f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {0.75f, 5.0f}));
}

TEST(FuseMulBeforeConvolution2DTest, Smoke) {
  Convolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.5f, 2.5f};

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  FuseMultiplyWithConvolution2D(mul_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.05f, 0.4f, 0.15f, 0.8f, 0.25f, 1.2f, 0.35f, 1.6f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.5f, 2.5f}));
}

TEST(FuseMulBeforeDepthwiseConvolution2DTest, Smoke) {
  DepthwiseConvolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(4);
  attr.bias.data = {1.5f, 2.5f, 1.0f, 2.0f};

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(4);
  mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  FuseMultiplyWithDepthwiseConvolution2D(mul_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.05f, 0.4f, 0.15f, 0.8f, 0.25f, 1.2f, 0.35f, 1.6f}));
  EXPECT_THAT(attr.bias.data,
              Pointwise(FloatNear(1e-6), {1.5f, 2.5f, 1.0f, 2.0f}));
}

TEST(FuseMulBeforeConvolutionTransposedTest, Smoke) {
  ConvolutionTransposedAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.5f, 2.5f};

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  FuseMultiplyWithConvolutionTransposed(mul_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.05f, 0.4f, 0.15f, 0.8f, 0.25f, 1.2f, 0.35f, 1.6f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.5f, 2.5f}));
}

TEST(FuseMulBeforeFullyConnectedTest, Smoke) {
  FullyConnectedAttributes attr;
  attr.weights.shape = OHWI(2, 1, 1, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f};
  attr.bias.shape = Linear(2);
  attr.bias.data = {1.5f, 2.5f};

  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
  MultiplyScalarAttributes mul_attr;
  mul_attr.param = mul_tensor;

  FuseMultiplyWithFullyConnected(mul_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6), {0.05f, 0.4f, 0.15f, 0.8f}));
  EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.5f, 2.5f}));
}

}  // namespace
}  // namespace gpu
}  // namespace tflite
@@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"

#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h"

namespace tflite {
namespace gpu {

bool ApplyGeneralTransformations(ModelTransformer* transformer) {
  // If any of these transforms returns false, the graph is in a broken state
  // and processing should not continue.
  return transformer->Apply("remove_degenerate_upsampling",
                            NewRemoveDegenerateUpsampling().get()) &&
         transformer->Apply("remove_single_input_add",
                            NewRemoveSingleInputAdd().get()) &&
         transformer->Apply("remove_single_input_concat",
                            NewRemoveSingleInputConcat().get()) &&
         transformer->Apply("make_padding_from_concat",
                            NewMakePaddingFromConcat().get()) &&
         transformer->Apply("make_fully_connected_from_convolution",
                            NewMakeFullyConnectedFromConvolution().get()) &&
         transformer->Apply("merge_padding_with_convolution",
                            NewMergePaddingWithConvolution2D().get()) &&
         transformer->Apply("merge_padding_with_pooling",
                            NewMergePaddingWithPooling().get()) &&
         transformer->Apply("merge_padding_with_depthwise_convolution",
                            NewMergePaddingWithDepthwiseConvolution().get()) &&
         transformer->Apply("merge_convolution_with_mul",
                            NewMergeConvolutionWithMul().get()) &&
         transformer->Apply("merge_convolution_with_add",
                            NewMergeConvolutionWithAdd().get()) &&
         transformer->Apply("merge_mul_with_convolution",
                            NewMergeMulWithConvolution().get()) &&
         transformer->Apply("merge_add_with_convolution",
                            NewMergeAddWithConvolution().get());
}

}  // namespace gpu
}  // namespace tflite
@@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GENERAL_TRANSFORMATIONS_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GENERAL_TRANSFORMATIONS_H_

#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

namespace tflite {
namespace gpu {

// @return false when something went wrong and left the graph in a broken
// state.
bool ApplyGeneralTransformations(ModelTransformer* transformer);

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GENERAL_TRANSFORMATIONS_H_
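A minimal usage sketch for the entry point declared above, assuming a GraphFloat32 has already been populated by the model builder. The second ModelTransformer constructor argument is passed as nullptr exactly as in the tests in this commit; the wrapper function name OptimizeGraph is illustrative, not part of the commit:

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"

namespace tflite {
namespace gpu {

// Runs the standard transformation pipeline on an already-built graph.
// Returns false if any transformation left the graph in a broken state.
bool OptimizeGraph(GraphFloat32* graph) {
  ModelTransformer transformer(graph, nullptr);
  return ApplyGeneralTransformations(&transformer);
}

}  // namespace gpu
}  // namespace tflite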
@@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h"

#include "absl/memory/memory.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {
namespace {

bool IsConvEquivalentToFullyConnected(const Convolution2DAttributes& attr) {
  return attr.weights.shape.w == 1 &&           //
         attr.weights.shape.h == 1 &&           //
         attr.strides == HW(1, 1) &&            //
         attr.dilations == HW(1, 1) &&          //
         attr.padding.prepended == HW(0, 0) &&  //
         attr.padding.appended == HW(0, 0);
}

class MakeFullyConnectedFromConvolution : public NodeTransformation {
 public:
  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
    if (node->operation.type != ToString(OperationType::CONVOLUTION_2D)) {
      return {TransformStatus::SKIPPED, ""};
    }
    auto inputs = graph->FindInputs(node->id);
    if (inputs.size() != 1) {
      return {TransformStatus::SKIPPED, ""};
    }

    const auto& input_shape = inputs[0]->tensor.shape;
    if (input_shape.w != 1 || input_shape.h != 1) {
      return {TransformStatus::SKIPPED, ""};
    }

    const auto& conv_attr = absl::any_cast<const Convolution2DAttributes&>(
        node->operation.attributes);
    if (!IsConvEquivalentToFullyConnected(conv_attr)) {
      return {TransformStatus::SKIPPED, ""};
    }

    FullyConnectedAttributes fc_attr;
    fc_attr.weights = conv_attr.weights;
    fc_attr.bias = conv_attr.bias;

    node->operation.attributes = fc_attr;
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    return {TransformStatus::APPLIED,
            "Replaced convolution with fully connected."};
  }
};

}  // namespace

std::unique_ptr<NodeTransformation> NewMakeFullyConnectedFromConvolution() {
  return absl::make_unique<MakeFullyConnectedFromConvolution>();
}

}  // namespace gpu
}  // namespace tflite
@@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_FULLY_CONNECTED_H_

#include <memory>

#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

namespace tflite {
namespace gpu {

// Turns a convolution with a 1x1 kernel and an input tensor with h=1 and w=1
// into a fully connected operation.
std::unique_ptr<NodeTransformation> NewMakeFullyConnectedFromConvolution();

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_FULLY_CONNECTED_H_
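The transformation declared above is valid because a 1x1 convolution applied to a 1x1 (h=w=1) input degenerates to a matrix-vector product: each output channel is a dot product over the input channels plus a bias. A standalone sketch of that equivalence in plain C++ (not part of this commit; the Conv helper and its explicit-shape signature are illustrative):

#include <cassert>
#include <cmath>
#include <vector>

// Minimal OHWI-weight convolution over a single-batch HWC input, valid
// padding, stride 1. Shapes are passed explicitly to keep the sketch short.
std::vector<float> Conv(const std::vector<float>& w, int o, int kh, int kw,
                        int ci, const std::vector<float>& x, int h, int wdt,
                        const std::vector<float>& bias) {
  const int oh = h - kh + 1, ow = wdt - kw + 1;
  std::vector<float> y(oh * ow * o, 0.0f);
  for (int d = 0; d < o; ++d)
    for (int yy = 0; yy < oh; ++yy)
      for (int xx = 0; xx < ow; ++xx) {
        float acc = bias[d];
        for (int ky = 0; ky < kh; ++ky)
          for (int kx = 0; kx < kw; ++kx)
            for (int s = 0; s < ci; ++s)
              acc += w[((d * kh + ky) * kw + kx) * ci + s] *
                     x[((yy + ky) * wdt + (xx + kx)) * ci + s];
        y[(yy * ow + xx) * o + d] = acc;
      }
  return y;
}

int main() {
  // 1x1 kernel, 1x1 spatial input: the convolution is a matrix-vector product.
  const int o = 2, ci = 3;
  const std::vector<float> w = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};  // OHWI
  const std::vector<float> bias = {1.0f, -1.0f};
  const std::vector<float> x = {0.5f, 2.0f, -1.5f};  // BHWC(1, 1, 1, 3)

  const std::vector<float> conv = Conv(w, o, 1, 1, ci, x, 1, 1, bias);

  for (int d = 0; d < o; ++d) {
    // Fully connected computation for output channel d.
    float fc = bias[d];
    for (int s = 0; s < ci; ++s) fc += w[d * ci + s] * x[s];
    assert(std::fabs(conv[d] - fc) < 1e-6f);
  }
  return 0;
}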
@@ -0,0 +1,108 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"

namespace tflite {
namespace gpu {
namespace {

TEST(MakeFullyConnected, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);

  Convolution2DAttributes attr0;
  attr0.padding.prepended = HW(0, 0);
  attr0.padding.appended = HW(0, 0);
  attr0.strides = HW(1, 1);
  attr0.dilations = HW(1, 1);
  attr0.weights.shape = OHWI(16, 1, 1, 8);
  attr0.bias.shape = Linear(16);

  Convolution2DAttributes attr1;
  attr1.padding.prepended = HW(0, 0);
  attr1.padding.appended = HW(0, 0);
  attr1.strides = HW(4, 4);
  attr1.dilations = HW(1, 1);
  attr1.weights.shape = OHWI(16, 4, 4, 16);
  attr1.bias.shape = Linear(16);

  Convolution2DAttributes attr2;
  attr2.padding.prepended = HW(0, 0);
  attr2.padding.appended = HW(0, 0);
  attr2.strides = HW(1, 1);
  attr2.dilations = HW(1, 1);
  attr2.weights.shape = OHWI(32, 1, 1, 16);
  attr2.bias.shape = Linear(32);

  auto conv1x1_node0 = graph.NewNode();
  conv1x1_node0->operation.type = ToString(OperationType::CONVOLUTION_2D);
  conv1x1_node0->operation.attributes = attr0;
  auto conv4x4_node1 = graph.NewNode();
  conv4x4_node1->operation.type = ToString(OperationType::CONVOLUTION_2D);
  conv4x4_node1->operation.attributes = attr1;
  auto conv1x1_node2 = graph.NewNode();
  conv1x1_node2->operation.type = ToString(OperationType::CONVOLUTION_2D);
  conv1x1_node2->operation.attributes = attr2;

  ASSERT_TRUE(graph.AddConsumer(conv1x1_node0->id, input->id).ok());

  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, conv1x1_node2, &output).ok());
  output->tensor.shape = BHWC(1, 1, 1, 32);

  Value<TensorRefFloat32>* link1;
  ASSERT_TRUE(
      ConnectTwoNodes(&graph, conv1x1_node0, conv4x4_node1, &link1).ok());
  link1->tensor.shape = BHWC(1, 4, 4, 16);

  Value<TensorRefFloat32>* link2;
  ASSERT_TRUE(
      ConnectTwoNodes(&graph, conv4x4_node1, conv1x1_node2, &link2).ok());
  link2->tensor.shape = BHWC(1, 1, 1, 16);

  ASSERT_EQ(3, graph.nodes().size());
  ASSERT_EQ(4, graph.values().size());

  auto transformation = NewMakeFullyConnectedFromConvolution();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("make_fully_connected", transformation.get());

  ASSERT_EQ(3, graph.nodes().size());
  ASSERT_EQ(4, graph.values().size());
  ASSERT_EQ(ToString(OperationType::CONVOLUTION_2D),
            graph.nodes()[0]->operation.type);
  ASSERT_EQ(ToString(OperationType::CONVOLUTION_2D),
            graph.nodes()[1]->operation.type);
  ASSERT_EQ(ToString(OperationType::FULLY_CONNECTED),
            graph.nodes()[2]->operation.type);
  auto fc_attr = absl::any_cast<FullyConnectedAttributes>(
      graph.nodes()[2]->operation.attributes);
  EXPECT_EQ(OHWI(32, 1, 1, 16), fc_attr.weights.shape);
  EXPECT_EQ(Linear(32), fc_attr.bias.shape);
}

}  // namespace
}  // namespace gpu
}  // namespace tflite
@@ -0,0 +1,101 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h"

#include "absl/memory/memory.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {
namespace {

bool IsConstZeros(const Node& node) {
  if (node.operation.type != ToString(OperationType::CONST)) {
    return false;
  }
  auto& attr =
      absl::any_cast<const ConstTensorAttributes&>(node.operation.attributes);
  for (auto f : attr.tensor.data) {
    if (f != 0) {
      return false;
    }
  }
  return true;
}

class MakePaddingFromZerosConcat : public NodeTransformation {
 public:
  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
    if (node->operation.type != ToString(OperationType::CONCAT)) {
      return {TransformStatus::SKIPPED, ""};
    }
    auto inputs = graph->FindInputs(node->id);
    if (inputs.size() != 2) {
      return {TransformStatus::SKIPPED, ""};
    }

    bool first = true;
    for (auto input : inputs) {
      auto dep = graph->FindProducer(input->id);
      if (dep != nullptr && IsConstZeros(*dep)) {
        auto& concat_attr =
            absl::any_cast<const ConcatAttributes&>(node->operation.attributes);
        PadAttributes pad_attr;
        pad_attr.type = PaddingContentType::ZEROS;
        pad_attr.appended = HWC(0, 0, 0);
        pad_attr.prepended = HWC(0, 0, 0);
        HWC* p = first ? &pad_attr.prepended : &pad_attr.appended;
        switch (concat_attr.axis) {
          case Axis::HEIGHT:
            p->h = input->tensor.shape.h;
            break;
          case Axis::WIDTH:
            p->w = input->tensor.shape.w;
            break;
          case Axis::CHANNELS:
            p->c = input->tensor.shape.c;
            break;
          default:
            return {TransformStatus::DECLINED,
                    "Padding for concat axis is unsupported: " +
                        ToString(concat_attr.axis)};
        }
        Status status = RemovePrecedingNode(graph, dep, node);
        if (!status.ok()) {
          return {TransformStatus::INVALID,
                  "Unable to remove const node: " + status.error_message()};
        }
        node->operation.attributes = pad_attr;
        node->operation.type = ToString(OperationType::PAD);
        return {TransformStatus::APPLIED, "Replaced concat with padding"};
      }
      first = false;
    }
    return {TransformStatus::SKIPPED, ""};
  }
};

}  // namespace

std::unique_ptr<NodeTransformation> NewMakePaddingFromConcat() {
  return absl::make_unique<MakePaddingFromZerosConcat>();
}

}  // namespace gpu
}  // namespace tflite
@@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_PADDING_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_PADDING_H_

#include <memory>

#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

namespace tflite {
namespace gpu {

// Turns a concat of exactly two tensors, where one of them is all zeros, into
// a padding operation.
std::unique_ptr<NodeTransformation> NewMakePaddingFromConcat();

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_PADDING_H_
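The declaration above rests on a simple observation: concatenating a tensor with an all-zero tensor along some axis produces the same values as padding the tensor with zeros along that axis (prepended if the zero tensor comes first, appended if it comes second). A tiny one-dimensional sketch of the appended case in plain C++ (not part of this commit):

#include <cassert>
#include <vector>

int main() {
  const std::vector<float> x = {1.0f, 2.0f, 3.0f};
  const std::vector<float> zeros(2, 0.0f);

  // concat([x, zeros]) along the only axis...
  std::vector<float> concatenated = x;
  concatenated.insert(concatenated.end(), zeros.begin(), zeros.end());

  // ...equals pad(x) with 2 zeros appended along that axis.
  std::vector<float> padded = x;
  padded.resize(x.size() + 2, 0.0f);

  assert(concatenated == padded);
  return 0;
}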
@@ -0,0 +1,75 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"

namespace tflite {
namespace gpu {
namespace {

TEST(MakePadding, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 2, 3, 5);

  auto concat_node = graph.NewNode();
  ASSERT_TRUE(graph.AddConsumer(concat_node->id, input->id).ok());
  concat_node->operation.type = ToString(OperationType::CONCAT);
  ConcatAttributes attr;
  attr.axis = Axis::HEIGHT;
  concat_node->operation.attributes = attr;

  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, concat_node, &output).ok());
  output->tensor.shape = BHWC(1, 7, 3, 5);

  auto const_node = graph.NewNode();
  const_node->operation.type = ToString(OperationType::CONST);
  ConstTensorAttributes const_attr;
  const_attr.tensor.shape = BHWC(1, 5, 3, 5);
  const_attr.tensor.data =
      std::vector<float>(const_attr.tensor.shape.DimensionsProduct(), 0);
  const_node->operation.attributes = const_attr;

  Value<TensorRefFloat32>* const_link;
  ASSERT_TRUE(
      ConnectTwoNodes(&graph, const_node, concat_node, &const_link).ok());
  const_link->tensor.shape = const_attr.tensor.shape;

  ASSERT_EQ(2, graph.nodes().size());

  auto transformation = NewMakePaddingFromConcat();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("make_padding", transformation.get());

  ASSERT_EQ(1, graph.nodes().size());
  ASSERT_EQ(2, graph.values().size());
  auto pad_node = graph.nodes()[0];
  ASSERT_EQ(ToString(OperationType::PAD), pad_node->operation.type);
  auto pad_attr = absl::any_cast<PadAttributes>(pad_node->operation.attributes);
  EXPECT_EQ(HWC(0, 0, 0), pad_attr.prepended);
  EXPECT_EQ(HWC(5, 0, 0), pad_attr.appended);
}

}  // namespace
}  // namespace gpu
}  // namespace tflite
@@ -0,0 +1,44 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCHING_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCHING_H_

// This file provides predicates for matching subgraphs.

#include <algorithm>
#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/model.h"

namespace tflite {
namespace gpu {

// Returns true if `nodes` and `types` have the same length and every node's
// operation type equals the corresponding entry in `types`.
template <typename T>
bool MatchesByOperationType(const T& nodes,
                            const std::vector<std::string>& types) {
  if (nodes.size() != types.size()) return false;
  return std::mismatch(nodes.begin(), nodes.end(), types.begin(),
                       [&](typename T::value_type a, const std::string& b) {
                         return a->operation.type == b;
                       })
             .first == nodes.end();
}

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCHING_H_
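A standalone sketch of how the predicate above behaves, using hypothetical stand-in structs instead of the real tflite::gpu::Node (the template only needs `node->operation.type`; the string values are illustrative, not the actual ToString outputs):

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Stand-ins for the graph types, just to exercise the template.
struct Operation {
  std::string type;
};
struct Node {
  Operation operation;
};

template <typename T>
bool MatchesByOperationType(const T& nodes,
                            const std::vector<std::string>& types) {
  if (nodes.size() != types.size()) return false;
  return std::mismatch(nodes.begin(), nodes.end(), types.begin(),
                       [&](typename T::value_type a, const std::string& b) {
                         return a->operation.type == b;
                       })
             .first == nodes.end();
}

int main() {
  Node pad{{"pad"}}, conv{{"convolution_2d"}};
  std::vector<Node*> sequence = {&pad, &conv};

  assert(MatchesByOperationType(sequence, {"pad", "convolution_2d"}));
  assert(!MatchesByOperationType(sequence, {"pad", "pooling_2d"}));
  assert(!MatchesByOperationType(sequence, {"pad"}));  // length mismatch
  return 0;
}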
@@ -0,0 +1,171 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"

#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/matching.h"

namespace tflite {
namespace gpu {
namespace {

template <typename Attr>
class MergePaddingWith2DOperation : public SequenceTransformation {
 public:
  explicit MergePaddingWith2DOperation(OperationType operation_type)
      : operations_to_match_(
            {ToString(OperationType::PAD), ToString(operation_type)}) {}

  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    if (!MatchesByOperationType(sequence, operations_to_match_)) {
      return {TransformStatus::SKIPPED, ""};
    }

    Node* pad_node = sequence.front();
    Node* op_node = sequence.back();

    PadAttributes pad_attr =
        absl::any_cast<PadAttributes>(pad_node->operation.attributes);

    if (pad_attr.type != PaddingContentType::ZEROS) {
      return {TransformStatus::DECLINED, "Only Zero padding is supported."};
    }
    if (pad_attr.appended.c != 0 || pad_attr.prepended.c != 0) {
      return {TransformStatus::DECLINED,
              "Pad has non-zero padding on non HW axis."};
    }

    Attr* node_attr = absl::any_cast<Attr>(&op_node->operation.attributes);
    Status status = RemovePrecedingNode(graph, pad_node, op_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove Pad node with Operation node: " +
                  status.error_message()};
    }

    node_attr->padding.appended.h += pad_attr.appended.h;
    node_attr->padding.appended.w += pad_attr.appended.w;
    node_attr->padding.prepended.h += pad_attr.prepended.h;
    node_attr->padding.prepended.w += pad_attr.prepended.w;
    return {
        TransformStatus::APPLIED,
        absl::StrCat("Added padding: prepended = {h = ", pad_attr.prepended.h,
                     ", w = ", pad_attr.prepended.w, "}, appended = { h = ",
                     pad_attr.appended.h, ", w = ", pad_attr.appended.w, "}")};
  }

 private:
  const std::vector<std::string> operations_to_match_;
};

}  // namespace

std::unique_ptr<SequenceTransformation> NewMergePaddingWithPooling() {
  return absl::make_unique<MergePaddingWith2DOperation<Pooling2DAttributes>>(
      OperationType::POOLING_2D);
}

std::unique_ptr<SequenceTransformation> NewMergePaddingWithConvolution2D() {
  return absl::make_unique<
      MergePaddingWith2DOperation<Convolution2DAttributes>>(
      OperationType::CONVOLUTION_2D);
}

std::unique_ptr<SequenceTransformation>
NewMergePaddingWithDepthwiseConvolution() {
  return absl::make_unique<
      MergePaddingWith2DOperation<DepthwiseConvolution2DAttributes>>(
      OperationType::DEPTHWISE_CONVOLUTION);
}

class MergePaddingWithAddOperation : public NodeTransformation {
 public:
  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
    if (node->operation.type != ToString(OperationType::PAD)) {
      return {TransformStatus::SKIPPED, ""};
    }
    auto inputs = graph->FindInputs(node->id);
    if (inputs.size() != 1) {
      return {TransformStatus::SKIPPED, ""};
    }

    const auto& input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
    if (input_shape.c % 4 != 0) {
      return {TransformStatus::DECLINED,
              "Pad with input where src_channels % 4 != 0"};
    }

    PadAttributes pad_attr =
        absl::any_cast<PadAttributes>(node->operation.attributes);

    if (pad_attr.type != PaddingContentType::ZEROS) {
      return {TransformStatus::DECLINED, "Only Zero padding is supported."};
    }
    if (pad_attr.prepended != HWC(0, 0, 0) || pad_attr.appended.h != 0 ||
        pad_attr.appended.w != 0) {
      return {TransformStatus::DECLINED,
              "Pad has padding not only in appended channels axis."};
    }

    auto pad_output = graph->FindOutputs(node->id)[0];
    auto consumer_nodes = graph->FindConsumers(pad_output->id);
    if (consumer_nodes.size() != 1) {
      return {TransformStatus::SKIPPED, ""};
    }
    auto add_node = consumer_nodes[0];
    auto consumer_type = OperationTypeFromString(add_node->operation.type);
    if (consumer_type != OperationType::ADD) {
      return {TransformStatus::SKIPPED, ""};
    }

    AddAttributes add_attr =
        absl::any_cast<AddAttributes>(add_node->operation.attributes);
    auto add_broadcasted_vector =
        absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
    if (add_broadcasted_vector) {
      return {TransformStatus::SKIPPED,
              "Cannot remove padding when the ADD is broadcasted."};
    }

    Status status = RemovePrecedingNode(graph, node, add_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove Pad node " + status.error_message()};
    }

    return {TransformStatus::APPLIED,
            "Removed padding with zeroes in appended channels dimension"};
  }
};

std::unique_ptr<NodeTransformation> NewMergePaddingWithAdd() {
  return absl::make_unique<MergePaddingWithAddOperation>();
}

}  // namespace gpu
}  // namespace tflite
@@ -0,0 +1,53 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MERGE_PADDING_WITH_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MERGE_PADDING_WITH_H_

#include <memory>

#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

namespace tflite {
namespace gpu {

std::unique_ptr<SequenceTransformation> NewMergePaddingWithPooling();

std::unique_ptr<SequenceTransformation> NewMergePaddingWithConvolution2D();

std::unique_ptr<SequenceTransformation>
NewMergePaddingWithDepthwiseConvolution();

// This transform requires the Add operation to support inputs of unequal
// shapes. The padding must be with zeroes, appended only in the channels (Z)
// axis, and the input tensor's channel count must be divisible by 4 (aligned).
// It replaces the following pattern:
//   1) a tensor is padded with zeroes in the Z dimension, for example from 24
//      to 32 channels;
//   2) that tensor is used only by an Add operation, which then adds these
//      useless zeroes in channels 24-32.
// The useless addition is removed by relying on Add with unequal inputs:
// instead of zero-filling and adding that part, the Add operation performs an
// additional check for this tensor:
//   if (channels < src_channels) {
//     result += tensor_from_pad_operation.data[index];
//   }
std::unique_ptr<NodeTransformation> NewMergePaddingWithAdd();

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MERGE_PADDING_WITH_H_
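The comment above boils down to an elementwise identity: adding a tensor whose extra, appended channels are all zero contributes nothing in those channels, so the pad can be dropped whenever the Add can consume inputs with unequal channel counts. A small standalone sketch of that equivalence in plain C++ (not part of this commit; the channel counts are illustrative and the %4 alignment requirement is ignored here):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  // `x` has 2 channels, the other Add input `y` has 4.
  const std::vector<float> x = {1.0f, 2.0f};
  const std::vector<float> y = {10.0f, 20.0f, 30.0f, 40.0f};

  // Reference: zero-pad x to 4 channels, then add elementwise.
  std::vector<float> padded = x;
  padded.resize(y.size(), 0.0f);
  std::vector<float> reference(y.size());
  for (std::size_t c = 0; c < y.size(); ++c) reference[c] = y[c] + padded[c];

  // Without the pad: add x only where it has data, skip the missing channels.
  std::vector<float> merged = y;
  for (std::size_t c = 0; c < y.size(); ++c) {
    if (c < x.size()) merged[c] += x[c];
  }

  for (std::size_t c = 0; c < y.size(); ++c)
    assert(std::fabs(reference[c] - merged[c]) < 1e-6f);
  return 0;
}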
@@ -0,0 +1,151 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"

namespace tflite {
namespace gpu {
namespace {

TEST(MergePaddingWith, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();

  auto pad_node = graph.NewNode();
  ASSERT_TRUE(graph.AddConsumer(pad_node->id, input->id).ok());
  pad_node->operation.type = ToString(OperationType::PAD);
  PadAttributes attr;
  attr.prepended = HWC(1, 1, 0);
  attr.appended = HWC(2, 2, 0);
  pad_node->operation.attributes = attr;

  auto conv_node = graph.NewNode();
  Value<TensorRefFloat32>* temp;
  ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node, conv_node, &temp).ok());
  ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok());
  conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
  Convolution2DAttributes conv_attr;
  conv_attr.padding.appended = HW(0, 0);
  conv_attr.padding.prepended = HW(0, 0);
  conv_node->operation.attributes = conv_attr;

  ASSERT_EQ(2, graph.nodes().size());

  auto transformation = NewMergePaddingWithConvolution2D();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("merge_padding", transformation.get());

  ASSERT_EQ(1, graph.nodes().size());
  ASSERT_EQ(2, graph.values().size());
  ASSERT_EQ(conv_node, graph.nodes()[0]);
  conv_attr =
      absl::any_cast<Convolution2DAttributes>(conv_node->operation.attributes);
  EXPECT_EQ(HW(1, 1), conv_attr.padding.prepended);
  EXPECT_EQ(HW(2, 2), conv_attr.padding.appended);
}

TEST(MergePaddingWith, MergeTwo) {
  GraphFloat32 graph;
  auto input = graph.NewValue();

  auto pad_node1 = graph.NewNode();
  ASSERT_TRUE(graph.AddConsumer(pad_node1->id, input->id).ok());
  pad_node1->operation.type = ToString(OperationType::PAD);
  PadAttributes attr;
  attr.prepended = HWC(1, 1, 0);
  attr.appended = HWC(0, 0, 0);
  pad_node1->operation.attributes = attr;

  auto pad_node2 = graph.NewNode();
  Value<TensorRefFloat32>* temp;
  ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp).ok());
  pad_node2->operation.type = ToString(OperationType::PAD);
  attr.prepended = HWC(0, 0, 0);
  attr.appended = HWC(2, 2, 0);
  pad_node2->operation.attributes = attr;

  auto conv_node = graph.NewNode();
  ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node2, conv_node, &temp).ok());
  ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok());
  conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
  Convolution2DAttributes conv_attr;
  conv_attr.padding.appended = HW(0, 0);
  conv_attr.padding.prepended = HW(0, 0);
  conv_node->operation.attributes = conv_attr;

  ASSERT_EQ(3, graph.nodes().size());

  auto transformation = NewMergePaddingWithConvolution2D();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("merge_padding", transformation.get());

  ASSERT_EQ(1, graph.nodes().size());
  ASSERT_EQ(2, graph.values().size());
  ASSERT_EQ(conv_node, graph.nodes()[0]);
  conv_attr =
      absl::any_cast<Convolution2DAttributes>(conv_node->operation.attributes);
  EXPECT_EQ(HW(1, 1), conv_attr.padding.prepended);
  EXPECT_EQ(HW(2, 2), conv_attr.padding.appended);
}

TEST(MergePaddingWithAdd, MergeOne) {
  GraphFloat32 graph;
  auto input0 = graph.NewValue();
  input0->tensor.shape = BHWC(1, 4, 4, 8);
  auto input1 = graph.NewValue();
  auto padded = graph.NewValue();
  auto output = graph.NewValue();

  auto pad_node = graph.NewNode();
  pad_node->operation.type = ToString(OperationType::PAD);
  PadAttributes pad_attr;
  pad_attr.prepended = HWC(0, 0, 0);
  pad_attr.appended = HWC(0, 0, 32);
  pad_node->operation.attributes = pad_attr;

  ASSERT_TRUE(graph.AddConsumer(pad_node->id, input0->id).ok());
  ASSERT_TRUE(graph.SetProducer(pad_node->id, padded->id).ok());

  auto add_node = graph.NewNode();
  AddAttributes add_attr;
  ASSERT_TRUE(graph.AddConsumer(add_node->id, padded->id).ok());
  ASSERT_TRUE(graph.AddConsumer(add_node->id, input1->id).ok());
  ASSERT_TRUE(graph.SetProducer(add_node->id, output->id).ok());
  add_node->operation.type = ToString(OperationType::ADD);
  add_node->operation.attributes = add_attr;

  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(4, graph.values().size());

  auto transformation = NewMergePaddingWithAdd();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("merge_padding", transformation.get());

  ASSERT_EQ(1, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());
  EXPECT_EQ(add_node, graph.nodes()[0]);
}

}  // namespace
}  // namespace gpu
}  // namespace tflite
@@ -0,0 +1,100 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h"

#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {
namespace {

using ShouldRemoveOperation = std::function<bool(GraphFloat32* graph, Node*)>;

class RemoveOperation : public SequenceTransformation {
 public:
  explicit RemoveOperation(ShouldRemoveOperation remove_predicate)
      : remove_predicate_(std::move(remove_predicate)) {}

  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    Node* prev_op_node = sequence.front();
    Node* op_node = sequence.back();
    if (!remove_predicate_(graph, op_node)) {
      return {TransformStatus::SKIPPED, ""};
    }
    Status status = RemoveFollowingNode(graph, op_node, prev_op_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove a node: " + status.error_message()};
    }
    return {TransformStatus::APPLIED, ""};
  }

 private:
  ShouldRemoveOperation remove_predicate_;
};

}  // namespace

std::unique_ptr<SequenceTransformation> NewRemoveSingleInputConcat() {
  // Using SequenceTransformation implies that CONCAT has a single input.
  auto type = ToString(OperationType::CONCAT);
  return absl::make_unique<RemoveOperation>(
      [type](GraphFloat32* graph, Node* node) {
        return type == node->operation.type;
      });
}

std::unique_ptr<SequenceTransformation> NewRemoveSingleInputAdd() {
  // Using SequenceTransformation implies that ADD has a single input.
  auto type = ToString(OperationType::ADD);
  return absl::make_unique<RemoveOperation>(
      [type](GraphFloat32* graph, Node* node) {
        if (node->operation.type != type) {
          return false;
        }
        auto& attr =
            absl::any_cast<const AddAttributes&>(node->operation.attributes);
        return absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param) ==
               nullptr;
      });
}

std::unique_ptr<SequenceTransformation> NewRemoveDegenerateUpsampling() {
  auto type = ToString(OperationType::UPSAMPLE_2D);
  return absl::make_unique<RemoveOperation>(
      [type](GraphFloat32* graph, Node* node) {
        if (node->operation.type != type) {
          return false;
        }
        auto inputs = graph->FindInputs(node->id);
        auto outputs = graph->FindOutputs(node->id);
        return inputs.size() == 1 && outputs.size() == 1 &&
               inputs[0]->tensor.shape == outputs[0]->tensor.shape;
      });
}

}  // namespace gpu
}  // namespace tflite
@@ -0,0 +1,35 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_REMOVE_NOOP_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_REMOVE_NOOP_H_

#include <memory>

#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

namespace tflite {
namespace gpu {

std::unique_ptr<SequenceTransformation> NewRemoveSingleInputConcat();

std::unique_ptr<SequenceTransformation> NewRemoveSingleInputAdd();

std::unique_ptr<SequenceTransformation> NewRemoveDegenerateUpsampling();

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_REMOVE_NOOP_H_
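As with the padding transforms, a minimal sketch of how these removal passes are run over a graph (assumed wiring; the transformer usage and the "noop" id follow the tests just below):

    GraphFloat32 graph;
    // ... populate graph ...
    ModelTransformer transformer(&graph, nullptr);
    auto remove_add = NewRemoveSingleInputAdd();
    transformer.Apply("noop", remove_add.get());
    auto remove_upsampling = NewRemoveDegenerateUpsampling();
    transformer.Apply("noop", remove_upsampling.get());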
@@ -0,0 +1,146 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"

namespace tflite {
namespace gpu {
namespace {

TEST(RemoveSingleInputAdd, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  auto first_node = graph.NewNode();
  ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());

  auto add_node = graph.NewNode();
  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
  add_node->operation.type = ToString(OperationType::ADD);
  add_node->operation.attributes = AddAttributes();

  Value<TensorRefFloat32>* temp;
  ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());

  auto transformation = NewRemoveSingleInputAdd();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("noop", transformation.get());

  EXPECT_EQ(1, graph.nodes().size());
  ASSERT_EQ(2, graph.values().size());
  ASSERT_EQ(first_node, graph.nodes()[0]);
  ASSERT_EQ(input, graph.values()[0]);
  ASSERT_EQ(output, graph.values()[1]);
}

TEST(RemoveSingleInputAdd, DoNotTrigger_Tensor) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  auto first_node = graph.NewNode();
  ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());

  auto add_node = graph.NewNode();
  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
  add_node->operation.type = ToString(OperationType::ADD);
  AddAttributes attr;
  attr.param = Tensor<Linear, DataType::FLOAT32>();
  add_node->operation.attributes = attr;

  Value<TensorRefFloat32>* temp;
  ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());

  auto transformation = NewRemoveSingleInputAdd();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("noop", transformation.get());

  EXPECT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());
}

TEST(RemoveSingleInputAdd, DoNotTrigger_Multiple) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  auto node_a = graph.NewNode();
  auto node_b = graph.NewNode();
  ASSERT_TRUE(graph.AddConsumer(node_a->id, input->id).ok());
  ASSERT_TRUE(graph.AddConsumer(node_b->id, input->id).ok());

  auto add_node = graph.NewNode();
  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
  add_node->operation.type = ToString(OperationType::ADD);

  Value<TensorRefFloat32>* temp;
  ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp).ok());
  ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp).ok());
  ASSERT_EQ(3, graph.nodes().size());
  ASSERT_EQ(4, graph.values().size());

  auto transformation = NewRemoveSingleInputAdd();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("noop", transformation.get());

  ASSERT_EQ(3, graph.nodes().size());
  ASSERT_EQ(4, graph.values().size());
}

TEST(RemoveDegenerateUpsampling, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  auto first_node = graph.NewNode();
  ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());

  auto node_to_remove = graph.NewNode();
  Value<TensorRefFloat32>* output;
  ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
  output->tensor.shape = BHWC(1, 5, 5, 1);
  node_to_remove->operation.type = ToString(OperationType::UPSAMPLE_2D);
  Upsample2DAttributes attr;
  attr.new_shape = HW(5, 5);
  attr.type = UpsamplingType::BILINEAR;
  node_to_remove->operation.attributes = attr;

  Value<TensorRefFloat32>* link;
  ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok());
  link->tensor.shape = output->tensor.shape;
  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());

  auto transformation = NewRemoveDegenerateUpsampling();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("noop", transformation.get());

  ASSERT_EQ(1, graph.nodes().size());
  ASSERT_EQ(2, graph.values().size());
  EXPECT_EQ(first_node, graph.nodes()[0]);
  EXPECT_EQ(input, graph.values()[0]);
  EXPECT_EQ(output, graph.values()[1]);
}

}  // namespace
}  // namespace gpu
}  // namespace tflite
208 tensorflow/lite/delegates/gpu/common/types.h Normal file
@@ -0,0 +1,208 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TYPES_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TYPES_H_

#include <array>
#include <cstddef>
#include <cstdint>
#include <string>

#include <fp16.h>

namespace tflite {
namespace gpu {

// TODO(akulik): make these types Google-style compliant.

using HalfBits = uint16_t;

class alignas(2) half {
 public:
  HalfBits bits;

  half() = default;

  half(const half& f) : bits(f.bits) {}

  explicit half(float other) { bits = fp16_ieee_from_fp32_value(other); }

  void operator=(float f) { *this = half(f); }

  operator float() const { return fp16_ieee_to_fp32_value(bits); }
};

template <typename T>
struct alignas(sizeof(T)) Vec4 {
  union {
    struct {
      T x, y, z, w;
    };
    std::array<T, 4> data_;
  };

  Vec4() : Vec4(T(0.0f)) {}

  template <typename S>
  Vec4(S x_, S y_, S z_, S w_) : x(x_), y(y_), z(z_), w(w_) {}
  explicit Vec4(T v) : x(v), y(v), z(v), w(v) {}

  template <typename S>
  explicit Vec4(S v) : x(v), y(v), z(v), w(v) {}

  Vec4(const Vec4& f) : x(f.x), y(f.y), z(f.z), w(f.w) {}

  template <typename S>
  Vec4(const Vec4<S>& f) : x(f.x), y(f.y), z(f.z), w(f.w) {}

  Vec4& operator=(const Vec4& other) {
    x = other.x;
    y = other.y;
    z = other.z;
    w = other.w;
    return *this;
  }

  static constexpr int size() { return 4; }

  T& operator[](size_t n) { return data_[n]; }
  T operator[](size_t n) const { return data_[n]; }

  bool operator==(const Vec4& value) const {
    return data_[0] == value[0] && data_[1] == value[1] &&
           data_[2] == value[2] && data_[3] == value[3];
  }
  bool operator!=(const Vec4& value) const {
    return !(this->operator==(value));
  }
};

template <typename T>
struct alignas(sizeof(T)) Vec3 {
  union {
    struct {
      T x, y, z;
    };
    std::array<T, 3> data_;
  };

  Vec3() : Vec3(T(0.0f)) {}

  template <typename S>
  constexpr Vec3(S x_, S y_, S z_) : x(x_), y(y_), z(z_) {}
  explicit Vec3(T v) : x(v), y(v), z(v) {}

  template <typename S>
  explicit Vec3(S v) : x(v), y(v), z(v) {}

  Vec3(const Vec3& f) : x(f.x), y(f.y), z(f.z) {}

  template <typename S>
  Vec3(const Vec3<S>& f) : x(f.x), y(f.y), z(f.z) {}

  Vec3& operator=(const Vec3& other) {
    x = other.x;
    y = other.y;
    z = other.z;
    return *this;
  }

  static constexpr int size() { return 3; }

  T& operator[](size_t n) { return data_[n]; }
  T operator[](size_t n) const { return data_[n]; }
  bool operator==(const Vec3& value) const {
    return data_[0] == value[0] && data_[1] == value[1] && data_[2] == value[2];
  }
  bool operator!=(const Vec3& value) const {
    return !(this->operator==(value));
  }
};

template <typename T>
struct alignas(sizeof(T)) Vec2 {
  union {
    struct {
      T x, y;
    };
    std::array<T, 2> data_;
  };

  Vec2() : Vec2(T(0.0f)) {}

  template <typename S>
  Vec2(S x_, S y_) : x(x_), y(y_) {}
  explicit Vec2(T v) : x(v), y(v) {}

  template <typename S>
  explicit Vec2(S v) : x(v), y(v) {}

  Vec2(const Vec2& f) : x(f.x), y(f.y) {}

  template <typename S>
  Vec2(const Vec2<S>& f) : x(f.x), y(f.y) {}

  Vec2& operator=(const Vec2& other) {
    x = other.x;
    y = other.y;
    return *this;
  }

  bool operator==(const Vec2& value) const {
    return data_[0] == value[0] && data_[1] == value[1];
  }

  bool operator!=(const Vec2& value) const {
    return !(this->operator==(value));
  }

  static constexpr int size() { return 2; }

  T& operator[](size_t n) { return data_[n]; }
  T operator[](size_t n) const { return data_[n]; }
};

using float2 = Vec2<float>;
using half2 = Vec2<half>;
using byte2 = Vec2<int8_t>;
using ubyte2 = Vec2<uint8_t>;
using short2 = Vec2<int16_t>;
using ushort2 = Vec2<uint16_t>;
using int2 = Vec2<int32_t>;
using uint2 = Vec2<uint32_t>;

using float3 = Vec3<float>;
using half3 = Vec3<half>;
using byte3 = Vec3<int8_t>;
using ubyte3 = Vec3<uint8_t>;
using short3 = Vec3<int16_t>;
using ushort3 = Vec3<uint16_t>;
using int3 = Vec3<int32_t>;
using uint3 = Vec3<uint32_t>;

using float4 = Vec4<float>;
using half4 = Vec4<half>;
using byte4 = Vec4<int8_t>;
using ubyte4 = Vec4<uint8_t>;
using short4 = Vec4<int16_t>;
using ushort4 = Vec4<uint16_t>;
using int4 = Vec4<int32_t>;
using uint4 = Vec4<uint32_t>;

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TYPES_H_
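A brief usage sketch for the types above (illustrative values, not taken from the commit): half round-trips a float through 16-bit storage, and the VecN aliases provide component and index access.

    tflite::gpu::half h(1.5f);      // stores fp16 bits via fp16_ieee_from_fp32_value
    float restored = h;             // implicit conversion back to float
    tflite::gpu::float4 v(0.0f, 1.0f, 2.0f, 3.0f);
    float y = v[1];                 // element access through operator[]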
51 tensorflow/lite/delegates/gpu/common/util.h Normal file
@@ -0,0 +1,51 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_UTIL_H_

#include "tensorflow/lite/delegates/gpu/common/types.h"

namespace tflite {
namespace gpu {

// @param n must be non negative
// @param divisor must be greater than zero
template <typename T, typename N>
T IntegralDivideRoundUp(T n, N divisor) {
  const T div = static_cast<T>(divisor);
  const T q = n / div;
  return n % div == 0 ? q : q + 1;
}

template <>
inline ::tflite::gpu::uint3 IntegralDivideRoundUp(
    ::tflite::gpu::uint3 n, ::tflite::gpu::uint3 divisor) {
  return ::tflite::gpu::uint3(IntegralDivideRoundUp(n.x, divisor.x),
                              IntegralDivideRoundUp(n.y, divisor.y),
                              IntegralDivideRoundUp(n.z, divisor.z));
}

// @param number or its components must be greater than zero
// @param n must be greater than zero
template <typename T, typename N>
T AlignByN(T number, N n) {
  return IntegralDivideRoundUp(number, n) * n;
}

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_UTIL_H_
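A worked example for the helpers above (illustrative values): the uint3 overload rounds each component up independently, which is how per-dimension workgroup counts are derived later in api.cc.

    using tflite::gpu::uint3;
    uint3 groups = tflite::gpu::IntegralDivideRoundUp(uint3(60, 45, 7), uint3(8, 8, 4));
    // groups == uint3(8, 6, 2): 60/8 -> 8, 45/8 -> 6, 7/4 -> 2, each rounded up.
    // AlignByN(60, 8) == 64: 60 rounded up to the next multiple of 8.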
53 tensorflow/lite/delegates/gpu/common/util_test.cc Normal file
@@ -0,0 +1,53 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/util.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

namespace tflite {
namespace gpu {
namespace {

using testing::Eq;

TEST(UtilTest, IntegralDivideRoundUp) {
  EXPECT_THAT(IntegralDivideRoundUp(0, 256), Eq(0));
  EXPECT_THAT(IntegralDivideRoundUp(2u, 256), Eq(1));
  EXPECT_THAT(IntegralDivideRoundUp(2, 256), Eq(1));
  EXPECT_THAT(IntegralDivideRoundUp(255u, 256), Eq(1));
  EXPECT_THAT(IntegralDivideRoundUp(255, 256), Eq(1));
  EXPECT_THAT(IntegralDivideRoundUp(256u, 256), Eq(1));
  EXPECT_THAT(IntegralDivideRoundUp(256, 256), Eq(1));
  EXPECT_THAT(IntegralDivideRoundUp(257u, 256), Eq(2));
  EXPECT_THAT(IntegralDivideRoundUp(257, 256), Eq(2));
}

TEST(UtilTest, AlignByN) {
  EXPECT_THAT(AlignByN(0u, 256), Eq(0));
  EXPECT_THAT(AlignByN(1u, 256), Eq(256));
  EXPECT_THAT(AlignByN(255u, 256), Eq(256));
  EXPECT_THAT(AlignByN(256u, 256), Eq(256));
  EXPECT_THAT(AlignByN(257u, 256), Eq(512));

  EXPECT_THAT(AlignByN(1, 4), Eq(4));
  EXPECT_THAT(AlignByN(80, 4), Eq(80));
  EXPECT_THAT(AlignByN(81, 4), Eq(84));
}

}  // namespace
}  // namespace gpu
}  // namespace tflite
433 tensorflow/lite/delegates/gpu/gl/BUILD Normal file
@@ -0,0 +1,433 @@
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")

cc_library(
    name = "api",
    srcs = ["api.cc"],
    hdrs = ["api.h"],
    deps = [
        ":command_queue",
        ":compiler",
        ":compiler_options",
        ":gl_call",
        ":gpu_info",
        ":node_shader",
        ":object",
        ":object_manager",
        ":portable",
        ":runtime",
        ":runtime_options",
        ":stats",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common:util",
        "//tensorflow/lite/delegates/gpu/gl/workgroups:calculator",
    ] + select({
        "//tensorflow/lite/delegates/gpu:tflite_gpu_binary_release": [],
        "//conditions:default": [
            ":serialization",
        ],
    }),
)

cc_library(
    name = "command_queue",
    srcs = ["command_queue.cc"],
    hdrs = ["command_queue.h"],
    deps = [
        ":gl_call",
        ":gl_program",
        ":gl_sync",
        ":gpu_info",
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "@com_google_absl//absl/memory",
    ],
)

flatbuffer_cc_library(
    name = "common_cc_fbs",
    srcs = ["common.fbs"],
)

# Generic schema for inference on GPU device.
flatbuffer_cc_library(
    name = "compiled_model_cc_fbs",
    srcs = ["compiled_model.fbs"],
    flatc_args = [
        "--scoped-enums",
    ],
    includes = [
        "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs_includes",
    ],
)

cc_library(
    name = "compiler",
    srcs = ["compiler.cc"],
    hdrs = ["compiler.h"],
    deps = [
        ":compiler_options",
        ":float16_conversions",
        ":gpu_info",
        ":node_shader",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/gl/compiler:compiled_node",
        "//tensorflow/lite/delegates/gpu/gl/compiler:fuse_auto_input",
        "//tensorflow/lite/delegates/gpu/gl/compiler:fuse_inline",
        "//tensorflow/lite/delegates/gpu/gl/compiler:fuse_inplace",
        "//tensorflow/lite/delegates/gpu/gl/compiler:shader_code",
        "//tensorflow/lite/delegates/gpu/gl/compiler:shader_codegen",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:any",
    ],
)

cc_library(
    name = "compiler_options",
    hdrs = ["compiler_options.h"],
    deps = [
        ":gpu_info",
        ":object",
    ],
)

cc_library(
    name = "egl_context",
    srcs = ["egl_context.cc"],
    hdrs = ["egl_context.h"],
    deps = [
        ":gl_call",
        ":gl_errors",
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:status",
    ],
)

cc_library(
    name = "egl_environment",
    srcs = ["egl_environment.cc"],
    hdrs = ["egl_environment.h"],
    deps = [
        ":egl_context",
        ":egl_surface",
        ":gl_call",
        ":gpu_info",
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/memory",
    ],
)

cc_library(
    name = "egl_surface",
    srcs = ["egl_surface.cc"],
    hdrs = ["egl_surface.h"],
    deps = [
        ":gl_call",
        ":gl_errors",
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:status",
    ],
)

cc_library(
    name = "float16_conversions",
    srcs = ["float16_conversions.cc"],
    hdrs = ["float16_conversions.h"],
    deps = [
        ":object",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:tensor",
        "@FP16",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_library(
    name = "gl_buffer",
    srcs = ["gl_buffer.cc"],
    hdrs = ["gl_buffer.h"],
    deps = [
        ":gl_call",
        ":gl_errors",
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/types:span",
    ],
)

cc_test(
    name = "gl_buffer_test",
    srcs = ["gl_buffer_test.cc"],
    linkopts = [
        "-lGLESv3",
        "-lEGL",
    ],
    tags = [
        "local",
        "nobuilder",
        "notap",
    ],
    deps = [
        ":egl_environment",
        ":gl_buffer",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "gl_call",
    hdrs = ["gl_call.h"],
    deps = [
        ":gl_errors",
        "//tensorflow/lite/delegates/gpu/common:status",
    ],
)

cc_library(
    name = "gl_errors",
    srcs = ["gl_errors.cc"],
    hdrs = ["gl_errors.h"],
    deps = [
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "gl_program",
    srcs = ["gl_program.cc"],
    hdrs = ["gl_program.h"],
    deps = [
        ":gl_call",
        ":gl_errors",
        ":gl_shader",
        ":portable",
        ":uniform_parameter",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_library(
    name = "gl_shader",
    srcs = ["gl_shader.cc"],
    hdrs = ["gl_shader.h"],
    deps = [
        ":gl_call",
        ":gl_errors",
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:status",
    ],
)

cc_library(
    name = "gl_texture",
    srcs = ["gl_texture.cc"],
    hdrs = ["gl_texture.h"],
    deps = [
        ":gl_call",
        ":gl_errors",
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:tensor",
        "//tensorflow/lite/delegates/gpu/common:types",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "gl_sync",
    srcs = ["gl_sync.cc"],
    hdrs = ["gl_sync.h"],
    deps = [
        ":gl_call",
        ":gl_errors",
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:status",
    ],
)

cc_library(
    name = "gpu_info",
    srcs = ["gpu_info.cc"],
    hdrs = ["gpu_info.h"],
    deps = [
        ":gl_errors",
        ":portable",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/strings",
    ],
)

flatbuffer_cc_library(
    name = "metadata_cc_fbs",
    srcs = ["metadata.fbs"],
    includes = [
        "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs_includes",
        "//tensorflow/lite/delegates/gpu/gl:workgroups_cc_fbs_includes",
    ],
)

cc_library(
    name = "node_shader",
    hdrs = ["node_shader.h"],
    deps = [
        ":compiler_options",
        ":gpu_info",
        ":object",
        ":uniform_parameter",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
    ],
)

cc_library(
    name = "object",
    hdrs = ["object.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common:util",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_library(
    name = "object_manager",
    srcs = ["object_manager.cc"],
    hdrs = ["object_manager.h"],
    deps = [
        ":gl_buffer",
        ":gl_texture",
        ":stats",
        "//tensorflow/lite/delegates/gpu/common:convert",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "portable",
    hdrs = [
        "portable_egl.h",
        "portable_gl31.h",
    ],
)

cc_library(
    name = "runtime",
    srcs = ["runtime.cc"],
    hdrs = ["runtime.h"],
    deps = [
        ":command_queue",
        ":gl_buffer",
        ":gl_call",
        ":gl_errors",
        ":gl_program",
        ":gl_shader",
        ":gl_texture",
        ":gpu_info",
        ":object",
        ":object_manager",
        ":portable",
        ":runtime_options",
        ":stats",
        ":uniform_parameter",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/gl/runtime:shared_buffer",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_library(
    name = "runtime_options",
    hdrs = ["runtime_options.h"],
)

cc_library(
    name = "serialization",
    srcs = ["serialization.cc"],
    hdrs = ["serialization.h"],
    deps = [
        ":common_cc_fbs",
        ":compiled_model_cc_fbs",
        ":object",
        ":uniform_parameter",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
        "@flatbuffers",
    ],
)

cc_test(
    name = "serialization_test",
    srcs = ["serialization_test.cc"],
    tags = [
        "local",
        "nobuilder",
        "notap",
    ],
    deps = [
        ":object",
        ":serialization",
        ":uniform_parameter",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "stats",
    hdrs = ["stats.h"],
    deps = [
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "uniform_parameter",
    hdrs = ["uniform_parameter.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:types",
        "@com_google_absl//absl/types:variant",
    ],
)

flatbuffer_cc_library(
    name = "workgroups_cc_fbs",
    srcs = ["workgroups.fbs"],
    includes = [
        "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs_includes",
    ],
)
418 tensorflow/lite/delegates/gpu/gl/api.cc Normal file
@@ -0,0 +1,418 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/api.h"

#include <algorithm>
#include <cstdint>
#include <deque>
#include <mutex>  // NOLINT
#include <unordered_map>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"
#include "tensorflow/lite/delegates/gpu/gl/runtime.h"

#ifndef TFLITE_GPU_BINARY_RELEASE
#include "tensorflow/lite/delegates/gpu/gl/serialization.h"
#endif  // TFLITE_GPU_BINARY_RELEASE

namespace tflite {
namespace gpu {
namespace gl {
namespace {

using ObjectsSizes = std::unordered_map<ValueId, size_t>;

enum class InferenceContextState {
  NOT_STARTED,
  IN_PROGRESS,
};

class InferenceContextImpl : public InferenceContext {
 public:
  explicit InferenceContextImpl(std::unique_ptr<Runtime> runtime)
      : runtime_(std::move(runtime)) {}

  Status Execute() final {
    std::lock_guard<std::mutex> lock(guard_);
    if (state_ != InferenceContextState::NOT_STARTED) {
      return FailedPreconditionError("InferenceContext is not reset");
    }
    state_ = InferenceContextState::IN_PROGRESS;
    return runtime_->Execute();
  }

  Status Reset() final {
    std::lock_guard<std::mutex> lock(guard_);
    // TODO(akulik): should Reset not return Status?
    state_ = InferenceContextState::NOT_STARTED;
    return OkStatus();
  }

  RuntimeStats stats() const final { return runtime_->stats(); }

 private:
  std::unique_ptr<Runtime> runtime_;

  mutable std::mutex guard_;
  InferenceContextState state_ = InferenceContextState::NOT_STARTED;
};

class InferenceContextWithBatchImpl : public InferenceContext {
 public:
  InferenceContextWithBatchImpl(const ObjectsSizes& sizes,
                                const ObjectManager* objects,
                                std::unique_ptr<ObjectManager> refs,
                                std::unique_ptr<Runtime> runtime)
      : sizes_(sizes),
        objects_(objects),
        refs_(std::move(refs)),
        runtime_(std::move(runtime)) {}

  Status Execute() final {
    std::lock_guard<std::mutex> lock(guard_);
    if (state_ != InferenceContextState::NOT_STARTED) {
      return FailedPreconditionError("InferenceContext is not reset");
    }
    state_ = InferenceContextState::IN_PROGRESS;

    // Calculate expected number of batches and check that all external objects
    // match that number.
    int num_batches = 0;
    for (const auto& s : sizes_) {
      const ValueId id = s.first;
      const size_t byte_size = s.second;

      auto buffer = objects_->FindBuffer(id);
      if (!buffer) continue;

      if (buffer->bytes_size() % byte_size) {
        return InvalidArgumentError(absl::StrCat(
            "Object ", id, " does not match expected byte size: ", byte_size));
      }
      size_t b = buffer->bytes_size() / byte_size;
      if (num_batches == 0) {
        num_batches = b;
      } else {
        if (num_batches != b) {
          return InvalidArgumentError(absl::StrCat(
              "Object ", id, " size does not match expected batch size: ", b,
              " vs ", num_batches));
        }
      }
    }

    for (size_t b = 0; b < num_batches; ++b) {
      // slice external objects by batch.
      for (const auto& s : sizes_) {
        const ValueId id = s.first;
        const size_t byte_size = s.second;
        auto buffer = objects_->FindBuffer(id);
        if (buffer) {
          auto ref = refs_->FindBuffer(id);
          if (!ref) {
            return InvalidArgumentError(
                absl::StrCat("Reference to ", id, " is not found"));
          }
          RETURN_IF_ERROR(buffer->MakeView(b * byte_size, byte_size, ref));
        }
      }
      RETURN_IF_ERROR(runtime_->Execute());
    }
    return OkStatus();
  }

  Status Reset() final {
    std::lock_guard<std::mutex> lock(guard_);
    state_ = InferenceContextState::NOT_STARTED;
    // TODO(akulik): should Reset not return Status?
    return OkStatus();
  }

  RuntimeStats stats() const final { return runtime_->stats(); }

 private:
  const ObjectsSizes sizes_;
  const ObjectManager* objects_;

  // view over external objects provided by a user.
  std::unique_ptr<ObjectManager> refs_;
  std::unique_ptr<Runtime> runtime_;

  mutable std::mutex guard_;
  InferenceContextState state_ = InferenceContextState::NOT_STARTED;
};
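// Worked example of the batch handling above (illustrative numbers, not part
// of the original sources): if an external object was registered with a
// byte_size of 1024 and the user-supplied buffer holds 4096 bytes, Execute()
// resolves num_batches to 4 and runs the runtime four times, each time
// re-pointing the per-batch references at the next 1024-byte slice via
// MakeView().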

struct ProgramParameters {
  // A list of uniform parameters to be set.
  std::vector<UniformParameter> parameters;

  // A list of objects to bind to opengl program.
  std::vector<Object> objects;

  uint3 workgroup_size;
  uint3 num_workgroups;

  size_t shader_idx;
};

std::string GetShaderHeader(uint3 localsize) {
  return absl::StrCat("#version 310 es\nlayout(local_size_x = ", localsize.x,
                      ", local_size_y = ", localsize.y,
                      ", local_size_z = ", localsize.z, ") in;\n");
}
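// For example (illustrative values), GetShaderHeader(uint3(8, 4, 1)) produces:
//   #version 310 es
//   layout(local_size_x = 8, local_size_y = 4, local_size_z = 1) in;
// This header is prepended to every partial shader before compilation in
// AddFullShader() below.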

class CompiledModelImpl
#ifndef TFLITE_GPU_BINARY_RELEASE
    : public CompiledModel,
      public DeserializationHandler {
#else
    : public CompiledModel {
#endif  // TFLITE_GPU_BINARY_RELEASE
 public:
  explicit CompiledModelImpl(const GpuInfo& gpu_info) : gpu_info_(gpu_info) {}

  // Called while compiling shaders from scratch
  Status Add(const WorkgroupsCalculator& workgroup_calculator,
             ShaderCode code) {
    // Calculate workgroup size.
    uint3 workgroup_size = workgroup_calculator.Calculate(code);
    uint3 num_workgroups = IntegralDivideRoundUp(code.workload, workgroup_size);

    for (const auto& object : code.objects) {
      if (IsRef(object)) {
        object_sizes_[GetRef(object)] = ByteSizeOf(object);
      }
    }

    // Store full shader and compile it if necessary.
    size_t shader_idx;
    RETURN_IF_ERROR(
        AddFullShader(code.source_code, workgroup_size, &shader_idx));
    programs_.push_back({
        std::move(code.parameters),
        std::move(code.objects),
        workgroup_size,
        num_workgroups,
        shader_idx,
    });
    return OkStatus();
  }

  // Store full shader and compile it if necessary.
  // Returns full_shader_index
  Status AddFullShader(const std::string& partial_shader,
                       const uint3& workgroup_size, size_t* size) {
    std::string shader_src = GetShaderHeader(workgroup_size) + partial_shader;
    auto it = shader_to_index_.find(shader_src);
    if (it == shader_to_index_.end()) {
      GlShader shader;
      RETURN_IF_ERROR(
          GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src, &shader));
      shaders_.push_back(std::move(shader));
      shader_to_index_.insert({shader_src, shader_to_index_.size()});
      *size = shader_to_index_.size() - 1;
    } else {
      *size = it->second;
    }
    return OkStatus();
  }

  Status NewRun(
      const RuntimeOptions& options, const ObjectManager* objects,
      CommandQueue* command_queue,
      std::unique_ptr<InferenceContext>* inference_context) const final {
    std::unique_ptr<ObjectManager> refs;
    if (dynamic_batch_) {
      // Runtime is using objects from refs that will point to provided objects.
      // At this point just create 0 batch slice references.
      refs = absl::make_unique<ObjectManager>();
      for (const auto& s : object_sizes_) {
        auto buffer = objects->FindBuffer(s.first);
        if (!buffer) continue;
        GlBuffer ref;
        RETURN_IF_ERROR(buffer->MakeView(0, s.second, &ref));
        RETURN_IF_ERROR(refs->RegisterBuffer(s.first, std::move(ref)));
      }
    }
    auto runtime = absl::make_unique<Runtime>(options, gpu_info_, command_queue,
                                              (refs ? refs.get() : objects));
    for (auto& c : programs_) {
      RETURN_IF_ERROR(runtime->AddProgram(shaders_[c.shader_idx], c.parameters,
                                          c.objects, c.num_workgroups));
    }
    RETURN_IF_ERROR(runtime->PrepareForExecution());
    if (dynamic_batch_) {
      *inference_context = absl::make_unique<InferenceContextWithBatchImpl>(
          object_sizes_, objects, std::move(refs), std::move(runtime));
    } else {
      *inference_context =
          absl::make_unique<InferenceContextImpl>(std::move(runtime));
    }
    return OkStatus();
  }

#ifndef TFLITE_GPU_BINARY_RELEASE
  // Called on deserialization
  Status OnProgram(const std::vector<UniformParameter>& parameters,
                   const std::vector<Object>& objects,
                   const uint3& workgroup_size, const uint3& num_workgroups,
                   size_t partial_shader_index) final {
    for (auto& object : objects) {
      if (IsRef(object)) {
        object_sizes_[GetRef(object)] = ByteSizeOf(object);
      }
    }

    size_t shader_idx;
    RETURN_IF_ERROR(AddFullShader(partial_shaders_[partial_shader_index],
                                  workgroup_size, &shader_idx));
    programs_.push_back({
        parameters,
        objects,
        workgroup_size,
        num_workgroups,
        shader_idx,
    });
    return OkStatus();
  }

  Status Serialize(
      std::vector<uint8_t>* serialized_compiled_model) const final {
    SerializedCompiledModelBuilder builder;

    // sort shaders first. They need to be serialized in order.
    std::vector<std::string> full_shaders(shaders_.size());
    for (const auto& shader : shader_to_index_) {
      full_shaders[shader.second] = shader.first;
    }

    std::unordered_map<std::string, size_t> partial_shader_to_index;
    std::vector<std::string> partial_shaders;
    for (const auto& program : programs_) {
      // Remove a header from a shader.
      std::string shader_without_header = full_shaders[program.shader_idx];
      shader_without_header.erase(0, shader_without_header.find("in;") + 3);

      // Insert shader into partial shaders array.
      auto it = partial_shader_to_index.find(shader_without_header);
      size_t shader_idx;
      if (it == partial_shader_to_index.end()) {
        shader_idx = partial_shaders.size();
        partial_shaders.push_back(shader_without_header);
        builder.AddShader(shader_without_header);
        partial_shader_to_index.insert({shader_without_header, shader_idx});
      } else {
        shader_idx = it->second;
      }
      builder.AddProgram(program.parameters, program.objects,
                         program.workgroup_size, program.num_workgroups,
                         shader_idx);
    }
    CompiledModelOptions options;
    options.dynamic_batch = dynamic_batch_;
    auto data = builder.Finalize(options);
    serialized_compiled_model->insert(serialized_compiled_model->end(),
                                      data.begin(), data.end());
    return OkStatus();
  }

  Status OnShader(absl::Span<const char> shader_src) final {
    std::string source(shader_src.data(), shader_src.size());
    partial_shaders_.push_back(source);
    return OkStatus();
  }

  void OnOptions(const CompiledModelOptions& options) final {
    dynamic_batch_ = options.dynamic_batch;
  }
#endif  // TFLITE_GPU_BINARY_RELEASE

  CompilerStats stats() const final { return stats_; }

  void set_dynamic_batch(bool dynamic_batch) { dynamic_batch_ = dynamic_batch; }

 private:
  const GpuInfo gpu_info_;
  bool dynamic_batch_ = false;

  std::vector<std::string> partial_shaders_;
  std::vector<GlShader> shaders_;

  // Shaders are serialized in order of their indices.
  std::unordered_map<std::string, size_t> shader_to_index_;
  std::deque<ProgramParameters> programs_;
  std::unordered_map<ValueId, size_t> object_sizes_;
  CompilerStats stats_;
};

// @return true if all tensors have same batch value.
bool IsBatchMatchesForAllValues(const GraphFloat32& model) {
  int32_t b = model.values()[0]->tensor.shape.b;
  for (auto value : model.values()) {
    if (value->tensor.shape.b != b) {
      return false;
    }
  }
  return true;
}

}  // namespace

Status Compile(const CompilationOptions& options, const GraphFloat32& model,
               const NodeShader& node_shader,
               const WorkgroupsCalculator& workgroup_calculator,
               std::unique_ptr<CompiledModel>* compiled_model) {
  if (!IsBatchMatchesForAllValues(model)) {
    return InvalidArgumentError("Only identical batch dimension is supported");
  }
  GpuInfo gpu_info;
  RETURN_IF_ERROR(RequestGpuInfo(&gpu_info));
  auto compiled_model_impl = absl::make_unique<CompiledModelImpl>(gpu_info);
  compiled_model_impl->set_dynamic_batch(options.dynamic_batch);
  auto compiler = NewCompiler(&node_shader, &gpu_info, options);
  RETURN_IF_ERROR(compiler->Compile(model, [&](ShaderCode code) -> Status {
    return compiled_model_impl->Add(workgroup_calculator, std::move(code));
  }));
||||||
|
*compiled_model = std::move(compiled_model_impl);
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef TFLITE_GPU_BINARY_RELEASE
|
||||||
|
Status ReadSerializedModel(const std::vector<uint8_t>& serialized_model,
|
||||||
|
std::unique_ptr<CompiledModel>* compiled_model) {
|
||||||
|
GpuInfo gpu_info;
|
||||||
|
RETURN_IF_ERROR(RequestGpuInfo(&gpu_info));
|
||||||
|
auto compiled_model_impl = absl::make_unique<CompiledModelImpl>(gpu_info);
|
||||||
|
RETURN_IF_ERROR(DeserializeCompiledModel(
|
||||||
|
absl::MakeConstSpan(serialized_model), compiled_model_impl.get()));
|
||||||
|
*compiled_model = std::move(compiled_model_impl);
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
#endif // TFLITE_GPU_BINARY_RELEASE
|
||||||
|
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
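Editorial note (not part of the original commit): Serialize() above deduplicates shader sources by stripping the generated header (everything up to the trailing "in;" declaration) and keying an unordered_map on the remaining body. A minimal standalone sketch of that idea follows; all names here are illustrative and the "in;" marker is assumed to be present, mirroring the erase() call above.

// Sketch only: shader-body deduplication in the spirit of Serialize() above.
#include <string>
#include <unordered_map>
#include <vector>

// Strips the generated prefix, assuming it ends right after the last "in;".
std::string StripHeader(const std::string& full_shader) {
  return full_shader.substr(full_shader.find("in;") + 3);
}

// Returns the index of a shader body, adding it only if it was not seen yet.
size_t DedupShader(const std::string& body,
                   std::unordered_map<std::string, size_t>* index,
                   std::vector<std::string>* storage) {
  auto it = index->find(body);
  if (it != index->end()) return it->second;
  const size_t idx = storage->size();
  storage->push_back(body);
  index->insert({body, idx});
  return idx;
}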
103
tensorflow/lite/delegates/gpu/gl/api.h
Normal file
@ -0,0 +1,103 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_API_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_API_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/command_queue.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
#include "tensorflow/lite/delegates/gpu/gl/object_manager.h"
#include "tensorflow/lite/delegates/gpu/gl/runtime_options.h"
#include "tensorflow/lite/delegates/gpu/gl/stats.h"
#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"

namespace tflite {
namespace gpu {
namespace gl {

class InferenceContext;

// Represents a model that was prepared for execution. It is stored in a format
// most suitable for execution and optionally may include pre-generated or
// pre-compiled GPU shaders or whatever is needed for efficient execution.
class CompiledModel {
 public:
  virtual ~CompiledModel() = default;

  virtual CompilerStats stats() const = 0;

  // Creates a new inference context. Result can outlive @this.
  //
  // The NewRun call as well as subsequent calls to InferenceContext methods
  // should be done from the same EGL context.
  virtual Status NewRun(
      const RuntimeOptions& options, const ObjectManager* objects,
      CommandQueue* command_queue,
      std::unique_ptr<InferenceContext>* inference_context) const = 0;

#ifndef TFLITE_GPU_BINARY_RELEASE
  // Serializes compiled model to a string.
  // @return OkStatus if serialization finished successfully.
  virtual Status Serialize(
      std::vector<uint8_t>* serialized_compiled_model) const = 0;
#endif  // TFLITE_GPU_BINARY_RELEASE
};

// Turns the given model into "compiled" form that is suitable for inference.
Status Compile(const CompilationOptions& options, const GraphFloat32& model,
               const NodeShader& node_shader,
               const WorkgroupsCalculator& workgroup_calculator,
               std::unique_ptr<CompiledModel>* compiled_model);

#ifndef TFLITE_GPU_BINARY_RELEASE
// Reads serialized representation previously created with
// CompiledModel::Serialize call.
Status ReadSerializedModel(const std::vector<uint8_t>& serialized_model,
                           std::unique_ptr<CompiledModel>* compiled_model);
#endif  // TFLITE_GPU_BINARY_RELEASE

// Encapsulates everything needed for one or more inference executions done
// sequentially.
//
// Thread-safe.
class InferenceContext {
 public:
  virtual ~InferenceContext() = default;

  virtual RuntimeStats stats() const = 0;

  // Executes inference.
  virtual Status Execute() = 0;

  // Asks the context to reset itself for another round. Note that inputs and
  // outputs are not cleared, so it is possible to re-use them.
  // It is an error to call Reset while a previous run is still in progress.
  virtual Status Reset() = 0;
};

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_API_H_
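Editorial note (not part of the original commit): a minimal sketch of the typical call sequence for the API declared above. It assumes the graph, node shader, workgroup calculator and input/output ObjectManager have already been set up by the caller, and it borrows RequestGpuInfo and NewCommandQueue from other files in this commit; error handling is collapsed into RETURN_IF_ERROR.

// Sketch only: compile a graph and run inference once.
Status CompileAndRunOnce(const GraphFloat32& graph,
                         const NodeShader& node_shader,
                         const WorkgroupsCalculator& calculator,
                         const ObjectManager* objects) {
  CompilationOptions compile_options;
  std::unique_ptr<CompiledModel> compiled_model;
  RETURN_IF_ERROR(Compile(compile_options, graph, node_shader, calculator,
                          &compiled_model));

  GpuInfo gpu_info;
  RETURN_IF_ERROR(RequestGpuInfo(&gpu_info));
  auto command_queue = NewCommandQueue(gpu_info);

  RuntimeOptions runtime_options;
  std::unique_ptr<InferenceContext> context;
  RETURN_IF_ERROR(compiled_model->NewRun(runtime_options, objects,
                                         command_queue.get(), &context));
  return context->Execute();
}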
85
tensorflow/lite/delegates/gpu/gl/command_queue.cc
Normal file
@ -0,0 +1,85 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/command_queue.h"

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_sync.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

class DefaultCommandQueue : public CommandQueue {
 public:
  Status Dispatch(const GlProgram& program, const uint3& workgroups) override {
    RETURN_IF_ERROR(program.Dispatch(workgroups));
    return TFLITE_GPU_CALL_GL(glMemoryBarrier, GL_ALL_BARRIER_BITS);
  }

  Status WaitForCompletion() override {
    // TODO(akulik): maybe let a user choose which wait method to use.
    return GlActiveSyncWait();
  }
};

// On Adreno, flush periodically as this affects performance. The command queue
// needs to be manually managed to ensure that accumulated work goes to the GPU
// as fast as it can.
//
// Also, on older Adreno devices glFlush is required after every memory barrier
// to avoid hitting a GPU driver bug.
class AdrenoCommandQueue : public DefaultCommandQueue {
 public:
  explicit AdrenoCommandQueue(int flush_every_n)
      : flush_every_n_(flush_every_n) {}

  Status Dispatch(const GlProgram& program, const uint3& workgroups) final {
    RETURN_IF_ERROR(DefaultCommandQueue::Dispatch(program, workgroups));
    if ((++program_counter_ % flush_every_n_) == 0) {
      glFlush();
    }
    return OkStatus();
  }

 private:
  const int flush_every_n_;
  int program_counter_ = 0;
};

}  // namespace

std::unique_ptr<CommandQueue> NewCommandQueue(const GpuInfo& gpu_info) {
  if (gpu_info.type == GpuType::ADRENO) {
    int flush_every_n = 1;
    // On Adreno 630 and Adreno 505 there is up to a 2x performance boost when
    // glFlush happens less often.
    if (gpu_info.gpu_model == GpuModel::ADRENO630 ||
        gpu_info.gpu_model == GpuModel::ADRENO505) {
      flush_every_n = 10;
    }
    return absl::make_unique<AdrenoCommandQueue>(flush_every_n);
  }
  return absl::make_unique<DefaultCommandQueue>();
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
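Editorial note (not part of the original commit): the periodic-flush policy implemented by AdrenoCommandQueue above can be isolated into a tiny counter, shown here only to make the cadence explicit. With flush_every_n = 10, glFlush would be issued on the 10th, 20th, 30th, ... dispatch, while every dispatch still gets the full memory barrier from DefaultCommandQueue::Dispatch.

// Sketch only: the flush cadence used by AdrenoCommandQueue, as a standalone helper.
class PeriodicFlushCounter {
 public:
  explicit PeriodicFlushCounter(int flush_every_n)
      : flush_every_n_(flush_every_n) {}

  // Returns true when the caller should issue glFlush().
  bool ShouldFlush() { return (++counter_ % flush_every_n_) == 0; }

 private:
  const int flush_every_n_;
  int counter_ = 0;
};
// PeriodicFlushCounter(10).ShouldFlush() is true on calls 10, 20, 30, ...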
52
tensorflow/lite/delegates/gpu/gl/command_queue.h
Normal file
@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMMAND_QUEUE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMMAND_QUEUE_H_

#include <memory>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"

namespace tflite {
namespace gpu {
namespace gl {

// GL programs can be executed directly via a dispatch call or using a queue
// abstraction similar to the one in OpenCL and Vulkan.
// CommandQueue executes given programs in the order they come.
class CommandQueue {
 public:
  virtual ~CommandQueue() = default;

  // Dispatches a program. It may or may not call glFlush.
  virtual Status Dispatch(const GlProgram& program,
                          const uint3& workgroups) = 0;

  // Waits until all programs dispatched prior to this call are completed.
  virtual Status WaitForCompletion() = 0;
};

// By default a memory barrier is inserted after every dispatch.
std::unique_ptr<CommandQueue> NewCommandQueue(const GpuInfo& gpu_info);

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMMAND_QUEUE_H_
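Editorial note (not part of the original commit): a minimal usage sketch of the queue abstraction declared above, using only the functions from this header. The program and workgroup setup are assumed to come from the caller.

// Sketch only: dispatch one program and wait for the GPU to finish.
Status DispatchAndWait(const GpuInfo& gpu_info, const GlProgram& program,
                       const uint3& workgroups) {
  std::unique_ptr<CommandQueue> queue = NewCommandQueue(gpu_info);
  RETURN_IF_ERROR(queue->Dispatch(program, workgroups));
  return queue->WaitForCompletion();
}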
16
tensorflow/lite/delegates/gpu/gl/common.fbs
Normal file
@ -0,0 +1,16 @@
namespace tflite.gpu.gl.data;

table Uint3 {
  x:uint32;
  y:uint32;
  z:uint32;
}

table Uint2 {
  x:uint32;
  y:uint32;
}

table Uint1 {
  x:uint32;
}
155
tensorflow/lite/delegates/gpu/gl/compiled_model.fbs
Normal file
@ -0,0 +1,155 @@
include "common.fbs";

namespace tflite.gpu.gl.data;

file_identifier "AFCM";

file_extension "flow";

// Encapsulates entire OpenGL program with all necessary dependencies and
// parameters.
table Program {
  // A collection of objects this program refers to.
  objects:[Object];

  // Uniform parameters to be set before execution.
  parameters:[UniformParameter];

  // Defines the number of work groups.
  number_workgroups:Uint3;

  // Defines the size of a workgroup.
  workgroup_size:Uint3;

  // Reference to a shader in this compiled model.
  shader_index:uint32;

  // Contains binary code that was once created after successful shader
  // compilation. Normally it is much faster to instantiate a program from
  // compiled binary.
  binary:ProgramBinary;
}

// Compiled binary representation of a program.
table ProgramBinary {
  format:uint32;  // GLenum

  // Compiled binary shader blob extracted from GL.
  binary:[ubyte];
}

enum ParameterType : byte {
  INT32 = 0,
  UINT32 = 1,
  FLOAT32 = 2,
  INT32_2 = 3,
}

enum DataType : byte {
  UNKNOWN = 0,
  FLOAT32 = 1,
  FLOAT16 = 2,
  INT32 = 3,
  INT16 = 4,
}

union DataVariant {
  DataInt32,
  DataFloat,
  DataUint32,
}

table DataFloat {
  data:[float];
}

table DataInt32 {
  data:[int32];
}

table DataUint32 {
  data:[uint32];
}

table UniformParameter {
  name:string;

  type:ParameterType;

  // Data is optional. If it is known in advance, it is encoded here, otherwise
  // a parameter will be set at runtime.
  data:DataVariant;
}

enum AccessType : byte {
  READ = 0,
  WRITE = 1,
  READ_WRITE = 2,
}

enum ObjectType : byte {
  UNKNOWN = 0,
  BUFFER = 1,
  TEXTURE = 2,
}

union ObjectVariant {
  ObjectData,
  ObjectRef,
}

union ObjectSize {
  Uint1,
  Uint2,
  Uint3,
}

table Object {
  access:AccessType;

  binding:uint32;

  data_type:DataType;

  type:ObjectType;

  size:ObjectSize;

  object:ObjectVariant;
}

// Represents a reference to another object provided by an object manager.
table ObjectRef {
  // Unique global identifier to be used by an object manager to look up this
  // buffer.
  global_id:uint32;
}

table ObjectData {
  data:[uint8];
}

// Represents entire model as a collection of programs, inputs and outputs.
table CompiledModel {
  parameters:Parameters;

  // A collection of shaders used by programs.
  shaders:[string];

  // A collection of programs that need to be executed in the same order.
  programs:[Program];
}

table Parameters {
  // Indicates the flow engine version that compiled this model. If the engine
  // version does not match the compiled model, the model needs to be
  // recompiled.
  // version:uint32; // not implemented

  // Could potentially be used to track the environment when a model was
  // compiled and detect whether it was changed and model recompilation is
  // needed.
  // environment_hash:uint32; // not implemented

  dynamic_batch:bool;
}

root_type CompiledModel;
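Editorial note (not part of the original commit): this schema is populated by CompiledModelImpl::Serialize in gl/api.cc above, which writes the deduplicated shader bodies into `shaders`, one `Program` per dispatch (parameters, objects, workgroup sizes, shader_index), and the `dynamic_batch` flag into `Parameters`. A hedged sketch of a full round trip through the schema, using only the public functions declared in gl/api.h and available only when TFLITE_GPU_BINARY_RELEASE is not defined:

// Sketch only: serialize a compiled model into an "AFCM" flatbuffer blob and
// load it back.
Status RoundTrip(const CompiledModel& model,
                 std::unique_ptr<CompiledModel>* restored) {
  std::vector<uint8_t> blob;  // holds a data::CompiledModel flatbuffer
  RETURN_IF_ERROR(model.Serialize(&blob));
  return ReadSerializedModel(blob, restored);
}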
295
tensorflow/lite/delegates/gpu/gl/compiler.cc
Normal file
@ -0,0 +1,295 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/compiler.h"

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h"
#include "tensorflow/lite/delegates/gpu/gl/float16_conversions.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

struct ExceedSizeChecker {
  bool operator()(uint32_t v) const { return v > max_size; }

  bool operator()(const uint2& v) const {
    return v.x > max_size || v.y > max_size;
  }

  bool operator()(const uint3& v) const {
    return v.x > max_size || v.y > max_size || v.z > max_z_size;
  }

  int max_size;
  int max_z_size;
};

// Returns true if any size variable exceeds the given limit
bool ExceedsMaxSize(const Object& object, const GpuInfo& gpu_info) {
  return absl::visit(ExceedSizeChecker{gpu_info.max_texture_size,
                                       gpu_info.max_array_texture_layers},
                     object.size);
}

ObjectType ChooseFastestObjectType(const GpuInfo& gpu_info) {
  return gpu_info.type == GpuType::ADRENO ? ObjectType::TEXTURE
                                          : ObjectType::BUFFER;
}

ObjectType ChooseFastestRefObjectType(const GpuInfo& gpu_info,
                                      const CompilationOptions& options) {
  if (gpu_info.type != GpuType::ADRENO) {
    return ObjectType::BUFFER;
  }
  switch (gpu_info.gpu_model) {
    case GpuModel::ADRENO630:
      return ObjectType::TEXTURE;
    default:
      return options.allow_precision_loss ? ObjectType::TEXTURE
                                          : ObjectType::BUFFER;
  }
}

// Compiler executes the following steps:
// 1. Runs NodeShader for every node in the input graph.
// 2. Creates a compiled graph that mirrors the input graph and keeps
//    GeneratedCode in operation's attributes.
// 3. Fuses nodes in the compiled graph.
// 4. Generates the full shader code using the nodes in the compiled graph.
class CompilerImpl : public Compiler {
 public:
  // We use const GpuInfo* because a pointer prevents binding a temporary
  // object.
  CompilerImpl(const NodeShader* node_shader, const GpuInfo* gpu_info,
               const CompilationOptions& options)
      : node_shader_(*node_shader), gpu_info_(*gpu_info), options_(options) {
    if (options_.preferred_obj_type == ObjectType::UNKNOWN) {
      options_.preferred_obj_type = ChooseFastestObjectType(*gpu_info);
    }
    if (options_.ref_obj_type == ObjectType::UNKNOWN) {
      options_.ref_obj_type = ChooseFastestRefObjectType(*gpu_info, options);
    }
  }

  Status Compile(const GraphFloat32& graph,
                 const ShaderCodeCallback& callback) final {
    // It is important to have ids in the compiled graph identical to the given
    // graph.
    RETURN_IF_ERROR(graph.MakeExactCopy(&compiled_graph_));

    // Clear out batch dimension for dynamic batch support.
    if (options_.dynamic_batch) {
      for (auto value : compiled_graph_.values()) {
        value->tensor.shape.b = 1;
      }
    }

    // Generate a shader for a node and all input/output objects.
    for (auto node : compiled_graph_.nodes()) {
      CompiledNodeAttributes attr;
      attr.node_indices.push_back(node->id);
      RETURN_IF_ERROR(node_shader_.GenerateCode(
          {&compiled_graph_, &gpu_info_, node, options_}, &attr.code));
      node->operation.attributes = std::move(attr);
    }

    ModelTransformer transformer(&compiled_graph_, nullptr);
    if (options_.fuse_operations) {
      FuseAutoOutputWithInline fuse_inline;
      if (!transformer.Apply("fuse_auto_with_inline", &fuse_inline)) {
        return InternalError("fuse_auto_with_inline failed");
      }
      FuseInplaceUpdate fuse_inplace;
      if (!transformer.Apply("fuse_inplace_update", &fuse_inplace)) {
        return InternalError("fuse_inplace failed");
      }
      if (options_.auto_input_fusion) {
        FuseAutoInput fuse_auto_input;
        if (!transformer.Apply("fuse_auto_input", &fuse_auto_input)) {
          return InternalError("fuse_auto_input failed");
        }
      }
    }
    RemoveUnusedInplaceUpdates remove_inplace_updates;
    if (!transformer.Apply("remove_inplace_updates", &remove_inplace_updates)) {
      return InternalError("remove_inplace_updates failed");
    }

    // Prepare internal objects.
    std::unordered_map<ValueId, Object> objects;
    for (auto value : compiled_graph_.values()) {
      Object object = MakePHWC4Ref(value->id, value->tensor.shape);
      object.data_type = value->tensor.type;
      // External references may not be upgraded to f16 nor be represented as
      // textures.
      bool is_external =
          graph.IsGraphOutput(value->id) || graph.IsGraphInput(value->id);
      if (is_external) {
        object.object_type = ObjectType::BUFFER;
      } else {
        if (options_.allow_precision_loss) {
          MaybeConvertToFloat16(&object);
        }
      }
      objects[value->id] = std::move(object);
    }

    // Prepare readonly objects and check whether object types are supported.
    for (auto node : compiled_graph_.nodes()) {
      auto& attr =
          absl::any_cast<CompiledNodeAttributes&>(node->operation.attributes);

      // Set workload explicitly.
      if (attr.code.workload == uint3()) {
        auto outputs = compiled_graph_.FindOutputs(node->id);
        auto shape = outputs[0]->tensor.shape;
        for (auto output : outputs) {
          if (shape != output->tensor.shape) {
            return FailedPreconditionError(
                "Workload uint3() requires all output sizes to match");
          }
        }
        attr.code.workload =
            uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4));
      }

      int num_textures = 0;
      // Counts number of used textures and chooses ObjectType for an object.
      auto set_object_type = [&](Object* object) {
        if (object->object_type == ObjectType::BUFFER) {
          // Don't change from buffer once it is set.
          return;
        }
        bool is_ref = IsRef(*object);
        if (num_textures < gpu_info_.max_image_units &&
            !ExceedsMaxSize(*object, gpu_info_) &&
            (object->object_type == ObjectType::TEXTURE ||
             (is_ref && options_.ref_obj_type == ObjectType::TEXTURE) ||
             (!is_ref && options_.preferred_obj_type == ObjectType::TEXTURE))) {
          object->object_type = ObjectType::TEXTURE;
          num_textures++;
        } else {
          object->object_type = ObjectType::BUFFER;
        }
      };

      for (auto& object : attr.code.objects) {
        // Downgrade readonly objects to F16 if requested.
        if (options_.allow_precision_loss) {
          MaybeConvertToFloat16(&object.second);
        }
        set_object_type(&object.second);
      }

      for (auto ref : compiled_graph_.FindInputs(node->id)) {
        set_object_type(&objects[ref->id]);
      }
      for (auto ref : compiled_graph_.FindOutputs(node->id)) {
        set_object_type(&objects[ref->id]);
      }
    }

    // Generate shaders from the transformed graph.
    ShaderCodegen codegen(options_, gpu_info_);
    for (auto node : compiled_graph_.nodes()) {
      auto& attr =
          absl::any_cast<CompiledNodeAttributes&>(node->operation.attributes);
      if (attr.code.source_code.empty()) {
        // noop. Skip this node.
        continue;
      }

      // Declare inputs and outputs explicitly.
      for (auto ref : compiled_graph_.FindInputs(node->id)) {
        auto object = objects[ref->id];
        object.access = AccessType::READ;
        attr.inputs.push_back(object);
      }
      for (auto ref : compiled_graph_.FindOutputs(node->id)) {
        auto object = objects[ref->id];
        object.access = AccessType::WRITE;
        attr.outputs.push_back(object);
      }

      // Allocate bindings. Textures must be bound first. max_image_units also
      // defines max binding number for a texture.
      uint32_t binding = 0;
      auto set_binding = [&](ObjectType type, Object& object) {
        if (object.object_type == type) {
          object.binding = binding++;
        }
      };
      for (auto& object : attr.inputs) {
        set_binding(ObjectType::TEXTURE, object);
      }
      for (auto& object : attr.outputs) {
        set_binding(ObjectType::TEXTURE, object);
      }
      for (auto& object : attr.code.objects) {
        set_binding(ObjectType::TEXTURE, object.second);
      }
      for (auto& object : attr.inputs) {
        set_binding(ObjectType::BUFFER, object);
      }
      for (auto& object : attr.outputs) {
        set_binding(ObjectType::BUFFER, object);
      }
      for (auto& object : attr.code.objects) {
        set_binding(ObjectType::BUFFER, object.second);
      }

      // Generate source code.
      ShaderCode shader_code;
      RETURN_IF_ERROR(codegen.Build(std::move(attr), &shader_code));
      RETURN_IF_ERROR(callback(std::move(shader_code)));
    }
    return OkStatus();
  }

 private:
  const NodeShader& node_shader_;
  const GpuInfo& gpu_info_;
  CompilationOptions options_;
  GraphFloat32 compiled_graph_;
};

}  // namespace

std::unique_ptr<Compiler> NewCompiler(const NodeShader* node_shader,
                                      const GpuInfo* gpu_info,
                                      const CompilationOptions& options) {
  return absl::make_unique<CompilerImpl>(node_shader, gpu_info, options);
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
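Editorial note (not part of the original commit): when a node leaves attr.code.workload as uint3(), CompilerImpl::Compile above derives the workload from the output shape as (w, h, ceil(c / 4)), since one work item covers one (x, y) position and a group of four channels. A self-contained worked example follows; the real IntegralDivideRoundUp helper lives in the common utilities, so an equivalent is redefined locally to keep the snippet standalone.

// Sketch only: the default workload calculation used above.
#include <cstdint>

uint32_t DivideRoundUp(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

struct Workload { uint32_t x, y, z; };

// For a WxHxC output, the workload is (w, h, ceil(c / 4)).
Workload DefaultWorkload(uint32_t w, uint32_t h, uint32_t c) {
  return {w, h, DivideRoundUp(c, 4)};
}
// DefaultWorkload(224, 224, 3) -> {224, 224, 1}
// DefaultWorkload(224, 224, 9) -> {224, 224, 3}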
54
tensorflow/lite/delegates/gpu/gl/compiler.h
Normal file
@ -0,0 +1,54 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_H_

#include <functional>
#include <memory>

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h"
#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"

namespace tflite {
namespace gpu {
namespace gl {

using ShaderCodeCallback = std::function<Status(ShaderCode code)>;

class Compiler {
 public:
  virtual ~Compiler() = default;

  // Goes over a graph and generates OpenGL shaders for the given graph.
  // Callback is called for every generated shader. Callback may execute
  // shaders as they come or store them elsewhere to execute later.
  virtual Status Compile(const GraphFloat32& graph,
                         const ShaderCodeCallback& callback) = 0;
};

std::unique_ptr<Compiler> NewCompiler(
    const NodeShader* node_shader, const GpuInfo* gpu_info,
    const CompilationOptions& options);

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_H_
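Editorial note (not part of the original commit): a minimal sketch of driving the Compiler interface declared above. The callback simply collects every generated shader; gl/api.cc's Compile() follows the same pattern, except that it hands each ShaderCode to its CompiledModelImpl instead of a vector.

// Sketch only: collect all generated shaders via ShaderCodeCallback.
Status CollectShaders(const GraphFloat32& graph, const NodeShader& node_shader,
                      const GpuInfo& gpu_info,
                      const CompilationOptions& options,
                      std::vector<ShaderCode>* shaders) {
  auto compiler = NewCompiler(&node_shader, &gpu_info, options);
  return compiler->Compile(graph, [&](ShaderCode code) -> Status {
    shaders->push_back(std::move(code));
    return OkStatus();
  });
}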
192
tensorflow/lite/delegates/gpu/gl/compiler/BUILD
Normal file
@ -0,0 +1,192 @@
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

cc_library(
    name = "preprocessor",
    srcs = ["preprocessor.cc"],
    hdrs = ["preprocessor.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:status",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "preprocessor_test",
    srcs = ["preprocessor_test.cc"],
    tags = [
        "local",
    ],
    deps = [
        ":preprocessor",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "parameter_accessor",
    srcs = ["parameter_accessor.cc"],
    hdrs = ["parameter_accessor.h"],
    deps = [
        ":preprocessor",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/gl:uniform_parameter",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_test(
    name = "parameter_accessor_test",
    srcs = ["parameter_accessor_test.cc"],
    tags = [
        "local",
    ],
    deps = [
        ":parameter_accessor",
        "//tensorflow/lite/delegates/gpu/common:types",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "object_accessor",
    srcs = ["object_accessor.cc"],
    hdrs = ["object_accessor.h"],
    deps = [
        ":parameter_accessor",
        ":preprocessor",
        "//tensorflow/lite/delegates/gpu/common:data_type",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/gl:object",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_test(
    name = "object_accessor_test",
    srcs = ["object_accessor_test.cc"],
    tags = [
        "local",
    ],
    deps = [
        ":object_accessor",
        ":parameter_accessor",
        "//tensorflow/lite/delegates/gpu/common:types",
        "@com_google_absl//absl/types:variant",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "shader_code",
    hdrs = ["shader_code.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/gl:object",
        "//tensorflow/lite/delegates/gpu/gl:uniform_parameter",
    ],
)

cc_library(
    name = "shader_codegen",
    srcs = ["shader_codegen.cc"],
    hdrs = ["shader_codegen.h"],
    deps = [
        ":compiled_node",
        ":object_accessor",
        ":parameter_accessor",
        ":preprocessor",
        ":shader_code",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/gl:compiler_options",
        "//tensorflow/lite/delegates/gpu/gl:gpu_info",
        "//tensorflow/lite/delegates/gpu/gl:object",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "compiled_node",
    srcs = ["compiled_node.cc"],
    hdrs = ["compiled_node.h"],
    deps = [
        ":rename",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/gl:node_shader",
        "//tensorflow/lite/delegates/gpu/gl:object",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "fuse_inplace",
    srcs = ["fuse_inplace.cc"],
    hdrs = ["fuse_inplace.h"],
    deps = [
        ":compiled_node",
        ":preprocessor",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/gl:node_shader",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:any",
    ],
)

cc_library(
    name = "fuse_inline",
    srcs = ["fuse_inline.cc"],
    hdrs = ["fuse_inline.h"],
    deps = [
        ":compiled_node",
        ":shader_code",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/gl:node_shader",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:any",
    ],
)

cc_library(
    name = "rename",
    srcs = ["rename.cc"],
    hdrs = ["rename.h"],
    deps = [
        ":object_accessor",
        ":parameter_accessor",
        ":preprocessor",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/gl:node_shader",
        "//tensorflow/lite/delegates/gpu/gl:object",
        "//tensorflow/lite/delegates/gpu/gl:uniform_parameter",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "fuse_auto_input",
    srcs = ["fuse_auto_input.cc"],
    hdrs = ["fuse_auto_input.h"],
    deps = [
        ":compiled_node",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:operations",
        "//tensorflow/lite/delegates/gpu/common:types",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:any",
        "@com_google_absl//absl/types:variant",
    ],
)
64
tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc
Normal file
@ -0,0 +1,64 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"

#include <unordered_set>

#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h"

namespace tflite {
namespace gpu {
namespace gl {

Status MergeCode(CompiledNodeAttributes* attr,
                 CompiledNodeAttributes* merged_attr) {
  // Build a map of known names.
  std::unordered_set<std::string> known_names;
  for (const auto& parameter : merged_attr->code.parameters) {
    known_names.insert(parameter.name);
  }
  for (const auto& object : merged_attr->code.objects) {
    known_names.insert(object.first);
  }

  // Rewrite parameters with unique names.
  int index =
      merged_attr->code.parameters.size() + merged_attr->code.objects.size();
  RETURN_IF_ERROR(Rename(
      [&](absl::string_view name) -> std::string {
        std::string n(name.begin(), name.end());
        // If a name is unique, keep it as is. Otherwise, append a unique
        // index.
        if (known_names.find(n) == known_names.end()) {
          return n;
        }
        return absl::StrCat(n, index++);
      },
      &attr->code));
  std::move(attr->code.objects.begin(), attr->code.objects.end(),
            std::back_inserter(merged_attr->code.objects));
  std::move(attr->code.parameters.begin(), attr->code.parameters.end(),
            std::back_inserter(merged_attr->code.parameters));
  std::move(attr->node_indices.begin(), attr->node_indices.end(),
            std::back_inserter(merged_attr->node_indices));
  return OkStatus();
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
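Editorial example (not part of the original commit) of the renaming rule in MergeCode above; the names are made up purely for illustration:

// If merged_attr already holds parameters {"alpha", "beta"} and one object
// {"weights"}, the rename index starts at 3. An incoming attr whose source
// uses "alpha" and "gamma" is rewritten so that the colliding "alpha" becomes
// "alpha3" (the index then advances to 4 for the next collision), while the
// unique "gamma" is kept as is. The renamed objects, parameters and node
// indices are then moved into merged_attr.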
52
tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h
Normal file
@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_COMPILED_NODE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_COMPILED_NODE_H_

#include <vector>

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"

namespace tflite {
namespace gpu {
namespace gl {

// Contains compiler internal attributes for each node after it was processed
// by NodeShader.
struct CompiledNodeAttributes {
  std::vector<Object> inputs;
  std::vector<Object> outputs;

  GeneratedCode code;

  // Nodes that are covered by the provided shader.
  std::vector<NodeId> node_indices;
};

// Moves all code objects, parameters and node indices from attr to merged_attr.
// Parameters and objects in attr.code.source_code are renamed to ensure
// uniqueness.
Status MergeCode(CompiledNodeAttributes* attr,
                 CompiledNodeAttributes* merged_attr);

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_COMPILED_NODE_H_
228
tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc
Normal file
@ -0,0 +1,228 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h"

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/types/any.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

std::pair<std::string, std::string> MakeValueReplacement(int n, int k) {
  return {absl::StrCat("value_", n), absl::StrCat("value_", k)};
}

std::pair<std::string, std::string> MakeDataReplacement(int n, int k) {
  return {absl::StrCat("input_data_", n), absl::StrCat("input_data_", k)};
}

}  // namespace

TransformResult FuseAutoInput::ApplyToNode(Node* node, GraphFloat32* graph) {
  auto& node_attr =
      absl::any_cast<CompiledNodeAttributes&>(node->operation.attributes);
  auto& node_code = node_attr.code;

  if (node_code.input != IOStructure::AUTO) {
    return {TransformStatus::SKIPPED, ""};
  }
  uint3 workgroup = node_code.workgroup;

  auto node_outputs = graph->FindOutputs(node->id);

  // Check which inputs could be fused into the current node.
  std::vector<std::pair<Node*, int>> nodes_to_fuse;
  std::vector<std::pair<ValueId, int>> input_values;
  int input_num = -1;
  for (auto input_value : graph->FindInputs(node->id)) {
    input_num++;
    const ValueId input_id = input_value->id;
    input_values.push_back({input_id, input_num});

    if (graph->FindConsumers(input_id).size() > 1) {
      continue;  // input is consumed by >1 nodes
    }
    Node* input_producer = graph->FindProducer(input_id);
    if (input_producer == nullptr) {
      continue;  // graph's input
    }
    if (graph->FindOutputs(input_producer->id).size() != 1) {
      continue;  // input node has more than one output
    }
    auto& input_producer_attr = absl::any_cast<const CompiledNodeAttributes&>(
        input_producer->operation.attributes);
    if (input_producer_attr.code.output != IOStructure::AUTO) {
      continue;
    }
    if (input_producer_attr.code.workload != node_code.workload &&
        uint3() != input_producer_attr.code.workload) {
      continue;
    }
    if (input_producer_attr.code.workgroup != uint3()) {
      // The new fused node should fuse only a single shader that has a
      // pre-defined workgroup. Such a shader is considered "heavy". Do not
      // fuse two heavy shaders into one.
      // TODO(eignasheva): make sure it still works.
      if (workgroup != uint3()) {
        continue;
      }
      workgroup = input_producer_attr.code.workgroup;
    }
    nodes_to_fuse.push_back({input_producer, input_num});
    input_values.pop_back();  // this value will not be used as input.
  }
  if (nodes_to_fuse.empty()) {
    return {TransformStatus::SKIPPED, ""};
  }

  // Break connections between current node and its inputs.
  for (auto value : graph->FindInputs(node->id)) {
    if (!graph->RemoveConsumer(node->id, value->id).ok()) {
      return {TransformStatus::INVALID, ""};
    }
  }

  std::string operation_type;
  std::string source_code;
  std::string values;

  // Node source code needs to be appended later to the end.
  std::swap(source_code, node_code.source_code);

  // Indicates value_k that is beyond the originally declared [0..n] values,
  // therefore it can be used by newly added dependencies.
  int extra_input_num = input_num;
  input_num = 0;

  // Fuse all nodes into one.
  for (auto input_and_num : nodes_to_fuse) {
    auto& input = input_and_num.first;
    auto& attr =
        absl::any_cast<CompiledNodeAttributes&>(input->operation.attributes);
    auto super_inputs = graph->FindInputs(input->id);

    // Replace all internal references in the input source code. For example:
    // source code "value_0 = max(0, value_0);" will be rewritten into
    // "value_2 = max(0, value_2);"
    std::vector<std::pair<std::string, std::string>> replacements;
    for (int i = 0; i < super_inputs.size(); ++i) {
      // Node source code uses value_N to access output value from the fused
      // node. Use correct reference.
      //
      // Here value_N does not correspond to input_N anymore. Instead it tracks
      // value_n and input_m independently. value_index uses an index needed
      // for the "final" shader, while input_num preserves the order of inputs.
      // For example:
      //   Shader A: input_0, input_1
      //     value_0 = value_0 > value_1 ? value_0 : value_1;
      //
      //   Shader B: input_0
      //     value_0 = max(0, value_0);
      //
      //   AddShader: input_0, input_1
      //     value_0 = value_0 + value_1;
      //
      // The fused shader is going to have 3 inputs: input_0 (A), input_1 (A),
      // input_2 (B). But Shader B needs to store its result in value_1,
      // because AddShader refers to it as 'value_1'. So, the fused shader
      // will look as follows:
      //
      //   // Shader A
      //   vec4 value_0 = input_data_0.data[gid.x, gid.y, gid.z];
      //   vec4 value_2 = input_data_1.data[gid.x, gid.y, gid.z];
      //   value_0 = value_0 > value_2 ? value_0 : value_2;
      //
      //   // Shader B
      //   vec4 value_1 = input_data_2.data[gid.x, gid.y, gid.z];
      //   value_1 = max(0, value_1);
      //
      //   // AddShader
      //   value_0 = value_0 + value_1;
      //
      //   output_data_0.data[gid.x, gid.y, gid.z] = value_0;
      int value_index = i == 0 ? input_and_num.second : ++extra_input_num;
      replacements.push_back(MakeValueReplacement(i, value_index));
      replacements.push_back(MakeDataReplacement(i, input_num));

      // Declare input values based on the input structure of the merged node.
      // This code copies what shader_codegen would do automatically.
      if (attr.code.input == IOStructure::AUTO) {
        absl::StrAppend(&values, " value_", value_index, " = $input_data_",
                        input_num, "[gid.x, gid.y, gid.z]$;\n");
      }

      if (!graph->AddConsumer(node->id, super_inputs[i]->id).ok()) {
        return {TransformStatus::INVALID, ""};
      }
      input_num++;
    }
    attr.code.source_code =
        absl::StrReplaceAll(attr.code.source_code, replacements);

    // Merge all objects, parameters and source code.
    if (!MergeCode(&attr, &node_attr).ok()) {
      return {TransformStatus::INVALID, "Unable to merge the code"};
    }
    absl::StrAppend(&node_attr.code.source_code, "{\n", attr.code.source_code,
                    "\n}");

    if (!operation_type.empty()) {
      operation_type += ",";
    }
    operation_type += input->operation.type;

    if (!graph->DeleteNode(input->id).ok()) {
      return {TransformStatus::INVALID, ""};
    }
  }

  // Add back all inputs that are used directly by the fused node.
  for (int i = 0; i < input_values.size(); i++) {
    if (node_code.input == IOStructure::AUTO) {
      absl::StrAppend(&values, " value_", input_values[i].second,
                      " = $input_data_", input_num,
                      "[gid.x, gid.y, gid.z]$;\n");
    }
    if (!graph->AddConsumer(node->id, input_values[i].first).ok()) {
      return {TransformStatus::INVALID, ""};
    }
    input_num++;
  }

  node_code.input = IOStructure::ONLY_DEFINITIONS;

  absl::StrAppend(&node->operation.type, "(", operation_type, ")");
  node_code.source_code =
      absl::StrCat(values, node_code.source_code, "{//FUSED",
                   node->operation.type, "\n", source_code, "\n}");

  return {TransformStatus::APPLIED, ""};
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
49
tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h
Normal file
@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_AUTO_INPUT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_AUTO_INPUT_H_

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

namespace tflite {
namespace gpu {
namespace gl {

// Fuses nodes that have auto output with an auto input node using the
// following rules.
//
// Source graph:
//   A B C
//   \ | /
//     D
//
// - A, B and C each have a single output marked as AUTO
// - Each output is used only by D
// - D has all inputs marked as AUTO
//
// Result: in the best case a single node that does (A,B,C)+D operations.
//
class FuseAutoInput : public NodeTransformation {
 public:
  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final;
};

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_AUTO_INPUT_H_
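For context, here is a minimal sketch of driving this transformation directly over a graph. It is not part of the commit; the nodes() accessor on GraphFloat32 and the status/message field names on TransformResult are assumptions based on how they are used elsewhere in this change.

// Sketch only (assumed API): run FuseAutoInput over every node in a graph.
FuseAutoInput fuse_auto_input;
for (Node* node : graph->nodes()) {  // nodes() accessor is an assumption.
  TransformResult result = fuse_auto_input.ApplyToNode(node, graph);
  if (result.status == TransformStatus::INVALID) {
    // result.message is assumed to carry the reason for the failure.
  }
}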
78
tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc
Normal file
@ -0,0 +1,78 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.h"

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"

namespace tflite {
namespace gpu {
namespace gl {

TransformResult FuseAutoOutputWithInline::ApplyToNodesSequence(
    const std::vector<Node*>& sequence, GraphFloat32* graph) {
  Node* node1 = sequence.front();
  Node* node2 = sequence.back();
  auto& attr1 =
      absl::any_cast<CompiledNodeAttributes&>(node1->operation.attributes);
  auto& attr2 =
      absl::any_cast<CompiledNodeAttributes&>(node2->operation.attributes);

  if (attr1.code.output != IOStructure::AUTO ||
      graph->FindInputs(node2->id).size() != 1 ||
      graph->FindOutputs(node2->id).size() != 1 ||
      attr2.code.output != IOStructure::AUTO ||
      attr2.code.input != IOStructure::AUTO ||
      (attr1.code.workload != attr2.code.workload &&
       uint3() != attr2.code.workload) ||
      graph->FindOutputs(node1->id).size() !=
          graph->FindInputs(node2->id).size()) {
    return {TransformStatus::SKIPPED, ""};
  }

  // Check if the code was not fused yet, and wrap source code into {}.
  if (node1->operation.type.find('+') == std::string::npos) {
    attr1.code.source_code =
        absl::StrCat("\n{\n", attr1.code.source_code, "\n}\n");
  }
  if (!MergeCode(&attr2, &attr1).ok()) {
    return {TransformStatus::INVALID, "Unable to merge two nodes"};
  }
  absl::StrAppend(&attr1.code.source_code, "{\n", attr2.code.source_code,
                  "\n}");
  node1->operation.type += "+" + node2->operation.type;

  if (!RemoveFollowingNode(graph, node2, node1).ok()) {
    return {TransformStatus::INVALID,
            "Unable to remove node " + std::to_string(node2->id)};
  }
  return {TransformStatus::APPLIED, ""};
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
57
tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.h
Normal file
@ -0,0 +1,57 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INLINE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INLINE_H_

#include <vector>

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

namespace tflite {
namespace gpu {
namespace gl {

// Fuses every two nodes where the first node has default output and the
// second node is INLINE.
//
// Generates code as follows:
//   1. all uniforms are inlined
//   2. source code is wrapped into {}
// For example:
//   value = clamp(value, 0.0, clip);
//     +
//   value = 1.0 / (1.0 + exp(-1.0 * value));
// will turn into:
//   {
//     value = clamp(value, 0.0, clip);
//   }
//   {
//     value = 1.0 / (1.0 + exp(-1.0 * value));
//   }
class FuseAutoOutputWithInline : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final;
};

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INLINE_H_
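A small string-level illustration of the wrapping described in the comment above; it only mirrors the StrCat/StrAppend calls in fuse_inline.cc and is a sketch, not code from the commit.

// Sketch only: the fused source is each shader's code wrapped into its own {}.
std::string first = "value = clamp(value, 0.0, clip);";
std::string second = "value = 1.0 / (1.0 + exp(-1.0 * value));";
std::string fused =
    absl::StrCat("\n{\n", first, "\n}\n", "{\n", second, "\n}");
// `fused` now contains both statements in separate blocks, as shown above.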
151
tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc
Normal file
@ -0,0 +1,151 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h"

#include <cstring>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

static const char* kInplacePrefix = "inplace_update:\0";

class EmptyInplaceRewrite : public InlineRewrite {
 public:
  RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
    if (input.compare(0, strlen(kInplacePrefix), kInplacePrefix) == 0) {
      num_rewrites_++;
      return RewriteStatus::SUCCESS;
    }
    return RewriteStatus::NOT_RECOGNIZED;
  }

  int num_rewrites() const { return num_rewrites_; }

 private:
  int num_rewrites_ = 0;
};

// Takes code as input. Replaces 'value_0' in the code with the value that
// comes in a rewrite. For example:
//   code:    value_0 = max(value_0, 0);
//   rewrite: inplace_update:result_12 -> result_12 = max(result_12, 0);
//
class InplaceCodeRewrite : public InlineRewrite {
 public:
  explicit InplaceCodeRewrite(const std::string& code) : code_(code) {}

  RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
    int len = strlen(kInplacePrefix);
    if (input.compare(0, len, kInplacePrefix) == 0) {
      auto variable_name = input.substr(len);
      absl::StrAppend(output,
                      absl::StrReplaceAll(code_, {{"value_0", variable_name}}));
      return RewriteStatus::SUCCESS;
    }
    return RewriteStatus::NOT_RECOGNIZED;
  }

 private:
  std::string code_;
};

}  // namespace

TransformResult RemoveUnusedInplaceUpdates::ApplyToNode(Node* node,
                                                        GraphFloat32* graph) {
  auto& attr =
      absl::any_cast<CompiledNodeAttributes&>(node->operation.attributes);
  // Remove inplace block by rewriting to empty string.
  EmptyInplaceRewrite rewrite;
  TextPreprocessor preprocessor('$', true);
  preprocessor.AddRewrite(&rewrite);
  if (!preprocessor.Rewrite(attr.code.source_code, &attr.code.source_code)
           .ok()) {
    return {TransformStatus::INVALID, ""};
  }
  return {rewrite.num_rewrites() > 0 ? TransformStatus::APPLIED
                                     : TransformStatus::SKIPPED,
          ""};
}

TransformResult FuseInplaceUpdate::ApplyToNodesSequence(
    const std::vector<Node*>& sequence, GraphFloat32* graph) {
  Node* node1 = sequence.front();
  Node* node2 = sequence.back();
  auto& attr1 =
      absl::any_cast<CompiledNodeAttributes&>(node1->operation.attributes);
  auto& attr2 =
      absl::any_cast<CompiledNodeAttributes&>(node2->operation.attributes);

  if (graph->FindInputs(node2->id).size() != 1 ||
      graph->FindOutputs(node2->id).size() != 1 ||
      attr2.code.output != IOStructure::AUTO ||
      attr2.code.input != IOStructure::AUTO ||
      (attr1.code.workload != attr2.code.workload &&
       uint3() != attr2.code.workload)) {
    return {TransformStatus::SKIPPED, ""};
  }

  // First count the replaces that would happen to check whether a rewrite is
  // needed.
  {
    EmptyInplaceRewrite counting_rewrite;
    TextPreprocessor preprocessor('$', true);
    preprocessor.AddRewrite(&counting_rewrite);
    std::string temp;
    if (!preprocessor.Rewrite(attr1.code.source_code, &temp).ok()) {
      return {TransformStatus::INVALID, ""};
    }
    // No rewrites in the source code. Skip it.
    if (counting_rewrite.num_rewrites() == 0) {
      return {TransformStatus::SKIPPED, ""};
    }
  }
  if (!MergeCode(&attr2, &attr1).ok()) {
    return {TransformStatus::INVALID, "Unable to merge two nodes"};
  }
  TextPreprocessor preprocessor('$', true);
  InplaceCodeRewrite rewrite(attr2.code.source_code);
  preprocessor.AddRewrite(&rewrite);
  if (!preprocessor.Rewrite(attr1.code.source_code, &attr1.code.source_code)
           .ok()) {
    return {TransformStatus::INVALID, ""};
  }
  node1->operation.type += "+" + node2->operation.type;

  if (!RemoveFollowingNode(graph, node2, node1).ok()) {
    return {TransformStatus::INVALID,
            "Unable to remove node " + std::to_string(node2->id)};
  }
  return {TransformStatus::APPLIED, ""};
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
67
tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h
Normal file
@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INPLACE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INPLACE_H_

#include <vector>

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

namespace tflite {
namespace gpu {
namespace gl {

// Fuses two shaders where the second shader is inlined into the first.
// The first shader should have a special symbol that defines the place where
// such fusion should be made and which variable needs to be changed.
// The second shader needs to operate on the 'value_0' variable.
// Example:
//
//   First shader:
//     vec4 result = input_data_0.data[gid.x, gid.y, gid.z];
//     $inplace_update:result$
//     ...
//     output_data_0.data[1,2,3] = result;
//
//   Second shader:
//     value_0 = max(value_0, 0);
//
//   Fused shader:
//     vec4 result = input_data_0.data[gid.x, gid.y, gid.z];
//     result = max(result, 0);
//     ...
//     output_data_0.data[1,2,3] = result;
//
class FuseInplaceUpdate : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final;
};

// Removes all %inplace_update:XXX% strings from the code.
class RemoveUnusedInplaceUpdates : public NodeTransformation {
 public:
  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final;
};

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INPLACE_H_
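A string-level sketch of the marker contract described above; the replacement mirrors the absl::StrReplaceAll call in fuse_inplace.cc and is illustrative only, not code from the commit.

// Sketch only: the $inplace_update:result$ marker is replaced by the second
// shader's code with 'value_0' renamed to the marker's variable.
std::string second_shader = "value_0 = max(value_0, 0);";
std::string inlined =
    absl::StrReplaceAll(second_shader, {{"value_0", "result"}});
// `inlined` == "result = max(result, 0);", which is substituted for the marker.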
546
tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc
Normal file
@ -0,0 +1,546 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h"

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace object_accessor_internal {

// Splits name[index1, index2...] into 'name' and {'index1', 'index2'...}.
IndexedElement ParseElement(absl::string_view input) {
  auto i = input.find('[');
  if (i == std::string::npos || input.back() != ']') {
    return {};
  }
  return {input.substr(0, i),
          absl::StrSplit(input.substr(i + 1, input.size() - i - 2), ',',
                         absl::SkipWhitespace())};
}

}  // namespace object_accessor_internal

namespace {

void MaybeConvertToHalf(DataType data_type, absl::string_view value,
                        std::string* output) {
  if (data_type == DataType::FLOAT16) {
    absl::StrAppend(output, "Vec4ToHalf(", value, ")");
  } else {
    absl::StrAppend(output, value);
  }
}

void MaybeConvertFromHalf(DataType data_type, absl::string_view value,
                          std::string* output) {
  if (data_type == DataType::FLOAT16) {
    absl::StrAppend(output, "Vec4FromHalf(", value, ")");
  } else {
    absl::StrAppend(output, value);
  }
}

struct ReadFromTextureGenerator {
  RewriteStatus operator()(uint32_t) const {
    if (element.indices.size() != 1) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    // 1D textures are emulated as 2D textures
    absl::StrAppend(result, "imageLoad(", element.object_name, ", ivec2(",
                    element.indices[0], ", 0))");
    return RewriteStatus::SUCCESS;
  }

  template <typename Shape>
  RewriteStatus operator()(const Shape&) const {
    if (element.indices.size() != Shape::size()) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    absl::StrAppend(result, "imageLoad(", element.object_name, ", ivec",
                    Shape::size(), "(", absl::StrJoin(element.indices, ", "),
                    "))");
    return RewriteStatus::SUCCESS;
  }

  const object_accessor_internal::IndexedElement& element;
  std::string* result;
};

struct ReadFromBufferGenerator {
  RewriteStatus operator()(uint32_t) const {
    if (element.indices.size() != 1) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    MaybeConvertFromHalf(
        data_type,
        absl::StrCat(element.object_name, ".data[", element.indices[0], "]"),
        result);
    return RewriteStatus::SUCCESS;
  }

  RewriteStatus operator()(const uint2& size) const {
    if (element.indices.size() == 1) {
      // Access by linear index. Use the method above to generate the accessor.
      return (*this)(1U);
    }
    if (element.indices.size() != 2) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    MaybeConvertFromHalf(
        data_type,
        absl::StrCat(element.object_name, ".data[", element.indices[0], " + $",
                     element.object_name, "_w$ * (", element.indices[1], ")]"),
        result);
    *requires_sizes = true;
    return RewriteStatus::SUCCESS;
  }

  RewriteStatus operator()(const uint3& size) const {
    if (element.indices.size() == 1) {
      // Access by linear index. Use the method above to generate the accessor.
      return (*this)(1U);
    }
    if (element.indices.size() != 3) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    MaybeConvertFromHalf(
        data_type,
        absl::StrCat(element.object_name, ".data[", element.indices[0], " + $",
                     element.object_name, "_w$ * (", element.indices[1], " + $",
                     element.object_name, "_h$ * (", element.indices[2], "))]"),
        result);
    *requires_sizes = true;
    return RewriteStatus::SUCCESS;
  }

  DataType data_type;
  const object_accessor_internal::IndexedElement& element;
  std::string* result;

  // Indicates that the generated code accessed _w and/or _h size variables.
  bool* requires_sizes;
};

// Generates code for reading an element from an object.
RewriteStatus GenerateReadAccessor(
    const Object& object,
    const object_accessor_internal::IndexedElement& element,
    std::string* result, bool* requires_sizes) {
  switch (object.object_type) {
    case ObjectType::BUFFER:
      return absl::visit(ReadFromBufferGenerator{object.data_type, element,
                                                 result, requires_sizes},
                         object.size);
    case ObjectType::TEXTURE:
      return absl::visit(ReadFromTextureGenerator{element, result},
                         object.size);
    case ObjectType::UNKNOWN:
      return RewriteStatus::ERROR;
  }
}

struct WriteToBufferGenerator {
  RewriteStatus operator()(uint32_t) const {
    if (element.indices.size() != 1) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    absl::StrAppend(result, element.object_name, ".data[", element.indices[0],
                    "] = ");
    MaybeConvertToHalf(data_type, value, result);
    return RewriteStatus::SUCCESS;
  }

  RewriteStatus operator()(const uint2& size) const {
    if (element.indices.size() == 1) {
      // Access by linear index. Use the method above to generate the accessor.
      return (*this)(1U);
    }
    if (element.indices.size() != 2) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    absl::StrAppend(result, element.object_name, ".data[", element.indices[0],
                    " + $", element.object_name, "_w$ * (", element.indices[1],
                    ")] = ");
    MaybeConvertToHalf(data_type, value, result);
    *requires_sizes = true;
    return RewriteStatus::SUCCESS;
  }

  RewriteStatus operator()(const uint3& size) const {
    if (element.indices.size() == 1) {
      // Access by linear index. Use the method above to generate the accessor.
      return (*this)(1U);
    }
    if (element.indices.size() != 3) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    absl::StrAppend(result, element.object_name, ".data[", element.indices[0],
                    " + $", element.object_name, "_w$ * (", element.indices[1],
                    " + $", element.object_name, "_h$ * (", element.indices[2],
                    "))] = ");
    MaybeConvertToHalf(data_type, value, result);
    *requires_sizes = true;
    return RewriteStatus::SUCCESS;
  }

  DataType data_type;
  const object_accessor_internal::IndexedElement& element;
  absl::string_view value;
  std::string* result;

  // Indicates that the generated code accessed _w and/or _h size variables.
  bool* requires_sizes;
};

struct WriteToTextureGenerator {
  RewriteStatus operator()(uint32_t) const {
    if (element.indices.size() != 1) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    // 1D textures are emulated as 2D textures
    absl::StrAppend(result, "imageStore(", element.object_name, ", ivec2(",
                    element.indices[0], ", 0), ", value, ")");
    return RewriteStatus::SUCCESS;
  }

  template <typename Shape>
  RewriteStatus operator()(const Shape&) const {
    if (element.indices.size() != Shape::size()) {
      result->append("WRONG_NUMBER_OF_INDICES");
      return RewriteStatus::ERROR;
    }
    absl::StrAppend(result, "imageStore(", element.object_name, ", ivec",
                    Shape::size(), "(", absl::StrJoin(element.indices, ", "),
                    "), ", value, ")");
    return RewriteStatus::SUCCESS;
  }

  const object_accessor_internal::IndexedElement& element;
  absl::string_view value;
  std::string* result;
};

// Generates code for writing a value to an element in an object.
RewriteStatus GenerateWriteAccessor(
    const Object& object,
    const object_accessor_internal::IndexedElement& element,
    absl::string_view value, std::string* result, bool* requires_sizes) {
  switch (object.object_type) {
    case ObjectType::BUFFER:
      return absl::visit(WriteToBufferGenerator{object.data_type, element,
                                                value, result, requires_sizes},
                         object.size);
    case ObjectType::TEXTURE:
      return absl::visit(WriteToTextureGenerator{element, value, result},
                         object.size);
    case ObjectType::UNKNOWN:
      return RewriteStatus::ERROR;
  }
}

std::string ToAccessModifier(AccessType access, bool use_readonly_modifier) {
  switch (access) {
    case AccessType::READ:
      return use_readonly_modifier ? " readonly" : "";
    case AccessType::WRITE:
      return " writeonly";
    case AccessType::READ_WRITE:
      return " restrict";
  }
  return " unknown_access";
}

std::string ToBufferType(DataType data_type) {
  switch (data_type) {
    case DataType::UINT8:
    case DataType::UINT16:
    case DataType::UINT32:
      return "uvec4";
    case DataType::INT8:
    case DataType::INT16:
    case DataType::INT32:
      return "ivec4";
    case DataType::FLOAT16:
      return "uvec2";
    case DataType::FLOAT32:
      return "vec4";
    default:
      return "unknown";
  }
}

struct TextureImageTypeGetter {
  std::string operator()(uint32_t) const {
    // 1D textures are emulated as 2D textures
    return (*this)(uint2());
  }

  std::string operator()(const uint2&) const {
    switch (type) {
      case DataType::UINT16:
      case DataType::UINT32:
        return "uimage2D";
      case DataType::INT16:
      case DataType::INT32:
        return "iimage2D";
      case DataType::FLOAT16:
      case DataType::FLOAT32:
        return "image2D";
      default:
        return "unknown";
    }
  }

  std::string operator()(const uint3&) const {
    switch (type) {
      case DataType::UINT16:
      case DataType::UINT32:
        return "uimage2DArray";
      case DataType::INT16:
      case DataType::INT32:
        return "iimage2DArray";
      case DataType::FLOAT16:
      case DataType::FLOAT32:
        return "image2DArray";
      default:
        return "unknown";
    }
  }

  DataType type;
};

std::string ToImageType(const Object& object) {
  return absl::visit(TextureImageTypeGetter{object.data_type}, object.size);
}

std::string ToImageLayoutQualifier(DataType type) {
  switch (type) {
    case DataType::UINT16:
      return "rgba16ui";
    case DataType::UINT32:
      return "rgba32ui";
    case DataType::INT16:
      return "rgba16i";
    case DataType::INT32:
      return "rgba32i";
    case DataType::FLOAT16:
      return "rgba16f";
    case DataType::FLOAT32:
      return "rgba32f";
    default:
      return "unknown";
  }
}

std::string ToImagePrecision(DataType type) {
  switch (type) {
    case DataType::UINT16:
    case DataType::INT16:
    case DataType::FLOAT16:
      return "mediump";
    case DataType::UINT32:
    case DataType::INT32:
    case DataType::FLOAT32:
      return "highp";
    default:
      return "unknown";
  }
}

struct SizeParametersAdder {
  void operator()(uint32_t) const {}

  void operator()(const uint2& size) const {
    parameters->AddParameter(
        {absl::StrCat(object_name, "_w"), static_cast<int32_t>(size.x)});
  }

  // p1 and p2 are padding. For some reason buffer does not map correctly
  // without it.
  void operator()(const uint3& size) const {
    parameters->AddParameter(
        {absl::StrCat(object_name, "_w"), static_cast<int32_t>(size.x)});
    parameters->AddParameter(
        {absl::StrCat(object_name, "_h"), static_cast<int32_t>(size.y)});
  }

  absl::string_view object_name;
  ParameterAccessor* parameters;
};

// Adds the parameters to the parameter accessor that represent the object size
// needed for indexed access.
//   - 1D : empty
//   - 2D : 'int object_name_w'
//   - 3D : 'int object_name_w' + 'int object_name_h'
void AddSizeParameters(absl::string_view object_name, const Object& object,
                       ParameterAccessor* parameters) {
  absl::visit(SizeParametersAdder{object_name, parameters}, object.size);
}

void GenerateObjectDeclaration(absl::string_view name, const Object& object,
                               std::string* declaration, bool is_mali) {
  switch (object.object_type) {
    case ObjectType::BUFFER:
      // readonly modifier used to fix shader compilation for Mali on
      // Android 8, see b/111601761
      absl::StrAppend(declaration, "layout(binding = ", object.binding, ")",
                      ToAccessModifier(object.access, !is_mali), " buffer B",
                      object.binding, " { ", ToBufferType(object.data_type),
                      " data[]; } ", name, ";\n");
      break;
    case ObjectType::TEXTURE:
      absl::StrAppend(declaration, "layout(",
                      ToImageLayoutQualifier(object.data_type),
                      ", binding = ", object.binding, ")",
                      ToAccessModifier(object.access, true), " uniform ",
                      ToImagePrecision(object.data_type), " ",
                      ToImageType(object), " ", name, ";\n");
      break;
    case ObjectType::UNKNOWN:
      // do nothing.
      break;
  }
}

}  // namespace

RewriteStatus ObjectAccessor::Rewrite(absl::string_view input,
                                      std::string* output) {
  // Splits 'a =b' into {'a','b'}.
  std::pair<absl::string_view, absl::string_view> n =
      absl::StrSplit(input, absl::MaxSplits('=', 1), absl::SkipWhitespace());
  if (n.first.empty()) {
    return RewriteStatus::NOT_RECOGNIZED;
  }
  if (n.second.empty()) {
    return RewriteRead(absl::StripAsciiWhitespace(n.first), output);
  }
  return RewriteWrite(absl::StripAsciiWhitespace(n.first),
                      absl::StripAsciiWhitespace(n.second), output);
}

RewriteStatus ObjectAccessor::RewriteRead(absl::string_view location,
                                          std::string* output) {
  auto element = object_accessor_internal::ParseElement(location);
  if (element.object_name.empty()) {
    return RewriteStatus::NOT_RECOGNIZED;
  }
  auto it = name_to_object_.find(
      std::string(element.object_name.data(), element.object_name.size()));
  if (it == name_to_object_.end()) {
    return RewriteStatus::NOT_RECOGNIZED;
  }
  bool requires_sizes = false;
  auto status =
      GenerateReadAccessor(it->second, element, output, &requires_sizes);
  if (requires_sizes) {
    AddSizeParameters(it->first, it->second, parameter_accessor_);
  }
  return status;
}

RewriteStatus ObjectAccessor::RewriteWrite(absl::string_view location,
                                           absl::string_view value,
                                           std::string* output) {
  // name[index1, index2...] = value
  auto element = object_accessor_internal::ParseElement(location);
  if (element.object_name.empty()) {
    return RewriteStatus::NOT_RECOGNIZED;
  }
  auto it = name_to_object_.find(
      std::string(element.object_name.data(), element.object_name.size()));
  if (it == name_to_object_.end()) {
    return RewriteStatus::NOT_RECOGNIZED;
  }
  bool requires_sizes = false;
  auto status = GenerateWriteAccessor(it->second, element, value, output,
                                      &requires_sizes);
  if (requires_sizes) {
    AddSizeParameters(it->first, it->second, parameter_accessor_);
  }
  return status;
}

bool ObjectAccessor::AddObject(const std::string& name, Object object) {
  if (object.object_type == ObjectType::UNKNOWN) {
    return false;
  }
  return name_to_object_.insert({name, std::move(object)}).second;
}

std::string ObjectAccessor::GetObjectDeclarations() const {
  std::string declarations;
  for (auto& o : name_to_object_) {
    GenerateObjectDeclaration(o.first, o.second, &declarations, is_mali_);
  }
  return declarations;
}

std::string ObjectAccessor::GetFunctionsDeclarations() const {
  std::string modifier = "";
  // Mali compiler does not want to compile a function without readonly
  // modifier. See b/111601761 for the context.
  if (is_mali_) {
    modifier = "readonly ";
  }
  // If there is at least one SSBO object with F16 data, then we need to output
  // the conversion functions as well.
  for (const auto& o : name_to_object_) {
    if (o.second.data_type == DataType::FLOAT16 &&
        o.second.object_type == ObjectType::BUFFER) {
      return absl::StrCat("vec4 Vec4FromHalf(in ", modifier,
                          "uvec2 v) { return vec4(unpackHalf2x16(v.x), "
                          "unpackHalf2x16(v.y)); }\n"
                          "uvec2 Vec4ToHalf(in ",
                          modifier,
                          "vec4 v) { return uvec2(packHalf2x16(v.xy), "
                          "packHalf2x16(v.zw)); }\n");
    }
  }
  return "";
}

std::vector<Object> ObjectAccessor::GetObjects() const {
  std::vector<Object> objects;
  for (auto& o : name_to_object_) {
    objects.push_back(o.second);
  }
  return objects;
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
105
tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h
Normal file
@ -0,0 +1,105 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OBJECT_ACCESSOR_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OBJECT_ACCESSOR_H_

#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"

namespace tflite {
namespace gpu {
namespace gl {

// This rewrite handles access to objects, both reads and writes.
//
// The following syntax is supported to access objects:
//
//   READ:
//     vec4 value = $data[i]$;
//       where data is a buffer or 1D texture
//     vec4 value = $data[i,j]$;
//       where data is a 2D texture
//     vec4 value = $data[i,j,k]$;
//       where data is a 3D texture
//
//   WRITE:
//     $data[i] = value$;
//       where data is a buffer or 1D texture
//     $data[i,j] = value$;
//       where data is a 2D texture
//     $data[i,j,k] = value$;
//       where data is a 3D texture
//
// The accessor supports all types (gvecN) as well as float16.
//
// TODO(akulik): support field in data[x,y,z].x
//
class ObjectAccessor : public InlineRewrite {
 public:
  ObjectAccessor(bool is_mali, ParameterAccessor* parameter_accessor)
      : is_mali_(is_mali), parameter_accessor_(parameter_accessor) {}

  RewriteStatus Rewrite(absl::string_view input, std::string* output) final;

  // Returns true if the object was successfully added.
  bool AddObject(const std::string& name, Object object);

  // Returns object declarations that need to be added to a shader's code.
  std::string GetObjectDeclarations() const;

  // Returns function declarations that need to be added to a shader's code.
  // These functions are used by code accessing objects.
  std::string GetFunctionsDeclarations() const;

  // Returns a collection of registered objects.
  std::vector<Object> GetObjects() const;

 private:
  RewriteStatus RewriteRead(absl::string_view location, std::string* output);

  RewriteStatus RewriteWrite(absl::string_view location,
                             absl::string_view value, std::string* output);

  std::unordered_map<std::string, Object> name_to_object_;

  const bool is_mali_;
  ParameterAccessor* parameter_accessor_;
};

// Implementation details below.

namespace object_accessor_internal {

// Refers to an element in an object.
struct IndexedElement {
  absl::string_view object_name;
  std::vector<absl::string_view> indices;
};

// Splits name[index1, index2...] into 'name' and {'index1', 'index2'...}.
IndexedElement ParseElement(absl::string_view input);

}  // namespace object_accessor_internal
}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OBJECT_ACCESSOR_H_
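A short usage sketch of the accessor on its own; MakeReadonlyBuffer and the expected output follow the tests later in this change, while the meaning of the ParameterAccessor constructor flag is an assumption. The sketch is illustrative, not code from the commit.

// Sketch only: rewrite a single read access through the accessor.
ParameterAccessor parameters(false);  // constructor flag meaning assumed
ObjectAccessor accessor(/*is_mali=*/false, &parameters);
accessor.AddObject("obj", MakeReadonlyBuffer(std::vector<float>{1.0}));
std::string generated;
// "obj[i]" is the text that appears between $...$ markers in shader code.
RewriteStatus status = accessor.Rewrite("obj[i]", &generated);
// On success, `generated` holds "obj.data[i]", matching the tests below.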
@ -0,0 +1,206 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h"

#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"

namespace tflite {
namespace gpu {
namespace gl {

struct ParameterComparator {
  template <typename T>
  bool operator()(const T& t) const {
    const T* v = absl::get_if<T>(&p.value);
    return v && t == *v;
  }
  const UniformParameter& p;
};

// partially equal
bool operator==(const UniformParameter& l, const UniformParameter& r) {
  return l.name == r.name && absl::visit(ParameterComparator{l}, r.value);
}

namespace {

TEST(Preprocessor, CornerCases) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  std::string result;
  ASSERT_EQ(accessor.Rewrite("", &result), RewriteStatus::NOT_RECOGNIZED);
  ASSERT_EQ(accessor.Rewrite("=", &result), RewriteStatus::NOT_RECOGNIZED);
}

TEST(Preprocessor, ReadFromBuffer) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(
      accessor.AddObject("obj", MakeReadonlyBuffer(std::vector<float>{1.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite("obj[i]", &result), RewriteStatus::SUCCESS);
  EXPECT_TRUE(parameters.GetUniformParameters().empty());
  ASSERT_EQ(result, "obj.data[i]");
}

TEST(Preprocessor, ReadFromBufferLinear) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(accessor.AddObject(
      "obj", MakeReadonlyBuffer(uint3(1, 2, 3), std::vector<float>{1.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite("obj[i]", &result), RewriteStatus::SUCCESS);
  EXPECT_TRUE(parameters.GetUniformParameters().empty());
  ASSERT_EQ(result, "obj.data[i]");
}

TEST(Preprocessor, ReadFromBufferByIndex) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(accessor.AddObject(
      "obj", MakeReadonlyBuffer(uint3(1, 2, 3), std::vector<float>{1.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite("obj[x,y + 5,z]", &result),
            RewriteStatus::SUCCESS);
  EXPECT_THAT(parameters.GetUniformParameters(),
              testing::UnorderedElementsAre(UniformParameter{"obj_w", 1},
                                            UniformParameter{"obj_h", 2}));
  ASSERT_EQ(result, "obj.data[x + $obj_w$ * (y + 5 + $obj_h$ * (z))]");
}

TEST(Preprocessor, ReadFromTexture) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(accessor.AddObject(
      "obj", MakeReadonlyTexture(uint3(1, 2, 3), {1.0, 2.0, 3.0, 4.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite("obj[i,j,k]", &result), RewriteStatus::SUCCESS);
  // textures don't need extra variables to be stored for indexed access
  EXPECT_TRUE(parameters.GetUniformParameters().empty());
  ASSERT_EQ(result, "imageLoad(obj, ivec3(i, j, k))");
}

TEST(Preprocessor, ReadFromTexture1D) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(
      accessor.AddObject("obj", MakeReadonlyTexture({1.0, 2.0, 3.0, 4.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite("obj[i]", &result), RewriteStatus::SUCCESS);
  EXPECT_TRUE(parameters.GetUniformParameters().empty());
  ASSERT_EQ(result, "imageLoad(obj, ivec2(i, 0))");
}

TEST(Preprocessor, WriteToBuffer) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(
      accessor.AddObject("obj", MakeReadonlyBuffer(std::vector<float>{1.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite(" obj[i] =value", &result),
            RewriteStatus::SUCCESS);
  EXPECT_TRUE(parameters.GetUniformParameters().empty());
  ASSERT_EQ(result, "obj.data[i] = value");
}

TEST(Preprocessor, WriteToBufferByIndex) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(accessor.AddObject(
      "obj", MakeReadonlyBuffer(uint3(1, 2, 3), {1.0, 2.0, 3.0, 4.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite(" obj[i,j,k] =value", &result),
            RewriteStatus::SUCCESS);
  EXPECT_THAT(parameters.GetUniformParameters(),
              testing::UnorderedElementsAre(UniformParameter{"obj_w", 1},
                                            UniformParameter{"obj_h", 2}));
  ASSERT_EQ(result, "obj.data[i + $obj_w$ * (j + $obj_h$ * (k))] = value");
}

TEST(Preprocessor, WriteToTexture) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(accessor.AddObject(
      "obj", MakeReadonlyTexture(uint3(1, 1, 1), {1.0, 2.0, 3.0, 4.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite("obj[i,j,k]= value ", &result),
            RewriteStatus::SUCCESS);
  ASSERT_EQ(result, "imageStore(obj, ivec3(i, j, k), value)");
}

TEST(Preprocessor, WriteToTexture1D) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(
      accessor.AddObject("obj", MakeReadonlyTexture({1.0, 2.0, 3.0, 4.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite("obj[i]= value ", &result),
            RewriteStatus::SUCCESS);
  EXPECT_TRUE(parameters.GetUniformParameters().empty());
  ASSERT_EQ(result, "imageStore(obj, ivec2(i, 0), value)");
}

TEST(Preprocessor, FailedWriteToBuffer) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(
      accessor.AddObject("obj", MakeReadonlyBuffer(std::vector<float>{1.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite(" obj[i,j] =value", &result),
            RewriteStatus::ERROR);
  ASSERT_EQ(result, "WRONG_NUMBER_OF_INDICES");
}

TEST(Preprocessor, FailedWriteToTexture) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(accessor.AddObject(
      "obj", MakeReadonlyTexture(uint3(1, 1, 1), {1.0, 2.0, 3.0, 4.0})));
  std::string result;
  EXPECT_EQ(accessor.Rewrite("obj[i]= value ", &result), RewriteStatus::ERROR);
  ASSERT_EQ(result, "WRONG_NUMBER_OF_INDICES");
}

TEST(Preprocessor, DeclareTexture) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(false, &parameters);
  ASSERT_TRUE(accessor.AddObject(
      "obj", MakeReadonlyTexture(uint3(1, 1, 1), {1.0, 2.0, 3.0, 4.0})));
  ASSERT_EQ(accessor.GetObjectDeclarations(),
            "layout(rgba32f, binding = 0) readonly uniform highp image2DArray "
            "obj;\n");
}

TEST(Preprocessor, DeclareBuffer) {
  ParameterAccessor parameters(false);
  ObjectAccessor accessor(true, &parameters);
  ASSERT_TRUE(
      accessor.AddObject("obj", MakeReadonlyBuffer(std::vector<float>{1.0})));
  ASSERT_EQ(accessor.GetObjectDeclarations(),
            "layout(binding = 0) buffer B0 { vec4 data[]; } obj;\n");
}

}  // namespace
}  // namespace gl
}  // namespace gpu
}  // namespace tflite
368
tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.cc
Normal file
@ -0,0 +1,368 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace parameter_accessor_internal {

// Parse the following regex manually
//   name(\[index\])?(\.field)?
ParameterReference Parse(absl::string_view input) {
  ParameterReference ref;
  auto start_index = input.find('[');
  if (start_index != std::string::npos) {
    auto end_index = input.rfind(']');
    if (end_index == std::string::npos) {
      return ref;
    }
    ref.index = input.substr(start_index + 1, end_index - start_index - 1);
    ref.name = input.substr(0, start_index);
    ref.field = input.substr(end_index + 1);
  } else {
    auto dot = input.find('.');
    if (dot != std::string::npos) {
      ref.name = input.substr(0, dot);
      ref.field = input.substr(dot);
    } else {
      ref.name = input;
    }
  }
  return ref;
}

}  // namespace parameter_accessor_internal

namespace {

struct UniformTypeGetter {
  std::string operator()(int) const { return "int"; }
  std::string operator()(const int2&) const { return "ivec2"; }
  std::string operator()(const std::vector<int2>&) const { return "ivec2"; }
  std::string operator()(const int4&) const { return "ivec4"; }
  std::string operator()(unsigned int) const { return "uint"; }
  std::string operator()(const uint4&) const { return "uvec4"; }
  std::string operator()(float) const { return "float"; }
  std::string operator()(const float2&) const { return "vec2"; }
  std::string operator()(const float4&) const { return "vec4"; }
};

// Returns GLSL uniform type of the given parameter.
std::string GetUniformType(const UniformParameter::ValueType& value) {
  return absl::visit(UniformTypeGetter(), value);
}

template <typename T>
void FormatValue(std::string* result, T t) {
  absl::StrAppend(result, t);
}

template <>
void FormatValue(std::string* result, float t) {
  absl::StrAppend(result, absl::StrFormat("%.9ff", t));
}

// Unfortunately absl::StrJoin with custom formatter requires formatter to use
// string, not std::string. Therefore, due to this compatibility issue data
// needs to be converted to string representation first and then joined.
template <typename T, int N>
std::vector<std::string> ToString(const std::array<T, N>& data) {
  std::vector<std::string> result(N);
  for (int i = 0; i < N; ++i) {
    FormatValue(&result[i], data[i]);
  }
  return result;
}

struct ConstGenerator {
  template <typename T>
  void operator()(T t) const {
    FormatValue(result, t);
  }

  template <typename T>
  void operator()(const Vec2<T>& v) const {
    absl::StrAppend(result, UniformTypeGetter()(v), "(",
                    absl::StrJoin(ToString<T, 2>(v.data_), ","), ")");
  }

  template <typename T>
  void operator()(const Vec3<T>& v) const {
    absl::StrAppend(result, UniformTypeGetter()(v), "(",
                    absl::StrJoin(ToString<T, 3>(v.data_), ","), ")");
  }

  template <typename T>
  void operator()(const Vec4<T>& v) const {
    absl::StrAppend(result, UniformTypeGetter()(v), "(",
                    absl::StrJoin(ToString<T, 4>(v.data_), ","), ")");
  }
|
template <typename T>
|
||||||
|
void operator()(const std::vector<T>& v) const {
|
||||||
|
std::string type = UniformTypeGetter()(v);
|
||||||
|
absl::StrAppend(result, type, "[", v.size(), "](");
|
||||||
|
bool first = true;
|
||||||
|
for (const auto& i : v) {
|
||||||
|
if (first) {
|
||||||
|
first = false;
|
||||||
|
} else {
|
||||||
|
absl::StrAppend(result, ",");
|
||||||
|
}
|
||||||
|
(*this)(i);
|
||||||
|
}
|
||||||
|
absl::StrAppend(result, ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string* result;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Appends string representation of a parameter value.
|
||||||
|
void GetValue(const UniformParameter::ValueType& value, std::string* result) {
|
||||||
|
absl::visit(ConstGenerator{result}, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct UniformDeclarationGenerator {
|
||||||
|
template <typename T>
|
||||||
|
void operator()(const T&) const {
|
||||||
|
absl::StrAppend(result, "uniform ", GetUniformType(param.value), " ",
|
||||||
|
param.name, ";\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator()(const std::vector<T>& v) const {
|
||||||
|
absl::StrAppend(result, "uniform ", GetUniformType(param.value), " ",
|
||||||
|
param.name, "[", v.size(), "];\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
const UniformParameter& param;
|
||||||
|
std::string* result;
|
||||||
|
};
|
||||||
|
|
||||||
|
void GenerateUniformDeclaration(const UniformParameter& parameter,
|
||||||
|
std::string* result) {
|
||||||
|
absl::visit(UniformDeclarationGenerator{parameter, result}, parameter.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct VariableLengthGetter {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(const T&) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(const std::vector<T>&) const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns true if value is a vector
|
||||||
|
bool IsVariableLength(const UniformParameter::ValueType& value) {
|
||||||
|
return absl::visit(VariableLengthGetter(), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
enum Field : uint8_t { UNKNOWN = 4, X = 0, Y = 1, Z = 2, W = 3 };
|
||||||
|
|
||||||
|
Field ToField(absl::string_view field_name) {
|
||||||
|
if (field_name.size() == 2 && field_name[0] == '.') {
|
||||||
|
switch (field_name[1]) {
|
||||||
|
case 'x':
|
||||||
|
return Field::X;
|
||||||
|
case 'y':
|
||||||
|
return Field::Y;
|
||||||
|
case 'z':
|
||||||
|
return Field::Z;
|
||||||
|
case 'w':
|
||||||
|
return Field::W;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Field::UNKNOWN;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FieldAccessor {
|
||||||
|
template <typename T>
|
||||||
|
void operator()(const T&) const {}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator()(const Vec2<T>& v) const {
|
||||||
|
FormatValue(result, v[field]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator()(const Vec3<T>& v) const {
|
||||||
|
FormatValue(result, v[field]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator()(const Vec4<T>& v) const {
|
||||||
|
FormatValue(result, v[field]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Field field;
|
||||||
|
std::string* result;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Appends formatted value of the given field.
|
||||||
|
void GetValue(const UniformParameter::ValueType& value, Field field,
|
||||||
|
std::string* result) {
|
||||||
|
absl::visit(FieldAccessor{field, result}, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FieldChecker {
|
||||||
|
// For trivial as well as variable-length types indexed access is not allowed.
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(const T&) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(const Vec2<T>& v) const {
|
||||||
|
return field < v.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(const Vec3<T>& v) const {
|
||||||
|
return field < v.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(const Vec4<T>& v) const {
|
||||||
|
return field < v.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(const std::vector<T>&) const {
|
||||||
|
// technically accessing [0] element of an empty vector is UB, but we need
|
||||||
|
// only type information for this check. Therefore, construct default T and
|
||||||
|
// use it instead.
|
||||||
|
T t;
|
||||||
|
return (*this)(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
Field field;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns true if field has field access and field is not out of bounds.
|
||||||
|
bool HasField(const UniformParameter::ValueType& value, Field field) {
|
||||||
|
return absl::visit(FieldChecker{field}, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AssembleAccessor(absl::string_view name, absl::string_view index,
|
||||||
|
absl::string_view field, std::string* result) {
|
||||||
|
if (index.empty()) {
|
||||||
|
absl::StrAppend(result, name, field);
|
||||||
|
} else {
|
||||||
|
absl::StrAppend(result, name, "[", index, "]", field);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
RewriteStatus ParameterAccessor::Rewrite(absl::string_view input,
|
||||||
|
std::string* output) {
|
||||||
|
auto ref = parameter_accessor_internal::Parse(input);
|
||||||
|
if (ref.name.empty()) {
|
||||||
|
absl::StrAppend(output, "INVALID_SYNTAX");
|
||||||
|
return RewriteStatus::ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto it = name_to_param_.find(std::string(ref.name.data(), ref.name.size()));
|
||||||
|
if (it == name_to_param_.end()) {
|
||||||
|
// Uniform with this name is not registered.
|
||||||
|
return RewriteStatus::NOT_RECOGNIZED;
|
||||||
|
}
|
||||||
|
const auto& value = it->second.value;
|
||||||
|
|
||||||
|
if (!ref.index.empty() && !IsVariableLength(value)) {
|
||||||
|
// Trying to access parameter by index, but it is not variable-length.
|
||||||
|
absl::StrAppend(output, "INVALID_ACCESS_BY_INDEX");
|
||||||
|
return RewriteStatus::ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
Field f = ToField(ref.field);
|
||||||
|
if (!ref.field.empty() && !HasField(value, f)) {
|
||||||
|
// Trying to access a parameter by field, but it does not have it.
|
||||||
|
absl::StrAppend(output, "INVALID_ACCESS_BY_FIELD");
|
||||||
|
return RewriteStatus::ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error checks are complete now.
|
||||||
|
|
||||||
|
// All variable-length parameters are encoded as-is without inlining.
|
||||||
|
if (!inline_values_ || IsVariableLength(value)) {
|
||||||
|
AssembleAccessor(it->second.name, ref.index, ref.field, output);
|
||||||
|
} else {
|
||||||
|
// Parameter + field is replaced with field value.
|
||||||
|
if (f != Field::UNKNOWN) {
|
||||||
|
GetValue(value, f, output);
|
||||||
|
} else {
|
||||||
|
// Parameter is accessed directly.
|
||||||
|
GetValue(value, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RewriteStatus::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ParameterAccessor::AddParameter(UniformParameter param) {
|
||||||
|
std::string name = param.name;
|
||||||
|
return name_to_param_.insert({name, std::move(param)}).second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ParameterAccessor::GetConstDeclarations() const {
|
||||||
|
// Variable length parameters are declared as const and accessed via variable
|
||||||
|
// with index.
|
||||||
|
std::string declarations;
|
||||||
|
for (auto& param : name_to_param_) {
|
||||||
|
const auto& value = param.second.value;
|
||||||
|
if (IsVariableLength(value)) {
|
||||||
|
absl::StrAppend(&declarations, "const ", GetUniformType(value), " ",
|
||||||
|
param.second.name, "[] = ");
|
||||||
|
GetValue(value, &declarations);
|
||||||
|
absl::StrAppend(&declarations, ";\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return declarations;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ParameterAccessor::GetUniformDeclarations() const {
|
||||||
|
std::string declarations;
|
||||||
|
if (!inline_values_) {
|
||||||
|
for (auto& param : name_to_param_) {
|
||||||
|
GenerateUniformDeclaration(param.second, &declarations);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return declarations;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<UniformParameter> ParameterAccessor::GetUniformParameters() const {
|
||||||
|
std::vector<UniformParameter> params;
|
||||||
|
if (!inline_values_) {
|
||||||
|
for (auto& param : name_to_param_) {
|
||||||
|
params.push_back(param.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
@ -0,0 +1,92 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PARAMETER_ACCESSOR_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PARAMETER_ACCESSOR_H_

#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h"

namespace tflite {
namespace gpu {
namespace gl {

// This rewrite handles access to parameters. It may rewrite a parameter with
// actual values if inline_values is set to true.
//
// The following syntax is supported to access parameters:
//   - simple parameter: name
//   - parameter with field: name.(x|y|z|w)
//   - parameter with index: name[i]
//   - parameter with index and field: name[i].(x|y|z|w)
//
// If 'inline_values' is set to true, non variable-length parameters are
// inlined. For example, 'base.x' is replaced with the value of the 'x' field
// of 'base'. Variable-length parameters are declared as const and accessed
// via index. These declarations are returned by GetConstDeclarations.
//
// If 'inline_values' is set to false, all parameters are declared as
// uniforms. Uniform declarations are returned by GetUniformDeclarations.
class ParameterAccessor : public InlineRewrite {
 public:
  explicit ParameterAccessor(bool inline_values)
      : inline_values_(inline_values) {}

  RewriteStatus Rewrite(absl::string_view input, std::string* output) final;

  // Returns true if the parameter was successfully added.
  bool AddParameter(UniformParameter param);

  // Returns const parameters that need to be inlined in a shader's code.
  std::string GetConstDeclarations() const;

  // Returns uniform declarations that need to be inlined in a shader's code.
  std::string GetUniformDeclarations() const;

  // Returns a collection of uniform parameters.
  std::vector<UniformParameter> GetUniformParameters() const;

 private:
  const bool inline_values_;
  // Unique parameter index used for obfuscation.
  uint32_t unique_param_index_ = 0;

  std::unordered_map<std::string, UniformParameter> name_to_param_;
};

// Implementation details below.

namespace parameter_accessor_internal {

struct ParameterReference {
  absl::string_view name;
  absl::string_view index;
  absl::string_view field;
};

// Parses the reference manually, following the regex:
//   name(\[index\])?(\.field)?
ParameterReference Parse(absl::string_view input);

}  // namespace parameter_accessor_internal
}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PARAMETER_ACCESSOR_H_
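A minimal usage sketch, not part of this commit: it mirrors the AddParameter/Rewrite behavior exercised in parameter_accessor_test.cc further below; the function name is illustrative only.

namespace tflite {
namespace gpu {
namespace gl {

// Sketch: with inline_values=true, parameter references are replaced with
// literal values when the source is rewritten.
inline void ParameterAccessorSketch() {
  ParameterAccessor accessor(/*inline_values=*/true);
  accessor.AddParameter(UniformParameter{"base", int2(1, 2)});

  std::string out;
  accessor.Rewrite("base", &out);    // out becomes "ivec2(1,2)"
  out.clear();
  accessor.Rewrite("base.y", &out);  // out becomes "2"
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite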
@ -0,0 +1,98 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace gl {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(Preprocessor, CornerCases) {
|
||||||
|
ParameterAccessor accessor(true);
|
||||||
|
std::string result;
|
||||||
|
ASSERT_EQ(accessor.Rewrite("unknown", &result),
|
||||||
|
RewriteStatus::NOT_RECOGNIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, Value) {
|
||||||
|
ParameterAccessor accessor(true);
|
||||||
|
ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", int32_t(1)}));
|
||||||
|
std::string result;
|
||||||
|
EXPECT_EQ(accessor.Rewrite("var", &result), RewriteStatus::SUCCESS);
|
||||||
|
ASSERT_EQ(result, "1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, ValueVec) {
|
||||||
|
ParameterAccessor accessor(true);
|
||||||
|
ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", int2(1, 2)}));
|
||||||
|
std::string result;
|
||||||
|
EXPECT_EQ(accessor.Rewrite("var", &result), RewriteStatus::SUCCESS);
|
||||||
|
ASSERT_EQ(result, "ivec2(1,2)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, Field) {
|
||||||
|
ParameterAccessor accessor(true);
|
||||||
|
ASSERT_TRUE(
|
||||||
|
accessor.AddParameter(UniformParameter{"var", float2(1.0, 2.1234567)}));
|
||||||
|
std::string result;
|
||||||
|
EXPECT_EQ(accessor.Rewrite("var.y", &result), RewriteStatus::SUCCESS);
|
||||||
|
ASSERT_EQ(result, "2.123456717f");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, FieldFail) {
|
||||||
|
ParameterAccessor accessor(true);
|
||||||
|
ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", 1.0f}));
|
||||||
|
ASSERT_TRUE(accessor.AddParameter(UniformParameter{"vec", float2(1.0, 1.0)}));
|
||||||
|
std::string result;
|
||||||
|
EXPECT_EQ(accessor.Rewrite("var.y", &result), RewriteStatus::ERROR);
|
||||||
|
ASSERT_EQ(result, "INVALID_ACCESS_BY_FIELD");
|
||||||
|
|
||||||
|
result.clear();
|
||||||
|
EXPECT_EQ(accessor.Rewrite("vec.z", &result), RewriteStatus::ERROR);
|
||||||
|
ASSERT_EQ(result, "INVALID_ACCESS_BY_FIELD");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, Variable) {
|
||||||
|
ParameterAccessor accessor(true);
|
||||||
|
std::vector<int2> v;
|
||||||
|
v.push_back(int2(1, 2));
|
||||||
|
ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", v}));
|
||||||
|
std::string result;
|
||||||
|
EXPECT_EQ(accessor.Rewrite("var[i].y", &result), RewriteStatus::SUCCESS);
|
||||||
|
ASSERT_EQ(result, "var[i].y");
|
||||||
|
ASSERT_EQ(accessor.GetConstDeclarations(),
|
||||||
|
"const ivec2 var[] = ivec2[1](ivec2(1,2));\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, InlineVariableFail) {
|
||||||
|
ParameterAccessor accessor(true);
|
||||||
|
ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", 1}));
|
||||||
|
std::string result;
|
||||||
|
EXPECT_EQ(accessor.Rewrite("var[i]", &result), RewriteStatus::ERROR);
|
||||||
|
ASSERT_EQ(result, "INVALID_ACCESS_BY_INDEX");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
95
tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace gl {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Given input string and a delimiter returns back a substring including
|
||||||
|
// delimiters. If there was only starting delimiter found, returns single char.
|
||||||
|
absl::string_view FindInlineBlock(absl::string_view s, char delimiter) {
|
||||||
|
size_t start = s.find(delimiter);
|
||||||
|
if (start != absl::string_view::npos) {
|
||||||
|
size_t end = s.find(delimiter, start + 1);
|
||||||
|
if (end != std::string::npos) {
|
||||||
|
return s.substr(start, end - start + 1);
|
||||||
|
}
|
||||||
|
// Special case to indicate that we didn't find the end.
|
||||||
|
return s.substr(start, 1);
|
||||||
|
}
|
||||||
|
return s.substr(s.size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the given 's' and its substring 'subs' returns new substring of 's' that
|
||||||
|
// begins past 'subs'.
|
||||||
|
absl::string_view PastSubstr(absl::string_view s, absl::string_view subs) {
|
||||||
|
return s.substr(subs.data() + subs.size() - s.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Status TextPreprocessor::Rewrite(const std::string& input,
|
||||||
|
std::string* output) {
|
||||||
|
absl::string_view s = input;
|
||||||
|
std::string result;
|
||||||
|
while (true) {
|
||||||
|
absl::string_view inline_block = FindInlineBlock(s, inline_delimiter_);
|
||||||
|
result.append(s.data(), inline_block.data() - s.data());
|
||||||
|
if (inline_block.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (inline_block.size() == 1) {
|
||||||
|
return NotFoundError("Unable to find end of inline block");
|
||||||
|
}
|
||||||
|
s = PastSubstr(s, inline_block);
|
||||||
|
bool processed = false;
|
||||||
|
for (auto& rewrite : inline_rewrites_) {
|
||||||
|
if (processed) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
switch (rewrite->Rewrite(inline_block.substr(1, inline_block.size() - 2),
|
||||||
|
&result)) {
|
||||||
|
case RewriteStatus::NOT_RECOGNIZED:
|
||||||
|
// try another rewrite.
|
||||||
|
break;
|
||||||
|
case RewriteStatus::SUCCESS:
|
||||||
|
processed = true;
|
||||||
|
break;
|
||||||
|
case RewriteStatus::ERROR:
|
||||||
|
return InternalError(absl::StrCat("Error while rewriting '",
|
||||||
|
inline_block, "': ", result));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!processed) {
|
||||||
|
if (!keep_unknown_rewrites_) {
|
||||||
|
return NotFoundError(absl::StrCat("Didn't find inline rewrite for '",
|
||||||
|
inline_block, "'"));
|
||||||
|
}
|
||||||
|
absl::StrAppend(&result, inline_block);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*output = std::move(result);
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
74
tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h
Normal file
@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PREPROCESSOR_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PREPROCESSOR_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {
namespace gl {

enum class RewriteStatus {
  SUCCESS = 0,
  NOT_RECOGNIZED = 1,
  ERROR = 2,
};

// An inline rewrite matches a string and rewrites it.
class InlineRewrite {
 public:
  virtual ~InlineRewrite() = default;

  virtual RewriteStatus Rewrite(absl::string_view input,
                                std::string* output) = 0;
};

// Text preprocessor runs a collection of registered rewrites.
// A single inline delimiter character must quote every piece of text that is
// to be rewritten.
class TextPreprocessor {
 public:
  // @param keep_unknown_rewrites if true, unhandled rewrites are kept as-is
  // instead of reporting an error.
  TextPreprocessor(char inline_delimiter, bool keep_unknown_rewrites)
      : inline_delimiter_(inline_delimiter),
        keep_unknown_rewrites_(keep_unknown_rewrites) {}

  void AddRewrite(InlineRewrite* rewrite) {
    inline_rewrites_.push_back(rewrite);
  }

  // input and output may point to the same object.
  Status Rewrite(const std::string& input, std::string* output);

 private:
  const char inline_delimiter_;
  const bool keep_unknown_rewrites_;

  std::vector<InlineRewrite*> inline_rewrites_;
};

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PREPROCESSOR_H_
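A small sketch, not part of this commit, of plugging a custom rewrite into the preprocessor; it mirrors the SingleRewrite pattern used in preprocessor_test.cc below.

namespace tflite {
namespace gpu {
namespace gl {

// Replaces $foo$ with "bla"; every other $...$ block is left untouched
// because keep_unknown_rewrites is true.
class FooRewrite : public InlineRewrite {
 public:
  RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
    if (input == "foo") {
      output->append("bla");
      return RewriteStatus::SUCCESS;
    }
    return RewriteStatus::NOT_RECOGNIZED;
  }
};

inline void PreprocessorSketch() {
  TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true);
  FooRewrite rewrite;
  preprocessor.AddRewrite(&rewrite);
  std::string out;
  // "Hello, $foo$ $name$" becomes "Hello, bla $name$".
  preprocessor.Rewrite("Hello, $foo$ $name$", &out).IgnoreError();
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite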
129
tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace gl {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class AccuInlineRewrite : public InlineRewrite {
|
||||||
|
public:
|
||||||
|
explicit AccuInlineRewrite(std::vector<std::string>* blocks)
|
||||||
|
: blocks_(blocks) {}
|
||||||
|
|
||||||
|
RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
|
||||||
|
blocks_->push_back(std::string(input.data(), input.size()));
|
||||||
|
output->append("r:");
|
||||||
|
output->append(input.data(), input.size());
|
||||||
|
return RewriteStatus::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string>* blocks_;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::string> ParseInlines(const std::string& text) {
|
||||||
|
std::vector<std::string> blocks;
|
||||||
|
TextPreprocessor preprocessor('$', false);
|
||||||
|
AccuInlineRewrite rewrite(&blocks);
|
||||||
|
preprocessor.AddRewrite(&rewrite);
|
||||||
|
std::string discard;
|
||||||
|
preprocessor.Rewrite(text, &discard).IgnoreError();
|
||||||
|
return blocks;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, CornerCases) {
|
||||||
|
EXPECT_THAT(ParseInlines(""), testing::ElementsAre());
|
||||||
|
EXPECT_THAT(ParseInlines("text text"), testing::ElementsAre());
|
||||||
|
EXPECT_THAT(ParseInlines("$$"), testing::ElementsAre(""));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, One) {
|
||||||
|
EXPECT_THAT(ParseInlines("$text$"), testing::ElementsAre("text"));
|
||||||
|
EXPECT_THAT(ParseInlines(" $text$ "), testing::ElementsAre("text"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, More) {
|
||||||
|
EXPECT_THAT(ParseInlines("Test $inline1$\n$inline2$ test $inline3$ "),
|
||||||
|
testing::ElementsAre("inline1", "inline2", "inline3"));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string RewriteInlines(const std::string& text) {
|
||||||
|
std::vector<std::string> blocks;
|
||||||
|
TextPreprocessor preprocessor('$', false);
|
||||||
|
AccuInlineRewrite rewrite(&blocks);
|
||||||
|
preprocessor.AddRewrite(&rewrite);
|
||||||
|
std::string out;
|
||||||
|
preprocessor.Rewrite(text, &out).IgnoreError();
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, RewriteCornerCases) {
|
||||||
|
EXPECT_EQ(RewriteInlines(""), "");
|
||||||
|
EXPECT_EQ(RewriteInlines("text text"), "text text");
|
||||||
|
EXPECT_EQ(RewriteInlines("$$"), "r:");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, RewriteOne) {
|
||||||
|
EXPECT_EQ(RewriteInlines("$text$"), "r:text");
|
||||||
|
EXPECT_EQ(RewriteInlines(" $text$ "), " r:text ");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, RewriteMore) {
|
||||||
|
EXPECT_EQ(RewriteInlines("Test $inline1$\n$inline2$ test $inline3$ "),
|
||||||
|
"Test r:inline1\nr:inline2 test r:inline3 ");
|
||||||
|
}
|
||||||
|
|
||||||
|
class SingleRewrite : public InlineRewrite {
|
||||||
|
public:
|
||||||
|
RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
|
||||||
|
if (input == "foo") {
|
||||||
|
output->append("bla");
|
||||||
|
return RewriteStatus::SUCCESS;
|
||||||
|
}
|
||||||
|
return RewriteStatus::NOT_RECOGNIZED;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string>* blocks_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(Preprocessor, KeepUnknownRewrites) {
|
||||||
|
TextPreprocessor preprocessor('$', true);
|
||||||
|
SingleRewrite rewrite;
|
||||||
|
preprocessor.AddRewrite(&rewrite);
|
||||||
|
std::string out;
|
||||||
|
ASSERT_TRUE(preprocessor.Rewrite("Good morning, $name$! $foo$", &out).ok());
|
||||||
|
EXPECT_EQ("Good morning, $name$! bla", out);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Preprocessor, KeepUnknownRewrites_Fail) {
|
||||||
|
TextPreprocessor preprocessor('$', false);
|
||||||
|
SingleRewrite rewrite;
|
||||||
|
preprocessor.AddRewrite(&rewrite);
|
||||||
|
std::string out;
|
||||||
|
EXPECT_FALSE(preprocessor.Rewrite("Good morning, $name$! $foo$", &out).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
203
tensorflow/lite/delegates/gpu/gl/compiler/rename.cc
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
|
#include "absl/strings/str_split.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/object.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace gl {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Rewrites names of all parameters according to returned values from the
|
||||||
|
// given NameFunctor.
|
||||||
|
class ParameterRewriter : public InlineRewrite {
|
||||||
|
public:
|
||||||
|
ParameterRewriter(const std::string& inline_delimiter,
|
||||||
|
const NameFunctor& name_func)
|
||||||
|
: inline_delimiter_(inline_delimiter), name_func_(name_func) {}
|
||||||
|
|
||||||
|
RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
|
||||||
|
auto ref = parameter_accessor_internal::Parse(input);
|
||||||
|
if (ref.name.empty()) {
|
||||||
|
absl::StrAppend(output, "INVALID_SYNTAX");
|
||||||
|
return RewriteStatus::ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto it =
|
||||||
|
name_to_param_.find(std::string(ref.name.data(), ref.name.size()));
|
||||||
|
if (it == name_to_param_.end()) {
|
||||||
|
return RewriteStatus::NOT_RECOGNIZED;
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconstruct access using the new name.
|
||||||
|
absl::StrAppend(output, inline_delimiter_, it->second.name);
|
||||||
|
if (!ref.index.empty()) {
|
||||||
|
absl::StrAppend(output, "[", ref.index, "]");
|
||||||
|
}
|
||||||
|
absl::StrAppend(output, ref.field, inline_delimiter_);
|
||||||
|
return RewriteStatus::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return true if parameter was successfully added.
|
||||||
|
bool AddParameter(UniformParameter param) {
|
||||||
|
std::string old_name = param.name;
|
||||||
|
param.name = name_func_(old_name);
|
||||||
|
return name_to_param_.insert({old_name, std::move(param)}).second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a collection of uniform parameters with updated names.
|
||||||
|
std::vector<UniformParameter> GetUniformParameters() const {
|
||||||
|
std::vector<UniformParameter> params;
|
||||||
|
params.reserve(name_to_param_.size());
|
||||||
|
for (auto& param : name_to_param_) {
|
||||||
|
params.push_back(param.second);
|
||||||
|
}
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string inline_delimiter_;
|
||||||
|
const NameFunctor name_func_;
|
||||||
|
|
||||||
|
std::unordered_map<std::string, UniformParameter> name_to_param_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Rewrites names of all objects according to returned values from the
|
||||||
|
// given NameFunctor.
|
||||||
|
class ObjectRewriter : public InlineRewrite {
|
||||||
|
public:
|
||||||
|
ObjectRewriter(const std::string& inline_delimiter,
|
||||||
|
const NameFunctor& name_func)
|
||||||
|
: inline_delimiter_(inline_delimiter), name_func_(name_func) {}
|
||||||
|
|
||||||
|
RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
|
||||||
|
// Splits 'a = b' into {'a','b'}.
|
||||||
|
std::pair<absl::string_view, absl::string_view> n =
|
||||||
|
absl::StrSplit(input, absl::MaxSplits('=', 1), absl::SkipWhitespace());
|
||||||
|
if (n.first.empty()) {
|
||||||
|
return RewriteStatus::NOT_RECOGNIZED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n.second.empty()) {
|
||||||
|
return RewriteRead(absl::StripAsciiWhitespace(n.first), output);
|
||||||
|
}
|
||||||
|
return RewriteWrite(absl::StripAsciiWhitespace(n.first),
|
||||||
|
absl::StripAsciiWhitespace(n.second), output);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return true if an object was successfully added.
|
||||||
|
bool AddObject(const std::string& name, Object object) {
|
||||||
|
std::string new_name = name_func_(name);
|
||||||
|
return name_to_object_.insert({name, {new_name, std::move(object)}}).second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a collection of registered objects with updated names.
|
||||||
|
std::vector<std::pair<std::string, Object>> GetObjects() const {
|
||||||
|
std::vector<std::pair<std::string, Object>> objects;
|
||||||
|
objects.reserve(name_to_object_.size());
|
||||||
|
for (auto& o : name_to_object_) {
|
||||||
|
objects.push_back(o.second);
|
||||||
|
}
|
||||||
|
return objects;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
RewriteStatus RewriteRead(absl::string_view location, std::string* output) {
|
||||||
|
auto element = object_accessor_internal::ParseElement(location);
|
||||||
|
if (element.object_name.empty()) {
|
||||||
|
absl::StrAppend(output, "UNABLE_TO_PARSE_INDEXED_ELEMENT");
|
||||||
|
return RewriteStatus::ERROR;
|
||||||
|
}
|
||||||
|
auto it = name_to_object_.find(
|
||||||
|
std::string(element.object_name.data(), element.object_name.size()));
|
||||||
|
if (it == name_to_object_.end()) {
|
||||||
|
return RewriteStatus::NOT_RECOGNIZED;
|
||||||
|
}
|
||||||
|
absl::StrAppend(output, inline_delimiter_, it->second.first, "[",
|
||||||
|
absl::StrJoin(element.indices, ","), "]",
|
||||||
|
inline_delimiter_);
|
||||||
|
return RewriteStatus::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
RewriteStatus RewriteWrite(absl::string_view location,
|
||||||
|
absl::string_view value, std::string* output) {
|
||||||
|
// name[index1, index2...] = value
|
||||||
|
auto element = object_accessor_internal::ParseElement(location);
|
||||||
|
if (element.object_name.empty()) {
|
||||||
|
absl::StrAppend(output, "UNABLE_TO_PARSE_INDEXED_ELEMENT");
|
||||||
|
return RewriteStatus::ERROR;
|
||||||
|
}
|
||||||
|
auto it = name_to_object_.find(
|
||||||
|
std::string(element.object_name.data(), element.object_name.size()));
|
||||||
|
if (it == name_to_object_.end()) {
|
||||||
|
return RewriteStatus::NOT_RECOGNIZED;
|
||||||
|
}
|
||||||
|
absl::StrAppend(output, inline_delimiter_, it->second.first, "[",
|
||||||
|
absl::StrJoin(element.indices, ","), "] = ", value,
|
||||||
|
inline_delimiter_);
|
||||||
|
return RewriteStatus::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string inline_delimiter_;
|
||||||
|
const NameFunctor name_func_;
|
||||||
|
|
||||||
|
std::unordered_map<std::string, std::pair<std::string, Object>>
|
||||||
|
name_to_object_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Status Rename(const NameFunctor& name_func, GeneratedCode* code) {
|
||||||
|
ParameterRewriter param_rewriter("$", name_func);
|
||||||
|
ObjectRewriter object_rewriter("$", name_func);
|
||||||
|
for (auto&& param : code->parameters) {
|
||||||
|
if (!param_rewriter.AddParameter(std::move(param))) {
|
||||||
|
return InternalError("Parameter name already exists");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto&& object : code->objects) {
|
||||||
|
if (!object_rewriter.AddObject(object.first, std::move(object.second))) {
|
||||||
|
return InternalError("Object name already exists");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TextPreprocessor preprocessor('$', /* keep_unknown_rewrites = */ true);
|
||||||
|
preprocessor.AddRewrite(¶m_rewriter);
|
||||||
|
preprocessor.AddRewrite(&object_rewriter);
|
||||||
|
std::string source_code;
|
||||||
|
RETURN_IF_ERROR(preprocessor.Rewrite(code->source_code, &source_code));
|
||||||
|
code->source_code = source_code;
|
||||||
|
code->parameters = param_rewriter.GetUniformParameters();
|
||||||
|
code->objects = object_rewriter.GetObjects();
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
41
tensorflow/lite/delegates/gpu/gl/compiler/rename.h
Normal file
@ -0,0 +1,41 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_RENAME_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_RENAME_H_

#include <functional>
#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"

namespace tflite {
namespace gpu {
namespace gl {

// Functor that takes the old name and returns a new name.
using NameFunctor = std::function<std::string(absl::string_view name)>;

// Rewrites source code, objects and parameters with the new names supplied
// by the given functor.
Status Rename(const NameFunctor& name_func, GeneratedCode* code);

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_RENAME_H_
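A hedged sketch, not from this commit, of how a caller might rename everything in a piece of generated code; GeneratedCode comes from node_shader.h, and the "node1_" prefix is purely illustrative.

#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h"

namespace tflite {
namespace gpu {
namespace gl {

// Prefixes every parameter and object name, and every reference to them in
// the shader source, with "node1_".
inline Status PrefixNamesSketch(GeneratedCode* code) {
  return Rename(
      [](absl::string_view name) { return absl::StrCat("node1_", name); },
      code);
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite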
68
tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODE_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODE_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/object.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace gl {
|
||||||
|
|
||||||
|
struct ShaderCode {
|
||||||
|
ShaderCode() = default;
|
||||||
|
ShaderCode(const std::vector<UniformParameter>& in_parameters,
|
||||||
|
const std::vector<Object>& in_objects, const uint3& in_workload,
|
||||||
|
const uint3& in_recommended_workgroup,
|
||||||
|
const std::string& in_source_code,
|
||||||
|
const std::vector<NodeId>& in_node_indices)
|
||||||
|
: parameters(in_parameters),
|
||||||
|
objects(in_objects),
|
||||||
|
workload(in_workload),
|
||||||
|
recommended_workgroup(in_recommended_workgroup),
|
||||||
|
source_code(in_source_code),
|
||||||
|
node_indices(in_node_indices) {}
|
||||||
|
|
||||||
|
// A list of uniform parameters to be set.
|
||||||
|
std::vector<UniformParameter> parameters;
|
||||||
|
|
||||||
|
// A list of objects to bind to opengl program.
|
||||||
|
std::vector<Object> objects;
|
||||||
|
|
||||||
|
uint3 workload;
|
||||||
|
|
||||||
|
// operation may specify recommended workgroup size
|
||||||
|
uint3 recommended_workgroup;
|
||||||
|
|
||||||
|
// Generated source code does not set local size, therefore it needs to be set
|
||||||
|
// elsewhere.
|
||||||
|
std::string source_code;
|
||||||
|
|
||||||
|
// nodes of the graph that are covered by the shader.
|
||||||
|
std::vector<NodeId> node_indices;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODE_H_
|
148
tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace gl {
|
||||||
|
|
||||||
|
ShaderCodegen::ShaderCodegen(const CompilationOptions& options,
|
||||||
|
const GpuInfo& gpu_info)
|
||||||
|
: options_(options), gpu_type_(gpu_info.type) {}
|
||||||
|
|
||||||
|
Status ShaderCodegen::Build(CompiledNodeAttributes attr,
|
||||||
|
ShaderCode* shader_code) const {
|
||||||
|
ParameterAccessor parameters(options_.inline_parameters);
|
||||||
|
ObjectAccessor objects(gpu_type_ == GpuType::MALI, ¶meters);
|
||||||
|
|
||||||
|
auto add_object = [&](const std::string& name, Object&& object) {
|
||||||
|
if (!objects.AddObject(name, std::forward<Object>(object))) {
|
||||||
|
return InternalError("There is an object with the same name");
|
||||||
|
}
|
||||||
|
return OkStatus();
|
||||||
|
};
|
||||||
|
|
||||||
|
auto add_parameter = [&](UniformParameter&& param) {
|
||||||
|
if (!parameters.AddParameter(std::forward<UniformParameter>(param))) {
|
||||||
|
return InternalError("There is a parameter with the same name");
|
||||||
|
}
|
||||||
|
return OkStatus();
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto&& param : attr.code.parameters) {
|
||||||
|
RETURN_IF_ERROR(add_parameter(std::move(param)));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto&& object : attr.code.objects) {
|
||||||
|
RETURN_IF_ERROR(add_object(object.first, std::move(object.second)));
|
||||||
|
}
|
||||||
|
|
||||||
|
int index = 0;
|
||||||
|
for (auto&& input : attr.inputs) {
|
||||||
|
RETURN_IF_ERROR(
|
||||||
|
add_object(absl::StrCat("input_data_", index++), std::move(input)));
|
||||||
|
}
|
||||||
|
index = 0;
|
||||||
|
for (auto&& output : attr.outputs) {
|
||||||
|
RETURN_IF_ERROR(
|
||||||
|
add_object(absl::StrCat("output_data_", index++), std::move(output)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(akulik): workload params need to go away and be replaced with
|
||||||
|
// output_data_0_w
|
||||||
|
RETURN_IF_ERROR(add_parameter(
|
||||||
|
{"workload_x", static_cast<int32_t>(attr.code.workload.x)}));
|
||||||
|
RETURN_IF_ERROR(add_parameter(
|
||||||
|
{"workload_y", static_cast<int32_t>(attr.code.workload.y)}));
|
||||||
|
RETURN_IF_ERROR(add_parameter(
|
||||||
|
{"workload_z", static_cast<int32_t>(attr.code.workload.z)}));
|
||||||
|
|
||||||
|
std::string source_code = R"(
|
||||||
|
ivec3 gid = ivec3(gl_GlobalInvocationID.xyz);
|
||||||
|
if (gid.x >= $workload_x$ || gid.y >= $workload_y$ || gid.z >= $workload_z$) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
switch (attr.code.input) {
|
||||||
|
case IOStructure::ONLY_DEFINITIONS:
|
||||||
|
for (int i = 0; i < attr.inputs.size(); ++i) {
|
||||||
|
absl::StrAppend(&source_code, " highp vec4 value_", i,
|
||||||
|
" = vec4(0);\n");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case IOStructure::AUTO: {
|
||||||
|
for (int i = 0; i < attr.inputs.size(); ++i) {
|
||||||
|
absl::StrAppend(&source_code, " highp vec4 value_", i,
|
||||||
|
" = $input_data_", i, "[gid.x, gid.y, gid.z]$;\n");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
source_code.append(attr.code.source_code);
|
||||||
|
|
||||||
|
if (attr.code.output == IOStructure::AUTO) {
|
||||||
|
for (int i = 0; i < attr.outputs.size(); ++i) {
|
||||||
|
absl::StrAppend(&source_code, " $output_data_", i,
|
||||||
|
"[gid.x, gid.y, gid.z] = value_", i, "$;\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point main function is already generated. Now we need to process
|
||||||
|
// object and parameter accessors.
|
||||||
|
|
||||||
|
// process objects first. Object accessor may introduce new uniform
|
||||||
|
// parameters that need to be rewritten in the subsequent pass.
|
||||||
|
{
|
||||||
|
TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true);
|
||||||
|
preprocessor.AddRewrite(&objects);
|
||||||
|
RETURN_IF_ERROR(preprocessor.Rewrite(source_code, &source_code));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/false);
|
||||||
|
preprocessor.AddRewrite(¶meters);
|
||||||
|
RETURN_IF_ERROR(preprocessor.Rewrite(source_code, &source_code));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options_.inline_parameters) {
|
||||||
|
source_code = absl::StrCat(parameters.GetConstDeclarations(), source_code);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string declarations = absl::StrCat(
|
||||||
|
objects.GetFunctionsDeclarations(), "\n", objects.GetObjectDeclarations(),
|
||||||
|
"\n", parameters.GetUniformDeclarations());
|
||||||
|
*shader_code = ShaderCode(
|
||||||
|
parameters.GetUniformParameters(), objects.GetObjects(),
|
||||||
|
attr.code.workload, attr.code.workgroup,
|
||||||
|
absl::StrCat("layout(std430) buffer;\nprecision ",
|
||||||
|
(options_.allow_precision_loss ? "mediump" : "highp"),
|
||||||
|
" float;\n", declarations, "\nvoid main() {\n", source_code,
|
||||||
|
"\n}"),
|
||||||
|
attr.node_indices);
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
54
tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODEGEN_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODEGEN_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/object.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace gl {
|
||||||
|
|
||||||
|
// This class is responsible for assembling a shader by putting together
|
||||||
|
// objects, parameters declarations and main function.
|
||||||
|
class ShaderCodegen {
|
||||||
|
public:
|
||||||
|
ShaderCodegen(const CompilationOptions& options, const GpuInfo& gpu_info);
|
||||||
|
|
||||||
|
// Builds final program representation.
|
||||||
|
Status Build(CompiledNodeAttributes attr, ShaderCode* shader_code) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const CompilationOptions options_;
|
||||||
|
const GpuType gpu_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODEGEN_H_
|
68
tensorflow/lite/delegates/gpu/gl/compiler_options.h
Normal file
@ -0,0 +1,68 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OPTIONS_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OPTIONS_H_

#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"

namespace tflite {
namespace gpu {
namespace gl {

// Default constructor for options turns on all optimizations.
struct CompilationOptions {
  // Allows quantizing tensors, downcasting values, processing in float16, etc.
  bool allow_precision_loss = false;

  // When set, multiple operations are fused into a single shader. As a result,
  // there will be fewer shaders, but each shader becomes larger.
  bool fuse_operations = true;

  // Parameters will be inlined into a shader. This in turn generates more
  // unique shaders, each of which needs to be compiled.
  bool inline_parameters = false;

  // If true, shaders that have auto-input and auto-output will use a single
  // object for reading and writing.
  bool inline_objects = true; // TODO(akulik): unsupported

  // Can only be textures or buffers.
  ObjectType preferred_obj_type = ObjectType::UNKNOWN;
  // The user may choose between textures and buffers: textures work better on
  // Adreno, and buffers are better for Mali.

  // Chooses the object type used to represent intermediate tensors. Buffers
  // have more efficient memory usage because they represent an opaque memory
  // blob, but textures work better on Adreno.
  // TODO(akulik): may be better name?
  ObjectType ref_obj_type = ObjectType::UNKNOWN;

  // If true, the user may change the BATCH dimension at runtime. Otherwise,
  // the batch size is fixed at compile time.
  // Dynamic mode uses less memory, while static mode may yield better
  // performance for small models.
  bool dynamic_batch = false;

  // Fuses consecutive nodes which have auto output and auto input.
  bool auto_input_fusion = true;
};

} // namespace gl
} // namespace gpu
} // namespace tflite

#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OPTIONS_H_
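For reference, a hedged sketch of overriding the defaults above from client code; ObjectType::TEXTURE is assumed to be one of the values declared in object.h (only ObjectType::UNKNOWN appears in this header), and the snippet is illustrative rather than taken from this change:

// Inside namespace tflite::gpu::gl; a usage sketch, not part of this commit.
CompilationOptions options;                        // fusion is on by default
options.allow_precision_loss = true;               // allow fp16 math in shaders
options.preferred_obj_type = ObjectType::TEXTURE;  // assumed enum value; textures tend to win on Adreno
options.dynamic_batch = false;                     // batch size fixed at compile time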
102
tensorflow/lite/delegates/gpu/gl/converters/BUILD
Normal file
102
tensorflow/lite/delegates/gpu/gl/converters/BUILD
Normal file
@ -0,0 +1,102 @@
package(default_visibility = ["//visibility:public"])

licenses(["notice"]) # Apache 2.0

cc_library(
    name = "util",
    hdrs = ["util.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common:util",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "bhwc_to_phwc4",
    srcs = ["bhwc_to_phwc4.cc"],
    hdrs = ["bhwc_to_phwc4.h"],
    deps = [
        ":util",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common:util",
        "//tensorflow/lite/delegates/gpu/gl:command_queue",
        "//tensorflow/lite/delegates/gpu/gl:gl_buffer",
        "//tensorflow/lite/delegates/gpu/gl:gl_program",
        "//tensorflow/lite/delegates/gpu/gl:gl_shader",
        "//tensorflow/lite/delegates/gpu/gl:uniform_parameter",
    ],
)

cc_test(
    name = "bhwc_to_phwc4_test",
    size = "small",
    srcs = ["bhwc_to_phwc4_test.cc"],
    linkopts = [
        "-lGLESv3",
        "-lEGL",
    ],
    tags = [
        "local",
        "nobuilder",
        "notap",
    ],
    deps = [
        ":bhwc_to_phwc4",
        "//tensorflow/lite/delegates/gpu/common:convert",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/gl:egl_environment",
        "//tensorflow/lite/delegates/gpu/gl:gl_buffer",
        "//tensorflow/lite/delegates/gpu/gl:portable",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "phwc4_to_bhwc",
    srcs = ["phwc4_to_bhwc.cc"],
    hdrs = ["phwc4_to_bhwc.h"],
    deps = [
        ":util",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common:types",
        "//tensorflow/lite/delegates/gpu/common:util",
        "//tensorflow/lite/delegates/gpu/gl:command_queue",
        "//tensorflow/lite/delegates/gpu/gl:gl_buffer",
        "//tensorflow/lite/delegates/gpu/gl:gl_program",
        "//tensorflow/lite/delegates/gpu/gl:gl_shader",
        "//tensorflow/lite/delegates/gpu/gl:uniform_parameter",
    ],
)

cc_test(
    name = "phwc4_to_bhwc_test",
    size = "small",
    srcs = ["phwc4_to_bhwc_test.cc"],
    linkopts = [
        "-lGLESv3",
        "-lEGL",
    ],
    tags = [
        "local",
        "nobuilder",
        "notap",
    ],
    deps = [
        ":phwc4_to_bhwc",
        "//tensorflow/lite/delegates/gpu/common:convert",
        "//tensorflow/lite/delegates/gpu/common:shape",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/gl:egl_environment",
        "//tensorflow/lite/delegates/gpu/gl:gl_buffer",
        "//tensorflow/lite/delegates/gpu/gl:portable",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
    ],
)
106
tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc
Normal file
106
tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc
Normal file
@ -0,0 +1,106 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h"

#include <algorithm>
#include <cstdint>
#include <string>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/gl/converters/util.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h"

namespace tflite {
namespace gpu {
namespace gl {

Status ConverterBhwcToPhwc4::Create(ConverterBhwcToPhwc4* converter) {
  uint3 workgroup_size = uint3(4, 4, 4);
  std::string shader_source = GetShaderHeader(workgroup_size) + R"(
    layout(std430) buffer;

    precision highp float;

    layout(binding = 0) readonly buffer B0 {
      float elements[];
    } input_data;

    layout(binding = 1) writeonly buffer B1 {
      vec4 elements[];
    } output_data;

    uniform ivec4 sizes_;

    void main() {
      ivec3 gid = ivec3(gl_GlobalInvocationID.xyz);
      if (gid.x >= sizes_.x || gid.y >= sizes_.y || gid.z >= sizes_.z) {
        return;
      }
      vec4 v = vec4(0);
      int dst_channel = gid.z * 4;
      int index = (gid.y * sizes_.x + gid.x) * sizes_.w + dst_channel;
      for (int i = 0; i < 4; ++i, ++index, ++dst_channel) {
        if (dst_channel >= sizes_.w) break;
        v[i] = input_data.elements[index];
      }
      output_data.elements[(gid.z * sizes_.y + gid.y) * sizes_.x + gid.x] = v;
    })";

  GlShader shader;
  RETURN_IF_ERROR(
      GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source, &shader));
  GlProgram program;
  RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program));
  *converter = ConverterBhwcToPhwc4(std::move(program), workgroup_size);
  return OkStatus();
}

Status ConverterBhwcToPhwc4::Convert(const BHWC& shape, const GlBuffer& source,
                                     CommandQueue* command_queue,
                                     GlBuffer* destination) {
  if (source.bytes_size() < BytesForBHWC(shape)) {
    return InvalidArgumentError(
        "BhwcToPhwc4: Input data size does not match expected size.");
  }
  if (destination->bytes_size() < BytesForPHWC4(shape)) {
    return InvalidArgumentError(
        "BhwcToPhwc4: output data size does not match expected size.");
  }
  if (shape.b != 1) {
    return UnimplementedError("BhwcToPhwc4: Batch size is not equal to 1.");
  }
  uint3 workload = uint3(shape.w, shape.h, shape.c);
  uint3 num_workgroups = IntegralDivideRoundUp(workload, workgroup_size_);

  RETURN_IF_ERROR(program_.SetParameter(UniformParameter{
      "sizes_",
      int4(static_cast<int32_t>(workload.x), static_cast<int32_t>(workload.y),
           static_cast<int32_t>(workload.z), static_cast<int32_t>(shape.c))}));
  RETURN_IF_ERROR(source.BindToIndex(0));
  RETURN_IF_ERROR(destination->BindToIndex(1));
  if (command_queue) {
    return command_queue->Dispatch(program_, num_workgroups);
  }
  return program_.Dispatch(num_workgroups);
}

} // namespace gl
} // namespace gpu
} // namespace tflite
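A worked example of the gather performed by the shader above, for illustration only: with W = sizes_.x = 4, H = sizes_.y = 3 and C = sizes_.w = 6, the invocation at gid = (x = 2, y = 1, z = 1) packs channels 4..7 into one vec4 (channels 6 and 7 exceed C and stay zero):

// dst_channel      = gid.z * 4                   = 4
// BHWC read index  = (y*W + x)*C + dst_channel   = (1*4 + 2)*6 + 4 = 40
// channels packed  = 4 and 5 (the loop breaks once dst_channel >= C)
// PHWC4 vec4 index = (gid.z*H + y)*W + x         = (1*3 + 1)*4 + 2 = 18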
53
tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h
Normal file
53
tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h
Normal file
@ -0,0 +1,53 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_BHWC_TO_PHWC4_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_BHWC_TO_PHWC4_H_

#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/command_queue.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"

namespace tflite {
namespace gpu {
namespace gl {

class ConverterBhwcToPhwc4 {
 public:
  // Creates an invalid object.
  ConverterBhwcToPhwc4() : program_(), workgroup_size_() {}

  static Status Create(ConverterBhwcToPhwc4* converter);

  Status Convert(const BHWC& shape, const GlBuffer& source,
                 CommandQueue* command_queue /* optional */,
                 GlBuffer* destination);

 private:
  explicit ConverterBhwcToPhwc4(GlProgram program, const uint3& workgroup_size)
      : program_(std::move(program)), workgroup_size_(workgroup_size) {}

  GlProgram program_;
  uint3 workgroup_size_;
};

} // namespace gl
} // namespace gpu
} // namespace tflite

#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_BHWC_TO_PHWC4_H_
@ -0,0 +1,94 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h"

#include <algorithm>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

inline std::vector<float> GenerateFloats(float multiplier, int size) {
  std::vector<float> v(size);
  for (int i = 0; i < size; ++i) {
    v[i] = multiplier * i * (i % 2 == 0 ? -1 : 1);
  }
  return v;
}

Status RunTest(const BHWC& shape) {
  // Create random input and calculate expected output for it.
  std::vector<float> input = GenerateFloats(0.01, shape.DimensionsProduct());
  std::vector<float> output(GetElementsSizeForPHWC4(shape), 0);
  RETURN_IF_ERROR(
      ConvertToPHWC4(absl::MakeConstSpan(input.data(), input.size()), shape,
                     absl::MakeSpan(output.data(), output.size())));

  std::unique_ptr<EglEnvironment> env;
  RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env));

  // Create input and output buffers
  GlBuffer input_buffer;
  RETURN_IF_ERROR(CreateReadOnlyShaderStorageBuffer(
      absl::MakeConstSpan(input.data(), input.size()), &input_buffer));

  GlBuffer output_buffer;
  RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
      GetElementsSizeForPHWC4(shape), &output_buffer));

  // Create converter and run it.
  ConverterBhwcToPhwc4 converter;
  RETURN_IF_ERROR(ConverterBhwcToPhwc4::Create(&converter));
  RETURN_IF_ERROR(
      converter.Convert(shape, input_buffer, nullptr, &output_buffer));

  std::vector<float> converted_output(output.size(), 0);
  RETURN_IF_ERROR(output_buffer.Read(
      absl::MakeSpan(converted_output.data(), converted_output.size())));
  if (output != converted_output) {
    return InternalError("Outputs don't match");
  }
  return OkStatus();
}

TEST(HwcToPhwc4, Smoke) {
  for (int32_t h : {1, 2, 3, 7, 20}) {
    for (int32_t w : {1, 2, 4, 5, 11}) {
      for (int32_t c : {1, 2, 4, 5, 8, 9}) {
        BHWC shape(1, h, w, c);
        EXPECT_TRUE(RunTest(shape).ok())
            << shape.h << " " << shape.w << " " << shape.c;
      }
    }
  }
}

} // namespace
} // namespace gl
} // namespace gpu
} // namespace tflite
102
tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc
Normal file
102
tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc
Normal file
@ -0,0 +1,102 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h"

#include <algorithm>
#include <cstdint>
#include <string>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/gl/converters/util.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h"

namespace tflite {
namespace gpu {
namespace gl {

Status ConverterPhwc4ToBhwc::Create(ConverterPhwc4ToBhwc* converter) {
  uint3 workgroup_size = uint3(4, 4, 4);
  std::string shader_source = GetShaderHeader(workgroup_size) + R"(
    layout(std430) buffer;

    precision highp float;

    layout(binding = 0) readonly buffer B0 {
      vec4 elements[];
    } input_data;

    layout(binding = 1) writeonly buffer B1 {
      float elements[];
    } output_data;

    uniform ivec4 sizes_;

    void main() {
      ivec3 gid = ivec3(gl_GlobalInvocationID.xyz);
      if (gid.x >= sizes_.x || gid.y >= sizes_.y || gid.z >= sizes_.z) {
        return;
      }
      output_data.elements[(gid.y * sizes_.x + gid.x) * sizes_.z + gid.z] =
          input_data.elements[(gid.z / 4 * sizes_.y + gid.y) * sizes_.x +
                              gid.x][gid.z % 4];
    })";

  GlShader shader;
  RETURN_IF_ERROR(
      GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source, &shader));
  GlProgram program;
  RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program));
  *converter = ConverterPhwc4ToBhwc(std::move(program), workgroup_size);
  return OkStatus();
}

Status ConverterPhwc4ToBhwc::Convert(const BHWC& shape, const GlBuffer& source,
                                     CommandQueue* command_queue,
                                     GlBuffer* destination) {
  if (source.bytes_size() < BytesForPHWC4(shape)) {
    return InvalidArgumentError(
        "Phwc4ToBhwc: Input data size does not match expected size.");
  }
  if (destination->bytes_size() < BytesForBHWC(shape)) {
    return InvalidArgumentError(
        "Phwc4ToBhwc: output data size does not match expected size.");
  }
  if (shape.b != 1) {
    return UnimplementedError("Phwc4ToBhwc: Batch size is not equal to 1.");
  }

  uint3 workload = uint3(shape.w, shape.h, shape.c);
  uint3 num_workgroups = IntegralDivideRoundUp(workload, workgroup_size_);

  // TODO(akulik): simply pass workload as soon as UniformParameter
  // supports uint3
  RETURN_IF_ERROR(program_.SetParameter(UniformParameter{
      "sizes_",
      int4(static_cast<int32_t>(workload.x), static_cast<int32_t>(workload.y),
           static_cast<int32_t>(workload.z), 0)}));
  RETURN_IF_ERROR(source.BindToIndex(0));
  RETURN_IF_ERROR(destination->BindToIndex(1));
  if (command_queue) {
    return command_queue->Dispatch(program_, num_workgroups);
  }
  return program_.Dispatch(num_workgroups);
}

} // namespace gl
} // namespace gpu
} // namespace tflite
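A worked example of the index arithmetic in the shader above, for illustration only: take W = sizes_.x = 4, H = sizes_.y = 3, C = sizes_.z = 6 and the element at (x = 2, y = 1, c = 5).

// BHWC (output) offset : (y*W + x)*C + c     = (1*4 + 2)*6 + 5 = 41
// PHWC4 slice          : c / 4               = 1
// PHWC4 vec4 offset    : (slice*H + y)*W + x = (1*3 + 1)*4 + 2 = 18
// vec4 component       : c % 4               = 1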
53
tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h
Normal file
53
tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h
Normal file
@ -0,0 +1,53 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_PHWC4_TO_BHWC_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_PHWC4_TO_BHWC_H_

#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/command_queue.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"

namespace tflite {
namespace gpu {
namespace gl {

class ConverterPhwc4ToBhwc {
 public:
  // Creates an invalid object.
  ConverterPhwc4ToBhwc() : program_(), workgroup_size_() {}

  static Status Create(ConverterPhwc4ToBhwc* converter);

  Status Convert(const BHWC& shape, const GlBuffer& source,
                 CommandQueue* command_queue /* optional */,
                 GlBuffer* destination);

 private:
  explicit ConverterPhwc4ToBhwc(GlProgram program, const uint3& workgroup_size)
      : program_(std::move(program)), workgroup_size_(workgroup_size) {}

  GlProgram program_;
  uint3 workgroup_size_;
};

} // namespace gl
} // namespace gpu
} // namespace tflite

#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_PHWC4_TO_BHWC_H_
@ -0,0 +1,95 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h"

#include <algorithm>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

inline std::vector<float> GenerateFloats(float multiplier, int size) {
  std::vector<float> v(size);
  for (int i = 0; i < size; ++i) {
    v[i] = multiplier * i * (i % 2 == 0 ? -1 : 1);
  }
  return v;
}

Status RunTest(const BHWC& shape) {
  // Create random input and calculate expected output for it.
  std::vector<float> input =
      GenerateFloats(0.01, GetElementsSizeForPHWC4(shape));
  std::vector<float> output(shape.DimensionsProduct(), 0);
  RETURN_IF_ERROR(
      ConvertFromPHWC4(absl::MakeConstSpan(input.data(), input.size()), shape,
                       absl::MakeSpan(output.data(), output.size())));

  std::unique_ptr<EglEnvironment> env;
  RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env));

  // Create input and output buffers
  GlBuffer input_buffer;
  RETURN_IF_ERROR(CreateReadOnlyShaderStorageBuffer(
      absl::MakeConstSpan(input.data(), input.size()), &input_buffer));

  GlBuffer output_buffer;
  RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
      shape.DimensionsProduct(), &output_buffer));

  // Create converter and run it.
  ConverterPhwc4ToBhwc converter;
  RETURN_IF_ERROR(ConverterPhwc4ToBhwc::Create(&converter));
  RETURN_IF_ERROR(
      converter.Convert(shape, input_buffer, nullptr, &output_buffer));

  std::vector<float> converted_output(output.size(), 0);
  RETURN_IF_ERROR(output_buffer.Read(
      absl::MakeSpan(converted_output.data(), converted_output.size())));
  if (output != converted_output) {
    return InternalError("Outputs don't match");
  }
  return OkStatus();
}

TEST(Phwc4ToHwc, Smoke) {
  for (int32_t h : {1, 2, 3, 7, 20}) {
    for (int32_t w : {1, 2, 4, 5, 11}) {
      for (int32_t c : {1, 2, 4, 5, 8, 9}) {
        BHWC shape(1, h, w, c);
        EXPECT_TRUE(RunTest(shape).ok())
            << shape.h << " " << shape.w << " " << shape.c;
      }
    }
  }
}

} // namespace
} // namespace gl
} // namespace gpu
} // namespace tflite
49
tensorflow/lite/delegates/gpu/gl/converters/util.h
Normal file
49
tensorflow/lite/delegates/gpu/gl/converters/util.h
Normal file
@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_UTIL_H_

#include <cstdint>
#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {
namespace gl {

inline std::string GetShaderHeader(const uint3& localsize) {
  return absl::StrCat("#version 310 es\nlayout(local_size_x = ", localsize.x,
                      ", local_size_y = ", localsize.y,
                      ", local_size_z = ", localsize.z, ") in;\n");
}

inline uint32_t BytesForPHWC4(const BHWC& shape) {
  return shape.b * shape.h * shape.w * AlignByN(shape.c, 4) * sizeof(float);
}

inline uint32_t BytesForBHWC(const BHWC& shape) {
  return shape.DimensionsProduct() * sizeof(float);
}

} // namespace gl
} // namespace gpu
} // namespace tflite

#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_UTIL_H_
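A quick worked example of the two size helpers above, purely for illustration: for a float tensor of shape BHWC = 1x8x8x3, the channel dimension is padded to 4 in PHWC4, so the converted buffer is larger than the packed BHWC one.

// Inside namespace tflite::gpu::gl; the values follow directly from the
// inline functions above.
BHWC shape(1, 8, 8, 3);
uint32_t bhwc_bytes = BytesForBHWC(shape);    // 1*8*8*3 * sizeof(float) = 768
uint32_t phwc4_bytes = BytesForPHWC4(shape);  // 1*8*8*4 * sizeof(float) = 1024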
143
tensorflow/lite/delegates/gpu/gl/egl_context.cc
Normal file
143
tensorflow/lite/delegates/gpu/gl/egl_context.cc
Normal file
@ -0,0 +1,143 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/egl_context.h"

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

Status GetConfig(EGLDisplay display, const EGLint* attributes,
                 EGLConfig* config) {
  EGLint config_count;
  bool chosen = eglChooseConfig(display, attributes, config, 1, &config_count);
  RETURN_IF_ERROR(GetOpenGlErrors());
  if (!chosen || config_count == 0) {
    return InternalError("No EGL error, but eglChooseConfig failed.");
  }
  return OkStatus();
}

Status CreateContext(EGLDisplay display, EGLContext shared_context,
                     EGLConfig config, EglContext* egl_context) {
  static const EGLint attributes[] = {EGL_CONTEXT_CLIENT_VERSION, 3,
#ifdef _DEBUG // Add debugging bit
                                      EGL_CONTEXT_FLAGS_KHR,
                                      EGL_CONTEXT_OPENGL_DEBUG_BIT_KHR,
#endif
                                      EGL_NONE};
  EGLContext context =
      eglCreateContext(display, config, shared_context, attributes);
  RETURN_IF_ERROR(GetOpenGlErrors());
  if (context == EGL_NO_CONTEXT) {
    return InternalError("No EGL error, but eglCreateContext failed.");
  }
  *egl_context = EglContext(context, display, config);
  return OkStatus();
}

bool HasExtension(EGLDisplay display, const char* name) {
  return strstr(eglQueryString(display, EGL_EXTENSIONS), name);
}

} // namespace

void EglContext::Invalidate() {
  if (context_ != EGL_NO_CONTEXT) {
    eglMakeCurrent(display_, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT);
    eglDestroyContext(display_, context_);
    context_ = EGL_NO_CONTEXT;
  }
}

EglContext::EglContext(EglContext&& other)
    : context_(other.context_),
      display_(other.display_),
      config_(other.config_) {
  other.context_ = EGL_NO_CONTEXT;
}

EglContext& EglContext::operator=(EglContext&& other) {
  if (this != &other) {
    Invalidate();
    std::swap(context_, other.context_);
    display_ = other.display_;
    config_ = other.config_;
  }
  return *this;
}

Status EglContext::MakeCurrent(EGLSurface read, EGLSurface write) {
  bool is_made_current = eglMakeCurrent(display_, write, read, context_);
  RETURN_IF_ERROR(GetOpenGlErrors());
  if (!is_made_current) {
    return InternalError("No EGL error, but eglMakeCurrent failed.");
  }
  return OkStatus();
}

bool EglContext::IsCurrent() const {
  return context_ == eglGetCurrentContext();
}

Status CreateConfiglessContext(EGLDisplay display, EGLContext shared_context,
                               EglContext* egl_context) {
  if (!HasExtension(display, "EGL_KHR_no_config_context")) {
    return UnavailableError("EGL_KHR_no_config_context not supported");
  }
  return CreateContext(display, shared_context, EGL_NO_CONFIG_KHR, egl_context);
}

Status CreateSurfacelessContext(EGLDisplay display, EGLContext shared_context,
                                EglContext* egl_context) {
  if (!HasExtension(display, "EGL_KHR_create_context")) {
    return UnavailableError("EGL_KHR_create_context not supported");
  }
  if (!HasExtension(display, "EGL_KHR_surfaceless_context")) {
    return UnavailableError("EGL_KHR_surfaceless_context not supported");
  }
  const EGLint attributes[] = {EGL_RENDERABLE_TYPE, EGL_OPENGL_ES3_BIT_KHR,
                               EGL_NONE};
  EGLConfig config;
  RETURN_IF_ERROR(GetConfig(display, attributes, &config));
  return CreateContext(display, shared_context, config, egl_context);
}

Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context,
                            EglContext* egl_context) {
  const EGLint attributes[] = {EGL_SURFACE_TYPE,
                               EGL_PBUFFER_BIT,
                               EGL_BLUE_SIZE,
                               8,
                               EGL_GREEN_SIZE,
                               8,
                               EGL_RED_SIZE,
                               8,
                               EGL_RENDERABLE_TYPE,
                               EGL_OPENGL_ES3_BIT_KHR,
                               EGL_NONE};
  EGLConfig config;
  RETURN_IF_ERROR(GetConfig(display, attributes, &config));
  return CreateContext(display, shared_context, config, egl_context);
}

} // namespace gl
} // namespace gpu
} // namespace tflite
93
tensorflow/lite/delegates/gpu/gl/egl_context.h
Normal file
93
tensorflow/lite/delegates/gpu/gl/egl_context.h
Normal file
@ -0,0 +1,93 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_CONTEXT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_CONTEXT_H_

#include <vector>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_egl.h"

namespace tflite {
namespace gpu {
namespace gl {

// EglContext is an RAII wrapper for an EGLContext.
//
// EglContext is moveable but not copyable.
//
// See https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglIntro.xhtml for
// more info.
class EglContext {
 public:
  // Creates an invalid EglContext.
  EglContext()
      : context_(EGL_NO_CONTEXT),
        display_(EGL_NO_DISPLAY),
        config_(EGL_NO_CONFIG_KHR) {}

  EglContext(EGLContext context, EGLDisplay display, EGLConfig config)
      : context_(context), display_(display), config_(config) {}

  // Move only
  EglContext(EglContext&& other);
  EglContext& operator=(EglContext&& other);
  EglContext(const EglContext&) = delete;
  EglContext& operator=(const EglContext&) = delete;

  ~EglContext() { Invalidate(); }

  EGLContext context() const { return context_; }

  EGLDisplay display() const { return display_; }

  EGLConfig config() const { return config_; }

  // Make this EglContext the current EGL context on this thread, replacing
  // the existing current.
  Status MakeCurrent(EGLSurface read, EGLSurface write);

  Status MakeCurrentSurfaceless() {
    return MakeCurrent(EGL_NO_SURFACE, EGL_NO_SURFACE);
  }

  // Returns true if this is the currently bound EGL context.
  bool IsCurrent() const;

 private:
  void Invalidate();

  EGLContext context_;
  EGLDisplay display_;
  EGLConfig config_;
};

// Uses the EGL_KHR_no_config_context extension to create a config-less
// context; most modern hardware supports this extension.
Status CreateConfiglessContext(EGLDisplay display, EGLContext shared_context,
                               EglContext* egl_context);

Status CreateSurfacelessContext(EGLDisplay display, EGLContext shared_context,
                                EglContext* egl_context);

Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context,
                            EglContext* egl_context);

} // namespace gl
} // namespace gpu
} // namespace tflite

#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_CONTEXT_H_
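A hedged sketch of how the factory functions above are typically combined, assuming `display` is an already-initialized EGLDisplay (EglEnvironment below shows how the delegate obtains one); the wrapper function is illustrative:

// Inside namespace tflite::gpu::gl; illustrative only.
Status MakeStandaloneContext(EGLDisplay display, EglContext* context) {
  // Prefer a config-less context and fall back to a surfaceless one.
  if (!CreateConfiglessContext(display, EGL_NO_CONTEXT, context).ok()) {
    RETURN_IF_ERROR(CreateSurfacelessContext(display, EGL_NO_CONTEXT, context));
  }
  return context->MakeCurrentSurfaceless();  // bind without read/draw surfaces
}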
149
tensorflow/lite/delegates/gpu/gl/egl_environment.cc
Normal file
149
tensorflow/lite/delegates/gpu/gl/egl_environment.cc
Normal file
@ -0,0 +1,149 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h"

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

// TODO(akulik): detect power management event when all contexts are destroyed
// and OpenGL ES is reinitialized. See eglMakeCurrent

Status InitDisplay(EGLDisplay* egl_display) {
  RETURN_IF_ERROR(
      TFLITE_GPU_CALL_EGL(eglGetDisplay, egl_display, EGL_DEFAULT_DISPLAY));
  if (*egl_display == EGL_NO_DISPLAY) {
    return UnavailableError("eglGetDisplay returned nullptr");
  }
  bool is_initialized;
  RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglInitialize, &is_initialized,
                                      *egl_display, nullptr, nullptr));
  if (!is_initialized) {
    return InternalError("No EGL error, but eglInitialize failed");
  }
  return OkStatus();
}

} // namespace

Status EglEnvironment::NewEglEnvironment(
    std::unique_ptr<EglEnvironment>* egl_environment) {
  *egl_environment = absl::make_unique<EglEnvironment>();
  RETURN_IF_ERROR((*egl_environment)->Init());
  return OkStatus();
}

EglEnvironment::~EglEnvironment() {
  if (dummy_framebuffer_ != GL_INVALID_INDEX) {
    glDeleteFramebuffers(1, &dummy_framebuffer_);
  }
  if (dummy_texture_ != GL_INVALID_INDEX) {
    glDeleteTextures(1, &dummy_texture_);
  }
}

Status EglEnvironment::Init() {
  bool is_bound;
  RETURN_IF_ERROR(
      TFLITE_GPU_CALL_EGL(eglBindAPI, &is_bound, EGL_OPENGL_ES_API));
  if (!is_bound) {
    return InternalError("No EGL error, but eglBindAPI failed");
  }

  // Re-use the context and display if they were created on this thread.
  if (eglGetCurrentContext() != EGL_NO_CONTEXT) {
    display_ = eglGetCurrentDisplay();
    context_ = EglContext(eglGetCurrentContext(), display_, EGL_NO_CONFIG_KHR);
  } else {
    RETURN_IF_ERROR(InitDisplay(&display_));

    Status status = InitConfiglessContext();
    if (!status.ok()) {
      status = InitSurfacelessContext();
    }
    if (!status.ok()) {
      status = InitPBufferContext();
    }
    if (!status.ok()) {
      return status;
    }
  }

  if (gpu_info_.type == GpuType::UNKNOWN) {
    RETURN_IF_ERROR(RequestGpuInfo(&gpu_info_));
  }
  // TODO(akulik): when do we need ForceSyncTurning?
  ForceSyncTurning();
  return OkStatus();
}

Status EglEnvironment::InitConfiglessContext() {
  RETURN_IF_ERROR(CreateConfiglessContext(display_, EGL_NO_CONTEXT, &context_));
  return context_.MakeCurrentSurfaceless();
}

Status EglEnvironment::InitSurfacelessContext() {
  RETURN_IF_ERROR(
      CreateSurfacelessContext(display_, EGL_NO_CONTEXT, &context_));
  Status status = context_.MakeCurrentSurfaceless();
  if (!status.ok()) {
    return status;
  }

  // PowerVR supports EGL_KHR_surfaceless_context, but glFenceSync crashes on
  // PowerVR when the context is surface-less.
  RETURN_IF_ERROR(RequestGpuInfo(&gpu_info_));
  if (gpu_info_.type == GpuType::POWERVR) {
    return UnavailableError(
        "Surface-less context is not properly supported on powervr.");
  }
  return OkStatus();
}

Status EglEnvironment::InitPBufferContext() {
  RETURN_IF_ERROR(CreatePBufferContext(display_, EGL_NO_CONTEXT, &context_));
  RETURN_IF_ERROR(CreatePbufferRGBSurface(context_.config(), display_, 1, 1,
                                          &surface_read_));
  RETURN_IF_ERROR(CreatePbufferRGBSurface(context_.config(), display_, 1, 1,
                                          &surface_draw_));
  return context_.MakeCurrent(surface_read_.surface(), surface_draw_.surface());
}

void EglEnvironment::ForceSyncTurning() {
  glGenFramebuffers(1, &dummy_framebuffer_);
  glBindFramebuffer(GL_FRAMEBUFFER, dummy_framebuffer_);

  glGenTextures(1, &dummy_texture_);
  glBindTexture(GL_TEXTURE_2D, dummy_texture_);
  glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA8, 4, 4);
  glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
                         dummy_texture_, 0);

  GLenum draw_buffers[1] = {GL_COLOR_ATTACHMENT0};
  glDrawBuffers(1, draw_buffers);

  glViewport(0, 0, 4, 4);
  glClear(GL_COLOR_BUFFER_BIT);
}

} // namespace gl
} // namespace gpu
} // namespace tflite
72
tensorflow/lite/delegates/gpu/gl/egl_environment.h
Normal file
72
tensorflow/lite/delegates/gpu/gl/egl_environment.h
Normal file
@ -0,0 +1,72 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_ENVIRONMENT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_ENVIRONMENT_H_

#include <memory>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/egl_context.h"
#include "tensorflow/lite/delegates/gpu/gl/egl_surface.h"
#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_egl.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"

namespace tflite {
namespace gpu {
namespace gl {

// This class encapsulates the creation of the OpenGL objects needed before
// working with OpenGL: it binds the OpenGL ES API, creates a new EGL context,
// binds it to an EGL display and creates surfaces if needed.
//
// An EGL environment needs to be created once per thread.
class EglEnvironment {
 public:
  static Status NewEglEnvironment(
      std::unique_ptr<EglEnvironment>* egl_environment);

  EglEnvironment() = default;
  ~EglEnvironment();

  const EglContext& context() const { return context_; }
  EGLDisplay display() const { return display_; }
  const GpuInfo& gpu_info() const { return gpu_info_; }

 private:
  Status Init();
  Status InitConfiglessContext();
  Status InitSurfacelessContext();
  Status InitPBufferContext();

  EGLDisplay display_ = EGL_NO_DISPLAY;
  EglContext context_;
  EglSurface surface_draw_;
  EglSurface surface_read_;
  GpuInfo gpu_info_;

  // Strange hack that helps on Mali GPUs;
  // without it glFinish and glFenceSync don't work.
  void ForceSyncTurning();
  GLuint dummy_framebuffer_ = GL_INVALID_INDEX;
  GLuint dummy_texture_ = GL_INVALID_INDEX;
};

} // namespace gl
} // namespace gpu
} // namespace tflite

#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_ENVIRONMENT_H_
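A usage sketch matching how the converter tests in this change drive the environment: one environment per thread, created before any other GL calls. The wrapper function is illustrative and not part of this commit:

// Inside namespace tflite::gpu::gl; illustrative only.
Status SetUpGl(std::unique_ptr<EglEnvironment>* env) {
  RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(env));
  // (*env)->gpu_info() can now be used to specialize shaders per vendor, and
  // the new EGL context is already current on this thread.
  return OkStatus();
}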
71
tensorflow/lite/delegates/gpu/gl/egl_surface.cc
Normal file
71
tensorflow/lite/delegates/gpu/gl/egl_surface.cc
Normal file
@ -0,0 +1,71 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/egl_surface.h"

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h"

namespace tflite {
namespace gpu {
namespace gl {

EglSurface::EglSurface(EglSurface&& other)
    : surface_(other.surface_), display_(other.display_) {
  other.surface_ = EGL_NO_SURFACE;
}

EglSurface& EglSurface::operator=(EglSurface&& other) {
  if (this != &other) {
    display_ = other.display_;
    Invalidate();
    std::swap(surface_, other.surface_);
  }
  return *this;
}

void EglSurface::Invalidate() {
  if (surface_ != EGL_NO_SURFACE) {
    eglDestroySurface(display_, surface_);
    surface_ = EGL_NO_SURFACE;
  }
}

Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display,
                               uint32_t height, uint32_t width,
                               EglSurface* egl_surface) {
  const EGLint pbuffer_attributes[] = {EGL_WIDTH,
                                       static_cast<EGLint>(width),
                                       EGL_HEIGHT,
                                       static_cast<EGLint>(height),
                                       EGL_TEXTURE_FORMAT,
                                       EGL_TEXTURE_RGB,
                                       EGL_TEXTURE_TARGET,
                                       EGL_TEXTURE_2D,
                                       EGL_NONE};
  EGLSurface surface =
      eglCreatePbufferSurface(display, config, pbuffer_attributes);
  RETURN_IF_ERROR(GetOpenGlErrors());
  if (surface == EGL_NO_SURFACE) {
    return InternalError("No EGL error, but eglCreatePbufferSurface failed");
  }
  *egl_surface = EglSurface(surface, display);
  return OkStatus();
}

} // namespace gl
} // namespace gpu
} // namespace tflite
67
tensorflow/lite/delegates/gpu/gl/egl_surface.h
Normal file
67
tensorflow/lite/delegates/gpu/gl/egl_surface.h
Normal file
@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_SURFACE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_SURFACE_H_

#include <cstdint>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_egl.h"

namespace tflite {
namespace gpu {
namespace gl {

// An RAII wrapper for EGLSurface.
// See https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglIntro.xhtml for
// an introduction to the concepts.
//
// EglSurface is moveable but not copyable.
class EglSurface {
 public:
  // Creates an invalid EglSurface.
  EglSurface() : surface_(EGL_NO_SURFACE), display_(EGL_NO_DISPLAY) {}

  EglSurface(EGLSurface surface, EGLDisplay display)
      : surface_(surface), display_(display) {}

  // Move-only
  EglSurface(EglSurface&& other);
  EglSurface& operator=(EglSurface&& other);
  EglSurface(const EglSurface&) = delete;
  EglSurface& operator=(const EglSurface&) = delete;

  ~EglSurface() { Invalidate(); }

  EGLSurface surface() const { return surface_; }

 private:
  void Invalidate();

  EGLSurface surface_;
  EGLDisplay display_;
};

// Creates off-screen pbuffer-based surface of the given height and width.
Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display,
                               uint32_t height, uint32_t width,
                               EglSurface* egl_surface);

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_SURFACE_H_
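A minimal usage sketch for the surface API above (illustration only, not part of this commit). It assumes an already initialized EGL display and a matching EGLConfig, for example the ones owned by EglEnvironment, and a surrounding function that returns Status:

  // Creates a 1x1 off-screen pbuffer surface; the underlying EGLSurface is
  // destroyed automatically when `surface` goes out of scope.
  EglSurface surface;
  RETURN_IF_ERROR(CreatePbufferRGBSurface(config, display,
                                          /*height=*/1, /*width=*/1, &surface));
  // surface.surface() can now be passed to eglMakeCurrent().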
73
tensorflow/lite/delegates/gpu/gl/float16_conversions.cc
Normal file
73
tensorflow/lite/delegates/gpu/gl/float16_conversions.cc
Normal file
@ -0,0 +1,73 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/float16_conversions.h"

#include <cstdint>
#include <vector>

#include <fp16.h>
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

// Performs in-place conversion of float32 into float16
bool ToFloat16(std::vector<uint8_t>* values) {
  if (values->size() % sizeof(float) != 0) {
    return false;
  }

  uint16_t* store_f16 = reinterpret_cast<uint16_t*>(values->data());
  const float* load_f32 = reinterpret_cast<const float*>(values->data());
  const float* end_load_f32 =
      reinterpret_cast<const float*>(values->data() + values->size());

  while (load_f32 != end_load_f32) {
    *store_f16++ = fp16_ieee_from_fp32_value(*load_f32++);
  }

  values->resize(values->size() / 2);
  return true;
}

struct ConverterToFloat16 {
  bool operator()(ObjectData& data) const {  // NOLINT
    return ToFloat16(&data);
  }

  bool operator()(ObjectRef& buffer) const {  // NOLINT
    return true;
  }
};

}  // namespace

bool MaybeConvertToFloat16(Object* object) {
  if (object->data_type == DataType::FLOAT32 &&
      absl::visit(ConverterToFloat16(), object->object)) {
    object->data_type = DataType::FLOAT16;
    return true;
  }
  return false;
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
32
tensorflow/lite/delegates/gpu/gl/float16_conversions.h
Normal file
32
tensorflow/lite/delegates/gpu/gl/float16_conversions.h
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_FLOAT16_CONVERSIONS_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_FLOAT16_CONVERSIONS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/gl/object.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace gl {
|
||||||
|
|
||||||
|
// If an object is float32, converts it to float16 representation.
|
||||||
|
bool MaybeConvertToFloat16(Object* object);
|
||||||
|
|
||||||
|
} // namespace gl
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_FLOAT16_CONVERSIONS_H_
|
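The conversion above is built on fp16_ieee_from_fp32_value from the bundled FP16 library. A standalone sketch of the same float32-to-float16 packing (illustration only, not part of this commit; the helper name PackAsFloat16 is hypothetical):

  #include <cstdint>
  #include <vector>

  #include <fp16.h>

  // Packs float32 values into IEEE half-precision words, mirroring the
  // in-place loop in ToFloat16 but writing into a separate vector.
  std::vector<uint16_t> PackAsFloat16(const std::vector<float>& values) {
    std::vector<uint16_t> result;
    result.reserve(values.size());
    for (float v : values) {
      result.push_back(fp16_ieee_from_fp32_value(v));
    }
    return result;
  }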
89
tensorflow/lite/delegates/gpu/gl/gl_buffer.cc
Normal file
89
tensorflow/lite/delegates/gpu/gl/gl_buffer.cc
Normal file
@ -0,0 +1,89 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"

#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {
namespace gl {

Status CopyBuffer(const GlBuffer& read_buffer, const GlBuffer& write_buffer) {
  if (read_buffer.bytes_size() != write_buffer.bytes_size()) {
    return InvalidArgumentError(
        "Read buffer does not match write buffer size.");
  }
  gl_buffer_internal::BufferBinder read_buffer_binder(GL_COPY_READ_BUFFER,
                                                      read_buffer.id());
  gl_buffer_internal::BufferBinder write_buffer_binder(GL_COPY_WRITE_BUFFER,
                                                       write_buffer.id());
  return TFLITE_GPU_CALL_GL(glCopyBufferSubData, GL_COPY_READ_BUFFER,
                            GL_COPY_WRITE_BUFFER, read_buffer.offset(),
                            write_buffer.offset(), read_buffer.bytes_size());
}

GlBuffer::GlBuffer(GlBuffer&& buffer)
    : GlBuffer(buffer.target_, buffer.id_, buffer.bytes_size_, buffer.offset_,
               buffer.has_ownership_) {
  buffer.has_ownership_ = false;
}

GlBuffer& GlBuffer::operator=(GlBuffer&& buffer) {
  if (this != &buffer) {
    Invalidate();

    target_ = buffer.target_;
    bytes_size_ = buffer.bytes_size_;
    offset_ = buffer.offset_;
    has_ownership_ = buffer.has_ownership_;
    id_ = buffer.id_;
    buffer.has_ownership_ = false;
  }
  return *this;
}

GlBuffer::~GlBuffer() { Invalidate(); }

void GlBuffer::Invalidate() {
  if (has_ownership_ && id_ != GL_INVALID_INDEX) {
    TFLITE_GPU_CALL_GL(glDeleteBuffers, 1, &id_).IgnoreError();
    id_ = GL_INVALID_INDEX;
  }
}

Status GlBuffer::BindToIndex(uint32_t index) const {
  return TFLITE_GPU_CALL_GL(glBindBufferRange, target_, index, id_, offset_,
                            bytes_size_);
}

Status GlBuffer::MakeView(size_t offset, size_t bytes_size,
                          GlBuffer* gl_buffer) {
  if (offset + bytes_size > bytes_size_) {
    return OutOfRangeError("GlBuffer view is out of range.");
  }
  *gl_buffer = GlBuffer(target_, id_, bytes_size, offset_ + offset,
                        /*has_ownership=*/false);
  return OkStatus();
}

GlBuffer GlBuffer::MakeRef() {
  return GlBuffer(target_, id_, bytes_size_, offset_,
                  /* has_ownership = */ false);
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
298
tensorflow/lite/delegates/gpu/gl/gl_buffer.h
Normal file
298
tensorflow/lite/delegates/gpu/gl/gl_buffer.h
Normal file
@ -0,0 +1,298 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_BUFFER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_BUFFER_H_

#include <cstring>
#include <functional>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"

namespace tflite {
namespace gpu {
namespace gl {

// Buffer is an RAII wrapper for an OpenGL buffer object.
// See https://www.khronos.org/opengl/wiki/Buffer_Object for more information.
//
// Buffer is moveable but not copyable.
class GlBuffer {
 public:
  // @param has_ownership indicates that GlBuffer is responsible for
  // corresponding GL buffer deletion.
  GlBuffer(GLenum target, GLuint id, size_t bytes_size, size_t offset,
           bool has_ownership)
      : target_(target),
        id_(id),
        bytes_size_(bytes_size),
        offset_(offset),
        has_ownership_(has_ownership) {}

  // Creates an invalid buffer.
  GlBuffer() : GlBuffer(GL_INVALID_ENUM, GL_INVALID_INDEX, 0, 0, false) {}

  // Move-only
  GlBuffer(GlBuffer&& buffer);
  GlBuffer& operator=(GlBuffer&& buffer);
  GlBuffer(const GlBuffer&) = delete;
  GlBuffer& operator=(const GlBuffer&) = delete;

  ~GlBuffer();

  // Reads data from buffer into CPU memory. Data should point to a region that
  // has at least bytes_size available.
  template <typename T>
  Status Read(absl::Span<T> data) const;

  // Writes data to a buffer.
  template <typename T>
  Status Write(absl::Span<const T> data);

  // Maps GPU memory to CPU address space and calls reader that may read from
  // that memory.
  template <typename T>
  Status MappedRead(
      const std::function<Status(absl::Span<const T>)>& reader) const;

  // Maps GPU memory to CPU address space and calls writer that may write into
  // that memory.
  template <typename T>
  Status MappedWrite(const std::function<Status(absl::Span<T>)>& writer);

  Status MakeView(size_t offset, size_t bytes_size, GlBuffer* gl_buffer);

  // Makes a copy without ownership of the buffer.
  GlBuffer MakeRef();

  // Binds a buffer to an index.
  Status BindToIndex(uint32_t index) const;

  // Releases the ownership of the buffer object.
  void Release() { has_ownership_ = false; }

  size_t bytes_size() const { return bytes_size_; }

  const GLenum target() const { return target_; }

  const GLuint id() const { return id_; }

  bool is_valid() const { return id_ != GL_INVALID_INDEX; }

  size_t offset() const { return offset_; }

  // @return true if this object actually owns the corresponding GL buffer
  // and manages its lifetime.
  bool has_ownership() const { return has_ownership_; }

 private:
  void Invalidate();

  GLenum target_;
  GLuint id_;
  size_t bytes_size_;
  size_t offset_;
  bool has_ownership_;
};

Status CopyBuffer(const GlBuffer& read_buffer, const GlBuffer& write_buffer);

// Creates a new shader storage buffer that will be modified and used many
// times.
//
// See https://www.khronos.org/opengl/wiki/Shader_Storage_Buffer_Object for
// details.
template <typename T>
Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements,
                                          GlBuffer* gl_buffer);

// Creates a new shader storage buffer that will be filled with data once and
// then used many times.
template <typename T>
Status CreateReadOnlyShaderStorageBuffer(absl::Span<const T> data,
                                         GlBuffer* gl_buffer);

// Adapts raw Buffer::Read method to read data into a vector.
template <typename T>
Status AppendFromBuffer(const GlBuffer& buffer, std::vector<T>* data) {
  if (buffer.bytes_size() % sizeof(T) != 0) {
    return InvalidArgumentError("Buffer is not aligned");
  }
  size_t num_elements = buffer.bytes_size() / sizeof(T);
  data->resize(data->size() + num_elements);
  return buffer.Read<T>(
      absl::MakeSpan(data->data() + data->size() - num_elements, num_elements));
}

////////////////////////////////////////////////////////////////////////////////
// Implementation details are below.

namespace gl_buffer_internal {

// RAII for creating and/or owning buffer id.
class BufferId {
 public:
  BufferId() : id_(GL_INVALID_INDEX) {
    TFLITE_GPU_CALL_GL(glGenBuffers, 1 /* number of buffers */, &id_)
        .IgnoreError();
    // only possible error here is when a number of buffers is negative.
  }

  explicit BufferId(GLuint id) : id_(id) {}

  ~BufferId() {
    if (id_ != GL_INVALID_INDEX) {
      TFLITE_GPU_CALL_GL(glDeleteBuffers, 1, &id_).IgnoreError();
    }
  }

  GLuint id() const { return id_; }

  GLuint Release() {
    GLuint id = GL_INVALID_INDEX;
    std::swap(id, id_);
    return id;
  }

 private:
  GLuint id_;
};

// RAII for binding and unbinding a buffer.
class BufferBinder {
 public:
  BufferBinder(GLenum target, GLuint id) : target_(target) {
    TFLITE_GPU_CALL_GL(glBindBuffer, target_, id).IgnoreError();
  }

  ~BufferBinder() {
    TFLITE_GPU_CALL_GL(glBindBuffer, target_, 0).IgnoreError();
  }

 private:
  const GLenum target_;
};

// RAII for mapping and unmapping a buffer.
class BufferMapper {
 public:
  BufferMapper(GLenum target, size_t offset, size_t bytes, GLbitfield access)
      : target_(target),
        data_(glMapBufferRange(target_, offset, bytes, access)) {}

  ~BufferMapper() { TFLITE_GPU_CALL_GL(glUnmapBuffer, target_).IgnoreError(); }

  void* data() { return data_; }

 private:
  const GLenum target_;
  void* data_;
};

}  // namespace gl_buffer_internal

template <typename T>
Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements,
                                          GlBuffer* gl_buffer) {
  gl_buffer_internal::BufferId id;
  gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id());
  // TODO(akulik): benchmark DYNAMIC vs STREAM buffer
  RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glBufferData, GL_SHADER_STORAGE_BUFFER,
                                     num_elements * sizeof(T), nullptr,
                                     GL_STREAM_COPY));
  *gl_buffer = GlBuffer{GL_SHADER_STORAGE_BUFFER, id.Release(),
                        num_elements * sizeof(T), 0, true};
  return OkStatus();
}

template <typename T>
Status CreateReadOnlyShaderStorageBuffer(absl::Span<const T> data,
                                         GlBuffer* gl_buffer) {
  gl_buffer_internal::BufferId id;
  gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id());
  RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glBufferData, GL_SHADER_STORAGE_BUFFER,
                                     data.size() * sizeof(T), data.data(),
                                     GL_STATIC_READ));
  *gl_buffer = GlBuffer{GL_SHADER_STORAGE_BUFFER, id.Release(),
                        data.size() * sizeof(T), 0, true};
  return OkStatus();
}

template <typename T>
Status GlBuffer::Read(absl::Span<T> data) const {
  if (data.size() * sizeof(T) < bytes_size()) {
    return InvalidArgumentError(
        "Read from buffer failed. Destination data is shorter than buffer.");
  }
  // TODO(akulik): glCopyBufferSubData is actually available in ES 3.1, try it.
  return MappedRead<T>([this, data](absl::Span<const T> src) {
    std::memcpy(data.data(), src.data(), bytes_size());
    return OkStatus();
  });
}

template <typename T>
Status GlBuffer::Write(absl::Span<const T> data) {
  if (data.size() * sizeof(T) > bytes_size_) {
    return InvalidArgumentError(
        "Write to buffer failed. Source data is larger than buffer.");
  }
  gl_buffer_internal::BufferBinder binder(target_, id_);
  return TFLITE_GPU_CALL_GL(glBufferSubData, target_, offset_, bytes_size_,
                            data.data());
}

template <typename T>
Status GlBuffer::MappedRead(
    const std::function<Status(absl::Span<const T> d)>& reader) const {
  if (bytes_size_ % sizeof(T) != 0) {
    return InvalidArgumentError("Buffer is not aligned");
  }
  gl_buffer_internal::BufferBinder binder(target_, id_);
  gl_buffer_internal::BufferMapper mapper(target_, offset_, bytes_size_,
                                          GL_MAP_READ_BIT);
  if (!mapper.data()) {
    return GetOpenGlErrors();
  }
  return reader(absl::MakeSpan(reinterpret_cast<const T*>(mapper.data()),
                               bytes_size_ / sizeof(T)));
}

template <typename T>
Status GlBuffer::MappedWrite(
    const std::function<Status(absl::Span<T> d)>& writer) {
  if (bytes_size_ % sizeof(T) != 0) {
    return InvalidArgumentError("Buffer is not aligned");
  }
  gl_buffer_internal::BufferBinder binder(target_, id_);
  gl_buffer_internal::BufferMapper mapper(target_, offset_, bytes_size_,
                                          GL_MAP_WRITE_BIT);
  if (!mapper.data()) {
    return GetOpenGlErrors();
  }
  return writer(absl::MakeSpan(reinterpret_cast<T*>(mapper.data()),
                               bytes_size_ / sizeof(T)));
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_BUFFER_H_
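A usage sketch of the buffer helpers declared above (illustration only, not part of this commit). It assumes a current GL context, for example one created via EglEnvironment as in the test below, and a surrounding function that returns Status:

  // Uploads four floats into a read-only SSBO and reads them back.
  std::vector<float> input = {0.0f, 1.0f, 2.0f, 3.0f};
  GlBuffer buffer;
  RETURN_IF_ERROR(CreateReadOnlyShaderStorageBuffer<float>(input, &buffer));
  std::vector<float> output;
  RETURN_IF_ERROR(AppendFromBuffer(buffer, &output));  // output == input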
126
tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc
Normal file
126
tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc
Normal file
@ -0,0 +1,126 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"

#include <memory>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

TEST(Buffer, Read) {
  std::unique_ptr<EglEnvironment> env;
  ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok());
  std::vector<float> test = {0, 1, 2, 3};
  GlBuffer buffer;
  ASSERT_TRUE(CreateReadOnlyShaderStorageBuffer<float>(test, &buffer).ok());
  std::vector<float> from_buffer;
  ASSERT_TRUE(AppendFromBuffer(buffer, &from_buffer).ok());
  EXPECT_EQ(test, from_buffer);
}

TEST(Buffer, Write) {
  std::unique_ptr<EglEnvironment> env;
  ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok());
  GlBuffer buffer;
  ASSERT_TRUE(CreateReadWriteShaderStorageBuffer<float>(4, &buffer).ok());
  std::vector<float> test = {0, 1, 2, 3};
  ASSERT_TRUE(buffer.Write<float>(test).ok());
  std::vector<float> from_buffer;
  ASSERT_TRUE(AppendFromBuffer(buffer, &from_buffer).ok());
  EXPECT_EQ(test, from_buffer);
}

TEST(Buffer, View) {
  std::unique_ptr<EglEnvironment> env;
  ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok());
  GlBuffer buffer;
  ASSERT_TRUE(CreateReadWriteShaderStorageBuffer<float>(6, &buffer).ok());
  EXPECT_TRUE(buffer.has_ownership());
  EXPECT_EQ(24, buffer.bytes_size());
  EXPECT_EQ(0, buffer.offset());

  // Create a view and write data through it. MakeView expects a pointer to an
  // existing GlBuffer, so the view lives on the stack here.
  GlBuffer view;
  ASSERT_TRUE(buffer.MakeView(4, 16, &view).ok());
  EXPECT_FALSE(view.has_ownership());
  EXPECT_EQ(16, view.bytes_size());
  EXPECT_EQ(4, view.offset());
  std::vector<float> test = {1, 2, 3, 4};
  ASSERT_TRUE(view.Write<float>(test).ok());

  // Check that data indeed landed in a buffer with proper offset.
  std::vector<float> from_buffer;
  ASSERT_TRUE(AppendFromBuffer(buffer, &from_buffer).ok());
  EXPECT_THAT(from_buffer, testing::ElementsAre(0, 1, 2, 3, 4, 0));

  std::vector<float> from_view;
  ASSERT_TRUE(AppendFromBuffer(view, &from_view).ok());
  EXPECT_THAT(from_view, testing::ElementsAre(1, 2, 3, 4));
}

TEST(Buffer, SubView) {
  std::unique_ptr<EglEnvironment> env;
  ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok());
  GlBuffer buffer;
  ASSERT_TRUE(CreateReadWriteShaderStorageBuffer<float>(6, &buffer).ok());

  // Create a view and another view over that view. The second call must fail
  // because it reaches past the end of the first view.
  GlBuffer view1;
  ASSERT_TRUE(buffer.MakeView(4, 16, &view1).ok());
  GlBuffer view2;
  EXPECT_NE(view1.MakeView(1, 16, &view2), OkStatus());
  ASSERT_TRUE(view1.MakeView(2, 2, &view2).ok());

  EXPECT_FALSE(view2.has_ownership());
  EXPECT_EQ(2, view2.bytes_size());
  EXPECT_EQ(6, view2.offset());
}

TEST(Buffer, Copy) {
  std::unique_ptr<EglEnvironment> env;
  ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok());
  GlBuffer buffer;
  ASSERT_TRUE(CreateReadWriteShaderStorageBuffer<float>(4, &buffer).ok());

  // Create two views over different parts of the buffer.
  GlBuffer view1;
  ASSERT_TRUE(buffer.MakeView(4, 4, &view1).ok());

  GlBuffer view2;
  ASSERT_TRUE(buffer.MakeView(8, 4, &view2).ok());

  // Copy data from one view to another.
  ASSERT_TRUE(view1.Write<float>({1}).ok());
  ASSERT_TRUE(CopyBuffer(view1, view2).ok());

  // Check that data indeed landed correctly.
  std::vector<float> from_buffer;
  ASSERT_TRUE(AppendFromBuffer(buffer, &from_buffer).ok());
  EXPECT_THAT(from_buffer, testing::ElementsAre(0, 1, 1, 0));
}

}  // namespace
}  // namespace gl
}  // namespace gpu
}  // namespace tflite
115
tensorflow/lite/delegates/gpu/gl/gl_call.h
Normal file
115
tensorflow/lite/delegates/gpu/gl/gl_call.h
Normal file
@ -0,0 +1,115 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_CALL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_CALL_H_

#include <string>
#include <type_traits>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h"

namespace tflite {
namespace gpu {
namespace gl {

// Primary purpose of this file is to provide a useful macro for calling GL
// functions and checking errors. It also attaches a context to the status in
// case of a GL error.
//
// Use TFLITE_GPU_CALL_GL as follows:
//
// For GL functions with a return value:
//   Before:
//     GLint result = glFunc(...);
//     RETURN_IF_ERROR(GetOpenGlErrors());
//   After:
//     GLint result;
//     RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glFunc, &result, ...));
//
// For GL functions without a return value:
//   Before:
//     glFunc(...);
//     RETURN_IF_ERROR(GetOpenGlErrors());
//   After:
//     RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glFunc, ...));

namespace gl_call_internal {

// For GL functions with a return value.
template <typename T>
struct Caller {
  template <typename F, typename ErrorF, typename... Params>
  Status operator()(const std::string& context, F func, ErrorF error_func,
                    T* result, Params&&... params) {
    *result = func(std::forward<Params>(params)...);
    const auto status = error_func();
    if (status.ok()) return OkStatus();
    return Status(status.code(), status.error_message() + ": " + context);
  }
};

// For GL functions without a return value.
template <>
struct Caller<void> {
  template <typename F, typename ErrorF, typename... Params>
  Status operator()(const std::string& context, F func, ErrorF error_func,
                    Params&&... params) {
    func(std::forward<Params>(params)...);
    const auto status = error_func();
    if (status.ok()) return OkStatus();
    return Status(status.code(), status.error_message() + ": " + context);
  }
};

template <typename F, typename ErrorF, typename ResultT, typename... ParamsT>
Status CallAndCheckError(const std::string& context, F func, ErrorF error_func,
                         ResultT* result, ParamsT&&... params) {
  return Caller<ResultT>()(context, func, error_func, result,
                           std::forward<ParamsT>(params)...);
}

template <typename F, typename ErrorF, typename... Params>
Status CallAndCheckError(const std::string& context, F func, ErrorF error_func,
                         Params&&... params) {
  return Caller<void>()(context, func, error_func,
                        std::forward<Params>(params)...);
}

}  // namespace gl_call_internal

// XX_STRINGIFY is a helper macro to effectively apply # operator to an
// arbitrary value.
#define TFLITE_GPU_INTERNAL_STRINGIFY_HELPER(x) #x
#define TFLITE_GPU_INTERNAL_STRINGIFY(x) TFLITE_GPU_INTERNAL_STRINGIFY_HELPER(x)
#define TFLITE_GPU_FILE_LINE \
  __FILE__ ":" TFLITE_GPU_INTERNAL_STRINGIFY(__LINE__)

#define TFLITE_GPU_CALL_GL(method, ...)                   \
  ::tflite::gpu::gl::gl_call_internal::CallAndCheckError( \
      #method " in " TFLITE_GPU_FILE_LINE, method,        \
      ::tflite::gpu::gl::GetOpenGlErrors, __VA_ARGS__)

#define TFLITE_GPU_CALL_EGL(method, ...)                  \
  ::tflite::gpu::gl::gl_call_internal::CallAndCheckError( \
      #method " in " TFLITE_GPU_FILE_LINE, method,        \
      ::tflite::gpu::gl::GetEglError, __VA_ARGS__)

}  // namespace gl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_CALL_H_
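A concrete sketch of the macro in both modes described above (illustration only, not part of this commit; assumes a current GL ES context and a surrounding function that returns Status):

  // With a return value: the pointer right after the function receives it.
  GLuint program_id = 0;
  RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glCreateProgram, &program_id));

  // Without a return value: arguments are forwarded as-is.
  RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glDeleteProgram, program_id));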
Some files were not shown because too many files have changed in this diff