Export grappler code from C++ to Python with pybind11 instead of swig. Adding pywrap_required_hdrs to downstream grappler build files which will go away with bazel first class shared library support.
PiperOrigin-RevId: 292986813 Change-Id: I24cbeba85e593fcb4604913874c0ac01eac73e4d
This commit is contained in:
parent
7954fb8fd1
commit
e3e22538e8
@ -41,6 +41,16 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"training/coordinator.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients",
|
||||
srcs = [
|
||||
|
@ -4,6 +4,18 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"devices.h",
|
||||
"grappler_item.h",
|
||||
"grappler_item_builder.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "op_types",
|
||||
srcs = ["op_types.cc"],
|
||||
|
@ -53,6 +53,19 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"cluster.h",
|
||||
"single_machine.h",
|
||||
"utils.h",
|
||||
"virtual_cluster.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cluster",
|
||||
srcs = ["cluster.cc"],
|
||||
|
@ -36,6 +36,25 @@ tf_pyclif_proto_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"analytical_cost_estimator.h",
|
||||
"cost_estimator.h",
|
||||
"graph_memory.h",
|
||||
"graph_properties.h",
|
||||
"measuring_cost_estimator.h",
|
||||
"op_context.h",
|
||||
"op_level_cost_estimator.h",
|
||||
"utils.h",
|
||||
"virtual_placer.h",
|
||||
"virtual_scheduler.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_properties",
|
||||
srcs = ["graph_properties.cc"],
|
||||
|
@ -11,6 +11,17 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"graph_optimizer.h",
|
||||
"meta_optimizer.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "static_schedule",
|
||||
srcs = ["static_schedule.cc"],
|
||||
|
@ -8,6 +8,16 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"topological_sort.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "scc",
|
||||
srcs = ["scc.cc"],
|
||||
|
@ -4,6 +4,16 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"graph_verifier.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_verifier",
|
||||
hdrs = [
|
||||
|
@ -364,6 +364,40 @@ cc_library(
|
||||
] + tf_protos_grappler(),
|
||||
)
|
||||
|
||||
# Necessary for the pywrap inclusion below. Combining targets does not work
|
||||
# properly.
|
||||
tf_pybind_cc_library_wrapper(
|
||||
name = "cost_analyzer_headers",
|
||||
deps = [
|
||||
":cost_analyzer_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_cost_analyzer",
|
||||
srcs = ["grappler/cost_analyzer_wrapper.cc"],
|
||||
hdrs = [
|
||||
"grappler/cost_analyzer.h",
|
||||
"//tensorflow/cc:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/costs:pywrap_required_hdrs",
|
||||
"//tensorflow/core/public:session.h",
|
||||
"//tensorflow/core/public:session_options.h",
|
||||
],
|
||||
module_name = "_pywrap_cost_analyzer",
|
||||
deps = [
|
||||
":cost_analyzer_headers",
|
||||
":pybind11_status",
|
||||
"//tensorflow/core:core_cpu_headers_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_id",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "model_analyzer_lib",
|
||||
srcs = ["grappler/model_analyzer.cc"],
|
||||
@ -380,14 +414,16 @@ cc_library(
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_model_analyzer",
|
||||
srcs = ["grappler/model_analyzer_wrapper.cc"],
|
||||
hdrs = ["grappler/model_analyzer.h"],
|
||||
hdrs = [
|
||||
"grappler/model_analyzer.h",
|
||||
"//tensorflow/core/grappler:pywrap_required_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_model_analyzer",
|
||||
deps = [
|
||||
":model_analyzer_lib",
|
||||
":pybind11_status",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item_builder",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
@ -5652,10 +5688,6 @@ tf_py_wrap_cc(
|
||||
name = "pywrap_tensorflow_internal",
|
||||
srcs = ["tensorflow.i"],
|
||||
swig_includes = [
|
||||
"grappler/cluster.i",
|
||||
"grappler/cost_analyzer.i",
|
||||
"grappler/item.i",
|
||||
"grappler/tf_optimizer.i",
|
||||
"lib/core/strings.i",
|
||||
"platform/base.i",
|
||||
],
|
||||
@ -5721,10 +5753,11 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
|
||||
":cpp_python_util", # util
|
||||
":py_func_lib", # py_func
|
||||
":model_analyzer_lib", # model_analyzer
|
||||
"//tensorflow/core/util:port", # util_port
|
||||
":cost_analyzer_lib", # cost_analyzer
|
||||
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
||||
"//tensorflow/core/util:port", # util_port
|
||||
"//tensorflow/core:framework_internal_impl", # op_def_registry
|
||||
"//tensorflow/core:lib_internal_impl", # device_lib
|
||||
"//tensorflow/core:core_cpu_impl", # device_lib
|
||||
@ -5749,6 +5782,20 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
|
||||
"//tensorflow/core:core_cpu_base_no_ops", # tf_session
|
||||
"//tensorflow/c:python_api", # tf_session
|
||||
"//tensorflow/python:tf_session_helper", # tf_session
|
||||
"//tensorflow/core/grappler:grappler_item", # tf_item
|
||||
"//tensorflow/core/grappler/costs:graph_properties", # tf_item
|
||||
"//tensorflow/core/grappler:grappler_item_builder", # tf_item
|
||||
"//tensorflow/core/grappler/utils:topological_sort", # tf_item
|
||||
"//tensorflow/core/grappler/costs:utils", # tf_cluster
|
||||
"//tensorflow/core/grappler/optimizers:meta_optimizer", # tf_optimizer
|
||||
"//tensorflow/core/grappler/clusters:cluster", # tf_cluster
|
||||
"//tensorflow/core/grappler/clusters:single_machine", # tf_cluster
|
||||
"//tensorflow/core/grappler/costs:op_level_cost_estimator", # tf_cluster
|
||||
"//tensorflow/core/grappler/clusters:virtual_cluster", # tf_cluster
|
||||
"//tensorflow/core/grappler/costs:graph_memory", # tf_cluster
|
||||
"//tensorflow/core/grappler:devices", # tf_cluster
|
||||
"//tensorflow/core/grappler/clusters:utils", # tf_optimizer
|
||||
"//tensorflow/core/grappler/costs:measuring_cost_estimator", # tf_cluster
|
||||
]
|
||||
|
||||
# Filter the DEF file to reduce the number of symbols to 64K or less.
|
||||
@ -7359,11 +7406,32 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":pywrap_tensorflow_internal",
|
||||
":_pywrap_tf_item",
|
||||
"//tensorflow/core/grappler/costs:op_performance_data_py",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_tf_item",
|
||||
srcs = ["grappler/item_wrapper.cc"],
|
||||
hdrs = [
|
||||
"//tensorflow/cc:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/costs:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/utils:pywrap_required_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_tf_item",
|
||||
deps = [
|
||||
":pybind11_status",
|
||||
"@pybind11",
|
||||
"//tensorflow/core:core_cpu_headers_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_id",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
] + if_not_windows(["//tensorflow/core/grappler/costs:graph_properties"]), # b/148556093,
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "item_test",
|
||||
size = "small",
|
||||
@ -7414,11 +7482,34 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":pywrap_tensorflow_internal",
|
||||
":_pywrap_tf_cluster",
|
||||
"//tensorflow/core/grappler/costs:op_performance_data_py",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_tf_cluster",
|
||||
srcs = ["grappler/cluster_wrapper.cc"],
|
||||
hdrs = [
|
||||
"//tensorflow/cc:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/costs:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/utils:pywrap_required_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_tf_cluster",
|
||||
deps = [
|
||||
":pybind11_status",
|
||||
"//tensorflow/core:core_cpu_headers_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_id",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "cluster_test",
|
||||
size = "small",
|
||||
@ -7451,11 +7542,34 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":pywrap_tensorflow_internal",
|
||||
":_pywrap_tf_optimizer",
|
||||
":tf_cluster",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_tf_optimizer",
|
||||
srcs = ["grappler/tf_optimizer_wrapper.cc"],
|
||||
hdrs = [
|
||||
"//tensorflow/cc:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/costs:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/optimizers:pywrap_required_hdrs",
|
||||
"//tensorflow/core/grappler/verifiers:pywrap_required_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_tf_optimizer",
|
||||
deps = [
|
||||
":pybind11_status",
|
||||
"//tensorflow/core:core_cpu_headers_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_id",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tf_optimizer_test",
|
||||
size = "small",
|
||||
@ -7619,7 +7733,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pywrap_tensorflow_internal",
|
||||
":_pywrap_cost_analyzer",
|
||||
":tf_cluster",
|
||||
":tf_item",
|
||||
],
|
||||
|
@ -1,450 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
%include <std_shared_ptr.i>
|
||||
%include "item.i"
|
||||
|
||||
// Wrap the cluster into an object that swig can manipulate. This ensures it will call the object
|
||||
// destructor upon garbage collection instead of leaking memory.
|
||||
struct GCluster {
|
||||
std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
|
||||
};
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/protobuf/device_properties.pb.h"
|
||||
|
||||
template <>
|
||||
bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
return false;
|
||||
}
|
||||
|
||||
tensorflow::NamedDevice named_device;
|
||||
if (!named_device.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The NamedDevice could not be parsed as a valid protocol buffer");
|
||||
return false;
|
||||
}
|
||||
if (out) *out = named_device;
|
||||
return true;
|
||||
}
|
||||
%}
|
||||
|
||||
%typemap(in) const std::vector<tensorflow::NamedDevice>& (std::vector<tensorflow::NamedDevice> temp) {
|
||||
if (!tf_vector_input_helper($input, &temp, &_PyObjAs<tensorflow::NamedDevice>)) {
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%typemap(in) const tensorflow::NamedDevice& (tensorflow::NamedDevice temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The NamedDevice could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%typemap(in) const tensorflow::RunMetadata& (tensorflow::RunMetadata temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The RunMetadata could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%typemap(in) const string& (string temp) {
|
||||
char *buf;
|
||||
Py_ssize_t len;
|
||||
if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) return NULL;
|
||||
temp.assign(buf, len);
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%{
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/grappler/devices.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/clusters/single_machine.h"
|
||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_memory.h"
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/protobuf/device_properties.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
|
||||
// Provide the implementation of the GCluster struct here.
|
||||
struct GCluster {
|
||||
GCluster() {}
|
||||
GCluster(tensorflow::grappler::Cluster* cluster) : cluster_(cluster) {}
|
||||
|
||||
tensorflow::grappler::Cluster* operator->() const {
|
||||
return cluster_.get();
|
||||
}
|
||||
tensorflow::grappler::Cluster* get() const {
|
||||
return cluster_.get();
|
||||
}
|
||||
bool is_none() const {
|
||||
return cluster_.get() == nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
|
||||
};
|
||||
|
||||
|
||||
static GCluster TF_NewCluster(bool allow_soft_placement,
|
||||
bool disable_detailed_stats, TF_Status* status) {
|
||||
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
|
||||
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
|
||||
int timeout_s = 60 * 10;
|
||||
tensorflow::grappler::Cluster* cluster_ =
|
||||
new tensorflow::grappler::SingleMachine(
|
||||
timeout_s, num_cpu_cores, num_gpus);
|
||||
cluster_->DisableDetailedStats(disable_detailed_stats);
|
||||
cluster_->AllowSoftPlacement(allow_soft_placement);
|
||||
cluster_->SetNumWarmupSteps(10);
|
||||
tensorflow::Status s = cluster_->Provision();
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
return GCluster(cluster_);
|
||||
}
|
||||
|
||||
static GCluster TF_NewVirtualCluster(
|
||||
const std::vector<tensorflow::NamedDevice>& named_devices, TF_Status* status) {
|
||||
std::unordered_map<string, tensorflow::DeviceProperties> devices;
|
||||
for (const auto& named_device : named_devices) {
|
||||
devices[named_device.name()]= named_device.properties();
|
||||
}
|
||||
tensorflow::grappler::Cluster* cluster_ =
|
||||
new tensorflow::grappler::VirtualCluster(devices);
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
tensorflow::Status s = cluster_->Provision();
|
||||
PyGILState_Release(gstate);
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
return GCluster(cluster_);
|
||||
}
|
||||
|
||||
static void TF_ShutdownCluster(GCluster cluster) {
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
cluster->Shutdown();
|
||||
PyGILState_Release(gstate);
|
||||
}
|
||||
|
||||
tensorflow::Status _GetOpPerformanceDataAndRunTime(
|
||||
const tensorflow::grappler::GrapplerItem& item,
|
||||
tensorflow::grappler::CostEstimator* cost_measure,
|
||||
tensorflow::OpPerformanceList* op_performance_data,
|
||||
tensorflow::grappler::Costs* costs) {
|
||||
tensorflow::Status status = cost_measure->Initialize(item);
|
||||
if (!status.ok()) return status;
|
||||
|
||||
tensorflow::RunMetadata run_metadata;
|
||||
TF_RETURN_IF_ERROR(
|
||||
cost_measure->PredictCosts(item.graph, &run_metadata, costs));
|
||||
|
||||
if (op_performance_data) {
|
||||
*op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData(
|
||||
run_metadata.cost_graph(), item.graph);
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
static PyObject* TF_ListDevices(GCluster cluster) {
|
||||
const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
PyObject* result = PyList_New(devices.size());
|
||||
int i = 0;
|
||||
for (auto& dev : devices) {
|
||||
tensorflow::NamedDevice d;
|
||||
d.set_name(dev.first);
|
||||
*d.mutable_properties() = dev.second;
|
||||
string dev_str = d.SerializeAsString();
|
||||
PyObject* dev_obj = PyBytes_FromStringAndSize(dev_str.data(),
|
||||
dev_str.size());
|
||||
PyList_SetItem(result, i, dev_obj);
|
||||
++i;
|
||||
}
|
||||
PyGILState_Release(gstate);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject* TF_ListAvailableOps() {
|
||||
tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
|
||||
std::vector<tensorflow::OpDef> ops;
|
||||
registry->GetRegisteredOps(&ops);
|
||||
std::vector<string> op_names;
|
||||
for (const tensorflow::OpDef& op : ops) {
|
||||
op_names.push_back(op.name());
|
||||
}
|
||||
std::sort(op_names.begin(), op_names.end());
|
||||
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
PyObject* result = PyList_New(op_names.size());
|
||||
for (int i = 0; i < op_names.size(); ++i) {
|
||||
PyList_SetItem(result, i, PyString_FromString(op_names[i].c_str()));
|
||||
}
|
||||
PyGILState_Release(gstate);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) {
|
||||
if (cluster.is_none() || item.is_none()) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
|
||||
std::unordered_map<string, std::vector<string>> device_types;
|
||||
for (const auto& dev : devices) {
|
||||
device_types[dev.second.type()].push_back(dev.first);
|
||||
}
|
||||
|
||||
std::unordered_map<string, std::set<string>> supported_device_types;
|
||||
std::unordered_map<string, std::set<string>> device_restrictions;
|
||||
|
||||
for (const auto& node : item->graph.node()) {
|
||||
for (const auto& dev : device_types) {
|
||||
const string& type = dev.first;
|
||||
if (cluster->type() != "single_machine") {
|
||||
// The actual kernel may not be linked in this binary.
|
||||
supported_device_types[node.name()].insert(type);
|
||||
} else {
|
||||
// Check the kernel capabilities
|
||||
const tensorflow::DeviceType dev_type(type);
|
||||
tensorflow::Status s = tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
|
||||
if (s.ok()) {
|
||||
supported_device_types[node.name()].insert(type);
|
||||
|
||||
// Check which inputs are restricted to reside on the host.
|
||||
// TODO: extends this to support outputs as well
|
||||
tensorflow::MemoryTypeVector inp_mtypes;
|
||||
tensorflow::MemoryTypeVector out_mtypes;
|
||||
s = tensorflow::MemoryTypesForNode(tensorflow::OpRegistry::Global(), dev_type, node,
|
||||
&inp_mtypes, &out_mtypes);
|
||||
if (s.ok()) {
|
||||
for (int i = 0; i < inp_mtypes.size(); ++i) {
|
||||
if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
|
||||
device_restrictions[tensorflow::grappler::NodeName(node.input(i))].insert("CPU");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
PyObject* result = PyDict_New();
|
||||
|
||||
for (const auto& supported_dev : supported_device_types) {
|
||||
const string& node = supported_dev.first;
|
||||
std::set<string> feasible;
|
||||
const auto it = device_restrictions.find(node);
|
||||
if (it != device_restrictions.end()) {
|
||||
const std::set<string>& candidates = supported_dev.second;
|
||||
const std::set<string>& valid = it->second;
|
||||
std::set_intersection(candidates.begin(), candidates.end(), valid.begin(), valid.end(),
|
||||
std::inserter(feasible, feasible.begin()));
|
||||
} else {
|
||||
feasible = supported_dev.second;
|
||||
}
|
||||
|
||||
std::vector<string> device_names;
|
||||
for (const string& type : feasible) {
|
||||
auto it = device_types.find(type);
|
||||
CHECK(it != device_types.end());
|
||||
for (const string& name : it->second) {
|
||||
device_names.push_back(name);
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* dev = PyList_New(device_names.size());
|
||||
for (int i = 0; i < device_names.size(); ++i) {
|
||||
PyList_SetItem(dev, i, PyString_FromString(device_names[i].c_str()));
|
||||
}
|
||||
CHECK_EQ(0, PyDict_SetItem(result, PyString_FromString(node.c_str()), dev));
|
||||
}
|
||||
PyGILState_Release(gstate);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) {
|
||||
tensorflow::grappler::OpLevelCostEstimator estimator;
|
||||
tensorflow::grappler::DeviceInfo info =
|
||||
estimator.GetDeviceInfo(device.properties());
|
||||
return info.gigaops;
|
||||
}
|
||||
|
||||
static PyObject* TF_MeasureCosts(
|
||||
GItem item,
|
||||
GCluster cluster,
|
||||
bool generate_timeline, TF_Status* status) {
|
||||
tensorflow::OpPerformanceList op_performance_data;
|
||||
tensorflow::StepStats step_stats;
|
||||
|
||||
const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
|
||||
tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), num_measurements, 0);
|
||||
|
||||
tensorflow::grappler::Costs costs;
|
||||
tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
|
||||
*item, &cost_measure, &op_performance_data, &costs);
|
||||
double run_time = FLT_MAX;
|
||||
if (s.ok()) {
|
||||
run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
|
||||
}
|
||||
if (generate_timeline) {
|
||||
tensorflow::RunMetadata metadata;
|
||||
tensorflow::Status run_status = cluster->Run(
|
||||
item->graph, item->feed, item->fetch, &metadata);
|
||||
if (run_status.ok()) {
|
||||
step_stats = metadata.step_stats();
|
||||
} else {
|
||||
s = run_status;
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
if (!s.ok()) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
PyObject* op_perf_objs = PyList_New(
|
||||
op_performance_data.op_performance_size());
|
||||
for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
|
||||
string op_perf_str =
|
||||
op_performance_data.op_performance(i).SerializeAsString();
|
||||
PyObject* op_perf_obj = PyBytes_FromStringAndSize(op_perf_str.data(),
|
||||
op_perf_str.size());
|
||||
PyList_SetItem(op_perf_objs, i, op_perf_obj);
|
||||
}
|
||||
|
||||
PyObject* run_time_obj = PyFloat_FromDouble(run_time);
|
||||
|
||||
string step_stats_str = step_stats.SerializeAsString();
|
||||
PyObject* metadata_obj = PyBytes_FromStringAndSize(step_stats_str.data(),
|
||||
step_stats_str.size());
|
||||
|
||||
PyObject* ret = PyTuple_New(3);
|
||||
if (PyTuple_SetItem(ret, 0, op_perf_objs) != 0 ||
|
||||
PyTuple_SetItem(ret, 1, run_time_obj) != 0 ||
|
||||
PyTuple_SetItem(ret, 2, metadata_obj) != 0) {
|
||||
Py_DECREF(ret);
|
||||
Py_XDECREF(op_perf_objs);
|
||||
Py_XDECREF(run_time_obj);
|
||||
Py_XDECREF(metadata_obj);
|
||||
s = tensorflow::Status(tensorflow::error::Code::INTERNAL,
|
||||
"Error setting return tuples.");
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
Py_INCREF(Py_None);
|
||||
ret = Py_None;
|
||||
}
|
||||
PyGILState_Release(gstate);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
static PyObject* TF_DeterminePeakMemoryUsage(
|
||||
GItem item,
|
||||
GCluster cluster,
|
||||
TF_Status* status) {
|
||||
if (item.is_none() || cluster.is_none()) {
|
||||
tensorflow::Status s(tensorflow::error::Code::INTERNAL,
|
||||
"You need both a cluster and an item to determine peak memory usage");
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
tensorflow::grappler::GraphMemory memory(*item);
|
||||
|
||||
tensorflow::Status s;
|
||||
if (cluster->DetailedStatsEnabled()) {
|
||||
s = memory.InferDynamically(cluster.get());
|
||||
} else {
|
||||
s = memory.InferStatically(cluster->GetDevices());
|
||||
}
|
||||
if (!s.ok()) {
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
PyObject* result = PyDict_New();
|
||||
for (const auto& device : cluster->GetDevices()) {
|
||||
const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
|
||||
memory.GetPeakMemoryUsage(device.first);
|
||||
PyObject* per_device = PyList_New(usage.live_tensors.size());
|
||||
for (int i = 0; i < usage.live_tensors.size(); ++i) {
|
||||
const auto& live_tensor = usage.live_tensors[i];
|
||||
PyObject* live = PyTuple_New(5);
|
||||
PyTuple_SetItem(live, 0, PyString_FromString(live_tensor.node.c_str()));
|
||||
PyTuple_SetItem(live, 1, PyInt_FromLong(live_tensor.output_id));
|
||||
PyTuple_SetItem(live, 2, PyLong_FromLong(live_tensor.memory_used));
|
||||
PyTuple_SetItem(live, 3, PyLong_FromLong(live_tensor.allocation_time.count()));
|
||||
PyTuple_SetItem(live, 4, PyLong_FromLong(live_tensor.deallocation_time.count()));
|
||||
PyList_SetItem(per_device, i, live);
|
||||
|
||||
}
|
||||
PyObject* ret = PyTuple_New(2);
|
||||
PyTuple_SetItem(ret, 0, PyLong_FromLong(usage.used_memory));
|
||||
PyTuple_SetItem(ret, 1, per_device);
|
||||
PyDict_SetItem(result, PyString_FromString(device.first.c_str()), ret);
|
||||
}
|
||||
PyGILState_Release(gstate);
|
||||
return result;
|
||||
}
|
||||
|
||||
%}
|
||||
|
||||
// Wrap these functions.
|
||||
static GCluster TF_NewCluster(
|
||||
bool allow_soft_placement, bool disable_detailed_stats, TF_Status* status);
|
||||
static GCluster TF_NewVirtualCluster(
|
||||
const std::vector<tensorflow::NamedDevice>& named_devices,
|
||||
TF_Status* status);
|
||||
static void TF_ShutdownCluster(GCluster cluster);
|
||||
static PyObject* TF_ListDevices(GCluster cluster);
|
||||
static PyObject* TF_ListAvailableOps();
|
||||
static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item);
|
||||
static float TF_EstimatePerformance(const tensorflow::NamedDevice& device);
|
||||
static PyObject* TF_MeasureCosts(
|
||||
GItem item, GCluster cluster,
|
||||
bool generate_timeline, TF_Status* status);
|
||||
static PyObject* TF_DeterminePeakMemoryUsage(
|
||||
GItem item, GCluster cluster,
|
||||
TF_Status* status);
|
@ -23,7 +23,7 @@ import contextlib
|
||||
from tensorflow.core.framework import step_stats_pb2
|
||||
from tensorflow.core.grappler.costs import op_performance_data_pb2
|
||||
from tensorflow.core.protobuf import device_properties_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_cluster
|
||||
from tensorflow.python import _pywrap_tf_cluster as tf_cluster
|
||||
|
||||
|
||||
class Cluster(object):
|
||||
@ -92,13 +92,9 @@ class Cluster(object):
|
||||
item: The item for which to measure the costs.
|
||||
Returns: The triplet op_perfs, runtime, step_stats.
|
||||
"""
|
||||
ret_from_swig = tf_cluster.TF_MeasureCosts(item.tf_item, self._tf_cluster,
|
||||
self._generate_timeline)
|
||||
op_perf_bytes_list, run_time, step_stats_bytes = tf_cluster.TF_MeasureCosts(
|
||||
item.tf_item, self._tf_cluster, self._generate_timeline)
|
||||
|
||||
if ret_from_swig is None:
|
||||
return None
|
||||
|
||||
op_perf_bytes_list, run_time, step_stats_bytes = ret_from_swig
|
||||
op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes)
|
||||
for op_perf_bytes in op_perf_bytes_list]
|
||||
return (op_perfs, run_time,
|
||||
|
332
tensorflow/python/grappler/cluster_wrapper.cc
Normal file
332
tensorflow/python/grappler/cluster_wrapper.cc
Normal file
@ -0,0 +1,332 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <cfloat>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/clusters/single_machine.h"
|
||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||
#include "tensorflow/core/grappler/costs/cost_estimator.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_memory.h"
|
||||
#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
|
||||
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/devices.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/device_properties.pb.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
// Runs `cost_measure` over `item` and extracts both the per-op performance
// protos and the aggregate cost estimate.
//
// Args:
//   item: grappler item (graph plus feeds/fetches) to measure.
//   cost_measure: estimator to initialize and run; not owned.
//   op_performance_data: if non-null, filled with per-op performance data
//     derived from the cost graph in the run metadata.
//   costs: receives the aggregate cost prediction.
//
// Returns a non-OK status if either initialization or prediction fails, so
// the caller can fall back gracefully (e.g. TF_MeasureCosts reports a
// FLT_MAX run time on failure).
tensorflow::Status _GetOpPerformanceDataAndRunTime(
    const tensorflow::grappler::GrapplerItem& item,
    tensorflow::grappler::CostEstimator* cost_measure,
    tensorflow::OpPerformanceList* op_performance_data,
    tensorflow::grappler::Costs* costs) {
  tensorflow::Status status = cost_measure->Initialize(item);
  if (!status.ok()) return status;

  tensorflow::RunMetadata run_metadata;
  // Propagate prediction failures through the returned Status rather than
  // raising here: the caller's `s.ok()` check is the single point of error
  // handling, and raising would bypass its fallback path.
  status = cost_measure->PredictCosts(item.graph, &run_metadata, costs);
  if (!status.ok()) return status;

  if (op_performance_data) {
    *op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData(
        run_metadata.cost_graph(), item.graph);
  }
  return tensorflow::Status::OK();
}
|
||||
|
||||
// Treat Cluster as an opaque handle: it crosses the Python/C++ boundary by
// pointer and is never converted to a Python-native representation.
PYBIND11_MAKE_OPAQUE(tensorflow::grappler::Cluster);

// Pybind11 replacement for the old SWIG cluster bindings. Exposes the
// grappler Cluster lifecycle (create/provision/shutdown) together with
// device listing, kernel-support queries, and cost/memory measurement.
PYBIND11_MODULE(_pywrap_tf_cluster, m) {
  // Register the opaque Cluster type so the functions below can accept and
  // return Cluster* handles from Python.
  py::class_<tensorflow::grappler::Cluster> grappler_cluster(
      m, "tensorflow::grappler::Cluster");

  // Creates and provisions a SingleMachine cluster using all locally
  // available logical CPU cores and GPUs. Ownership of the returned raw
  // pointer is transferred to Python (released from the unique_ptr below).
  m.def("TF_NewCluster",
        [](bool allow_soft_placement,
           bool disable_detailed_stats) -> tensorflow::grappler::Cluster* {
          // TODO(petebu): Make these named arguments with default values
          // instead.
          int num_cpu_cores =
              tensorflow::grappler::GetNumAvailableLogicalCPUCores();
          int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
          // 10-minute provisioning/run timeout.
          int timeout_s = 60 * 10;
          std::unique_ptr<tensorflow::grappler::Cluster> cluster =
              std::make_unique<tensorflow::grappler::SingleMachine>(
                  timeout_s, num_cpu_cores, num_gpus);
          cluster->DisableDetailedStats(disable_detailed_stats);
          cluster->AllowSoftPlacement(allow_soft_placement);
          cluster->SetNumWarmupSteps(10);
          // Raises a registered Python exception if provisioning fails.
          MaybeRaiseRegisteredFromStatus(cluster->Provision());
          return cluster.release();
        });

  // Creates a VirtualCluster from serialized NamedDevice protos. Raises
  // ValueError (std::invalid_argument) on unparseable input. Ownership of
  // the returned handle is transferred to Python.
  m.def("TF_NewVirtualCluster",
        [](const std::vector<py::bytes>& serialized_named_devices)
            -> tensorflow::grappler::Cluster* {
          std::vector<tensorflow::NamedDevice> named_devices;
          for (const auto& s : serialized_named_devices) {
            tensorflow::NamedDevice named_device;
            if (!named_device.ParseFromString(s)) {
              throw std::invalid_argument(
                  "The NamedDevice could not be parsed as a valid protocol "
                  "buffer");
            }
            named_devices.push_back(named_device);
          }

          std::unordered_map<std::string, tensorflow::DeviceProperties> devices;
          for (const auto& named_device : named_devices) {
            devices[named_device.name()] = named_device.properties();
          }
          std::unique_ptr<tensorflow::grappler::Cluster> cluster =
              std::make_unique<tensorflow::grappler::VirtualCluster>(devices);
          {
            // TODO(petebu): Do we need to hold the GIL here?
            py::gil_scoped_acquire acquire;
            MaybeRaiseRegisteredFromStatus(cluster->Provision());
          }
          return cluster.release();
        });

  // Releases the resources held by the cluster without destroying the
  // handle itself.
  m.def("TF_ShutdownCluster", [](tensorflow::grappler::Cluster* cluster) {
    // TODO(petebu): Do we need to hold the GIL here?
    py::gil_scoped_acquire acquire;
    cluster->Shutdown();
  });

  // Returns the cluster's devices as serialized NamedDevice protos.
  m.def("TF_ListDevices",
        [](tensorflow::grappler::Cluster* cluster) -> std::vector<py::bytes> {
          const std::unordered_map<std::string, tensorflow::DeviceProperties>&
              devices = cluster->GetDevices();
          std::vector<py::bytes> named_devices;
          for (auto& dev : devices) {
            tensorflow::NamedDevice d;
            d.set_name(dev.first);
            *d.mutable_properties() = dev.second;
            named_devices.push_back(d.SerializeAsString());
          }
          return named_devices;
        });

  // Returns the sorted names of all ops registered in this binary.
  m.def("TF_ListAvailableOps", []() -> std::vector<std::string> {
    tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
    std::vector<tensorflow::OpDef> ops;
    registry->GetRegisteredOps(&ops);
    std::vector<std::string> op_names;
    for (const tensorflow::OpDef& op : ops) {
      op_names.push_back(op.name());
    }
    std::sort(op_names.begin(), op_names.end());
    return op_names;
  });

  // Maps each node name in `item` to the device names it can execute on,
  // taking kernel availability and host-memory input restrictions into
  // account. Raises on a null cluster or item.
  m.def(
      "TF_GetSupportedDevices",
      [](tensorflow::grappler::Cluster* cluster,
         tensorflow::grappler::GrapplerItem* item)
          -> std::unordered_map<std::string, std::vector<std::string>> {
        if (cluster == nullptr || item == nullptr) {
          MaybeRaiseRegisteredFromStatus(tensorflow::Status(
              tensorflow::errors::Internal("You need both a cluster and an "
                                           "item to get supported devices.")));
        }
        const std::unordered_map<std::string, tensorflow::DeviceProperties>&
            devices = cluster->GetDevices();
        // Device type (e.g. "CPU", "GPU") -> device names of that type.
        std::unordered_map<std::string, std::vector<std::string>> device_types;
        for (const auto& dev : devices) {
          device_types[dev.second.type()].push_back(dev.first);
        }

        // Node name -> device types that have a kernel for the node.
        std::unordered_map<std::string, std::set<std::string>>
            supported_device_types;
        // Node name -> device types it is restricted to (host-memory feeds).
        std::unordered_map<std::string, std::set<std::string>>
            device_restrictions;

        for (const auto& node : item->graph.node()) {
          for (const auto& dev : device_types) {
            const std::string& type = dev.first;
            if (cluster->type() != "single_machine") {
              // The actual kernel may not be linked in this binary.
              supported_device_types[node.name()].insert(type);
            } else {
              // Check the kernel capabilities
              const tensorflow::DeviceType dev_type(type);
              tensorflow::Status s =
                  tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
              if (s.ok()) {
                supported_device_types[node.name()].insert(type);

                // Check which inputs are restricted to reside on the host.
                // TODO: extend this to support outputs as well
                tensorflow::MemoryTypeVector inp_mtypes;
                tensorflow::MemoryTypeVector out_mtypes;
                // NOTE: shadows the FindKernelDef status above (inner scope).
                tensorflow::Status s = tensorflow::MemoryTypesForNode(
                    tensorflow::OpRegistry::Global(), dev_type, node,
                    &inp_mtypes, &out_mtypes);
                if (s.ok()) {
                  for (int i = 0; i < inp_mtypes.size(); ++i) {
                    if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
                      // The producer of a host-memory input must be able to
                      // run on CPU.
                      device_restrictions[tensorflow::grappler::NodeName(
                                              node.input(i))]
                          .insert("CPU");
                      break;
                    }
                  }
                }
              }
            }
          }
        }

        std::unordered_map<std::string, std::vector<std::string>> result;
        for (const auto& supported_dev : supported_device_types) {
          const std::string& node = supported_dev.first;
          // Intersect supported types with any restrictions for this node.
          std::set<std::string> feasible;
          const auto it = device_restrictions.find(node);
          if (it != device_restrictions.end()) {
            const std::set<std::string>& candidates = supported_dev.second;
            const std::set<std::string>& valid = it->second;
            std::set_intersection(candidates.begin(), candidates.end(),
                                  valid.begin(), valid.end(),
                                  std::inserter(feasible, feasible.begin()));
          } else {
            feasible = supported_dev.second;
          }

          // Expand the feasible device types back into concrete device names.
          std::vector<std::string> device_names;
          for (const std::string& type : feasible) {
            auto it = device_types.find(type);
            DCHECK(it != device_types.end());
            for (const std::string& name : it->second) {
              device_names.push_back(name);
            }
          }
          result[node] = device_names;
        }
        return result;
      });

  // Returns the analytically estimated throughput (in giga-ops) for a
  // serialized NamedDevice. Raises ValueError on unparseable input.
  m.def("TF_EstimatePerformance", [](const py::bytes& serialized_device) {
    tensorflow::NamedDevice device;
    if (!device.ParseFromString(serialized_device)) {
      throw std::invalid_argument(
          "The NamedDevice could not be parsed as a valid protocol buffer");
    }
    tensorflow::grappler::OpLevelCostEstimator estimator;
    tensorflow::grappler::DeviceInfo info =
        estimator.GetDeviceInfo(device.properties());
    return info.gigaops;
  });

  // Measures the cost of running `item` on `cluster`. Returns a tuple of
  // (serialized OpPerformance protos, run time in seconds, serialized
  // StepStats). Run time is FLT_MAX when measurement fails; StepStats is
  // empty unless `generate_timeline` is set.
  m.def("TF_MeasureCosts",
        [](tensorflow::grappler::GrapplerItem* item,
           tensorflow::grappler::Cluster* cluster, bool generate_timeline)
            -> std::tuple<std::vector<py::bytes>, double, py::bytes> {
          // Virtual clusters are deterministic, so one measurement suffices.
          const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
          tensorflow::grappler::MeasuringCostEstimator cost_measure(
              cluster, num_measurements, 0);

          tensorflow::OpPerformanceList op_performance_data;
          tensorflow::grappler::Costs costs;
          tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
              *item, &cost_measure, &op_performance_data, &costs);
          double run_time = FLT_MAX;
          if (s.ok()) {
            // Convert nanoseconds to seconds.
            run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
          }
          tensorflow::StepStats step_stats;
          if (generate_timeline) {
            // Re-run the graph to collect step stats for timeline display.
            tensorflow::RunMetadata metadata;
            MaybeRaiseRegisteredFromStatus(
                cluster->Run(item->graph, item->feed, item->fetch, &metadata));
            step_stats = metadata.step_stats();
          }

          std::vector<py::bytes> op_perf_objs;
          op_perf_objs.resize(op_performance_data.op_performance_size());
          for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
            op_perf_objs[i] =
                op_performance_data.op_performance(i).SerializeAsString();
          }

          py::bytes step_stats_str = step_stats.SerializeAsString();
          return std::make_tuple(op_perf_objs, run_time, step_stats_str);
        });

  // Per-tensor usage record: (node name, output index, bytes used,
  // allocation time, deallocation time).
  using DurationType = tensorflow::grappler::Costs::Duration::rep;
  using MemoryUsage =
      std::tuple<std::string, int, size_t, DurationType, DurationType>;

  // Maps each device name to (peak bytes used, live tensors at the peak).
  // Uses dynamic inference when the cluster collects detailed stats, static
  // inference otherwise. Raises on a null cluster or item.
  m.def(
      "TF_DeterminePeakMemoryUsage",
      [](tensorflow::grappler::GrapplerItem* item,
         tensorflow::grappler::Cluster* cluster)
          -> std::unordered_map<std::string,
                                std::tuple<int64_t, std::vector<MemoryUsage>>> {
        if (item == nullptr || cluster == nullptr) {
          MaybeRaiseRegisteredFromStatus(
              tensorflow::Status(tensorflow::errors::Internal(
                  "You need both a cluster and an item to determine peak "
                  "memory usage.")));
        }
        tensorflow::grappler::GraphMemory memory(*item);

        if (cluster->DetailedStatsEnabled()) {
          MaybeRaiseRegisteredFromStatus(memory.InferDynamically(cluster));
        } else {
          MaybeRaiseRegisteredFromStatus(
              memory.InferStatically(cluster->GetDevices()));
        }

        std::unordered_map<std::string,
                           std::tuple<int64_t, std::vector<MemoryUsage>>>
            result;
        for (const auto& device : cluster->GetDevices()) {
          const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
              memory.GetPeakMemoryUsage(device.first);
          std::vector<MemoryUsage> per_device;
          for (int i = 0; i < usage.live_tensors.size(); ++i) {
            const auto& live_tensor = usage.live_tensors[i];
            per_device.push_back(std::make_tuple(
                live_tensor.node, live_tensor.output_id,
                live_tensor.memory_used, live_tensor.allocation_time.count(),
                live_tensor.deallocation_time.count()));
          }
          result[device.first] = std::make_tuple(usage.used_memory, per_device);
        }
        return result;
      });
}
|
@ -1,67 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
%include "cluster.i"
|
||||
|
||||
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The MetaGraphDef could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/clusters/single_machine.h"
|
||||
#include "tensorflow/core/grappler/devices.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/python/grappler/cost_analyzer.h"
|
||||
%}
|
||||
|
||||
%{
|
||||
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
|
||||
bool verbose, GCluster cluster) {
|
||||
tensorflow::grappler::ItemConfig cfg;
|
||||
cfg.apply_optimizations = false;
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg);
|
||||
if (!item) {
|
||||
return "Error: failed to preprocess metagraph: check your log file for errors";
|
||||
}
|
||||
|
||||
string suffix;
|
||||
tensorflow::grappler::CostAnalyzer analyzer(*item, cluster.get(), suffix);
|
||||
|
||||
std::stringstream os;
|
||||
analyzer.GenerateReport(os, per_node_report, verbose);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
%}
|
||||
|
||||
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
|
||||
bool verbose, GCluster cluster);
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as tf_wrap
|
||||
from tensorflow.python import _pywrap_cost_analyzer as tf_wrap
|
||||
from tensorflow.python.grappler import cluster as gcluster
|
||||
from tensorflow.python.grappler import item as gitem
|
||||
|
||||
@ -44,10 +44,9 @@ def GenerateCostReport(metagraph,
|
||||
if cluster is None:
|
||||
cluster = gcluster.Cluster(disable_detailed_stats=False)
|
||||
|
||||
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
|
||||
return tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
|
||||
per_node_report, verbose,
|
||||
cluster.tf_cluster)
|
||||
return ret_from_swig
|
||||
|
||||
|
||||
def GenerateMemoryReport(metagraph, detailed_report=True, cluster=None):
|
||||
|
58
tensorflow/python/grappler/cost_analyzer_wrapper.cc
Normal file
58
tensorflow/python/grappler/cost_analyzer_wrapper.cc
Normal file
@ -0,0 +1,58 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/grappler/clusters/single_machine.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/python/grappler/cost_analyzer.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
// Pybind11 replacement for the old SWIG cost-analyzer bindings.
PYBIND11_MODULE(_pywrap_cost_analyzer, m) {
  // Builds a grappler item from a serialized MetaGraphDef and returns a
  // textual cost report as bytes. Parse/preprocess failures are reported by
  // returning an error message as the report body (matching the previous
  // SWIG behavior); report-generation failures raise via the Status.
  m.def("GenerateCostReport",
        [](const py::bytes& serialized_metagraph, bool per_node_report,
           bool verbose, tensorflow::grappler::Cluster* cluster) -> py::bytes {
          tensorflow::MetaGraphDef metagraph;
          if (!metagraph.ParseFromString(serialized_metagraph)) {
            return "The MetaGraphDef could not be parsed as a valid protocol "
                   "buffer";
          }

          // Analyze the graph as-is: skip grappler's optimization passes.
          tensorflow::grappler::ItemConfig cfg;
          cfg.apply_optimizations = false;
          std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
              tensorflow::grappler::GrapplerItemFromMetaGraphDef(
                  "metagraph", metagraph, cfg);
          if (item == nullptr) {
            return "Error: failed to preprocess metagraph: check your log file "
                   "for errors";
          }

          // Empty suffix: no per-run file-name disambiguation needed here.
          std::string suffix;
          tensorflow::grappler::CostAnalyzer analyzer(*item, cluster, suffix);

          std::stringstream os;
          tensorflow::MaybeRaiseFromStatus(
              analyzer.GenerateReport(os, per_node_report, verbose));
          return py::bytes(os.str());
        });
}
|
@ -1,315 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include <std_shared_ptr.i>
|
||||
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The MetaGraphDef could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
// Wrap the item into an object that swig can manipulate. This ensures it will call the object
|
||||
// destructor upon garbage collection instead of leaking memory.
|
||||
struct GItem {
|
||||
std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
|
||||
};
|
||||
|
||||
|
||||
%{
|
||||
#include <unordered_set>
|
||||
#include <map>
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
// Provide the implementation of the GItem struct here.
// Shared-ownership wrapper around a GrapplerItem so SWIG invokes the
// destructor on garbage collection instead of leaking the item.
struct GItem {
  // Default-constructed GItem holds no item (is_none() returns true).
  GItem() {}
  // Takes ownership of `item`.
  GItem(tensorflow::grappler::GrapplerItem* item) : item_(item) {}

  // Smart-pointer-like access to the wrapped item.
  tensorflow::grappler::GrapplerItem* operator->() const {
    return item_.get();
  }
  const tensorflow::grappler::GrapplerItem& operator*() const {
    return *item_.get();
  }
  // True when no item is wrapped (mirrors Python's None).
  bool is_none() const {
    return item_.get() == nullptr;
  }
  // Shared so copies of the wrapper (made by SWIG) co-own the item.
  std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
};
|
||||
|
||||
// Builds a GrapplerItem from a MetaGraphDef and wraps it in a GItem.
//
// Requires the metagraph to define a "train_op" collection; otherwise sets
// an InvalidArgument error on `status` and returns an empty GItem. On
// success, `status` is set to OK.
static GItem TF_NewItem(
    const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
    bool ignore_user_placement, TF_Status* status) {
  if (meta_graph.collection_def().count("train_op") == 0) {
    tensorflow::Set_TF_Status_from_Status(
        status,
        tensorflow::errors::InvalidArgument("train_op not specified in the metagraph"));
    return nullptr;
  }

  tensorflow::grappler::ItemConfig cfg;
  cfg.ignore_user_placement = ignore_user_placement;
  cfg.ignore_colocation = ignore_colocation;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef("item", meta_graph, cfg);
  if (!item) {
    tensorflow::Set_TF_Status_from_Status(
        status,
        tensorflow::errors::InvalidArgument("Invalid metagraph"));
    return nullptr;
  }
  tensorflow::Set_TF_Status_from_Status(status, tensorflow::Status::OK());
  // GItem takes ownership of the released pointer.
  return GItem(item.release());
}
|
||||
|
||||
// Returns a Python list of the names of ops in the fanin of the item's main
// ops and enqueue ops ("important" ops). Returns None for an empty item.
//
// When `sort_topologically` is set, the names are returned in topological
// order of the induced subgraph (sort failures are reported via `status`);
// otherwise the order is the unordered-set iteration order.
static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
                                         TF_Status* status) {
  if (item.is_none()) {
    Py_RETURN_NONE;
  }

  // Collect the union of node names feeding the main and enqueue ops.
  std::vector<const tensorflow::NodeDef*> main_ops = item->MainOpsFanin();
  std::vector<const tensorflow::NodeDef*> enqueue_ops = item->EnqueueOpsFanin();
  std::unordered_set<string> op_names;
  for (auto op : main_ops) {
    op_names.insert(op->name());
  }
  for (auto op : enqueue_ops) {
    op_names.insert(op->name());
  }

  std::vector<string> ops;
  if (sort_topologically) {
    // Build the subgraph induced by the important ops and sort it.
    tensorflow::GraphDef subgraph;
    for (const tensorflow::NodeDef& node : item->graph.node()) {
      if (op_names.find(node.name()) != op_names.end()) {
        *subgraph.add_node() = node;
      }
    }
    tensorflow::Status s = tensorflow::grappler::TopologicalSort(&subgraph);
    tensorflow::Set_TF_Status_from_Status(status, s);
    for (const tensorflow::NodeDef& node : subgraph.node()) {
      ops.push_back(node.name());
    }
  }
  else {
    for (const auto& op_name : op_names) {
      ops.push_back(op_name);
    }
  }

  // Hold the GIL while creating and populating Python objects.
  PyGILState_STATE gstate = PyGILState_Ensure();
  PyObject* result = PyList_New(ops.size());
  for (int i = 0; i < ops.size(); ++i) {
    // PyList_SetItem steals the reference to the new string.
    PyList_SetItem(result, i, PyString_FromString(ops[i].c_str()));
  }
  PyGILState_Release(gstate);
  return result;
}
|
||||
|
||||
// Statically infers tensor properties for the item's graph and returns a
// Python dict mapping node name -> list of serialized
// OpInfo.TensorProperties (one per output). Returns None for an empty item
// or when static inference fails.
static PyObject* TF_GetOpProperties(GItem item) {
  if (item.is_none()) {
    Py_RETURN_NONE;
  }
  tensorflow::grappler::GraphProperties properties(*item);
  // `false`: do not assume feed nodes' shapes are fully known.
  tensorflow::Status status = properties.InferStatically(false);
  if (!status.ok()) {
    Py_RETURN_NONE;
  }

  // Hold the GIL while creating and populating Python objects.
  PyGILState_STATE gstate = PyGILState_Ensure();
  PyObject* props = PyDict_New();
  for (const auto& node : item->graph.node()) {
    const string& node_name = node.name();
    const std::vector<tensorflow::OpInfo::TensorProperties>& output_props =
        properties.GetOutputProperties(node_name);

    PyObject* prop = PyList_New(output_props.size());
    for (int i = 0; i < output_props.size(); ++i) {
      // Serialize each output's properties; Python deserializes the proto.
      string output_prop_str = output_props[i].SerializeAsString();
      PyObject* output_prop = PyBytes_FromStringAndSize(output_prop_str.data(),
                                                        output_prop_str.size());
      PyList_SetItem(prop, i, output_prop);
    }
    CHECK_EQ(0, PyDict_SetItem(props, PyString_FromString(node_name.c_str()), prop));
  }
  PyGILState_Release(gstate);
  return props;
}
|
||||
|
||||
// Union-find (disjoint-set forest) over node names, used to compute
// colocation groups: sets of nodes that must be placed together.
// Uses union by rank in Group() and path compression in Find().
class ColocationGroups {
 public:
  // Merges the sets containing `x` and `y`, creating singleton entries for
  // names not seen before. No-op when both are already in the same set.
  void Group(const std::string& x, const std::string& y) {
    Rep* x_root = Find(x);
    Rep* y_root = Find(y);

    // x and y are already in the same set
    if (x_root == y_root) {
      return;
    }
    // x and y are not in same set, so we merge them.
    // Union by rank: attach the shallower tree under the deeper one so
    // the forest stays balanced and Find stays near-constant.
    if (x_root->rank < y_root->rank) {
      x_root->parent = y_root;
    } else if (x_root->rank > y_root->rank) {
      y_root->parent = x_root;
    } else {
      // Arbitrarily make one root the new parent
      y_root->parent = x_root;
      x_root->rank = x_root->rank + 1;
    }
  }

  // Appends one vector per disjoint set to `groups`, each holding the
  // member names. Ordering of groups and of members is unspecified
  // (unordered_map iteration order).
  void ExtractGroups(std::vector<std::vector<std::string>>* groups) {
    groups->reserve(nodes_.size());
    std::unordered_map<const Rep*, int> group_ids;
    for (const auto& rep : nodes_) {
      Rep* r = Find(rep.first);
      auto it = group_ids.find(r);
      std::vector<std::string>* g;
      if (it == group_ids.end()) {
        // First member of a new group: assign the next group id.
        int id = group_ids.size();
        group_ids[r] = id;
        groups->resize(id + 1);
        g = &groups->back();
      } else {
        int id = it->second;
        g = &((*groups)[id]);
      }
      g->push_back(rep.first);
    }
  }

 private:
  struct Rep {
    // Parent in the tree used to encode the set.
    Rep* parent;
    // Rank in the tree, used to figure out how to compress the path to the
    // root of the tree.
    int rank;
    // The node.
    std::string value;
  };

  // Returns the representative (root) of the set containing `n`, creating
  // a fresh singleton set on first sight. Applies path compression to
  // speed up future queries.
  Rep* Find(const std::string& n) {
    auto it = nodes_.find(n);
    if (it == nodes_.end()) {
      // This is the first time we process this handle, create an entry for
      // it. Ownership lives in nodes_; the previous version raw-`new`ed
      // Reps into a map of raw pointers and never deleted them (leak).
      auto node = std::make_unique<Rep>();
      Rep* raw = node.get();
      raw->parent = raw;
      raw->rank = 0;
      raw->value = n;
      nodes_[n] = std::move(node);
      return raw;
    }
    // Return the representative for the set, which is the root of the tree.
    Rep* node = it->second.get();
    Rep* root = node->parent;
    while (root != root->parent) {
      root = root->parent;
    }
    // Path compression: point every node on the walked path at the root.
    while (node->parent != root) {
      Rep* next = node->parent;
      node->parent = root;
      node = next;
    }
    return root;
  }

  // Owns every Rep; destroying the map releases all nodes (Rule of Zero,
  // no explicit destructor needed).
  std::unordered_map<std::string, std::unique_ptr<Rep>> nodes_;
};
|
||||
|
||||
// Returns a Python list of tuples of node names, one tuple per colocation
// group. Nodes are grouped when a node consumes a reference-typed input:
// the consumer and the producer of the ref must live on the same device.
// Returns None for an empty item. Nodes whose op lookup or name-range
// resolution fails are skipped.
static PyObject* TF_GetColocationGroups(GItem item) {
  if (item.is_none()) {
    Py_RETURN_NONE;
  }
  ColocationGroups groupings;
  tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
  for (const auto& node : item->graph.node()) {
    const tensorflow::OpDef* op_def;
    tensorflow::Status s = registry->LookUpOpDef(node.op(), &op_def);
    if (!s.ok()) {
      continue;
    }
    tensorflow::NameRangeMap inputs;
    tensorflow::NameRangeMap outputs;
    s = tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs);
    if (!s.ok()) {
      continue;
    }
    for (const auto& arg : op_def->input_arg()) {
      // Only reference inputs force colocation.
      if (!arg.is_ref()) {
        continue;
      }
      const auto& range = inputs[arg.name()];
      for (int i = range.first; i < range.second; ++i) {
        // Group this node with the producer of each ref input.
        string input = tensorflow::grappler::NodeName(node.input(i));
        groupings.Group(node.name(), input);
      }
    }
  }

  std::vector<std::vector<string>> groups;
  groupings.ExtractGroups(&groups);

  // Hold the GIL while creating and populating Python objects.
  PyGILState_STATE gstate = PyGILState_Ensure();
  PyObject* result = PyList_New(groups.size());
  for (int i = 0; i < groups.size(); ++i) {
    const std::vector<string>& group = groups[i];
    PyObject* g = PyTuple_New(group.size());
    for (int j = 0; j < group.size(); ++j) {
      const string& node_name = group[j];
      // PyTuple_SetItem steals the reference to the new string.
      PyTuple_SetItem(g, j, PyString_FromString(node_name.c_str()));
    }
    PyList_SetItem(result, i, g);
  }
  PyGILState_Release(gstate);
  return result;
}
|
||||
|
||||
%}
|
||||
|
||||
|
||||
// Wrap these functions.
|
||||
static GItem TF_NewItem(
|
||||
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
|
||||
bool ignore_user_placement, TF_Status* status);
|
||||
static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
|
||||
TF_Status* status);
|
||||
static PyObject* TF_GetOpProperties(GItem item);
|
||||
static PyObject* TF_GetColocationGroups(GItem item);
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.core.grappler.costs import op_performance_data_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_item
|
||||
from tensorflow.python import _pywrap_tf_item as tf_item
|
||||
|
||||
|
||||
class Item(object):
|
||||
@ -53,11 +53,14 @@ class Item(object):
|
||||
return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically)
|
||||
|
||||
def GetOpProperties(self):
|
||||
ret_from_swig = tf_item.TF_GetOpProperties(self.tf_item)
|
||||
"""Get Op properties."""
|
||||
props = tf_item.TF_GetOpProperties(self.tf_item)
|
||||
properties = {}
|
||||
for key, values in ret_from_swig.items():
|
||||
for key, values in props.items():
|
||||
prop = []
|
||||
for value in values:
|
||||
# TODO(petebu): Make this conversion to a dictionary be done in the C++
|
||||
# wrapper for performance.
|
||||
prop.append(
|
||||
op_performance_data_pb2.OpInfo.TensorProperties.FromString(value))
|
||||
properties[key] = prop
|
||||
|
245
tensorflow/python/grappler/item_wrapper.cc
Normal file
245
tensorflow/python/grappler/item_wrapper.cc
Normal file
@ -0,0 +1,245 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
class ColocationGroups {
|
||||
public:
|
||||
void Group(const std::string& x, const std::string& y) {
|
||||
Rep* x_root = Find(x);
|
||||
Rep* y_root = Find(y);
|
||||
|
||||
// x and y are already in the same set
|
||||
if (x_root == y_root) {
|
||||
return;
|
||||
}
|
||||
// x and y are not in same set, so we merge them
|
||||
// Use the occasion to strengthen what we know about the handle by merging
|
||||
// the information about the 2 subsets.
|
||||
if (x_root->rank < y_root->rank) {
|
||||
x_root->parent = y_root;
|
||||
} else if (x_root->rank > y_root->rank) {
|
||||
y_root->parent = x_root;
|
||||
} else {
|
||||
// Arbitrarily make one root the new parent
|
||||
y_root->parent = x_root;
|
||||
x_root->rank = x_root->rank + 1;
|
||||
}
|
||||
}
|
||||
|
||||
void ExtractGroups(std::vector<std::vector<std::string>>* groups) {
|
||||
groups->reserve(nodes_.size());
|
||||
std::unordered_map<const Rep*, int> group_ids;
|
||||
for (const auto& rep : nodes_) {
|
||||
Rep* r = Find(rep.first);
|
||||
auto it = group_ids.find(r);
|
||||
std::vector<std::string>* g;
|
||||
if (it == group_ids.end()) {
|
||||
int id = group_ids.size();
|
||||
group_ids[r] = id;
|
||||
groups->resize(id + 1);
|
||||
g = &groups->back();
|
||||
} else {
|
||||
int id = it->second;
|
||||
g = &((*groups)[id]);
|
||||
}
|
||||
g->push_back(rep.first);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
struct Rep {
|
||||
// Parent in the tree used to encode the set.
|
||||
Rep* parent;
|
||||
// Rank in the tree, used to figure out how to compress the path to the root
|
||||
// of the tree.
|
||||
int rank;
|
||||
// The node.
|
||||
std::string value;
|
||||
};
|
||||
|
||||
Rep* Find(const std::string& n) {
|
||||
auto it = nodes_.find(n);
|
||||
if (it == nodes_.end()) {
|
||||
// This is the first time we process this handle, create an entry for it.
|
||||
Rep* node = new Rep;
|
||||
node->parent = node;
|
||||
node->rank = 0;
|
||||
node->value = n;
|
||||
nodes_[n] = node;
|
||||
return node;
|
||||
}
|
||||
// Return the representative for the set, which is the root of the tree.
|
||||
// Apply path compression to speedup future queries.
|
||||
Rep* node = it->second;
|
||||
Rep* root = node->parent;
|
||||
while (root != root->parent) {
|
||||
root = root->parent;
|
||||
}
|
||||
while (node->parent != root) {
|
||||
Rep* next = node->parent;
|
||||
node->parent = root;
|
||||
node = next;
|
||||
}
|
||||
return root;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, Rep*> nodes_;
|
||||
};
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(tensorflow::grappler::GrapplerItem);
|
||||
|
||||
PYBIND11_MODULE(_pywrap_tf_item, m) {
|
||||
py::class_<tensorflow::grappler::GrapplerItem> grappler_item(
|
||||
m, "tensorflow::grappler::GrapplerItem");
|
||||
|
||||
m.def("TF_NewItem",
|
||||
[](const py::bytes& serialized_metagraph, bool ignore_colocation,
|
||||
bool ignore_user_placement) -> tensorflow::grappler::GrapplerItem* {
|
||||
tensorflow::MetaGraphDef metagraph;
|
||||
if (!metagraph.ParseFromString(serialized_metagraph)) {
|
||||
throw std::invalid_argument(
|
||||
"The MetaGraphDef could not be parsed as a valid protocol "
|
||||
"buffer");
|
||||
}
|
||||
if (metagraph.collection_def().count("train_op") == 0) {
|
||||
MaybeRaiseRegisteredFromStatus(tensorflow::errors::InvalidArgument(
|
||||
"train_op not specified in the metagraph"));
|
||||
}
|
||||
|
||||
tensorflow::grappler::ItemConfig cfg;
|
||||
cfg.ignore_user_placement = ignore_user_placement;
|
||||
cfg.ignore_colocation = ignore_colocation;
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
|
||||
"item", metagraph, cfg);
|
||||
if (item == nullptr) {
|
||||
MaybeRaiseRegisteredFromStatus(
|
||||
tensorflow::errors::InvalidArgument("Invalid metagraph"));
|
||||
}
|
||||
return item.release();
|
||||
});
|
||||
|
||||
m.def("TF_IdentifyImportantOps",
|
||||
[](tensorflow::grappler::GrapplerItem* item,
|
||||
bool sort_topologically) -> std::vector<std::string> {
|
||||
std::vector<const tensorflow::NodeDef*> main_ops =
|
||||
item->MainOpsFanin();
|
||||
std::vector<const tensorflow::NodeDef*> enqueue_ops =
|
||||
item->EnqueueOpsFanin();
|
||||
std::unordered_set<std::string> op_names;
|
||||
for (auto op : main_ops) {
|
||||
op_names.insert(op->name());
|
||||
}
|
||||
for (auto op : enqueue_ops) {
|
||||
op_names.insert(op->name());
|
||||
}
|
||||
|
||||
std::vector<std::string> ops;
|
||||
if (sort_topologically) {
|
||||
tensorflow::GraphDef subgraph;
|
||||
for (const tensorflow::NodeDef& node : item->graph.node()) {
|
||||
if (op_names.find(node.name()) != op_names.end()) {
|
||||
*subgraph.add_node() = node;
|
||||
}
|
||||
}
|
||||
tensorflow::MaybeRaiseFromStatus(
|
||||
tensorflow::grappler::TopologicalSort(&subgraph));
|
||||
for (const tensorflow::NodeDef& node : subgraph.node()) {
|
||||
ops.push_back(node.name());
|
||||
}
|
||||
} else {
|
||||
for (const auto& op_name : op_names) {
|
||||
ops.push_back(op_name);
|
||||
}
|
||||
}
|
||||
return ops;
|
||||
});
|
||||
|
||||
m.def("TF_GetOpProperties",
|
||||
[](tensorflow::grappler::GrapplerItem* item)
|
||||
-> std::unordered_map<std::string, std::vector<py::bytes>> {
|
||||
tensorflow::grappler::GraphProperties properties(*item);
|
||||
tensorflow::MaybeRaiseFromStatus(properties.InferStatically(false));
|
||||
|
||||
std::unordered_map<std::string, std::vector<py::bytes>> props;
|
||||
for (const auto& node : item->graph.node()) {
|
||||
const std::string& node_name = node.name();
|
||||
const std::vector<tensorflow::OpInfo::TensorProperties>&
|
||||
output_props = properties.GetOutputProperties(node_name);
|
||||
|
||||
std::vector<py::bytes> prop;
|
||||
prop.reserve(output_props.size());
|
||||
for (const auto& output_prop : output_props) {
|
||||
prop.push_back(output_prop.SerializeAsString());
|
||||
}
|
||||
props[node_name] = prop;
|
||||
}
|
||||
return props;
|
||||
});
|
||||
|
||||
m.def("TF_GetColocationGroups",
|
||||
[](tensorflow::grappler::GrapplerItem* item)
|
||||
-> std::vector<std::vector<std::string>> {
|
||||
ColocationGroups groupings;
|
||||
tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
|
||||
for (const auto& node : item->graph.node()) {
|
||||
const tensorflow::OpDef* op_def;
|
||||
if (!registry->LookUpOpDef(node.op(), &op_def).ok()) {
|
||||
continue;
|
||||
}
|
||||
tensorflow::NameRangeMap inputs;
|
||||
tensorflow::NameRangeMap outputs;
|
||||
if (!tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs)
|
||||
.ok()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& arg : op_def->input_arg()) {
|
||||
if (!arg.is_ref()) {
|
||||
continue;
|
||||
}
|
||||
const auto& range = inputs[arg.name()];
|
||||
for (int i = range.first; i < range.second; ++i) {
|
||||
groupings.Group(node.name(),
|
||||
tensorflow::grappler::NodeName(node.input(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> groups;
|
||||
groupings.ExtractGroups(&groups);
|
||||
return groups;
|
||||
});
|
||||
}
|
@ -1,144 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
%include "cluster.i"
|
||||
|
||||
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The MetaGraphDef could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%typemap(in) const tensorflow::ConfigProto& (
|
||||
tensorflow::ConfigProto temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The ConfigProto could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%{
|
||||
#include <memory>
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/clusters/utils.h"
|
||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
|
||||
void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* device_map) {
|
||||
tensorflow::SessionOptions options;
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||
tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(options, "", &devices);
|
||||
if (!status.ok()) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const std::unique_ptr<tensorflow::Device>& device : devices) {
|
||||
tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
|
||||
prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
|
||||
|
||||
// Overwrite the memory limit since users might have requested to use only a fraction of the
|
||||
// available device memory.
|
||||
const tensorflow::DeviceAttributes& attr = device->attributes();
|
||||
prop.set_memory_size(attr.memory_limit());
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* TF_OptimizeGraph(
|
||||
GCluster cluster,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
const tensorflow::MetaGraphDef& metagraph,
|
||||
bool verbose, const string& graph_id,
|
||||
bool strip_default_attributes,
|
||||
TF_Status* status) {
|
||||
tensorflow::grappler::ItemConfig item_config;
|
||||
item_config.apply_optimizations = false;
|
||||
item_config.ignore_user_placement = false;
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
|
||||
|
||||
if (!grappler_item) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::DeviceBase* cpu_device = nullptr;
|
||||
tensorflow::GraphDef out_graph;
|
||||
tensorflow::grappler::MetaOptimizer optimizer(cpu_device, config_proto);
|
||||
tensorflow::Status s = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
if (!s.ok()) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
if (strip_default_attributes) {
|
||||
tensorflow::StripDefaultAttributes(*tensorflow::OpRegistry::Global(),
|
||||
out_graph.mutable_node());
|
||||
}
|
||||
if (verbose) {
|
||||
optimizer.PrintResult();
|
||||
}
|
||||
string out_graph_str = out_graph.SerializeAsString();
|
||||
PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
|
||||
out_graph_str.size());
|
||||
return ret;
|
||||
}
|
||||
%}
|
||||
|
||||
|
||||
// Wrap this function
|
||||
PyObject* TF_OptimizeGraph(
|
||||
GCluster cluster,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
const tensorflow::MetaGraphDef& metagraph, bool verbose,
|
||||
const string& graph_id, bool strip_default_attributes, TF_Status* status);
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_opt
|
||||
from tensorflow.python import _pywrap_tf_optimizer as tf_opt
|
||||
from tensorflow.python.grappler import cluster as gcluster
|
||||
|
||||
|
||||
@ -52,12 +52,8 @@ def OptimizeGraph(config_proto,
|
||||
type(config_proto))
|
||||
if cluster is None:
|
||||
cluster = gcluster.Cluster()
|
||||
ret_from_swig = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
|
||||
out_graph = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
|
||||
config_proto.SerializeToString(),
|
||||
metagraph.SerializeToString(),
|
||||
verbose, graph_id,
|
||||
strip_default_attributes)
|
||||
if ret_from_swig is None:
|
||||
return None
|
||||
out_graph = graph_pb2.GraphDef().FromString(ret_from_swig)
|
||||
return out_graph
|
||||
metagraph.SerializeToString(), verbose,
|
||||
graph_id, strip_default_attributes)
|
||||
return graph_pb2.GraphDef().FromString(out_graph)
|
||||
|
108
tensorflow/python/grappler/tf_optimizer_wrapper.cc
Normal file
108
tensorflow/python/grappler/tf_optimizer_wrapper.cc
Normal file
@ -0,0 +1,108 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/clusters/utils.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/device_properties.pb.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void DetectDevices(
|
||||
std::unordered_map<std::string, tensorflow::DeviceProperties>* device_map) {
|
||||
tensorflow::SessionOptions options;
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||
if (!tensorflow::DeviceFactory::AddDevices(options, "", &devices).ok()) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const std::unique_ptr<tensorflow::Device>& device : devices) {
|
||||
tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
|
||||
prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
|
||||
|
||||
// Overwrite the memory limit since users might have requested to use only a
|
||||
// fraction of the available device memory.
|
||||
const tensorflow::DeviceAttributes& attr = device->attributes();
|
||||
prop.set_memory_size(attr.memory_limit());
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_pywrap_tf_optimizer, m) {
|
||||
m.def(
|
||||
"TF_OptimizeGraph",
|
||||
[](tensorflow::grappler::Cluster* cluster,
|
||||
const py::bytes& serialized_config_proto,
|
||||
const py::bytes& serialized_metagraph, bool verbose,
|
||||
const std::string& graph_id,
|
||||
bool strip_default_attributes) -> py::bytes {
|
||||
tensorflow::ConfigProto config_proto;
|
||||
if (!config_proto.ParseFromString(serialized_config_proto)) {
|
||||
throw std::invalid_argument(
|
||||
"The ConfigProto could not be parsed as a valid protocol buffer");
|
||||
}
|
||||
tensorflow::MetaGraphDef metagraph;
|
||||
if (!metagraph.ParseFromString(serialized_metagraph)) {
|
||||
throw std::invalid_argument(
|
||||
"The MetaGraphDef could not be parsed as a valid protocol "
|
||||
"buffer");
|
||||
}
|
||||
|
||||
tensorflow::grappler::ItemConfig item_config;
|
||||
// This disables graph optimizations in the older graph optimizer, which
|
||||
// tend to overlap / be redundant with those in Grappler.
|
||||
item_config.apply_optimizations = false;
|
||||
item_config.ignore_user_placement = false;
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
|
||||
graph_id, metagraph, item_config);
|
||||
if (!grappler_item) {
|
||||
throw std::invalid_argument(
|
||||
"Failed to import metagraph, check error log for more info.");
|
||||
}
|
||||
|
||||
tensorflow::DeviceBase* cpu_device = nullptr;
|
||||
tensorflow::GraphDef out_graph;
|
||||
tensorflow::grappler::MetaOptimizer optimizer(cpu_device, config_proto);
|
||||
|
||||
MaybeRaiseRegisteredFromStatus(
|
||||
optimizer.Optimize(cluster, *grappler_item, &out_graph));
|
||||
if (strip_default_attributes) {
|
||||
tensorflow::StripDefaultAttributes(*tensorflow::OpRegistry::Global(),
|
||||
out_graph.mutable_node());
|
||||
}
|
||||
if (verbose) {
|
||||
optimizer.PrintResult();
|
||||
}
|
||||
return out_graph.SerializeAsString();
|
||||
});
|
||||
}
|
@ -17,10 +17,11 @@ limitations under the License.
|
||||
* The includes are intentionally not alphabetically sorted, as the order of
|
||||
* includes follows dependency order */
|
||||
|
||||
%include "tensorflow/python/grappler/cluster.i"
|
||||
%include "tensorflow/python/grappler/item.i"
|
||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||
%include "tensorflow/python/grappler/cost_analyzer.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
%}
|
||||
|
||||
// TODO(slebedev): This is a temporary workaround for projects implicitly
|
||||
// relying on TensorFlow exposing tensorflow::Status.
|
||||
|
@ -243,3 +243,59 @@ tensorflow::TF_GraphSetTensorShape_wrapper
|
||||
tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper
|
||||
tensorflow::TF_TryEvaluateConstant_wrapper
|
||||
|
||||
[grappler_item] # tf_item
|
||||
tensorflow::grappler::GrapplerItem::MainOpsFanin
|
||||
tensorflow::grappler::GrapplerItem::EnqueueOpsFanin
|
||||
|
||||
[graph_properties] # tf_item
|
||||
tensorflow::grappler::GraphProperties::InferStatically
|
||||
tensorflow::grappler::GraphProperties::GetOutputProperties
|
||||
|
||||
[grappler_item_builder] # tf_item
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef
|
||||
|
||||
[topological_sort] # tf_item
|
||||
tensorflow::grappler::TopologicalSort
|
||||
|
||||
[clusters/utils] # tf_cluster tf_optimizer
|
||||
tensorflow::grappler::GetDeviceInfo
|
||||
|
||||
[costs/utils] # tf_optimizer tf_cluster
|
||||
tensorflow::grappler::CostGraphToOpPerformanceData
|
||||
tensorflow::grappler::GetDeviceInfo
|
||||
|
||||
[meta_optimizer] # tf_optimizer
|
||||
tensorflow::grappler::MetaOptimizer::MetaOptimizer
|
||||
tensorflow::grappler::MetaOptimizer::Optimize
|
||||
tensorflow::grappler::MetaOptimizer::PrintResult
|
||||
|
||||
[clusters/cluster] # tf_cluster
|
||||
tensorflow::grappler::Cluster::AllowSoftPlacement
|
||||
tensorflow::grappler::Cluster::SetNumWarmupSteps
|
||||
tensorflow::grappler::Cluster::DisableDetailedStats
|
||||
tensorflow::grappler::Cluster::DetailedStatsEnabled
|
||||
|
||||
[single_machine] # tf_cluster
|
||||
tensorflow::grappler::SingleMachine::SingleMachine
|
||||
|
||||
[op_level_cost_estimator] # tf_cluster
|
||||
tensorflow::grappler::OpLevelCostEstimator::OpLevelCostEstimator
|
||||
tensorflow::grappler::OpLevelCostEstimator::PredictCosts
|
||||
tensorflow::grappler::OpLevelCostEstimator::GetDeviceInfo
|
||||
|
||||
[virtual_cluster] # tf_cluster
|
||||
tensorflow::grappler::VirtualCluster::VirtualCluster
|
||||
|
||||
[graph_memory] # tf_cluster
|
||||
tensorflow::grappler::GraphMemory::InferStatically
|
||||
tensorflow::grappler::GraphMemory::InferDynamically
|
||||
|
||||
[measuring_cost_estimator] # tf_cluster
|
||||
tensorflow::grappler::MeasuringCostEstimator::MeasuringCostEstimator
|
||||
tensorflow::grappler::MeasuringCostEstimator::Initialize
|
||||
tensorflow::grappler::MeasuringCostEstimator::PredictCosts
|
||||
|
||||
[devices] # tf_cluster
|
||||
tensorflow::grappler::GetNumAvailableGPUs
|
||||
tensorflow::grappler::GetNumAvailableLogicalCPUCores
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user