Introduce custom transformations placeholders and rename ApplyModelTransformations.
PiperOrigin-RevId: 329575742 Change-Id: Iefd7c92c7bfdb11ed2e8645ef9e6ad08ddfdd621
This commit is contained in:
parent
5adacc8807
commit
f635a5bb85
@ -54,7 +54,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:model_transformations",
|
||||
"//tensorflow/lite/delegates/gpu/gl:api",
|
||||
"//tensorflow/lite/delegates/gpu/gl:command_queue",
|
||||
"//tensorflow/lite/delegates/gpu/gl:compiler",
|
||||
@ -96,7 +96,6 @@ objc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
|
||||
"//tensorflow/lite/delegates/gpu/metal:api",
|
||||
"//tensorflow/lite/delegates/gpu/metal:buffer_convert",
|
||||
"//tensorflow/lite/delegates/gpu/metal:compiled_model",
|
||||
|
@ -346,7 +346,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:model_builder",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:model_transformations",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
@ -97,8 +97,8 @@ class Delegate {
|
||||
// Apply general transformations on the graph.
|
||||
NullTransformationReporter reporter;
|
||||
ModelTransformer transformer(&graph, &reporter);
|
||||
if (!ApplyGeneralTransformations(&transformer)) {
|
||||
return absl::InternalError("Graph general transformations failed");
|
||||
if (!ApplyModelTransformations(&transformer)) {
|
||||
return absl::InternalError("Graph transformations failed");
|
||||
}
|
||||
|
||||
InferenceEnvironmentOptions env_options;
|
||||
|
@ -23,7 +23,10 @@ cc_library(
|
||||
)
|
||||
|
||||
exports_files(
|
||||
["custom_parsers.h"],
|
||||
[
|
||||
"custom_parsers.h",
|
||||
"custom_transformations.h",
|
||||
],
|
||||
visibility = ["//tensorflow/lite/delegates/gpu/common:__subpackages__"],
|
||||
)
|
||||
|
||||
@ -125,7 +128,7 @@ cc_library(
|
||||
"//tensorflow/lite:kernel_api",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:model_transformations",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels/internal:reference_base",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
|
@ -0,0 +1,29 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_TRANSFORMATIONS_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_TRANSFORMATIONS_H_
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
// Applies all implemented custom model transformations.
|
||||
bool ApplyCustomTransformations(ModelTransformer* transformer);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_TRANSFORMATIONS_H_
|
@ -14,3 +14,12 @@ cc_library(
|
||||
"@com_google_absl//absl/types:any",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "custom_transformations",
|
||||
srcs = ["custom_transformations.cc"],
|
||||
hdrs = ["//tensorflow/lite/delegates/gpu/common:custom_transformations.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/custom_transformations.h"
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
bool ApplyCustomTransformations(ModelTransformer* transformer) { return true; }
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -45,7 +45,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
|
||||
#include "tensorflow/lite/delegates/utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
@ -2822,8 +2822,8 @@ absl::Status BuildFinalModel(
|
||||
// Apply general transformations on the graph.
|
||||
NullTransformationReporter reporter;
|
||||
ModelTransformer transformer(graph, &reporter);
|
||||
if (!ApplyGeneralTransformations(&transformer)) {
|
||||
return absl::InternalError("Graph general transformations failed");
|
||||
if (!ApplyModelTransformations(&transformer)) {
|
||||
return absl::InternalError("Graph transformations failed");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -33,6 +33,6 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:model_builder",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:model_transformations",
|
||||
],
|
||||
)
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
@ -95,8 +95,8 @@ absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
|
||||
|
||||
NullTransformationReporter reporter;
|
||||
ModelTransformer transformer(graph, &reporter);
|
||||
if (!ApplyGeneralTransformations(&transformer)) {
|
||||
return absl::InternalError("Graph general transformations failed");
|
||||
if (!ApplyModelTransformations(&transformer)) {
|
||||
return absl::InternalError("Graph transformations failed");
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -1,3 +1,5 @@
|
||||
load("//tensorflow/core/platform:build_config.bzl", "tf_platform_alias")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -118,9 +120,9 @@ cc_test(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "general_transformations",
|
||||
srcs = ["general_transformations.cc"],
|
||||
hdrs = ["general_transformations.h"],
|
||||
name = "model_transformations",
|
||||
srcs = ["model_transformations.cc"],
|
||||
hdrs = ["model_transformations.h"],
|
||||
deps = [
|
||||
":add_quant_adjustments",
|
||||
":fuse_add_to_conv",
|
||||
@ -130,7 +132,7 @@ cc_library(
|
||||
":merge_padding_with",
|
||||
":remove_noop",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
],
|
||||
] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/custom_transformations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"
|
||||
@ -29,6 +30,8 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
namespace {
|
||||
|
||||
bool ApplyGeneralTransformations(ModelTransformer* transformer) {
|
||||
// whenever any of these transforms return false, that means that a graph
|
||||
// is in the broken state and processing should not continue.
|
||||
@ -60,5 +63,12 @@ bool ApplyGeneralTransformations(ModelTransformer* transformer) {
|
||||
NewMergeMulWithConvolution().get());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool ApplyModelTransformations(ModelTransformer* transformer) {
|
||||
return ApplyCustomTransformations(transformer) &&
|
||||
ApplyGeneralTransformations(transformer);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -21,8 +21,9 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
// Applies custom and general transformations to the model in the proper order.
|
||||
// @return false when something went wrong that turned a graph in a broken state
|
||||
bool ApplyGeneralTransformations(ModelTransformer* transformer);
|
||||
bool ApplyModelTransformations(ModelTransformer* transformer);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -35,7 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/api.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/command_queue.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/compiler.h"
|
||||
@ -138,8 +138,8 @@ class Delegate {
|
||||
// Apply general transformations on the graph.
|
||||
NullTransformationReporter reporter;
|
||||
ModelTransformer transformer(&graph, &reporter);
|
||||
if (!ApplyGeneralTransformations(&transformer)) {
|
||||
return absl::InternalError("Graph general transformations failed");
|
||||
if (!ApplyModelTransformations(&transformer)) {
|
||||
return absl::InternalError("Graph transformations failed");
|
||||
}
|
||||
|
||||
if (!env_) RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env_));
|
||||
|
@ -38,7 +38,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/api.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
|
||||
|
Loading…
Reference in New Issue
Block a user