Merge pull request #46868 from Intel-tensorflow:yang/graph-api-registration
PiperOrigin-RevId: 360977145 Change-Id: I783f7923144f62168e3a9951fd0a3359913340ef
This commit is contained in:
commit
0577bbdd46
@ -659,6 +659,7 @@ tf_cuda_cc_test(
|
||||
# linkstatic = tf_kernel_tests_linkstatic(),
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":c_test_util",
|
||||
":test_op_kernel",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
|
@ -238,6 +238,15 @@ Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BufferToMessage(const TF_Buffer* in,
|
||||
tensorflow::protobuf::MessageLite* out) {
|
||||
if (in == nullptr || !out->ParseFromArray(in->data, in->length)) {
|
||||
return errors::InvalidArgument("Unparseable ", out->GetTypeName(),
|
||||
" proto");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
|
||||
const char* mutation_type) {
|
||||
// If any session has already run this node_id, mark this session as
|
||||
|
@ -190,6 +190,9 @@ namespace tensorflow {
|
||||
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
|
||||
TF_Buffer* out);
|
||||
|
||||
Status BufferToMessage(const TF_Buffer* in,
|
||||
tensorflow::protobuf::MessageLite* out);
|
||||
|
||||
// Set the shapes and types of the output's handle.
|
||||
//
|
||||
// The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/c_test_util.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/cc/saved_model/signature_constants.h"
|
||||
@ -44,6 +45,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/str_util.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
@ -2576,6 +2578,20 @@ TEST(CAPI, TestTensorIsNotAligned) {
|
||||
TF_DeleteTensor(a);
|
||||
}
|
||||
|
||||
TEST(CAPI, MessageBufferConversion) {
|
||||
NodeDef node_in, node_out;
|
||||
node_in.set_name("Test name");
|
||||
node_in.set_op("Test op");
|
||||
|
||||
TF_Buffer* buffer = TF_NewBuffer();
|
||||
TF_CHECK_OK(MessageToBuffer(node_in, buffer));
|
||||
TF_CHECK_OK(BufferToMessage(buffer, &node_out));
|
||||
TF_DeleteBuffer(buffer);
|
||||
|
||||
protobuf::util::MessageDifferencer differencer;
|
||||
EXPECT_TRUE(differencer.Compare(node_in, node_out));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
|
67
tensorflow/c/experimental/grappler/BUILD
Normal file
67
tensorflow/c/experimental/grappler/BUILD
Normal file
@ -0,0 +1,67 @@
|
||||
# Description:
|
||||
# Graph C API.
|
||||
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grappler_hdrs",
|
||||
hdrs = ["grappler.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grappler",
|
||||
srcs = ["grappler.cc"],
|
||||
hdrs = [
|
||||
"grappler.h",
|
||||
"grappler_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "grappler_test",
|
||||
srcs = ["grappler_test.cc"],
|
||||
deps = [
|
||||
":grappler",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/clusters:single_machine",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
"//tensorflow/core/protobuf:error_codes_proto_impl_cc",
|
||||
],
|
||||
)
|
195
tensorflow/c/experimental/grappler/grappler.cc
Normal file
195
tensorflow/c/experimental/grappler/grappler.cc
Normal file
@ -0,0 +1,195 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// This file extends/implements core graph optimizer base classes in terms of
|
||||
// the C API defined in grappler.h. A class "CSomething" represents a
|
||||
// "Something" that can be manipulated via calls in the C interface and a C
|
||||
// struct called "TP_Something".
|
||||
|
||||
#include "tensorflow/c/experimental/grappler/grappler.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/experimental/grappler/grappler_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace {
|
||||
|
||||
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
|
||||
do { \
|
||||
if (STRUCT_OBJ.struct_size == 0) { \
|
||||
return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \
|
||||
"struct_size field in " #STRUCT_NAME \
|
||||
" must be set to " #SIZE_VALUE_NAME "."); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME) \
|
||||
do { \
|
||||
if (STRUCT_OBJ.NAME == 0) { \
|
||||
return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \
|
||||
"'" #NAME "' field in " #STRUCT_NAME \
|
||||
" must be set."); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
tensorflow::Status ValidateTPOptimizerRegistrationParams(
|
||||
const TP_OptimizerRegistrationParams& params) {
|
||||
VALIDATE_STRUCT_SIZE(TP_OptimizerRegistrationParams, params,
|
||||
TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE);
|
||||
VALIDATE_MEMBER(TP_OptimizerRegistrationParams, params, device_type);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ValidateTPOptimizer(const TP_Optimizer& optimizer) {
|
||||
VALIDATE_STRUCT_SIZE(TP_Optimizer, optimizer, TP_OPTIMIZER_STRUCT_SIZE);
|
||||
VALIDATE_MEMBER(TP_Optimizer, optimizer, optimize_func);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ValidateTPOptimizerConfigs(
|
||||
const TP_OptimizerConfigs& configs) {
|
||||
VALIDATE_STRUCT_SIZE(TP_OptimizerConfigs, configs,
|
||||
TP_OPTIMIZER_CONFIGS_STRUCT_SIZE);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
#undef VALIDATE_MEMBER
|
||||
#undef VALIDATE_STRUCT_SIZE
|
||||
|
||||
// A map containing the input graph as its key, and TF_GrapplerItem as the
|
||||
// value. Users can fetch GrapplerItem for additional info to transform the
|
||||
// graph.
|
||||
absl::flat_hash_map<TF_Buffer*, const TF_GrapplerItem*>* GrapplerItemMap() {
|
||||
static absl::flat_hash_map<TF_Buffer*, const TF_GrapplerItem*>*
|
||||
grappler_items =
|
||||
new absl::flat_hash_map<TF_Buffer*, const TF_GrapplerItem*>;
|
||||
return grappler_items;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
Status CGraphOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph_def) {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
OwnedTFBuffer graph_buf(TF_NewBuffer());
|
||||
OwnedTFBuffer optimized_graph_buf(TF_NewBuffer());
|
||||
TF_RETURN_IF_ERROR(MessageToBuffer(item.graph, graph_buf.get()));
|
||||
|
||||
const auto it = GrapplerItemMap()->find(graph_buf.get());
|
||||
if (it == GrapplerItemMap()->end())
|
||||
GrapplerItemMap()->insert(
|
||||
{graph_buf.get(), reinterpret_cast<const TF_GrapplerItem*>(&item)});
|
||||
|
||||
optimizer_.optimize_func(c_optimizer_, graph_buf.get(),
|
||||
optimized_graph_buf.get(), c_status.get());
|
||||
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
BufferToMessage(optimized_graph_buf.get(), optimized_graph_def));
|
||||
|
||||
GrapplerItemMap()->erase(graph_buf.get());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CONFIG_TOGGLE(optimizer) \
|
||||
if (tp_configs.optimizer == TF_TriState_Off) \
|
||||
configs.toggle_config[#optimizer] = RewriterConfig::OFF; \
|
||||
else \
|
||||
configs.toggle_config[#optimizer] = RewriterConfig::ON;
|
||||
|
||||
void CGraphOptimizerRegister(
|
||||
const PluginGraphOptimizerRegistry::Creator& creator,
|
||||
const TP_OptimizerConfigs tp_configs, const char* device_type) {
|
||||
ConfigList configs;
|
||||
// disable_model_pruning is turned off by default.
|
||||
if (tp_configs.disable_model_pruning == TF_TriState_On)
|
||||
configs.disable_model_pruning = true;
|
||||
else
|
||||
configs.disable_model_pruning = false;
|
||||
// The other configs are turned on by default.
|
||||
CONFIG_TOGGLE(implementation_selector);
|
||||
CONFIG_TOGGLE(function_optimization);
|
||||
CONFIG_TOGGLE(common_subgraph_elimination);
|
||||
CONFIG_TOGGLE(arithmetic_optimization);
|
||||
CONFIG_TOGGLE(debug_stripper);
|
||||
CONFIG_TOGGLE(constant_folding);
|
||||
CONFIG_TOGGLE(shape_optimization);
|
||||
CONFIG_TOGGLE(auto_mixed_precision);
|
||||
CONFIG_TOGGLE(auto_mixed_precision_mkl);
|
||||
CONFIG_TOGGLE(pin_to_host_optimization);
|
||||
CONFIG_TOGGLE(layout_optimizer);
|
||||
CONFIG_TOGGLE(remapping);
|
||||
CONFIG_TOGGLE(loop_optimization);
|
||||
CONFIG_TOGGLE(dependency_optimization);
|
||||
CONFIG_TOGGLE(auto_parallel);
|
||||
CONFIG_TOGGLE(memory_optimization);
|
||||
CONFIG_TOGGLE(scoped_allocator_optimization);
|
||||
PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
|
||||
creator, device_type, configs);
|
||||
}
|
||||
|
||||
#undef CONFIG_TOGGLE
|
||||
|
||||
tensorflow::Status InitGraphPlugin(void* dso_handle) {
|
||||
tensorflow::Env* env = tensorflow::Env::Default();
|
||||
|
||||
// Step 1: Load symbol for `TF_InitPlugin`
|
||||
void* dso_symbol;
|
||||
TF_RETURN_IF_ERROR(
|
||||
env->GetSymbolFromLibrary(dso_handle, "TF_InitGraph", &dso_symbol));
|
||||
|
||||
// Step 2: Call `TF_InitPlugin`
|
||||
auto init_fn = reinterpret_cast<TFInitGraphPluginFn>(dso_symbol);
|
||||
return InitGraphPlugin(init_fn);
|
||||
}
|
||||
|
||||
tensorflow::Status InitGraphPlugin(TFInitGraphPluginFn init_fn) {
|
||||
TP_OptimizerRegistrationParams params{
|
||||
TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE};
|
||||
TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE};
|
||||
TP_OptimizerConfigs optimizer_configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE};
|
||||
params.major_version = SE_MAJOR;
|
||||
params.minor_version = SE_MINOR;
|
||||
params.patch_version = SE_PATCH;
|
||||
params.optimizer = &optimizer;
|
||||
params.optimizer_configs = &optimizer_configs;
|
||||
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
init_fn(¶ms, c_status.get());
|
||||
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
|
||||
TF_RETURN_IF_ERROR(ValidateTPOptimizerRegistrationParams(params));
|
||||
TF_RETURN_IF_ERROR(ValidateTPOptimizer(optimizer));
|
||||
TF_RETURN_IF_ERROR(ValidateTPOptimizerConfigs(optimizer_configs));
|
||||
|
||||
CGraphOptimizerRegister(
|
||||
[=]() { return new CGraphOptimizer(optimizer, params.device_type); },
|
||||
optimizer_configs, params.device_type);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
173
tensorflow/c/experimental/grappler/grappler.h
Normal file
173
tensorflow/c/experimental/grappler/grappler.h
Normal file
@ -0,0 +1,173 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// C API for Graph. The API is under active development and eventually
|
||||
// should allow registering a plugin graph optimizer with TensorFlow.
|
||||
//
|
||||
// Conventions:
|
||||
// * Struct prefix indicates whether struct fields should be filled by the
|
||||
// plugin or core implementation:
|
||||
// * Struct that should be filled by the plugin: `TP_OptimizerConfigs`,
|
||||
// `TP_Optimizer`, `TP_OptimizerRegistrationParams`
|
||||
// * Struct that should be filled by the proper: `TF_GrapplerItem`,
|
||||
// `TF_GraphProperties`, `TF_FunctionLibraryDefinition`
|
||||
// * We use `struct_size` for version checking. It should be set both by
|
||||
// core and the plugin.
|
||||
// * For example, `TF_InitGraph` function receives
|
||||
// `TP_OptimizerRegistrationParams*` as input with `struct_size`
|
||||
// populated by core. The plugin is responsible for setting
|
||||
// `struct_size` as well, along with all other fields.
|
||||
// * Refer to "TensorFlow Versioning Strategy" section at
|
||||
// https://github.com/tensorflow/community/pull/257/files.
|
||||
// * Note that the API is still under active development and doesn't have
|
||||
// versioning guarantees yet.
|
||||
// * `void* ext` is a free-form field that can be populated by
|
||||
// a plugin in `TP_*` structs or potential future extension points .
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// /* Sample TensorFlow code below, exact implementation might differ. */
|
||||
// // Version checking uses `struct_size`. It should be set both by core
|
||||
// // and the plugin.
|
||||
// TP_OptimizerRegistrationParams params{
|
||||
// TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE};
|
||||
// TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE};
|
||||
// TP_OptimizerConfigs configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE};
|
||||
// params.optimizer = &optimizer;
|
||||
// params.configs = &configs;
|
||||
//
|
||||
// /* Plugin code below */
|
||||
// void TF_InitGraph(TP_OptimizerRegistrationParams* params,
|
||||
// TF_Status* status) {
|
||||
// params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE;
|
||||
// params->device_type = "MY_DEVICE";
|
||||
//
|
||||
// // Disable certain optimizer.
|
||||
// params->optimizer_configs->struct_size =
|
||||
// TP_OPTIMIZER_CONFIGS_STRUCT_SIZE; params->optimizer_configs->remapping =
|
||||
// TF_TriState_Off;
|
||||
//
|
||||
// // Set functions to create a new optimizer.
|
||||
// params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE;
|
||||
// params->optimizer->create_func = (My_optimizer::create_func);
|
||||
// }
|
||||
|
||||
#define SE_MAJOR 0
|
||||
#define SE_MINOR 0
|
||||
#define SE_PATCH 1
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// TF_TriState is the C API typedef for tri-state.
|
||||
typedef enum TF_TriState {
|
||||
TF_TriState_Default = 0,
|
||||
TF_TriState_Off,
|
||||
TF_TriState_On,
|
||||
} TF_TriState;
|
||||
|
||||
// Flags indicating whether existing optimizers should be turned off.
|
||||
// It's optional for plugin to set functions to return true/false. If not
|
||||
// set, proper uses configuration set by user.
|
||||
typedef struct TP_OptimizerConfigs {
|
||||
size_t struct_size;
|
||||
void* ext; // reserved for future use
|
||||
TF_TriState disable_model_pruning;
|
||||
TF_TriState implementation_selector;
|
||||
TF_TriState function_optimization;
|
||||
TF_TriState common_subgraph_elimination;
|
||||
TF_TriState arithmetic_optimization;
|
||||
TF_TriState debug_stripper;
|
||||
TF_TriState constant_folding;
|
||||
TF_TriState shape_optimization;
|
||||
TF_TriState auto_mixed_precision;
|
||||
TF_TriState auto_mixed_precision_mkl;
|
||||
TF_TriState pin_to_host_optimization;
|
||||
TF_TriState layout_optimizer;
|
||||
TF_TriState remapping;
|
||||
TF_TriState loop_optimization;
|
||||
TF_TriState dependency_optimization;
|
||||
TF_TriState auto_parallel;
|
||||
TF_TriState memory_optimization;
|
||||
TF_TriState scoped_allocator_optimization;
|
||||
} TP_OptimizerConfigs;
|
||||
|
||||
#define TP_OPTIMIZER_CONFIGS_STRUCT_SIZE \
|
||||
TF_OFFSET_OF_END(TP_OptimizerConfigs, scoped_allocator_optimization)
|
||||
|
||||
// Struct for Optimizer. Plugin authors must provide an optimize function.
|
||||
// Creation and deletion functions are optional.
|
||||
typedef struct TP_Optimizer {
|
||||
size_t struct_size;
|
||||
void* ext; // reserved for future use
|
||||
|
||||
// [Optional]
|
||||
// Create function for optimizer.
|
||||
void* (*create_func)();
|
||||
|
||||
// Optimizer function for optimizer. The first param is an optimizer created
|
||||
// by create_func. The second param is input graph. The third param is output
|
||||
// graph.
|
||||
void (*optimize_func)(void*, TF_Buffer*, TF_Buffer*, TF_Status*);
|
||||
|
||||
// [Optional]
|
||||
// Destroy function for optimizer. If Create function is provided, destroy
|
||||
// function is must.
|
||||
void (*destroy_func)(void*);
|
||||
} TP_Optimizer;
|
||||
|
||||
#define TP_OPTIMIZER_STRUCT_SIZE TF_OFFSET_OF_END(TP_Optimizer, destroy_func)
|
||||
|
||||
typedef struct TP_OptimizerRegistrationParams {
|
||||
size_t struct_size;
|
||||
void* ext; // reserved for future use
|
||||
|
||||
// Graph C API version.
|
||||
int32_t major_version;
|
||||
int32_t minor_version;
|
||||
int32_t patch_version;
|
||||
|
||||
// Backend device type supported by the optimizer.
|
||||
const char* device_type;
|
||||
TP_OptimizerConfigs* optimizer_configs; // output, set by plugin
|
||||
TP_Optimizer* optimizer; // output, set by plugin
|
||||
} TP_OptimizerRegistrationParams;
|
||||
|
||||
#define TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE \
|
||||
TF_OFFSET_OF_END(TP_OptimizerRegistrationParams, optimizer)
|
||||
|
||||
// TF_InitGraph is used to do graph optimizer registration.
|
||||
// Plugin should implement TF_InitGraph to register graph optimizers.
|
||||
void TF_InitGraph(TP_OptimizerRegistrationParams* params, TF_Status* status);
|
||||
|
||||
// TF_GrapplerItem represents a combination of a graph, one of more fetch nodes,
|
||||
// and potentially a set of nodes to feed.
|
||||
typedef struct TF_GrapplerItem TF_GrapplerItem;
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_
|
106
tensorflow/c/experimental/grappler/grappler_internal.h
Normal file
106
tensorflow/c/experimental/grappler/grappler_internal.h
Normal file
@ -0,0 +1,106 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Classes and utilities that work with Graph C API for internal use.
|
||||
// This includes functions used for optimizer registration and interfaces needed
|
||||
// for testing.
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/experimental/grappler/grappler.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Plugin initialization function that a device plugin
|
||||
// must define.
|
||||
typedef void (*TFInitGraphPluginFn)(TP_OptimizerRegistrationParams* const,
|
||||
TF_Status* const);
|
||||
|
||||
// Registers Graph optimizers.
|
||||
Status InitGraphPlugin(void* dso_handle);
|
||||
|
||||
// Allow registering a graph optimizer using a function (used for
|
||||
// testing).
|
||||
Status InitGraphPlugin(TFInitGraphPluginFn init_fn);
|
||||
|
||||
struct GrapplerItem;
|
||||
class Cluster;
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||
};
|
||||
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
|
||||
struct TFBufferDeleter {
|
||||
void operator()(TF_Buffer* buf) const { TF_DeleteBuffer(buf); }
|
||||
};
|
||||
using OwnedTFBuffer = std::unique_ptr<TF_Buffer, TFBufferDeleter>;
|
||||
|
||||
class CGraphOptimizer : public CustomGraphOptimizer {
|
||||
public:
|
||||
explicit CGraphOptimizer(TP_Optimizer optimizer, const char* device_type)
|
||||
: optimizer_(optimizer), device_type_(device_type) {
|
||||
if (optimizer.create_func != nullptr) {
|
||||
c_optimizer_ = (*optimizer_.create_func)();
|
||||
} else {
|
||||
c_optimizer_ = nullptr;
|
||||
}
|
||||
}
|
||||
std::string name() const override { return "PluggableGraphOptimizer"; }
|
||||
bool UsesFunctionLibrary() const override { return false; }
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimized_graph, double result) override {}
|
||||
Status Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
|
||||
return Status::OK();
|
||||
}
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph_def) override;
|
||||
|
||||
~CGraphOptimizer() override {
|
||||
if (optimizer_.destroy_func != nullptr) {
|
||||
(*optimizer_.destroy_func)(c_optimizer_);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
TP_Optimizer optimizer_;
|
||||
std::string device_type_;
|
||||
void* c_optimizer_;
|
||||
};
|
||||
|
||||
// Registration function to register a CGraphOptimizer along with plugin configs
|
||||
// and device type.
|
||||
void CGraphOptimizerRegister(
|
||||
const PluginGraphOptimizerRegistry::Creator& creator,
|
||||
const TP_OptimizerConfigs tp_configs, const char* device_type);
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
|
117
tensorflow/c/experimental/grappler/grappler_test.cc
Normal file
117
tensorflow/c/experimental/grappler/grappler_test.cc
Normal file
@ -0,0 +1,117 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0(the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/grappler/grappler.h"
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/experimental/grappler/grappler_internal.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/grappler/clusters/single_machine.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
void optimize_func(void* optimizer, TF_Buffer* graph_buf,
|
||||
TF_Buffer* optimized_graph_buf, TF_Status* tf_status) {}
|
||||
|
||||
void PopulateDefaultParam(TP_OptimizerRegistrationParams* params) {
|
||||
params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE;
|
||||
params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE;
|
||||
params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE;
|
||||
params->optimizer->create_func = nullptr;
|
||||
params->optimizer->optimize_func = optimize_func;
|
||||
params->optimizer->destroy_func = nullptr;
|
||||
}
|
||||
|
||||
TEST(Grappler, SuccessfulRegistration) {
|
||||
auto plugin_init = [](TP_OptimizerRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultParam(params);
|
||||
params->device_type = "Success";
|
||||
params->optimizer_configs->remapping = TF_TriState_Off;
|
||||
};
|
||||
|
||||
TF_ASSERT_OK(InitGraphPlugin(plugin_init));
|
||||
ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers(
|
||||
std::set<string>{"Success"})
|
||||
.size(),
|
||||
1);
|
||||
ConfigList config = PluginGraphOptimizerRegistry::GetPluginConfigs(
|
||||
true, std::set<string>{"Success"});
|
||||
ASSERT_EQ(config.toggle_config["remapping"], RewriterConfig::OFF);
|
||||
}
|
||||
|
||||
TEST(Grappler, MultiplePluginRegistration) {
|
||||
auto plugin_init_0 = [](TP_OptimizerRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultParam(params);
|
||||
params->device_type = "Device0";
|
||||
};
|
||||
auto plugin_init_1 = [](TP_OptimizerRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultParam(params);
|
||||
params->device_type = "Device1";
|
||||
};
|
||||
|
||||
TF_ASSERT_OK(InitGraphPlugin(plugin_init_0));
|
||||
TF_ASSERT_OK(InitGraphPlugin(plugin_init_1));
|
||||
ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers(
|
||||
std::set<string>{"Device0", "Device1"})
|
||||
.size(),
|
||||
2);
|
||||
}
|
||||
|
||||
TEST(Grappler, DeviceTypeNotSet) {
|
||||
auto plugin_init = [](TP_OptimizerRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultParam(params);
|
||||
params->device_type = nullptr;
|
||||
};
|
||||
|
||||
tensorflow::Status status = InitGraphPlugin(plugin_init);
|
||||
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||
ASSERT_EQ(
|
||||
status.error_message(),
|
||||
"'device_type' field in TP_OptimizerRegistrationParams must be set.");
|
||||
}
|
||||
|
||||
TEST(Grappler, OptimizeFuncNotSet) {
|
||||
auto plugin_init = [](TP_OptimizerRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultParam(params);
|
||||
params->device_type = "FuncNotSet";
|
||||
params->optimizer->optimize_func = nullptr;
|
||||
};
|
||||
|
||||
tensorflow::Status status = InitGraphPlugin(plugin_init);
|
||||
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||
ASSERT_EQ(status.error_message(),
|
||||
"'optimize_func' field in TP_Optimizer must be set.");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -21,8 +21,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator>
|
||||
RegistrationMap;
|
||||
RegistrationMap* registered_optimizers = nullptr;
|
||||
@ -31,6 +31,53 @@ RegistrationMap* GetRegistrationMap() {
|
||||
registered_optimizers = new RegistrationMap;
|
||||
return registered_optimizers;
|
||||
}
|
||||
|
||||
// This map is a global map for registered plugin optimizers. It contains the
|
||||
// device_type as its key, and an optimizer creator as the value.
|
||||
typedef std::unordered_map<string, PluginGraphOptimizerRegistry::Creator>
|
||||
PluginRegistrationMap;
|
||||
PluginRegistrationMap* registered_plugin_optimizers = nullptr;
|
||||
PluginRegistrationMap* GetPluginRegistrationMap() {
|
||||
if (registered_plugin_optimizers == nullptr)
|
||||
registered_plugin_optimizers = new PluginRegistrationMap;
|
||||
return registered_plugin_optimizers;
|
||||
}
|
||||
|
||||
// This map is a global map for registered plugin configs. It contains the
|
||||
// device_type as its key, and ConfigList as the value.
|
||||
typedef std::unordered_map<string, ConfigList> PluginConfigMap;
|
||||
PluginConfigMap* plugin_config_map = nullptr;
|
||||
PluginConfigMap* GetPluginConfigMap() {
|
||||
if (plugin_config_map == nullptr) plugin_config_map = new PluginConfigMap;
|
||||
return plugin_config_map;
|
||||
}
|
||||
|
||||
// Returns plugin's default configuration for each Grappler optimizer (on/off).
|
||||
// See tensorflow/core/protobuf/rewriter_config.proto for more details about
|
||||
// each optimizer.
|
||||
const ConfigList& DefaultPluginConfigs() {
|
||||
static ConfigList* default_plugin_configs =
|
||||
new ConfigList(/*disable_model_pruning=*/false,
|
||||
{{"implementation_selector", RewriterConfig::ON},
|
||||
{"function_optimization", RewriterConfig::ON},
|
||||
{"common_subgraph_elimination", RewriterConfig::ON},
|
||||
{"arithmetic_optimization", RewriterConfig::ON},
|
||||
{"debug_stripper", RewriterConfig::ON},
|
||||
{"constant_folding", RewriterConfig::ON},
|
||||
{"shape_optimization", RewriterConfig::ON},
|
||||
{"auto_mixed_precision", RewriterConfig::ON},
|
||||
{"auto_mixed_precision_mkl", RewriterConfig::ON},
|
||||
{"pin_to_host_optimization", RewriterConfig::ON},
|
||||
{"layout_optimizer", RewriterConfig::ON},
|
||||
{"remapping", RewriterConfig::ON},
|
||||
{"loop_optimization", RewriterConfig::ON},
|
||||
{"dependency_optimization", RewriterConfig::ON},
|
||||
{"auto_parallel", RewriterConfig::ON},
|
||||
{"memory_optimization", RewriterConfig::ON},
|
||||
{"scoped_allocator_optimization", RewriterConfig::ON}});
|
||||
return *default_plugin_configs;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<CustomGraphOptimizer>
|
||||
@ -57,5 +104,114 @@ void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
|
||||
GetRegistrationMap()->insert({name, optimizer_creator});
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<CustomGraphOptimizer>>
|
||||
PluginGraphOptimizerRegistry::CreateOptimizers(
|
||||
const std::set<string>& device_types) {
|
||||
std::vector<std::unique_ptr<CustomGraphOptimizer>> optimizer_list;
|
||||
for (auto it = GetPluginRegistrationMap()->begin();
|
||||
it != GetPluginRegistrationMap()->end(); ++it) {
|
||||
if (device_types.find(it->first) == device_types.end()) continue;
|
||||
LOG(INFO) << "Plugin optimizer for device_type " << it->first
|
||||
<< " is enabled.";
|
||||
optimizer_list.emplace_back(
|
||||
std::unique_ptr<CustomGraphOptimizer>(it->second()));
|
||||
}
|
||||
return optimizer_list;
|
||||
}
|
||||
|
||||
void PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
|
||||
const Creator& optimizer_creator, const std::string& device_type,
|
||||
ConfigList& configs) {
|
||||
auto ret = GetPluginConfigMap()->insert({device_type, configs});
|
||||
if (!ret.second) {
|
||||
LOG(FATAL) << "PluginGraphOptimizer with device_type " // Crash OK
|
||||
<< device_type << " is registered twice.";
|
||||
}
|
||||
GetPluginRegistrationMap()->insert({device_type, optimizer_creator});
|
||||
}
|
||||
|
||||
void PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(
|
||||
const std::set<string>& device_types) {
|
||||
bool init = false, conflict = false;
|
||||
ConfigList plugin_configs;
|
||||
// Check if plugin's configs have conflict.
|
||||
for (const auto& device_type : device_types) {
|
||||
const auto it = GetPluginConfigMap()->find(device_type);
|
||||
if (it == GetPluginConfigMap()->end()) continue;
|
||||
auto cur_plugin_configs = it->second;
|
||||
|
||||
if (!init) {
|
||||
plugin_configs = cur_plugin_configs;
|
||||
init = true;
|
||||
} else {
|
||||
if (!(plugin_configs == cur_plugin_configs)) {
|
||||
conflict = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!conflict) return;
|
||||
LOG(WARNING) << "Plugins have conflicting configs. Potential performance "
|
||||
"regression may happen.";
|
||||
for (const auto& device_type : device_types) {
|
||||
const auto it = GetPluginConfigMap()->find(device_type);
|
||||
if (it == GetPluginConfigMap()->end()) continue;
|
||||
auto cur_plugin_configs = it->second;
|
||||
|
||||
// Print logs in following style:
|
||||
// disable_model_pruning 0
|
||||
// remapping 1
|
||||
// ...
|
||||
string logs = "";
|
||||
strings::StrAppend(&logs, "disable_model_pruning\t\t",
|
||||
cur_plugin_configs.disable_model_pruning, "\n");
|
||||
for (auto const& pair : cur_plugin_configs.toggle_config) {
|
||||
strings::StrAppend(&logs, pair.first, string(32 - pair.first.size(), ' '),
|
||||
(pair.second != RewriterConfig::OFF), "\n");
|
||||
}
|
||||
LOG(WARNING) << "Plugin's configs for device_type " << device_type << ":\n"
|
||||
<< logs;
|
||||
}
|
||||
}
|
||||
|
||||
ConfigList PluginGraphOptimizerRegistry::GetPluginConfigs(
|
||||
bool use_plugin_optimizers, const std::set<string>& device_types) {
|
||||
if (!use_plugin_optimizers) return DefaultPluginConfigs();
|
||||
|
||||
ConfigList ret_plugin_configs = DefaultPluginConfigs();
|
||||
for (const auto& device_type : device_types) {
|
||||
const auto it = GetPluginConfigMap()->find(device_type);
|
||||
if (it == GetPluginConfigMap()->end()) continue;
|
||||
auto cur_plugin_configs = it->second;
|
||||
// If any of the plugin turns on `disable_model_pruning`,
|
||||
// then `disable_model_pruning` should be true;
|
||||
if (cur_plugin_configs.disable_model_pruning == true)
|
||||
ret_plugin_configs.disable_model_pruning = true;
|
||||
|
||||
// If any of the plugin turns off a certain optimizer,
|
||||
// then the optimizer should be turned off;
|
||||
for (auto& pair : cur_plugin_configs.toggle_config) {
|
||||
if (cur_plugin_configs.toggle_config[pair.first] == RewriterConfig::OFF)
|
||||
ret_plugin_configs.toggle_config[pair.first] = RewriterConfig::OFF;
|
||||
}
|
||||
}
|
||||
|
||||
return ret_plugin_configs;
|
||||
}
|
||||
|
||||
bool PluginGraphOptimizerRegistry::IsConfigsConflict(
|
||||
ConfigList& user_config, ConfigList& plugin_config) {
|
||||
if (plugin_config == DefaultPluginConfigs()) return false;
|
||||
if (user_config.disable_model_pruning != plugin_config.disable_model_pruning)
|
||||
return true;
|
||||
// Returns true if user_config is turned on but plugin_config is turned off.
|
||||
for (auto& pair : user_config.toggle_config) {
|
||||
if ((user_config.toggle_config[pair.first] == RewriterConfig::ON) &&
|
||||
(plugin_config.toggle_config[pair.first] == RewriterConfig::OFF))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -25,6 +25,23 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Contains plugin's configurations for each Grappler optimizer (on/off).
|
||||
// See tensorflow/core/protobuf/rewriter_config.proto for optimizer description.
|
||||
struct ConfigList {
|
||||
ConfigList() {}
|
||||
ConfigList(bool disable_model_pruning,
|
||||
std::unordered_map<string, RewriterConfig_Toggle> config)
|
||||
: disable_model_pruning(disable_model_pruning),
|
||||
toggle_config(std::move(config)) {}
|
||||
|
||||
bool operator==(const ConfigList& other) const {
|
||||
return (disable_model_pruning == other.disable_model_pruning) &&
|
||||
(toggle_config == other.toggle_config);
|
||||
}
|
||||
bool disable_model_pruning; // Don't remove unnecessary ops from the graph.
|
||||
std::unordered_map<string, RewriterConfig_Toggle> toggle_config;
|
||||
};
|
||||
|
||||
class CustomGraphOptimizerRegistry {
|
||||
public:
|
||||
static std::unique_ptr<CustomGraphOptimizer> CreateByNameOrNull(
|
||||
@ -59,6 +76,40 @@ class CustomGraphOptimizerRegistrar {
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, \
|
||||
#MyCustomGraphOptimizerClass)
|
||||
|
||||
// A separate registry to register all plug-in CustomGraphOptimizers.
|
||||
class PluginGraphOptimizerRegistry {
|
||||
public:
|
||||
// Constructs a list of plug-in CustomGraphOptimizers from the global map
|
||||
// `registered_plugin_optimizers`.
|
||||
static std::vector<std::unique_ptr<CustomGraphOptimizer>> CreateOptimizers(
|
||||
const std::set<string>& device_types);
|
||||
|
||||
typedef std::function<CustomGraphOptimizer*()> Creator;
|
||||
|
||||
// Returns plugin's config. If any of the config is turned off, the returned
|
||||
// config will be turned off.
|
||||
static ConfigList GetPluginConfigs(bool use_plugin_optimizers,
|
||||
const std::set<string>& device_types);
|
||||
|
||||
// Registers plugin graph optimizer which can be called during program
|
||||
// initialization. Dies if multiple plugins with the same `device_type` are
|
||||
// registered. This class is not thread-safe.
|
||||
static void RegisterPluginOptimizerOrDie(const Creator& optimizer_creator,
|
||||
const std::string& device_type,
|
||||
ConfigList& configs);
|
||||
|
||||
// Prints plugin's configs if there are some conflicts.
|
||||
static void PrintPluginConfigsIfConflict(
|
||||
const std::set<string>& device_types);
|
||||
|
||||
// Returns true when `plugin_config` conflicts with `user_config`:
|
||||
// - Plugin's `disable_model_pruning` is not equal to `user_config`'s, or
|
||||
// - At least one of plugin's `toggle_config`s is on when it is set to off in
|
||||
// `user_config`'s.
|
||||
static bool IsConfigsConflict(ConfigList& user_config,
|
||||
ConfigList& plugin_config);
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -29,11 +29,12 @@ namespace grappler {
|
||||
namespace {
|
||||
|
||||
static const char* kTestOptimizerName = "Test";
|
||||
static const char* kTestPluginOptimizerName = "TestPlugin";
|
||||
|
||||
class TestGraphOptimizer : public CustomGraphOptimizer {
|
||||
public:
|
||||
Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
|
||||
nullptr) override {
|
||||
Status Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
|
||||
return Status::OK();
|
||||
}
|
||||
string name() const override { return kTestOptimizerName; }
|
||||
@ -86,6 +87,34 @@ TEST(GraphOptimizerRegistryTest, CrashesOnDuplicateRegistration) {
|
||||
"twice");
|
||||
}
|
||||
|
||||
class TestPluginGraphOptimizer : public CustomGraphOptimizer {
|
||||
public:
|
||||
Status Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
|
||||
return Status::OK();
|
||||
}
|
||||
string name() const override { return kTestPluginOptimizerName; }
|
||||
bool UsesFunctionLibrary() const override { return false; }
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) override {
|
||||
return Status::OK();
|
||||
}
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimized_graph, double result) override {}
|
||||
};
|
||||
|
||||
TEST(PluginGraphOptimizerRegistryTest, CrashesOnDuplicateRegistration) {
|
||||
const auto creator = []() { return new TestPluginGraphOptimizer; };
|
||||
ConfigList config_list;
|
||||
PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(creator, "GPU",
|
||||
config_list);
|
||||
PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(creator, "CPU",
|
||||
config_list);
|
||||
EXPECT_DEATH(PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
|
||||
creator, "GPU", config_list),
|
||||
"twice");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user