Introduce custom transformations placeholders and rename ApplyModelTransformations.

PiperOrigin-RevId: 329575742
Change-Id: Iefd7c92c7bfdb11ed2e8645ef9e6ad08ddfdd621
This commit is contained in:
A. Unique TensorFlower 2020-09-01 13:56:08 -07:00 committed by TensorFlower Gardener
parent 5adacc8807
commit f635a5bb85
15 changed files with 103 additions and 25 deletions

View File

@ -54,7 +54,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:tensor",
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
"//tensorflow/lite/delegates/gpu/common/transformations:model_transformations",
"//tensorflow/lite/delegates/gpu/gl:api",
"//tensorflow/lite/delegates/gpu/gl:command_queue",
"//tensorflow/lite/delegates/gpu/gl:compiler",
@ -96,7 +96,6 @@ objc_library(
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:tensor",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
"//tensorflow/lite/delegates/gpu/metal:api",
"//tensorflow/lite/delegates/gpu/metal:buffer_convert",
"//tensorflow/lite/delegates/gpu/metal:compiled_model",

View File

@ -346,7 +346,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:model_builder",
"//tensorflow/lite/delegates/gpu/common:model_transformer",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
"//tensorflow/lite/delegates/gpu/common/transformations:model_transformations",
"@com_google_absl//absl/types:span",
],
)

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
namespace tflite {
namespace gpu {
@ -97,8 +97,8 @@ class Delegate {
// Apply general transformations on the graph.
NullTransformationReporter reporter;
ModelTransformer transformer(&graph, &reporter);
if (!ApplyGeneralTransformations(&transformer)) {
return absl::InternalError("Graph general transformations failed");
if (!ApplyModelTransformations(&transformer)) {
return absl::InternalError("Graph transformations failed");
}
InferenceEnvironmentOptions env_options;

View File

@ -23,7 +23,10 @@ cc_library(
)
exports_files(
["custom_parsers.h"],
[
"custom_parsers.h",
"custom_transformations.h",
],
visibility = ["//tensorflow/lite/delegates/gpu/common:__subpackages__"],
)
@ -125,7 +128,7 @@ cc_library(
"//tensorflow/lite:kernel_api",
"//tensorflow/lite:util",
"//tensorflow/lite/c:common",
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
"//tensorflow/lite/delegates/gpu/common/transformations:model_transformations",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels/internal:reference_base",
"//tensorflow/lite/kernels/internal:tensor",

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_TRANSFORMATIONS_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_TRANSFORMATIONS_H_
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
namespace tflite {
namespace gpu {
// Applies all implemented custom model transformations.
bool ApplyCustomTransformations(ModelTransformer* transformer);
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_TRANSFORMATIONS_H_

View File

@ -14,3 +14,12 @@ cc_library(
"@com_google_absl//absl/types:any",
],
)
cc_library(
name = "custom_transformations",
srcs = ["custom_transformations.cc"],
hdrs = ["//tensorflow/lite/delegates/gpu/common:custom_transformations.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:model_transformer",
],
)

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/custom_transformations.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
namespace tflite {
namespace gpu {
bool ApplyCustomTransformations(ModelTransformer* transformer) { return true; }
} // namespace gpu
} // namespace tflite

View File

@ -45,7 +45,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@ -2822,8 +2822,8 @@ absl::Status BuildFinalModel(
// Apply general transformations on the graph.
NullTransformationReporter reporter;
ModelTransformer transformer(graph, &reporter);
if (!ApplyGeneralTransformations(&transformer)) {
return absl::InternalError("Graph general transformations failed");
if (!ApplyModelTransformations(&transformer)) {
return absl::InternalError("Graph transformations failed");
}
return absl::OkStatus();
}

View File

@ -33,6 +33,6 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:model_builder",
"//tensorflow/lite/delegates/gpu/common:model_transformer",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
"//tensorflow/lite/delegates/gpu/common/transformations:model_transformations",
],
)

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/model_builder.h"
@ -95,8 +95,8 @@ absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
NullTransformationReporter reporter;
ModelTransformer transformer(graph, &reporter);
if (!ApplyGeneralTransformations(&transformer)) {
return absl::InternalError("Graph general transformations failed");
if (!ApplyModelTransformations(&transformer)) {
return absl::InternalError("Graph transformations failed");
}
return absl::OkStatus();

View File

@ -1,3 +1,5 @@
load("//tensorflow/core/platform:build_config.bzl", "tf_platform_alias")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
@ -118,9 +120,9 @@ cc_test(
)
cc_library(
name = "general_transformations",
srcs = ["general_transformations.cc"],
hdrs = ["general_transformations.h"],
name = "model_transformations",
srcs = ["model_transformations.cc"],
hdrs = ["model_transformations.h"],
deps = [
":add_quant_adjustments",
":fuse_add_to_conv",
@ -130,7 +132,7 @@ cc_library(
":merge_padding_with",
":remove_noop",
"//tensorflow/lite/delegates/gpu/common:model_transformer",
],
] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"),
)
cc_library(

View File

@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
#include <memory>
#include "tensorflow/lite/delegates/gpu/common/custom_transformations.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"
@ -29,6 +30,8 @@ limitations under the License.
namespace tflite {
namespace gpu {
namespace {
bool ApplyGeneralTransformations(ModelTransformer* transformer) {
// whenever any of these transforms return false, that means that a graph
// is in the broken state and processing should not continue.
@ -60,5 +63,12 @@ bool ApplyGeneralTransformations(ModelTransformer* transformer) {
NewMergeMulWithConvolution().get());
}
} // namespace
bool ApplyModelTransformations(ModelTransformer* transformer) {
return ApplyCustomTransformations(transformer) &&
ApplyGeneralTransformations(transformer);
}
} // namespace gpu
} // namespace tflite

View File

@ -21,8 +21,9 @@ limitations under the License.
namespace tflite {
namespace gpu {
// Applies custom and general transformations to the model in the proper order.
// @return false when something went wrong that turned a graph in a broken state
bool ApplyGeneralTransformations(ModelTransformer* transformer);
bool ApplyModelTransformations(ModelTransformer* transformer);
} // namespace gpu
} // namespace tflite

View File

@ -35,7 +35,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
#include "tensorflow/lite/delegates/gpu/gl/api.h"
#include "tensorflow/lite/delegates/gpu/gl/command_queue.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler.h"
@ -138,8 +138,8 @@ class Delegate {
// Apply general transformations on the graph.
NullTransformationReporter reporter;
ModelTransformer transformer(&graph, &reporter);
if (!ApplyGeneralTransformations(&transformer)) {
return absl::InternalError("Graph general transformations failed");
if (!ApplyModelTransformations(&transformer)) {
return absl::InternalError("Graph transformations failed");
}
if (!env_) RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env_));

View File

@ -38,7 +38,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/metal/api.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"