Merge pull request #46868 from Intel-tensorflow:yang/graph-api-registration

PiperOrigin-RevId: 360977145
Change-Id: I783f7923144f62168e3a9951fd0a3359913340ef
TensorFlower Gardener 2021-03-04 12:28:33 -08:00
commit 0577bbdd46
12 changed files with 926 additions and 3 deletions

tensorflow/c/BUILD

@@ -659,6 +659,7 @@ tf_cuda_cc_test(
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":c_api_internal",
":c_test_util",
":test_op_kernel",
"//tensorflow/cc:cc_ops",

tensorflow/c/c_api.cc

@@ -238,6 +238,15 @@ Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
return Status::OK();
}
Status BufferToMessage(const TF_Buffer* in,
tensorflow::protobuf::MessageLite* out) {
if (in == nullptr || !out->ParseFromArray(in->data, in->length)) {
return errors::InvalidArgument("Unparseable ", out->GetTypeName(),
" proto");
}
return Status::OK();
}
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type) {
// If any session has already run this node_id, mark this session as

tensorflow/c/c_api_internal.h

@@ -190,6 +190,9 @@ namespace tensorflow {
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out);
Status BufferToMessage(const TF_Buffer* in,
tensorflow::protobuf::MessageLite* out);
// Set the shapes and types of the output's handle.
//
// The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must

tensorflow/c/c_api_test.cc

@@ -21,6 +21,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
@@ -44,6 +45,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
@@ -2576,6 +2578,20 @@ TEST(CAPI, TestTensorIsNotAligned) {
TF_DeleteTensor(a);
}
TEST(CAPI, MessageBufferConversion) {
NodeDef node_in, node_out;
node_in.set_name("Test name");
node_in.set_op("Test op");
TF_Buffer* buffer = TF_NewBuffer();
TF_CHECK_OK(MessageToBuffer(node_in, buffer));
TF_CHECK_OK(BufferToMessage(buffer, &node_out));
TF_DeleteBuffer(buffer);
protobuf::util::MessageDifferencer differencer;
EXPECT_TRUE(differencer.Compare(node_in, node_out));
}
} // namespace
} // namespace tensorflow

tensorflow/c/experimental/grappler/BUILD

@@ -0,0 +1,67 @@
# Description:
# Graph C API.
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "grappler_hdrs",
hdrs = ["grappler.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status_headers",
],
)
cc_library(
name = "grappler",
srcs = ["grappler.cc"],
hdrs = [
"grappler.h",
"grappler_internal.h",
],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/container:flat_hash_map",
],
)
tf_cc_test(
name = "grappler_test",
srcs = ["grappler_test.cc"],
deps = [
":grappler",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/clusters:single_machine",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"//tensorflow/core/protobuf:error_codes_proto_impl_cc",
],
)

tensorflow/c/experimental/grappler/grappler.cc

@@ -0,0 +1,195 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file extends/implements core graph optimizer base classes in terms of
// the C API defined in grappler.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface and a C
// struct called "TP_Something".
#include "tensorflow/c/experimental/grappler/grappler.h"
#include <memory>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/grappler/grappler_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
namespace {
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
do { \
if (STRUCT_OBJ.struct_size == 0) { \
return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \
"struct_size field in " #STRUCT_NAME \
" must be set to " #SIZE_VALUE_NAME "."); \
} \
} while (0)
#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME) \
do { \
if (STRUCT_OBJ.NAME == 0) { \
return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \
"'" #NAME "' field in " #STRUCT_NAME \
" must be set."); \
} \
} while (0)
tensorflow::Status ValidateTPOptimizerRegistrationParams(
const TP_OptimizerRegistrationParams& params) {
VALIDATE_STRUCT_SIZE(TP_OptimizerRegistrationParams, params,
TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE);
VALIDATE_MEMBER(TP_OptimizerRegistrationParams, params, device_type);
return tensorflow::Status::OK();
}
tensorflow::Status ValidateTPOptimizer(const TP_Optimizer& optimizer) {
VALIDATE_STRUCT_SIZE(TP_Optimizer, optimizer, TP_OPTIMIZER_STRUCT_SIZE);
VALIDATE_MEMBER(TP_Optimizer, optimizer, optimize_func);
return tensorflow::Status::OK();
}
tensorflow::Status ValidateTPOptimizerConfigs(
const TP_OptimizerConfigs& configs) {
VALIDATE_STRUCT_SIZE(TP_OptimizerConfigs, configs,
TP_OPTIMIZER_CONFIGS_STRUCT_SIZE);
return tensorflow::Status::OK();
}
#undef VALIDATE_MEMBER
#undef VALIDATE_STRUCT_SIZE
// A map containing the input graph as its key, and TF_GrapplerItem as the
// value. Users can fetch GrapplerItem for additional info to transform the
// graph.
absl::flat_hash_map<TF_Buffer*, const TF_GrapplerItem*>* GrapplerItemMap() {
static absl::flat_hash_map<TF_Buffer*, const TF_GrapplerItem*>*
grappler_items =
new absl::flat_hash_map<TF_Buffer*, const TF_GrapplerItem*>;
return grappler_items;
}
} // namespace
namespace tensorflow {
namespace grappler {
Status CGraphOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph_def) {
OwnedTFStatus c_status(TF_NewStatus());
OwnedTFBuffer graph_buf(TF_NewBuffer());
OwnedTFBuffer optimized_graph_buf(TF_NewBuffer());
TF_RETURN_IF_ERROR(MessageToBuffer(item.graph, graph_buf.get()));
const auto it = GrapplerItemMap()->find(graph_buf.get());
if (it == GrapplerItemMap()->end())
GrapplerItemMap()->insert(
{graph_buf.get(), reinterpret_cast<const TF_GrapplerItem*>(&item)});
optimizer_.optimize_func(c_optimizer_, graph_buf.get(),
optimized_graph_buf.get(), c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(
BufferToMessage(optimized_graph_buf.get(), optimized_graph_def));
GrapplerItemMap()->erase(graph_buf.get());
return Status::OK();
}
#define CONFIG_TOGGLE(optimizer) \
if (tp_configs.optimizer == TF_TriState_Off) \
configs.toggle_config[#optimizer] = RewriterConfig::OFF; \
else \
configs.toggle_config[#optimizer] = RewriterConfig::ON;
void CGraphOptimizerRegister(
const PluginGraphOptimizerRegistry::Creator& creator,
const TP_OptimizerConfigs tp_configs, const char* device_type) {
ConfigList configs;
// disable_model_pruning is turned off by default.
if (tp_configs.disable_model_pruning == TF_TriState_On)
configs.disable_model_pruning = true;
else
configs.disable_model_pruning = false;
// The other configs are turned on by default.
CONFIG_TOGGLE(implementation_selector);
CONFIG_TOGGLE(function_optimization);
CONFIG_TOGGLE(common_subgraph_elimination);
CONFIG_TOGGLE(arithmetic_optimization);
CONFIG_TOGGLE(debug_stripper);
CONFIG_TOGGLE(constant_folding);
CONFIG_TOGGLE(shape_optimization);
CONFIG_TOGGLE(auto_mixed_precision);
CONFIG_TOGGLE(auto_mixed_precision_mkl);
CONFIG_TOGGLE(pin_to_host_optimization);
CONFIG_TOGGLE(layout_optimizer);
CONFIG_TOGGLE(remapping);
CONFIG_TOGGLE(loop_optimization);
CONFIG_TOGGLE(dependency_optimization);
CONFIG_TOGGLE(auto_parallel);
CONFIG_TOGGLE(memory_optimization);
CONFIG_TOGGLE(scoped_allocator_optimization);
PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
creator, device_type, configs);
}
#undef CONFIG_TOGGLE
tensorflow::Status InitGraphPlugin(void* dso_handle) {
tensorflow::Env* env = tensorflow::Env::Default();
// Step 1: Load the `TF_InitGraph` symbol from the plugin library.
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "TF_InitGraph", &dso_symbol));
// Step 2: Call `TF_InitGraph`
auto init_fn = reinterpret_cast<TFInitGraphPluginFn>(dso_symbol);
return InitGraphPlugin(init_fn);
}
tensorflow::Status InitGraphPlugin(TFInitGraphPluginFn init_fn) {
TP_OptimizerRegistrationParams params{
TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE};
TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE};
TP_OptimizerConfigs optimizer_configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE};
params.major_version = SE_MAJOR;
params.minor_version = SE_MINOR;
params.patch_version = SE_PATCH;
params.optimizer = &optimizer;
params.optimizer_configs = &optimizer_configs;
OwnedTFStatus c_status(TF_NewStatus());
init_fn(&params, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateTPOptimizerRegistrationParams(params));
TF_RETURN_IF_ERROR(ValidateTPOptimizer(optimizer));
TF_RETURN_IF_ERROR(ValidateTPOptimizerConfigs(optimizer_configs));
CGraphOptimizerRegister(
[=]() { return new CGraphOptimizer(optimizer, params.device_type); },
optimizer_configs, params.device_type);
return Status::OK();
}
} // namespace grappler
} // namespace tensorflow

tensorflow/c/experimental/grappler/grappler.h

@@ -0,0 +1,173 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/tf_status.h"
// --------------------------------------------------------------------------
// C API for Graph. The API is under active development and eventually
// should allow registering a plugin graph optimizer with TensorFlow.
//
// Conventions:
// * Struct prefix indicates whether struct fields should be filled by the
// plugin or by the core implementation:
// * Structs that should be filled by the plugin: `TP_OptimizerConfigs`,
// `TP_Optimizer`, `TP_OptimizerRegistrationParams`
// * Structs that should be filled by core (TensorFlow proper): `TF_GrapplerItem`,
// `TF_GraphProperties`, `TF_FunctionLibraryDefinition`
// * We use `struct_size` for version checking. It should be set both by
// core and the plugin.
// * For example, `TF_InitGraph` function receives
// `TP_OptimizerRegistrationParams*` as input with `struct_size`
// populated by core. The plugin is responsible for setting
// `struct_size` as well, along with all other fields.
// * Refer to "TensorFlow Versioning Strategy" section at
// https://github.com/tensorflow/community/pull/257/files.
// * Note that the API is still under active development and doesn't have
// versioning guarantees yet.
// * `void* ext` is a free-form field that can be populated by
// a plugin in `TP_*` structs or potential future extension points.
//
// Example usage:
//
// /* Sample TensorFlow code below, exact implementation might differ. */
// // Version checking uses `struct_size`. It should be set both by core
// // and the plugin.
// TP_OptimizerRegistrationParams params{
// TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE};
// TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE};
// TP_OptimizerConfigs configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE};
// params.optimizer = &optimizer;
// params.configs = &configs;
//
// /* Plugin code below */
// void TF_InitGraph(TP_OptimizerRegistrationParams* params,
// TF_Status* status) {
// params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE;
// params->device_type = "MY_DEVICE";
//
// // Disable certain optimizer.
// params->optimizer_configs->struct_size =
// TP_OPTIMIZER_CONFIGS_STRUCT_SIZE; params->optimizer_configs->remapping =
// TF_TriState_Off;
//
// // Set functions to create a new optimizer.
// params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE;
// params->optimizer->create_func = (My_optimizer::create_func);
// }
#define SE_MAJOR 0
#define SE_MINOR 0
#define SE_PATCH 1
#ifdef __cplusplus
extern "C" {
#endif
// TF_TriState is the C API typedef for tri-state.
typedef enum TF_TriState {
TF_TriState_Default = 0,
TF_TriState_Off,
TF_TriState_On,
} TF_TriState;
// Flags indicating whether existing optimizers should be turned on or off.
// Setting these flags is optional for the plugin. If a flag is left at
// TF_TriState_Default, core (TensorFlow proper) uses the configuration set by
// the user.
typedef struct TP_OptimizerConfigs {
size_t struct_size;
void* ext; // reserved for future use
TF_TriState disable_model_pruning;
TF_TriState implementation_selector;
TF_TriState function_optimization;
TF_TriState common_subgraph_elimination;
TF_TriState arithmetic_optimization;
TF_TriState debug_stripper;
TF_TriState constant_folding;
TF_TriState shape_optimization;
TF_TriState auto_mixed_precision;
TF_TriState auto_mixed_precision_mkl;
TF_TriState pin_to_host_optimization;
TF_TriState layout_optimizer;
TF_TriState remapping;
TF_TriState loop_optimization;
TF_TriState dependency_optimization;
TF_TriState auto_parallel;
TF_TriState memory_optimization;
TF_TriState scoped_allocator_optimization;
} TP_OptimizerConfigs;
#define TP_OPTIMIZER_CONFIGS_STRUCT_SIZE \
TF_OFFSET_OF_END(TP_OptimizerConfigs, scoped_allocator_optimization)
// Struct for Optimizer. Plugin authors must provide an optimize function.
// Creation and deletion functions are optional.
typedef struct TP_Optimizer {
size_t struct_size;
void* ext; // reserved for future use
// [Optional]
// Create function for the optimizer. Returns an opaque optimizer object that
// is passed back to optimize_func and destroy_func.
void* (*create_func)();
// Optimize function. The first param is the optimizer created by create_func
// (or null if no create_func is provided), the second param is the input
// graph, and the third param is the output graph.
void (*optimize_func)(void*, TF_Buffer*, TF_Buffer*, TF_Status*);
// [Optional]
// Destroy function for the optimizer. If create_func is provided, destroy_func
// must be provided as well.
void (*destroy_func)(void*);
} TP_Optimizer;
#define TP_OPTIMIZER_STRUCT_SIZE TF_OFFSET_OF_END(TP_Optimizer, destroy_func)
typedef struct TP_OptimizerRegistrationParams {
size_t struct_size;
void* ext; // reserved for future use
// Graph C API version.
int32_t major_version;
int32_t minor_version;
int32_t patch_version;
// Backend device type supported by the optimizer.
const char* device_type;
TP_OptimizerConfigs* optimizer_configs; // output, set by plugin
TP_Optimizer* optimizer; // output, set by plugin
} TP_OptimizerRegistrationParams;
#define TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE \
TF_OFFSET_OF_END(TP_OptimizerRegistrationParams, optimizer)
// TF_InitGraph is used for graph optimizer registration. A plugin should
// implement TF_InitGraph to register its graph optimizers.
void TF_InitGraph(TP_OptimizerRegistrationParams* params, TF_Status* status);
// TF_GrapplerItem represents a combination of a graph, one or more fetch nodes,
// and potentially a set of nodes to feed.
typedef struct TF_GrapplerItem TF_GrapplerItem;
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_
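For illustration, a minimal plugin built against the header above might look like the sketch below. The device type "MY_DEVICE", the helper MyOptimizer_Optimize, and the pass-through behavior are hypothetical, not part of this change; a real plugin would deserialize the input GraphDef, transform it, and serialize the result into the output buffer.

// Hypothetical plugin sketch: optimize_func copies the serialized input
// GraphDef into the output buffer unchanged.
#include <cstdlib>
#include <cstring>

#include "tensorflow/c/experimental/grappler/grappler.h"

static void MyOptimizer_Optimize(void* optimizer, TF_Buffer* input_graph,
                                 TF_Buffer* output_graph, TF_Status* status) {
  (void)optimizer;  // create_func is not provided below, so this is null.
  void* copy = std::malloc(input_graph->length);
  std::memcpy(copy, input_graph->data, input_graph->length);
  output_graph->data = copy;
  output_graph->length = input_graph->length;
  output_graph->data_deallocator = [](void* data, size_t) { std::free(data); };
  TF_SetStatus(status, TF_OK, "");
}

void TF_InitGraph(TP_OptimizerRegistrationParams* params, TF_Status* status) {
  params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE;
  params->device_type = "MY_DEVICE";  // hypothetical device type
  params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE;
  // Ask core to turn remapping off while this plugin is active.
  params->optimizer_configs->remapping = TF_TriState_Off;
  params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE;
  params->optimizer->create_func = nullptr;   // optional, omitted here
  params->optimizer->optimize_func = MyOptimizer_Optimize;
  params->optimizer->destroy_func = nullptr;  // optional, omitted here
  TF_SetStatus(status, TF_OK, "");
}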

tensorflow/c/experimental/grappler/grappler_internal.h

@@ -0,0 +1,106 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Classes and utilities that work with Graph C API for internal use.
// This includes functions used for optimizer registration and interfaces needed
// for testing.
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/grappler/grappler.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
// Plugin initialization function that a device plugin
// must define.
typedef void (*TFInitGraphPluginFn)(TP_OptimizerRegistrationParams* const,
TF_Status* const);
// Registers Graph optimizers.
Status InitGraphPlugin(void* dso_handle);
// Allow registering a graph optimizer using a function (used for
// testing).
Status InitGraphPlugin(TFInitGraphPluginFn init_fn);
struct GrapplerItem;
class Cluster;
struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
struct TFBufferDeleter {
void operator()(TF_Buffer* buf) const { TF_DeleteBuffer(buf); }
};
using OwnedTFBuffer = std::unique_ptr<TF_Buffer, TFBufferDeleter>;
class CGraphOptimizer : public CustomGraphOptimizer {
public:
explicit CGraphOptimizer(TP_Optimizer optimizer, const char* device_type)
: optimizer_(optimizer), device_type_(device_type) {
if (optimizer.create_func != nullptr) {
c_optimizer_ = (*optimizer_.create_func)();
} else {
c_optimizer_ = nullptr;
}
}
std::string name() const override { return "PluggableGraphOptimizer"; }
bool UsesFunctionLibrary() const override { return false; }
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) override {}
Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
return Status::OK();
}
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph_def) override;
~CGraphOptimizer() override {
if (optimizer_.destroy_func != nullptr) {
(*optimizer_.destroy_func)(c_optimizer_);
}
}
private:
TP_Optimizer optimizer_;
std::string device_type_;
void* c_optimizer_;
};
// Registration function to register a CGraphOptimizer along with plugin configs
// and device type.
void CGraphOptimizerRegister(
const PluginGraphOptimizerRegistry::Creator& creator,
const TP_OptimizerConfigs tp_configs, const char* device_type);
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
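The tests in this change exercise the init_fn overload directly; a hedged sketch of the dso_handle path, assuming Env::LoadDynamicLibrary as the loading entry point (the wrapper name and library path are illustrative):

// Hedged sketch: load a plugin shared object and register its graph optimizer
// through the dso_handle overload declared above.
#include <string>

#include "tensorflow/c/experimental/grappler/grappler_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

tensorflow::Status LoadAndInitGraphPlugin(const std::string& library_path) {
  void* dso_handle = nullptr;
  TF_RETURN_IF_ERROR(tensorflow::Env::Default()->LoadDynamicLibrary(
      library_path.c_str(), &dso_handle));
  return tensorflow::grappler::InitGraphPlugin(dso_handle);
}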

tensorflow/c/experimental/grappler/grappler_test.cc

@@ -0,0 +1,117 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/grappler/grappler.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/grappler/grappler_internal.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
void optimize_func(void* optimizer, TF_Buffer* graph_buf,
TF_Buffer* optimized_graph_buf, TF_Status* tf_status) {}
void PopulateDefaultParam(TP_OptimizerRegistrationParams* params) {
params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE;
params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE;
params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE;
params->optimizer->create_func = nullptr;
params->optimizer->optimize_func = optimize_func;
params->optimizer->destroy_func = nullptr;
}
TEST(Grappler, SuccessfulRegistration) {
auto plugin_init = [](TP_OptimizerRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultParam(params);
params->device_type = "Success";
params->optimizer_configs->remapping = TF_TriState_Off;
};
TF_ASSERT_OK(InitGraphPlugin(plugin_init));
ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers(
std::set<string>{"Success"})
.size(),
1);
ConfigList config = PluginGraphOptimizerRegistry::GetPluginConfigs(
true, std::set<string>{"Success"});
ASSERT_EQ(config.toggle_config["remapping"], RewriterConfig::OFF);
}
TEST(Grappler, MultiplePluginRegistration) {
auto plugin_init_0 = [](TP_OptimizerRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultParam(params);
params->device_type = "Device0";
};
auto plugin_init_1 = [](TP_OptimizerRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultParam(params);
params->device_type = "Device1";
};
TF_ASSERT_OK(InitGraphPlugin(plugin_init_0));
TF_ASSERT_OK(InitGraphPlugin(plugin_init_1));
ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers(
std::set<string>{"Device0", "Device1"})
.size(),
2);
}
TEST(Grappler, DeviceTypeNotSet) {
auto plugin_init = [](TP_OptimizerRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultParam(params);
params->device_type = nullptr;
};
tensorflow::Status status = InitGraphPlugin(plugin_init);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(
status.error_message(),
"'device_type' field in TP_OptimizerRegistrationParams must be set.");
}
TEST(Grappler, OptimizeFuncNotSet) {
auto plugin_init = [](TP_OptimizerRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultParam(params);
params->device_type = "FuncNotSet";
params->optimizer->optimize_func = nullptr;
};
tensorflow::Status status = InitGraphPlugin(plugin_init);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(status.error_message(),
"'optimize_func' field in TP_Optimizer must be set.");
}
} // namespace
} // namespace grappler
} // namespace tensorflow

tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc

@@ -21,8 +21,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator>
RegistrationMap;
RegistrationMap* registered_optimizers = nullptr;
@@ -31,6 +31,53 @@ RegistrationMap* GetRegistrationMap() {
registered_optimizers = new RegistrationMap;
return registered_optimizers;
}
// This map is a global map for registered plugin optimizers. It contains the
// device_type as its key, and an optimizer creator as the value.
typedef std::unordered_map<string, PluginGraphOptimizerRegistry::Creator>
PluginRegistrationMap;
PluginRegistrationMap* registered_plugin_optimizers = nullptr;
PluginRegistrationMap* GetPluginRegistrationMap() {
if (registered_plugin_optimizers == nullptr)
registered_plugin_optimizers = new PluginRegistrationMap;
return registered_plugin_optimizers;
}
// This map is a global map for registered plugin configs. It contains the
// device_type as its key, and ConfigList as the value.
typedef std::unordered_map<string, ConfigList> PluginConfigMap;
PluginConfigMap* plugin_config_map = nullptr;
PluginConfigMap* GetPluginConfigMap() {
if (plugin_config_map == nullptr) plugin_config_map = new PluginConfigMap;
return plugin_config_map;
}
// Returns plugin's default configuration for each Grappler optimizer (on/off).
// See tensorflow/core/protobuf/rewriter_config.proto for more details about
// each optimizer.
const ConfigList& DefaultPluginConfigs() {
static ConfigList* default_plugin_configs =
new ConfigList(/*disable_model_pruning=*/false,
{{"implementation_selector", RewriterConfig::ON},
{"function_optimization", RewriterConfig::ON},
{"common_subgraph_elimination", RewriterConfig::ON},
{"arithmetic_optimization", RewriterConfig::ON},
{"debug_stripper", RewriterConfig::ON},
{"constant_folding", RewriterConfig::ON},
{"shape_optimization", RewriterConfig::ON},
{"auto_mixed_precision", RewriterConfig::ON},
{"auto_mixed_precision_mkl", RewriterConfig::ON},
{"pin_to_host_optimization", RewriterConfig::ON},
{"layout_optimizer", RewriterConfig::ON},
{"remapping", RewriterConfig::ON},
{"loop_optimization", RewriterConfig::ON},
{"dependency_optimization", RewriterConfig::ON},
{"auto_parallel", RewriterConfig::ON},
{"memory_optimization", RewriterConfig::ON},
{"scoped_allocator_optimization", RewriterConfig::ON}});
return *default_plugin_configs;
}
} // namespace
std::unique_ptr<CustomGraphOptimizer>
@@ -57,5 +104,114 @@ void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
GetRegistrationMap()->insert({name, optimizer_creator});
}
std::vector<std::unique_ptr<CustomGraphOptimizer>>
PluginGraphOptimizerRegistry::CreateOptimizers(
const std::set<string>& device_types) {
std::vector<std::unique_ptr<CustomGraphOptimizer>> optimizer_list;
for (auto it = GetPluginRegistrationMap()->begin();
it != GetPluginRegistrationMap()->end(); ++it) {
if (device_types.find(it->first) == device_types.end()) continue;
LOG(INFO) << "Plugin optimizer for device_type " << it->first
<< " is enabled.";
optimizer_list.emplace_back(
std::unique_ptr<CustomGraphOptimizer>(it->second()));
}
return optimizer_list;
}
void PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
const Creator& optimizer_creator, const std::string& device_type,
ConfigList& configs) {
auto ret = GetPluginConfigMap()->insert({device_type, configs});
if (!ret.second) {
LOG(FATAL) << "PluginGraphOptimizer with device_type " // Crash OK
<< device_type << " is registered twice.";
}
GetPluginRegistrationMap()->insert({device_type, optimizer_creator});
}
void PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(
const std::set<string>& device_types) {
bool init = false, conflict = false;
ConfigList plugin_configs;
// Check if plugin's configs have conflict.
for (const auto& device_type : device_types) {
const auto it = GetPluginConfigMap()->find(device_type);
if (it == GetPluginConfigMap()->end()) continue;
auto cur_plugin_configs = it->second;
if (!init) {
plugin_configs = cur_plugin_configs;
init = true;
} else {
if (!(plugin_configs == cur_plugin_configs)) {
conflict = true;
break;
}
}
}
if (!conflict) return;
LOG(WARNING) << "Plugins have conflicting configs. Potential performance "
"regression may happen.";
for (const auto& device_type : device_types) {
const auto it = GetPluginConfigMap()->find(device_type);
if (it == GetPluginConfigMap()->end()) continue;
auto cur_plugin_configs = it->second;
// Print logs in the following style:
// disable_model_pruning 0
// remapping 1
// ...
string logs = "";
strings::StrAppend(&logs, "disable_model_pruning\t\t",
cur_plugin_configs.disable_model_pruning, "\n");
for (auto const& pair : cur_plugin_configs.toggle_config) {
strings::StrAppend(&logs, pair.first, string(32 - pair.first.size(), ' '),
(pair.second != RewriterConfig::OFF), "\n");
}
LOG(WARNING) << "Plugin's configs for device_type " << device_type << ":\n"
<< logs;
}
}
ConfigList PluginGraphOptimizerRegistry::GetPluginConfigs(
bool use_plugin_optimizers, const std::set<string>& device_types) {
if (!use_plugin_optimizers) return DefaultPluginConfigs();
ConfigList ret_plugin_configs = DefaultPluginConfigs();
for (const auto& device_type : device_types) {
const auto it = GetPluginConfigMap()->find(device_type);
if (it == GetPluginConfigMap()->end()) continue;
auto cur_plugin_configs = it->second;
// If any of the plugins turns on `disable_model_pruning`,
// then `disable_model_pruning` should be true.
if (cur_plugin_configs.disable_model_pruning == true)
ret_plugin_configs.disable_model_pruning = true;
// If any of the plugins turns off a certain optimizer,
// then that optimizer should be turned off.
for (auto& pair : cur_plugin_configs.toggle_config) {
if (cur_plugin_configs.toggle_config[pair.first] == RewriterConfig::OFF)
ret_plugin_configs.toggle_config[pair.first] = RewriterConfig::OFF;
}
}
return ret_plugin_configs;
}
bool PluginGraphOptimizerRegistry::IsConfigsConflict(
ConfigList& user_config, ConfigList& plugin_config) {
if (plugin_config == DefaultPluginConfigs()) return false;
if (user_config.disable_model_pruning != plugin_config.disable_model_pruning)
return true;
// Returns true if user_config is turned on but plugin_config is turned off.
for (auto& pair : user_config.toggle_config) {
if ((user_config.toggle_config[pair.first] == RewriterConfig::ON) &&
(plugin_config.toggle_config[pair.first] == RewriterConfig::OFF))
return true;
}
return false;
}
} // end namespace grappler
} // end namespace tensorflow
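As a usage sketch of the merge rule in GetPluginConfigs above (the device type strings are illustrative and this snippet is not part of the change): when one registered plugin keeps the defaults and another turns remapping off, only remapping ends up off in the merged config.

// Hypothetical merge example: assume "DeviceA" registered with default configs
// and "DeviceB" registered with remapping = RewriterConfig::OFF.
#include <set>
#include <string>

#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"

void ExampleMergedPluginConfigs() {
  using tensorflow::grappler::ConfigList;
  using tensorflow::grappler::PluginGraphOptimizerRegistry;
  ConfigList merged = PluginGraphOptimizerRegistry::GetPluginConfigs(
      /*use_plugin_optimizers=*/true,
      std::set<std::string>{"DeviceA", "DeviceB"});
  // An optimizer turned off by any plugin ends up OFF:
  //   merged.toggle_config["remapping"] == RewriterConfig::OFF
  // Optimizers untouched by every plugin keep their default:
  //   merged.toggle_config["constant_folding"] == RewriterConfig::ON
  // disable_model_pruning stays false unless some plugin turned it on.
  (void)merged;
}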

tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h

@@ -25,6 +25,23 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
// Contains plugin's configurations for each Grappler optimizer (on/off).
// See tensorflow/core/protobuf/rewriter_config.proto for optimizer description.
struct ConfigList {
ConfigList() {}
ConfigList(bool disable_model_pruning,
std::unordered_map<string, RewriterConfig_Toggle> config)
: disable_model_pruning(disable_model_pruning),
toggle_config(std::move(config)) {}
bool operator==(const ConfigList& other) const {
return (disable_model_pruning == other.disable_model_pruning) &&
(toggle_config == other.toggle_config);
}
bool disable_model_pruning; // Don't remove unnecessary ops from the graph.
std::unordered_map<string, RewriterConfig_Toggle> toggle_config;
};
class CustomGraphOptimizerRegistry {
public:
static std::unique_ptr<CustomGraphOptimizer> CreateByNameOrNull(
@@ -59,6 +76,40 @@ class CustomGraphOptimizerRegistrar {
REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, \
#MyCustomGraphOptimizerClass)
// A separate registry to register all plug-in CustomGraphOptimizers.
class PluginGraphOptimizerRegistry {
public:
// Constructs a list of plug-in CustomGraphOptimizers from the global map
// `registered_plugin_optimizers`.
static std::vector<std::unique_ptr<CustomGraphOptimizer>> CreateOptimizers(
const std::set<string>& device_types);
typedef std::function<CustomGraphOptimizer*()> Creator;
// Returns the merged plugin configs. If any plugin turns an optimizer off,
// that optimizer is turned off in the returned config.
static ConfigList GetPluginConfigs(bool use_plugin_optimizers,
const std::set<string>& device_types);
// Registers a plugin graph optimizer; intended to be called during program
// initialization. Dies if multiple plugins with the same `device_type` are
// registered. This class is not thread-safe.
static void RegisterPluginOptimizerOrDie(const Creator& optimizer_creator,
const std::string& device_type,
ConfigList& configs);
// Prints the plugins' configs if they conflict.
static void PrintPluginConfigsIfConflict(
const std::set<string>& device_types);
// Returns true when `plugin_config` conflicts with `user_config`:
// - Plugin's `disable_model_pruning` is not equal to `user_config`'s, or
// - At least one optimizer is turned on in `user_config` but turned off in
// `plugin_config`.
static bool IsConfigsConflict(ConfigList& user_config,
ConfigList& plugin_config);
};
} // end namespace grappler
} // end namespace tensorflow
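A hypothetical caller sketch (not part of this change) showing how a consumer might query the registry declared above; the device types are whatever the caller has configured:

// Hedged sketch of a consumer of PluginGraphOptimizerRegistry.
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"

namespace tensorflow {
namespace grappler {

void RunRegisteredPluginOptimizersSketch(
    const std::set<std::string>& device_types) {
  // Warn once if the registered plugins disagree on their configs.
  PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(device_types);
  // One CustomGraphOptimizer per registered device type; calling Optimize() is
  // elided here because it needs a Cluster and a GrapplerItem.
  std::vector<std::unique_ptr<CustomGraphOptimizer>> optimizers =
      PluginGraphOptimizerRegistry::CreateOptimizers(device_types);
  (void)optimizers;
}

}  // namespace grappler
}  // namespace tensorflow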

tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc

@@ -29,11 +29,12 @@ namespace grappler {
namespace {
static const char* kTestOptimizerName = "Test";
static const char* kTestPluginOptimizerName = "TestPlugin";
class TestGraphOptimizer : public CustomGraphOptimizer {
public:
Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
nullptr) override {
Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
return Status::OK();
}
string name() const override { return kTestOptimizerName; }
@@ -86,6 +87,34 @@ TEST(GraphOptimizerRegistryTest, CrashesOnDuplicateRegistration) {
"twice");
}
class TestPluginGraphOptimizer : public CustomGraphOptimizer {
public:
Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
return Status::OK();
}
string name() const override { return kTestPluginOptimizerName; }
bool UsesFunctionLibrary() const override { return false; }
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) override {
return Status::OK();
}
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) override {}
};
TEST(PluginGraphOptimizerRegistryTest, CrashesOnDuplicateRegistration) {
const auto creator = []() { return new TestPluginGraphOptimizer; };
ConfigList config_list;
PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(creator, "GPU",
config_list);
PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(creator, "CPU",
config_list);
EXPECT_DEATH(PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
creator, "GPU", config_list),
"twice");
}
} // namespace
} // namespace grappler
} // namespace tensorflow