Export grappler code from C++ to Python with pybind11 instead of swig. Adding pywrap_required_hdrs to downstream grappler build files which will go away with bazel first class shared library support.

PiperOrigin-RevId: 292986813
Change-Id: I24cbeba85e593fcb4604913874c0ac01eac73e4d
This commit is contained in:
Amit Patankar 2020-02-03 13:07:13 -08:00 committed by TensorFlower Gardener
parent 7954fb8fd1
commit e3e22538e8
22 changed files with 1034 additions and 1017 deletions

View File

@ -41,6 +41,16 @@ filegroup(
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"training/coordinator.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "gradients",
srcs = [

View File

@ -4,6 +4,18 @@ package(
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"devices.h",
"grappler_item.h",
"grappler_item_builder.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "op_types",
srcs = ["op_types.cc"],

View File

@ -53,6 +53,19 @@ tf_cc_test(
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"cluster.h",
"single_machine.h",
"utils.h",
"virtual_cluster.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "cluster",
srcs = ["cluster.cc"],

View File

@ -36,6 +36,25 @@ tf_pyclif_proto_library(
visibility = ["//visibility:public"],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"analytical_cost_estimator.h",
"cost_estimator.h",
"graph_memory.h",
"graph_properties.h",
"measuring_cost_estimator.h",
"op_context.h",
"op_level_cost_estimator.h",
"utils.h",
"virtual_placer.h",
"virtual_scheduler.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "graph_properties",
srcs = ["graph_properties.cc"],

View File

@ -11,6 +11,17 @@ package(
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"graph_optimizer.h",
"meta_optimizer.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "static_schedule",
srcs = ["static_schedule.cc"],

View File

@ -8,6 +8,16 @@ package(
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"topological_sort.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "scc",
srcs = ["scc.cc"],

View File

@ -4,6 +4,16 @@ package(
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"graph_verifier.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "graph_verifier",
hdrs = [

View File

@ -364,6 +364,40 @@ cc_library(
] + tf_protos_grappler(),
)
# Necessary for the pywrap inclusion below. Combining targets does not work
# properly.
tf_pybind_cc_library_wrapper(
name = "cost_analyzer_headers",
deps = [
":cost_analyzer_lib",
],
)
tf_python_pybind_extension(
name = "_pywrap_cost_analyzer",
srcs = ["grappler/cost_analyzer_wrapper.cc"],
hdrs = [
"grappler/cost_analyzer.h",
"//tensorflow/cc:pywrap_required_hdrs",
"//tensorflow/core/grappler:pywrap_required_hdrs",
"//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
"//tensorflow/core/grappler/costs:pywrap_required_hdrs",
"//tensorflow/core/public:session.h",
"//tensorflow/core/public:session_options.h",
],
module_name = "_pywrap_cost_analyzer",
deps = [
":cost_analyzer_headers",
":pybind11_status",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_id",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@pybind11",
],
)
cc_library(
name = "model_analyzer_lib",
srcs = ["grappler/model_analyzer.cc"],
@ -380,14 +414,16 @@ cc_library(
tf_python_pybind_extension(
name = "_pywrap_model_analyzer",
srcs = ["grappler/model_analyzer_wrapper.cc"],
hdrs = ["grappler/model_analyzer.h"],
hdrs = [
"grappler/model_analyzer.h",
"//tensorflow/core/grappler:pywrap_required_hdrs",
],
module_name = "_pywrap_model_analyzer",
deps = [
":model_analyzer_lib",
":pybind11_status",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item_builder",
"@pybind11",
],
)
@ -5652,10 +5688,6 @@ tf_py_wrap_cc(
name = "pywrap_tensorflow_internal",
srcs = ["tensorflow.i"],
swig_includes = [
"grappler/cluster.i",
"grappler/cost_analyzer.i",
"grappler/item.i",
"grappler/tf_optimizer.i",
"lib/core/strings.i",
"platform/base.i",
],
@ -5721,10 +5753,11 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
":cpp_python_util", # util
":py_func_lib", # py_func
":model_analyzer_lib", # model_analyzer
"//tensorflow/core/util:port", # util_port
":cost_analyzer_lib", # cost_analyzer
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
"//tensorflow/core/util:port", # util_port
"//tensorflow/core:framework_internal_impl", # op_def_registry
"//tensorflow/core:lib_internal_impl", # device_lib
"//tensorflow/core:core_cpu_impl", # device_lib
@ -5749,6 +5782,20 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
"//tensorflow/core:core_cpu_base_no_ops", # tf_session
"//tensorflow/c:python_api", # tf_session
"//tensorflow/python:tf_session_helper", # tf_session
"//tensorflow/core/grappler:grappler_item", # tf_item
"//tensorflow/core/grappler/costs:graph_properties", # tf_item
"//tensorflow/core/grappler:grappler_item_builder", # tf_item
"//tensorflow/core/grappler/utils:topological_sort", # tf_item
"//tensorflow/core/grappler/costs:utils", # tf_cluster
"//tensorflow/core/grappler/optimizers:meta_optimizer", # tf_optimizer
"//tensorflow/core/grappler/clusters:cluster", # tf_cluster
"//tensorflow/core/grappler/clusters:single_machine", # tf_cluster
"//tensorflow/core/grappler/costs:op_level_cost_estimator", # tf_cluster
"//tensorflow/core/grappler/clusters:virtual_cluster", # tf_cluster
"//tensorflow/core/grappler/costs:graph_memory", # tf_cluster
"//tensorflow/core/grappler:devices", # tf_cluster
"//tensorflow/core/grappler/clusters:utils", # tf_optimizer
"//tensorflow/core/grappler/costs:measuring_cost_estimator", # tf_cluster
]
# Filter the DEF file to reduce the number of symbols to 64K or less.
@ -7359,11 +7406,32 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":pywrap_tensorflow_internal",
":_pywrap_tf_item",
"//tensorflow/core/grappler/costs:op_performance_data_py",
],
)
tf_python_pybind_extension(
name = "_pywrap_tf_item",
srcs = ["grappler/item_wrapper.cc"],
hdrs = [
"//tensorflow/cc:pywrap_required_hdrs",
"//tensorflow/core/grappler:pywrap_required_hdrs",
"//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
"//tensorflow/core/grappler/costs:pywrap_required_hdrs",
"//tensorflow/core/grappler/utils:pywrap_required_hdrs",
],
module_name = "_pywrap_tf_item",
deps = [
":pybind11_status",
"@pybind11",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_id",
"//tensorflow/core:protos_all_cc",
] + if_not_windows(["//tensorflow/core/grappler/costs:graph_properties"]), # b/148556093,
)
tf_py_test(
name = "item_test",
size = "small",
@ -7414,11 +7482,34 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":pywrap_tensorflow_internal",
":_pywrap_tf_cluster",
"//tensorflow/core/grappler/costs:op_performance_data_py",
],
)
tf_python_pybind_extension(
name = "_pywrap_tf_cluster",
srcs = ["grappler/cluster_wrapper.cc"],
hdrs = [
"//tensorflow/cc:pywrap_required_hdrs",
"//tensorflow/core/grappler:pywrap_required_hdrs",
"//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
"//tensorflow/core/grappler/costs:pywrap_required_hdrs",
"//tensorflow/core/grappler/utils:pywrap_required_hdrs",
],
module_name = "_pywrap_tf_cluster",
deps = [
":pybind11_status",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_id",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
"@pybind11",
],
)
cuda_py_test(
name = "cluster_test",
size = "small",
@ -7451,11 +7542,34 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":pywrap_tensorflow_internal",
":_pywrap_tf_optimizer",
":tf_cluster",
],
)
tf_python_pybind_extension(
name = "_pywrap_tf_optimizer",
srcs = ["grappler/tf_optimizer_wrapper.cc"],
hdrs = [
"//tensorflow/cc:pywrap_required_hdrs",
"//tensorflow/core/grappler:pywrap_required_hdrs",
"//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
"//tensorflow/core/grappler/costs:pywrap_required_hdrs",
"//tensorflow/core/grappler/optimizers:pywrap_required_hdrs",
"//tensorflow/core/grappler/verifiers:pywrap_required_hdrs",
],
module_name = "_pywrap_tf_optimizer",
deps = [
":pybind11_status",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_id",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@pybind11",
],
)
tf_py_test(
name = "tf_optimizer_test",
size = "small",
@ -7619,7 +7733,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow_internal",
":_pywrap_cost_analyzer",
":tf_cluster",
":tf_item",
],

View File

@ -1,450 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%include <std_shared_ptr.i>
%include "item.i"
// Wrap the cluster into an object that swig can manipulate. This ensures it will call the object
// destructor upon garbage collection instead of leaking memory.
struct GCluster {
std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
};
%{
#include "tensorflow/core/protobuf/device_properties.pb.h"
template <>
bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
return false;
}
tensorflow::NamedDevice named_device;
if (!named_device.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The NamedDevice could not be parsed as a valid protocol buffer");
return false;
}
if (out) *out = named_device;
return true;
}
%}
%typemap(in) const std::vector<tensorflow::NamedDevice>& (std::vector<tensorflow::NamedDevice> temp) {
if (!tf_vector_input_helper($input, &temp, &_PyObjAs<tensorflow::NamedDevice>)) {
SWIG_fail;
}
$1 = &temp;
}
%typemap(in) const tensorflow::NamedDevice& (tensorflow::NamedDevice temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The NamedDevice could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%typemap(in) const tensorflow::RunMetadata& (tensorflow::RunMetadata temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The RunMetadata could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%typemap(in) const string& (string temp) {
char *buf;
Py_ssize_t len;
if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) return NULL;
temp.assign(buf, len);
$1 = &temp;
}
%{
#include <memory>
#include <vector>
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/memory_types.h"
// Provide the implementation of the GCluster struct here.
struct GCluster {
GCluster() {}
GCluster(tensorflow::grappler::Cluster* cluster) : cluster_(cluster) {}
tensorflow::grappler::Cluster* operator->() const {
return cluster_.get();
}
tensorflow::grappler::Cluster* get() const {
return cluster_.get();
}
bool is_none() const {
return cluster_.get() == nullptr;
}
std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
};
static GCluster TF_NewCluster(bool allow_soft_placement,
bool disable_detailed_stats, TF_Status* status) {
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
int timeout_s = 60 * 10;
tensorflow::grappler::Cluster* cluster_ =
new tensorflow::grappler::SingleMachine(
timeout_s, num_cpu_cores, num_gpus);
cluster_->DisableDetailedStats(disable_detailed_stats);
cluster_->AllowSoftPlacement(allow_soft_placement);
cluster_->SetNumWarmupSteps(10);
tensorflow::Status s = cluster_->Provision();
tensorflow::Set_TF_Status_from_Status(status, s);
return GCluster(cluster_);
}
static GCluster TF_NewVirtualCluster(
const std::vector<tensorflow::NamedDevice>& named_devices, TF_Status* status) {
std::unordered_map<string, tensorflow::DeviceProperties> devices;
for (const auto& named_device : named_devices) {
devices[named_device.name()]= named_device.properties();
}
tensorflow::grappler::Cluster* cluster_ =
new tensorflow::grappler::VirtualCluster(devices);
PyGILState_STATE gstate = PyGILState_Ensure();
tensorflow::Status s = cluster_->Provision();
PyGILState_Release(gstate);
tensorflow::Set_TF_Status_from_Status(status, s);
return GCluster(cluster_);
}
static void TF_ShutdownCluster(GCluster cluster) {
PyGILState_STATE gstate = PyGILState_Ensure();
cluster->Shutdown();
PyGILState_Release(gstate);
}
tensorflow::Status _GetOpPerformanceDataAndRunTime(
const tensorflow::grappler::GrapplerItem& item,
tensorflow::grappler::CostEstimator* cost_measure,
tensorflow::OpPerformanceList* op_performance_data,
tensorflow::grappler::Costs* costs) {
tensorflow::Status status = cost_measure->Initialize(item);
if (!status.ok()) return status;
tensorflow::RunMetadata run_metadata;
TF_RETURN_IF_ERROR(
cost_measure->PredictCosts(item.graph, &run_metadata, costs));
if (op_performance_data) {
*op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData(
run_metadata.cost_graph(), item.graph);
}
return tensorflow::Status::OK();
}
static PyObject* TF_ListDevices(GCluster cluster) {
const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* result = PyList_New(devices.size());
int i = 0;
for (auto& dev : devices) {
tensorflow::NamedDevice d;
d.set_name(dev.first);
*d.mutable_properties() = dev.second;
string dev_str = d.SerializeAsString();
PyObject* dev_obj = PyBytes_FromStringAndSize(dev_str.data(),
dev_str.size());
PyList_SetItem(result, i, dev_obj);
++i;
}
PyGILState_Release(gstate);
return result;
}
static PyObject* TF_ListAvailableOps() {
tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
std::vector<tensorflow::OpDef> ops;
registry->GetRegisteredOps(&ops);
std::vector<string> op_names;
for (const tensorflow::OpDef& op : ops) {
op_names.push_back(op.name());
}
std::sort(op_names.begin(), op_names.end());
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* result = PyList_New(op_names.size());
for (int i = 0; i < op_names.size(); ++i) {
PyList_SetItem(result, i, PyString_FromString(op_names[i].c_str()));
}
PyGILState_Release(gstate);
return result;
}
static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) {
if (cluster.is_none() || item.is_none()) {
Py_RETURN_NONE;
}
const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
std::unordered_map<string, std::vector<string>> device_types;
for (const auto& dev : devices) {
device_types[dev.second.type()].push_back(dev.first);
}
std::unordered_map<string, std::set<string>> supported_device_types;
std::unordered_map<string, std::set<string>> device_restrictions;
for (const auto& node : item->graph.node()) {
for (const auto& dev : device_types) {
const string& type = dev.first;
if (cluster->type() != "single_machine") {
// The actual kernel may not be linked in this binary.
supported_device_types[node.name()].insert(type);
} else {
// Check the kernel capabilities
const tensorflow::DeviceType dev_type(type);
tensorflow::Status s = tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
if (s.ok()) {
supported_device_types[node.name()].insert(type);
// Check which inputs are restricted to reside on the host.
// TODO: extends this to support outputs as well
tensorflow::MemoryTypeVector inp_mtypes;
tensorflow::MemoryTypeVector out_mtypes;
s = tensorflow::MemoryTypesForNode(tensorflow::OpRegistry::Global(), dev_type, node,
&inp_mtypes, &out_mtypes);
if (s.ok()) {
for (int i = 0; i < inp_mtypes.size(); ++i) {
if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
device_restrictions[tensorflow::grappler::NodeName(node.input(i))].insert("CPU");
break;
}
}
}
}
}
}
}
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* result = PyDict_New();
for (const auto& supported_dev : supported_device_types) {
const string& node = supported_dev.first;
std::set<string> feasible;
const auto it = device_restrictions.find(node);
if (it != device_restrictions.end()) {
const std::set<string>& candidates = supported_dev.second;
const std::set<string>& valid = it->second;
std::set_intersection(candidates.begin(), candidates.end(), valid.begin(), valid.end(),
std::inserter(feasible, feasible.begin()));
} else {
feasible = supported_dev.second;
}
std::vector<string> device_names;
for (const string& type : feasible) {
auto it = device_types.find(type);
CHECK(it != device_types.end());
for (const string& name : it->second) {
device_names.push_back(name);
}
}
PyObject* dev = PyList_New(device_names.size());
for (int i = 0; i < device_names.size(); ++i) {
PyList_SetItem(dev, i, PyString_FromString(device_names[i].c_str()));
}
CHECK_EQ(0, PyDict_SetItem(result, PyString_FromString(node.c_str()), dev));
}
PyGILState_Release(gstate);
return result;
}
static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) {
tensorflow::grappler::OpLevelCostEstimator estimator;
tensorflow::grappler::DeviceInfo info =
estimator.GetDeviceInfo(device.properties());
return info.gigaops;
}
static PyObject* TF_MeasureCosts(
GItem item,
GCluster cluster,
bool generate_timeline, TF_Status* status) {
tensorflow::OpPerformanceList op_performance_data;
tensorflow::StepStats step_stats;
const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), num_measurements, 0);
tensorflow::grappler::Costs costs;
tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
*item, &cost_measure, &op_performance_data, &costs);
double run_time = FLT_MAX;
if (s.ok()) {
run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
}
if (generate_timeline) {
tensorflow::RunMetadata metadata;
tensorflow::Status run_status = cluster->Run(
item->graph, item->feed, item->fetch, &metadata);
if (run_status.ok()) {
step_stats = metadata.step_stats();
} else {
s = run_status;
}
}
tensorflow::Set_TF_Status_from_Status(status, s);
if (!s.ok()) {
Py_RETURN_NONE;
}
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* op_perf_objs = PyList_New(
op_performance_data.op_performance_size());
for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
string op_perf_str =
op_performance_data.op_performance(i).SerializeAsString();
PyObject* op_perf_obj = PyBytes_FromStringAndSize(op_perf_str.data(),
op_perf_str.size());
PyList_SetItem(op_perf_objs, i, op_perf_obj);
}
PyObject* run_time_obj = PyFloat_FromDouble(run_time);
string step_stats_str = step_stats.SerializeAsString();
PyObject* metadata_obj = PyBytes_FromStringAndSize(step_stats_str.data(),
step_stats_str.size());
PyObject* ret = PyTuple_New(3);
if (PyTuple_SetItem(ret, 0, op_perf_objs) != 0 ||
PyTuple_SetItem(ret, 1, run_time_obj) != 0 ||
PyTuple_SetItem(ret, 2, metadata_obj) != 0) {
Py_DECREF(ret);
Py_XDECREF(op_perf_objs);
Py_XDECREF(run_time_obj);
Py_XDECREF(metadata_obj);
s = tensorflow::Status(tensorflow::error::Code::INTERNAL,
"Error setting return tuples.");
tensorflow::Set_TF_Status_from_Status(status, s);
Py_INCREF(Py_None);
ret = Py_None;
}
PyGILState_Release(gstate);
return ret;
}
static PyObject* TF_DeterminePeakMemoryUsage(
GItem item,
GCluster cluster,
TF_Status* status) {
if (item.is_none() || cluster.is_none()) {
tensorflow::Status s(tensorflow::error::Code::INTERNAL,
"You need both a cluster and an item to determine peak memory usage");
tensorflow::Set_TF_Status_from_Status(status, s);
Py_RETURN_NONE;
}
tensorflow::grappler::GraphMemory memory(*item);
tensorflow::Status s;
if (cluster->DetailedStatsEnabled()) {
s = memory.InferDynamically(cluster.get());
} else {
s = memory.InferStatically(cluster->GetDevices());
}
if (!s.ok()) {
tensorflow::Set_TF_Status_from_Status(status, s);
Py_RETURN_NONE;
}
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* result = PyDict_New();
for (const auto& device : cluster->GetDevices()) {
const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
memory.GetPeakMemoryUsage(device.first);
PyObject* per_device = PyList_New(usage.live_tensors.size());
for (int i = 0; i < usage.live_tensors.size(); ++i) {
const auto& live_tensor = usage.live_tensors[i];
PyObject* live = PyTuple_New(5);
PyTuple_SetItem(live, 0, PyString_FromString(live_tensor.node.c_str()));
PyTuple_SetItem(live, 1, PyInt_FromLong(live_tensor.output_id));
PyTuple_SetItem(live, 2, PyLong_FromLong(live_tensor.memory_used));
PyTuple_SetItem(live, 3, PyLong_FromLong(live_tensor.allocation_time.count()));
PyTuple_SetItem(live, 4, PyLong_FromLong(live_tensor.deallocation_time.count()));
PyList_SetItem(per_device, i, live);
}
PyObject* ret = PyTuple_New(2);
PyTuple_SetItem(ret, 0, PyLong_FromLong(usage.used_memory));
PyTuple_SetItem(ret, 1, per_device);
PyDict_SetItem(result, PyString_FromString(device.first.c_str()), ret);
}
PyGILState_Release(gstate);
return result;
}
%}
// Wrap these functions.
static GCluster TF_NewCluster(
bool allow_soft_placement, bool disable_detailed_stats, TF_Status* status);
static GCluster TF_NewVirtualCluster(
const std::vector<tensorflow::NamedDevice>& named_devices,
TF_Status* status);
static void TF_ShutdownCluster(GCluster cluster);
static PyObject* TF_ListDevices(GCluster cluster);
static PyObject* TF_ListAvailableOps();
static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item);
static float TF_EstimatePerformance(const tensorflow::NamedDevice& device);
static PyObject* TF_MeasureCosts(
GItem item, GCluster cluster,
bool generate_timeline, TF_Status* status);
static PyObject* TF_DeterminePeakMemoryUsage(
GItem item, GCluster cluster,
TF_Status* status);

View File

@ -23,7 +23,7 @@ import contextlib
from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.grappler.costs import op_performance_data_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.python import pywrap_tensorflow as tf_cluster
from tensorflow.python import _pywrap_tf_cluster as tf_cluster
class Cluster(object):
@ -92,13 +92,9 @@ class Cluster(object):
item: The item for which to measure the costs.
Returns: The triplet op_perfs, runtime, step_stats.
"""
ret_from_swig = tf_cluster.TF_MeasureCosts(item.tf_item, self._tf_cluster,
self._generate_timeline)
op_perf_bytes_list, run_time, step_stats_bytes = tf_cluster.TF_MeasureCosts(
item.tf_item, self._tf_cluster, self._generate_timeline)
if ret_from_swig is None:
return None
op_perf_bytes_list, run_time, step_stats_bytes = ret_from_swig
op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes)
for op_perf_bytes in op_perf_bytes_list]
return (op_perfs, run_time,

View File

@ -0,0 +1,332 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cfloat>
#include <cstdint>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = pybind11;
tensorflow::Status _GetOpPerformanceDataAndRunTime(
const tensorflow::grappler::GrapplerItem& item,
tensorflow::grappler::CostEstimator* cost_measure,
tensorflow::OpPerformanceList* op_performance_data,
tensorflow::grappler::Costs* costs) {
tensorflow::Status status = cost_measure->Initialize(item);
if (!status.ok()) return status;
tensorflow::RunMetadata run_metadata;
MaybeRaiseRegisteredFromStatus(
cost_measure->PredictCosts(item.graph, &run_metadata, costs));
if (op_performance_data) {
*op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData(
run_metadata.cost_graph(), item.graph);
}
return tensorflow::Status::OK();
}
PYBIND11_MAKE_OPAQUE(tensorflow::grappler::Cluster);
PYBIND11_MODULE(_pywrap_tf_cluster, m) {
py::class_<tensorflow::grappler::Cluster> grappler_cluster(
m, "tensorflow::grappler::Cluster");
m.def("TF_NewCluster",
[](bool allow_soft_placement,
bool disable_detailed_stats) -> tensorflow::grappler::Cluster* {
// TODO(petebu): Make these named arguments with default values
// instead.
int num_cpu_cores =
tensorflow::grappler::GetNumAvailableLogicalCPUCores();
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
int timeout_s = 60 * 10;
std::unique_ptr<tensorflow::grappler::Cluster> cluster =
std::make_unique<tensorflow::grappler::SingleMachine>(
timeout_s, num_cpu_cores, num_gpus);
cluster->DisableDetailedStats(disable_detailed_stats);
cluster->AllowSoftPlacement(allow_soft_placement);
cluster->SetNumWarmupSteps(10);
MaybeRaiseRegisteredFromStatus(cluster->Provision());
return cluster.release();
});
m.def("TF_NewVirtualCluster",
[](const std::vector<py::bytes>& serialized_named_devices)
-> tensorflow::grappler::Cluster* {
std::vector<tensorflow::NamedDevice> named_devices;
for (const auto& s : serialized_named_devices) {
tensorflow::NamedDevice named_device;
if (!named_device.ParseFromString(s)) {
throw std::invalid_argument(
"The NamedDevice could not be parsed as a valid protocol "
"buffer");
}
named_devices.push_back(named_device);
}
std::unordered_map<std::string, tensorflow::DeviceProperties> devices;
for (const auto& named_device : named_devices) {
devices[named_device.name()] = named_device.properties();
}
std::unique_ptr<tensorflow::grappler::Cluster> cluster =
std::make_unique<tensorflow::grappler::VirtualCluster>(devices);
{
// TODO(petebu): Do we need to hold the GIL here?
py::gil_scoped_acquire acquire;
MaybeRaiseRegisteredFromStatus(cluster->Provision());
}
return cluster.release();
});
m.def("TF_ShutdownCluster", [](tensorflow::grappler::Cluster* cluster) {
// TODO(petebu): Do we need to hold the GIL here?
py::gil_scoped_acquire acquire;
cluster->Shutdown();
});
m.def("TF_ListDevices",
[](tensorflow::grappler::Cluster* cluster) -> std::vector<py::bytes> {
const std::unordered_map<std::string, tensorflow::DeviceProperties>&
devices = cluster->GetDevices();
std::vector<py::bytes> named_devices;
for (auto& dev : devices) {
tensorflow::NamedDevice d;
d.set_name(dev.first);
*d.mutable_properties() = dev.second;
named_devices.push_back(d.SerializeAsString());
}
return named_devices;
});
m.def("TF_ListAvailableOps", []() -> std::vector<std::string> {
tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
std::vector<tensorflow::OpDef> ops;
registry->GetRegisteredOps(&ops);
std::vector<std::string> op_names;
for (const tensorflow::OpDef& op : ops) {
op_names.push_back(op.name());
}
std::sort(op_names.begin(), op_names.end());
return op_names;
});
m.def(
"TF_GetSupportedDevices",
[](tensorflow::grappler::Cluster* cluster,
tensorflow::grappler::GrapplerItem* item)
-> std::unordered_map<std::string, std::vector<std::string>> {
if (cluster == nullptr || item == nullptr) {
MaybeRaiseRegisteredFromStatus(tensorflow::Status(
tensorflow::errors::Internal("You need both a cluster and an "
"item to get supported devices.")));
}
const std::unordered_map<std::string, tensorflow::DeviceProperties>&
devices = cluster->GetDevices();
std::unordered_map<std::string, std::vector<std::string>> device_types;
for (const auto& dev : devices) {
device_types[dev.second.type()].push_back(dev.first);
}
std::unordered_map<std::string, std::set<std::string>>
supported_device_types;
std::unordered_map<std::string, std::set<std::string>>
device_restrictions;
for (const auto& node : item->graph.node()) {
for (const auto& dev : device_types) {
const std::string& type = dev.first;
if (cluster->type() != "single_machine") {
// The actual kernel may not be linked in this binary.
supported_device_types[node.name()].insert(type);
} else {
// Check the kernel capabilities
const tensorflow::DeviceType dev_type(type);
tensorflow::Status s =
tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
if (s.ok()) {
supported_device_types[node.name()].insert(type);
// Check which inputs are restricted to reside on the host.
// TODO: extends this to support outputs as well
tensorflow::MemoryTypeVector inp_mtypes;
tensorflow::MemoryTypeVector out_mtypes;
tensorflow::Status s = tensorflow::MemoryTypesForNode(
tensorflow::OpRegistry::Global(), dev_type, node,
&inp_mtypes, &out_mtypes);
if (s.ok()) {
for (int i = 0; i < inp_mtypes.size(); ++i) {
if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
device_restrictions[tensorflow::grappler::NodeName(
node.input(i))]
.insert("CPU");
break;
}
}
}
}
}
}
}
std::unordered_map<std::string, std::vector<std::string>> result;
for (const auto& supported_dev : supported_device_types) {
const std::string& node = supported_dev.first;
std::set<std::string> feasible;
const auto it = device_restrictions.find(node);
if (it != device_restrictions.end()) {
const std::set<std::string>& candidates = supported_dev.second;
const std::set<std::string>& valid = it->second;
std::set_intersection(candidates.begin(), candidates.end(),
valid.begin(), valid.end(),
std::inserter(feasible, feasible.begin()));
} else {
feasible = supported_dev.second;
}
std::vector<std::string> device_names;
for (const std::string& type : feasible) {
auto it = device_types.find(type);
DCHECK(it != device_types.end());
for (const std::string& name : it->second) {
device_names.push_back(name);
}
}
result[node] = device_names;
}
return result;
});
m.def("TF_EstimatePerformance", [](const py::bytes& serialized_device) {
tensorflow::NamedDevice device;
if (!device.ParseFromString(serialized_device)) {
throw std::invalid_argument(
"The NamedDevice could not be parsed as a valid protocol buffer");
}
tensorflow::grappler::OpLevelCostEstimator estimator;
tensorflow::grappler::DeviceInfo info =
estimator.GetDeviceInfo(device.properties());
return info.gigaops;
});
m.def("TF_MeasureCosts",
[](tensorflow::grappler::GrapplerItem* item,
tensorflow::grappler::Cluster* cluster, bool generate_timeline)
-> std::tuple<std::vector<py::bytes>, double, py::bytes> {
const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
tensorflow::grappler::MeasuringCostEstimator cost_measure(
cluster, num_measurements, 0);
tensorflow::OpPerformanceList op_performance_data;
tensorflow::grappler::Costs costs;
tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
*item, &cost_measure, &op_performance_data, &costs);
double run_time = FLT_MAX;
if (s.ok()) {
run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
}
tensorflow::StepStats step_stats;
if (generate_timeline) {
tensorflow::RunMetadata metadata;
MaybeRaiseRegisteredFromStatus(
cluster->Run(item->graph, item->feed, item->fetch, &metadata));
step_stats = metadata.step_stats();
}
std::vector<py::bytes> op_perf_objs;
op_perf_objs.resize(op_performance_data.op_performance_size());
for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
op_perf_objs[i] =
op_performance_data.op_performance(i).SerializeAsString();
}
py::bytes step_stats_str = step_stats.SerializeAsString();
return std::make_tuple(op_perf_objs, run_time, step_stats_str);
});
using DurationType = tensorflow::grappler::Costs::Duration::rep;
using MemoryUsage =
std::tuple<std::string, int, size_t, DurationType, DurationType>;
m.def(
"TF_DeterminePeakMemoryUsage",
[](tensorflow::grappler::GrapplerItem* item,
tensorflow::grappler::Cluster* cluster)
-> std::unordered_map<std::string,
std::tuple<int64_t, std::vector<MemoryUsage>>> {
if (item == nullptr || cluster == nullptr) {
MaybeRaiseRegisteredFromStatus(
tensorflow::Status(tensorflow::errors::Internal(
"You need both a cluster and an item to determine peak "
"memory usage.")));
}
tensorflow::grappler::GraphMemory memory(*item);
if (cluster->DetailedStatsEnabled()) {
MaybeRaiseRegisteredFromStatus(memory.InferDynamically(cluster));
} else {
MaybeRaiseRegisteredFromStatus(
memory.InferStatically(cluster->GetDevices()));
}
std::unordered_map<std::string,
std::tuple<int64_t, std::vector<MemoryUsage>>>
result;
for (const auto& device : cluster->GetDevices()) {
const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
memory.GetPeakMemoryUsage(device.first);
std::vector<MemoryUsage> per_device;
for (int i = 0; i < usage.live_tensors.size(); ++i) {
const auto& live_tensor = usage.live_tensors[i];
per_device.push_back(std::make_tuple(
live_tensor.node, live_tensor.output_id,
live_tensor.memory_used, live_tensor.allocation_time.count(),
live_tensor.deallocation_time.count()));
}
result[device.first] = std::make_tuple(usage.used_memory, per_device);
}
return result;
});
}

View File

@ -1,67 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%include "cluster.i"
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The MetaGraphDef could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%{
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/python/grappler/cost_analyzer.h"
%}
%{
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
bool verbose, GCluster cluster) {
tensorflow::grappler::ItemConfig cfg;
cfg.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg);
if (!item) {
return "Error: failed to preprocess metagraph: check your log file for errors";
}
string suffix;
tensorflow::grappler::CostAnalyzer analyzer(*item, cluster.get(), suffix);
std::stringstream os;
analyzer.GenerateReport(os, per_node_report, verbose);
return os.str();
}
%}
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
bool verbose, GCluster cluster);

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as tf_wrap
from tensorflow.python import _pywrap_cost_analyzer as tf_wrap
from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import item as gitem
@ -44,10 +44,9 @@ def GenerateCostReport(metagraph,
if cluster is None:
cluster = gcluster.Cluster(disable_detailed_stats=False)
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
return tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
per_node_report, verbose,
cluster.tf_cluster)
return ret_from_swig
def GenerateMemoryReport(metagraph, detailed_report=True, cluster=None):

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <sstream>
#include <string>
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/python/grappler/cost_analyzer.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = pybind11;
PYBIND11_MODULE(_pywrap_cost_analyzer, m) {
m.def("GenerateCostReport",
[](const py::bytes& serialized_metagraph, bool per_node_report,
bool verbose, tensorflow::grappler::Cluster* cluster) -> py::bytes {
tensorflow::MetaGraphDef metagraph;
if (!metagraph.ParseFromString(serialized_metagraph)) {
return "The MetaGraphDef could not be parsed as a valid protocol "
"buffer";
}
tensorflow::grappler::ItemConfig cfg;
cfg.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
"metagraph", metagraph, cfg);
if (item == nullptr) {
return "Error: failed to preprocess metagraph: check your log file "
"for errors";
}
std::string suffix;
tensorflow::grappler::CostAnalyzer analyzer(*item, cluster, suffix);
std::stringstream os;
tensorflow::MaybeRaiseFromStatus(
analyzer.GenerateReport(os, per_node_report, verbose));
return py::bytes(os.str());
});
}

View File

@ -1,315 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include <std_shared_ptr.i>
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The MetaGraphDef could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
// Wrap the item into an object that swig can manipulate. This ensures it will call the object
// destructor upon garbage collection instead of leaking memory.
struct GItem {
std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
};
%{
#include <unordered_set>
#include <map>
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
// Provide the implementation fo the GItem struct here.
struct GItem {
GItem() {}
GItem(tensorflow::grappler::GrapplerItem* item) : item_(item) {}
tensorflow::grappler::GrapplerItem* operator->() const {
return item_.get();
}
const tensorflow::grappler::GrapplerItem& operator*() const {
return *item_.get();
}
bool is_none() const {
return item_.get() == nullptr;
}
std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
};
static GItem TF_NewItem(
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
bool ignore_user_placement, TF_Status* status) {
if (meta_graph.collection_def().count("train_op") == 0) {
tensorflow::Set_TF_Status_from_Status(
status,
tensorflow::errors::InvalidArgument("train_op not specified in the metagraph"));
return nullptr;
}
tensorflow::grappler::ItemConfig cfg;
cfg.ignore_user_placement = ignore_user_placement;
cfg.ignore_colocation = ignore_colocation;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef("item", meta_graph, cfg);
if (!item) {
tensorflow::Set_TF_Status_from_Status(
status,
tensorflow::errors::InvalidArgument("Invalid metagraph"));
return nullptr;
}
tensorflow::Set_TF_Status_from_Status(status, tensorflow::Status::OK());
return GItem(item.release());
}
static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
TF_Status* status) {
if (item.is_none()) {
Py_RETURN_NONE;
}
std::vector<const tensorflow::NodeDef*> main_ops = item->MainOpsFanin();
std::vector<const tensorflow::NodeDef*> enqueue_ops = item->EnqueueOpsFanin();
std::unordered_set<string> op_names;
for (auto op : main_ops) {
op_names.insert(op->name());
}
for (auto op : enqueue_ops) {
op_names.insert(op->name());
}
std::vector<string> ops;
if (sort_topologically) {
tensorflow::GraphDef subgraph;
for (const tensorflow::NodeDef& node : item->graph.node()) {
if (op_names.find(node.name()) != op_names.end()) {
*subgraph.add_node() = node;
}
}
tensorflow::Status s = tensorflow::grappler::TopologicalSort(&subgraph);
tensorflow::Set_TF_Status_from_Status(status, s);
for (const tensorflow::NodeDef& node : subgraph.node()) {
ops.push_back(node.name());
}
}
else {
for (const auto& op_name : op_names) {
ops.push_back(op_name);
}
}
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* result = PyList_New(ops.size());
for (int i = 0; i < ops.size(); ++i) {
PyList_SetItem(result, i, PyString_FromString(ops[i].c_str()));
}
PyGILState_Release(gstate);
return result;
}
static PyObject* TF_GetOpProperties(GItem item) {
if (item.is_none()) {
Py_RETURN_NONE;
}
tensorflow::grappler::GraphProperties properties(*item);
tensorflow::Status status = properties.InferStatically(false);
if (!status.ok()) {
Py_RETURN_NONE;
}
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* props = PyDict_New();
for (const auto& node : item->graph.node()) {
const string& node_name = node.name();
const std::vector<tensorflow::OpInfo::TensorProperties>& output_props =
properties.GetOutputProperties(node_name);
PyObject* prop = PyList_New(output_props.size());
for (int i = 0; i < output_props.size(); ++i) {
string output_prop_str = output_props[i].SerializeAsString();
PyObject* output_prop = PyBytes_FromStringAndSize(output_prop_str.data(),
output_prop_str.size());
PyList_SetItem(prop, i, output_prop);
}
CHECK_EQ(0, PyDict_SetItem(props, PyString_FromString(node_name.c_str()), prop));
}
PyGILState_Release(gstate);
return props;
}
class ColocationGroups {
public:
void Group(const string& x, const string& y) {
Rep* x_root = Find(x);
Rep* y_root = Find(y);
// x and y are already in the same set
if (x_root == y_root) {
return;
}
// x and y are not in same set, so we merge them
// Use the occasion to strengthen what we know about the handle by merging the
// information about the 2 subsets.
if (x_root->rank < y_root->rank) {
x_root->parent = y_root;
} else if (x_root->rank > y_root->rank) {
y_root->parent = x_root;
} else {
// Arbitrarily make one root the new parent
y_root->parent = x_root;
x_root->rank = x_root->rank + 1;
}
}
void ExtractGroups(std::vector<std::vector<string>>* groups) {
groups->reserve(nodes_.size());
std::unordered_map<const Rep*, int> group_ids;
for (const auto& rep : nodes_) {
Rep* r = Find(rep.first);
auto it = group_ids.find(r);
std::vector<string>* g;
if (it == group_ids.end()) {
int id = group_ids.size();
group_ids[r] = id;
groups->resize(id+1);
g = &groups->back();
} else {
int id = it->second;
g = &((*groups)[id]);
}
g->push_back(rep.first);
}
}
private:
struct Rep {
// Parent in the tree used to encode the set.
Rep* parent;
// Rank in the tree, used to figure out how to compress the path to the root
// of the tree.
int rank;
// The node.
string value;
};
Rep* Find(const string& n) {
auto it = nodes_.find(n);
if (it == nodes_.end()) {
// This is the first time we process this handle, create an entry for it.
Rep* node = new Rep;
node->parent = node;
node->rank = 0;
node->value = n;
nodes_[n] = node;
return node;
}
// Return the representative for the set, which is the root of the tree. Apply
// path compression to speedup future queries.
Rep* node = it->second;
Rep* root = node->parent;
while (root != root->parent) {
root = root->parent;
}
while (node->parent != root) {
Rep* next = node->parent;
node->parent = root;
node = next;
}
return root;
}
std::unordered_map<string, Rep*> nodes_;
};
static PyObject* TF_GetColocationGroups(GItem item) {
if (item.is_none()) {
Py_RETURN_NONE;
}
ColocationGroups groupings;
tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
for (const auto& node : item->graph.node()) {
const tensorflow::OpDef* op_def;
tensorflow::Status s = registry->LookUpOpDef(node.op(), &op_def);
if (!s.ok()) {
continue;
}
tensorflow::NameRangeMap inputs;
tensorflow::NameRangeMap outputs;
s = tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs);
if (!s.ok()) {
continue;
}
for (const auto& arg : op_def->input_arg()) {
if (!arg.is_ref()) {
continue;
}
const auto& range = inputs[arg.name()];
for (int i = range.first; i < range.second; ++i) {
string input = tensorflow::grappler::NodeName(node.input(i));
groupings.Group(node.name(), input);
}
}
}
std::vector<std::vector<string>> groups;
groupings.ExtractGroups(&groups);
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* result = PyList_New(groups.size());
for (int i = 0; i < groups.size(); ++i) {
const std::vector<string>& group = groups[i];
PyObject* g = PyTuple_New(group.size());
for (int j = 0; j < group.size(); ++j) {
const string& node_name = group[j];
PyTuple_SetItem(g, j, PyString_FromString(node_name.c_str()));
}
PyList_SetItem(result, i, g);
}
PyGILState_Release(gstate);
return result;
}
%}
// Wrap these functions.
static GItem TF_NewItem(
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
bool ignore_user_placement, TF_Status* status);
static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
TF_Status* status);
static PyObject* TF_GetOpProperties(GItem item);
static PyObject* TF_GetColocationGroups(GItem item);

View File

@ -20,7 +20,7 @@ from __future__ import print_function
from tensorflow.core.grappler.costs import op_performance_data_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python import pywrap_tensorflow as tf_item
from tensorflow.python import _pywrap_tf_item as tf_item
class Item(object):
@ -53,11 +53,14 @@ class Item(object):
return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically)
def GetOpProperties(self):
ret_from_swig = tf_item.TF_GetOpProperties(self.tf_item)
"""Get Op properties."""
props = tf_item.TF_GetOpProperties(self.tf_item)
properties = {}
for key, values in ret_from_swig.items():
for key, values in props.items():
prop = []
for value in values:
# TODO(petebu): Make this conversion to a dictionary be done in the C++
# wrapper for performance.
prop.append(
op_performance_data_pb2.OpInfo.TensorProperties.FromString(value))
properties[key] = prop

View File

@ -0,0 +1,245 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = pybind11;
class ColocationGroups {
public:
void Group(const std::string& x, const std::string& y) {
Rep* x_root = Find(x);
Rep* y_root = Find(y);
// x and y are already in the same set
if (x_root == y_root) {
return;
}
// x and y are not in same set, so we merge them
// Use the occasion to strengthen what we know about the handle by merging
// the information about the 2 subsets.
if (x_root->rank < y_root->rank) {
x_root->parent = y_root;
} else if (x_root->rank > y_root->rank) {
y_root->parent = x_root;
} else {
// Arbitrarily make one root the new parent
y_root->parent = x_root;
x_root->rank = x_root->rank + 1;
}
}
void ExtractGroups(std::vector<std::vector<std::string>>* groups) {
groups->reserve(nodes_.size());
std::unordered_map<const Rep*, int> group_ids;
for (const auto& rep : nodes_) {
Rep* r = Find(rep.first);
auto it = group_ids.find(r);
std::vector<std::string>* g;
if (it == group_ids.end()) {
int id = group_ids.size();
group_ids[r] = id;
groups->resize(id + 1);
g = &groups->back();
} else {
int id = it->second;
g = &((*groups)[id]);
}
g->push_back(rep.first);
}
}
private:
struct Rep {
// Parent in the tree used to encode the set.
Rep* parent;
// Rank in the tree, used to figure out how to compress the path to the root
// of the tree.
int rank;
// The node.
std::string value;
};
Rep* Find(const std::string& n) {
auto it = nodes_.find(n);
if (it == nodes_.end()) {
// This is the first time we process this handle, create an entry for it.
Rep* node = new Rep;
node->parent = node;
node->rank = 0;
node->value = n;
nodes_[n] = node;
return node;
}
// Return the representative for the set, which is the root of the tree.
// Apply path compression to speedup future queries.
Rep* node = it->second;
Rep* root = node->parent;
while (root != root->parent) {
root = root->parent;
}
while (node->parent != root) {
Rep* next = node->parent;
node->parent = root;
node = next;
}
return root;
}
std::unordered_map<std::string, Rep*> nodes_;
};
PYBIND11_MAKE_OPAQUE(tensorflow::grappler::GrapplerItem);
PYBIND11_MODULE(_pywrap_tf_item, m) {
py::class_<tensorflow::grappler::GrapplerItem> grappler_item(
m, "tensorflow::grappler::GrapplerItem");
m.def("TF_NewItem",
[](const py::bytes& serialized_metagraph, bool ignore_colocation,
bool ignore_user_placement) -> tensorflow::grappler::GrapplerItem* {
tensorflow::MetaGraphDef metagraph;
if (!metagraph.ParseFromString(serialized_metagraph)) {
throw std::invalid_argument(
"The MetaGraphDef could not be parsed as a valid protocol "
"buffer");
}
if (metagraph.collection_def().count("train_op") == 0) {
MaybeRaiseRegisteredFromStatus(tensorflow::errors::InvalidArgument(
"train_op not specified in the metagraph"));
}
tensorflow::grappler::ItemConfig cfg;
cfg.ignore_user_placement = ignore_user_placement;
cfg.ignore_colocation = ignore_colocation;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
"item", metagraph, cfg);
if (item == nullptr) {
MaybeRaiseRegisteredFromStatus(
tensorflow::errors::InvalidArgument("Invalid metagraph"));
}
return item.release();
});
m.def("TF_IdentifyImportantOps",
[](tensorflow::grappler::GrapplerItem* item,
bool sort_topologically) -> std::vector<std::string> {
std::vector<const tensorflow::NodeDef*> main_ops =
item->MainOpsFanin();
std::vector<const tensorflow::NodeDef*> enqueue_ops =
item->EnqueueOpsFanin();
std::unordered_set<std::string> op_names;
for (auto op : main_ops) {
op_names.insert(op->name());
}
for (auto op : enqueue_ops) {
op_names.insert(op->name());
}
std::vector<std::string> ops;
if (sort_topologically) {
tensorflow::GraphDef subgraph;
for (const tensorflow::NodeDef& node : item->graph.node()) {
if (op_names.find(node.name()) != op_names.end()) {
*subgraph.add_node() = node;
}
}
tensorflow::MaybeRaiseFromStatus(
tensorflow::grappler::TopologicalSort(&subgraph));
for (const tensorflow::NodeDef& node : subgraph.node()) {
ops.push_back(node.name());
}
} else {
for (const auto& op_name : op_names) {
ops.push_back(op_name);
}
}
return ops;
});
m.def("TF_GetOpProperties",
[](tensorflow::grappler::GrapplerItem* item)
-> std::unordered_map<std::string, std::vector<py::bytes>> {
tensorflow::grappler::GraphProperties properties(*item);
tensorflow::MaybeRaiseFromStatus(properties.InferStatically(false));
std::unordered_map<std::string, std::vector<py::bytes>> props;
for (const auto& node : item->graph.node()) {
const std::string& node_name = node.name();
const std::vector<tensorflow::OpInfo::TensorProperties>&
output_props = properties.GetOutputProperties(node_name);
std::vector<py::bytes> prop;
prop.reserve(output_props.size());
for (const auto& output_prop : output_props) {
prop.push_back(output_prop.SerializeAsString());
}
props[node_name] = prop;
}
return props;
});
m.def("TF_GetColocationGroups",
[](tensorflow::grappler::GrapplerItem* item)
-> std::vector<std::vector<std::string>> {
ColocationGroups groupings;
tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
for (const auto& node : item->graph.node()) {
const tensorflow::OpDef* op_def;
if (!registry->LookUpOpDef(node.op(), &op_def).ok()) {
continue;
}
tensorflow::NameRangeMap inputs;
tensorflow::NameRangeMap outputs;
if (!tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs)
.ok()) {
continue;
}
for (const auto& arg : op_def->input_arg()) {
if (!arg.is_ref()) {
continue;
}
const auto& range = inputs[arg.name()];
for (int i = range.first; i < range.second; ++i) {
groupings.Group(node.name(),
tensorflow::grappler::NodeName(node.input(i)));
}
}
}
std::vector<std::vector<std::string>> groups;
groupings.ExtractGroups(&groups);
return groups;
});
}

View File

@ -1,144 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%include "cluster.i"
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The MetaGraphDef could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%typemap(in) const tensorflow::ConfigProto& (
tensorflow::ConfigProto temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The ConfigProto could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%{
#include <memory>
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session_options.h"
void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* device_map) {
tensorflow::SessionOptions options;
std::vector<std::unique_ptr<tensorflow::Device>> devices;
tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(options, "", &devices);
if (!status.ok()) {
return;
}
for (const std::unique_ptr<tensorflow::Device>& device : devices) {
tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
// Overwrite the memory limit since users might have requested to use only a fraction of the
// available device memory.
const tensorflow::DeviceAttributes& attr = device->attributes();
prop.set_memory_size(attr.memory_limit());
}
}
PyObject* TF_OptimizeGraph(
GCluster cluster,
const tensorflow::ConfigProto& config_proto,
const tensorflow::MetaGraphDef& metagraph,
bool verbose, const string& graph_id,
bool strip_default_attributes,
TF_Status* status) {
tensorflow::grappler::ItemConfig item_config;
item_config.apply_optimizations = false;
item_config.ignore_user_placement = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
if (!grappler_item) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info.");
return nullptr;
}
tensorflow::DeviceBase* cpu_device = nullptr;
tensorflow::GraphDef out_graph;
tensorflow::grappler::MetaOptimizer optimizer(cpu_device, config_proto);
tensorflow::Status s = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
tensorflow::Set_TF_Status_from_Status(status, s);
if (!s.ok()) {
Py_RETURN_NONE;
}
if (strip_default_attributes) {
tensorflow::StripDefaultAttributes(*tensorflow::OpRegistry::Global(),
out_graph.mutable_node());
}
if (verbose) {
optimizer.PrintResult();
}
string out_graph_str = out_graph.SerializeAsString();
PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
out_graph_str.size());
return ret;
}
%}
// Wrap this function
PyObject* TF_OptimizeGraph(
GCluster cluster,
const tensorflow::ConfigProto& config_proto,
const tensorflow::MetaGraphDef& metagraph, bool verbose,
const string& graph_id, bool strip_default_attributes, TF_Status* status);

View File

@ -20,7 +20,7 @@ from __future__ import print_function
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_opt
from tensorflow.python import _pywrap_tf_optimizer as tf_opt
from tensorflow.python.grappler import cluster as gcluster
@ -52,12 +52,8 @@ def OptimizeGraph(config_proto,
type(config_proto))
if cluster is None:
cluster = gcluster.Cluster()
ret_from_swig = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
out_graph = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
config_proto.SerializeToString(),
metagraph.SerializeToString(),
verbose, graph_id,
strip_default_attributes)
if ret_from_swig is None:
return None
out_graph = graph_pb2.GraphDef().FromString(ret_from_swig)
return out_graph
metagraph.SerializeToString(), verbose,
graph_id, strip_default_attributes)
return graph_pb2.GraphDef().FromString(out_graph)

View File

@ -0,0 +1,108 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = pybind11;
void DetectDevices(
std::unordered_map<std::string, tensorflow::DeviceProperties>* device_map) {
tensorflow::SessionOptions options;
std::vector<std::unique_ptr<tensorflow::Device>> devices;
if (!tensorflow::DeviceFactory::AddDevices(options, "", &devices).ok()) {
return;
}
for (const std::unique_ptr<tensorflow::Device>& device : devices) {
tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
// Overwrite the memory limit since users might have requested to use only a
// fraction of the available device memory.
const tensorflow::DeviceAttributes& attr = device->attributes();
prop.set_memory_size(attr.memory_limit());
}
}
PYBIND11_MODULE(_pywrap_tf_optimizer, m) {
m.def(
"TF_OptimizeGraph",
[](tensorflow::grappler::Cluster* cluster,
const py::bytes& serialized_config_proto,
const py::bytes& serialized_metagraph, bool verbose,
const std::string& graph_id,
bool strip_default_attributes) -> py::bytes {
tensorflow::ConfigProto config_proto;
if (!config_proto.ParseFromString(serialized_config_proto)) {
throw std::invalid_argument(
"The ConfigProto could not be parsed as a valid protocol buffer");
}
tensorflow::MetaGraphDef metagraph;
if (!metagraph.ParseFromString(serialized_metagraph)) {
throw std::invalid_argument(
"The MetaGraphDef could not be parsed as a valid protocol "
"buffer");
}
tensorflow::grappler::ItemConfig item_config;
// This disables graph optimizations in the older graph optimizer, which
// tend to overlap / be redundant with those in Grappler.
item_config.apply_optimizations = false;
item_config.ignore_user_placement = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
graph_id, metagraph, item_config);
if (!grappler_item) {
throw std::invalid_argument(
"Failed to import metagraph, check error log for more info.");
}
tensorflow::DeviceBase* cpu_device = nullptr;
tensorflow::GraphDef out_graph;
tensorflow::grappler::MetaOptimizer optimizer(cpu_device, config_proto);
MaybeRaiseRegisteredFromStatus(
optimizer.Optimize(cluster, *grappler_item, &out_graph));
if (strip_default_attributes) {
tensorflow::StripDefaultAttributes(*tensorflow::OpRegistry::Global(),
out_graph.mutable_node());
}
if (verbose) {
optimizer.PrintResult();
}
return out_graph.SerializeAsString();
});
}

View File

@ -17,10 +17,11 @@ limitations under the License.
* The includes are intentionally not alphabetically sorted, as the order of
* includes follows dependency order */
%include "tensorflow/python/grappler/cluster.i"
%include "tensorflow/python/grappler/item.i"
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/core/lib/core/status.h"
%}
// TODO(slebedev): This is a temporary workaround for projects implicitly
// relying on TensorFlow exposing tensorflow::Status.

View File

@ -243,3 +243,59 @@ tensorflow::TF_GraphSetTensorShape_wrapper
tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper
tensorflow::TF_TryEvaluateConstant_wrapper
[grappler_item] # tf_item
tensorflow::grappler::GrapplerItem::MainOpsFanin
tensorflow::grappler::GrapplerItem::EnqueueOpsFanin
[graph_properties] # tf_item
tensorflow::grappler::GraphProperties::InferStatically
tensorflow::grappler::GraphProperties::GetOutputProperties
[grappler_item_builder] # tf_item
tensorflow::grappler::GrapplerItemFromMetaGraphDef
[topological_sort] # tf_item
tensorflow::grappler::TopologicalSort
[clusters/utils] # tf_cluster tf_optimizer
tensorflow::grappler::GetDeviceInfo
[costs/utils] # tf_optimizer tf_cluster
tensorflow::grappler::CostGraphToOpPerformanceData
tensorflow::grappler::GetDeviceInfo
[meta_optimizer] # tf_optimizer
tensorflow::grappler::MetaOptimizer::MetaOptimizer
tensorflow::grappler::MetaOptimizer::Optimize
tensorflow::grappler::MetaOptimizer::PrintResult
[clusters/cluster] # tf_cluster
tensorflow::grappler::Cluster::AllowSoftPlacement
tensorflow::grappler::Cluster::SetNumWarmupSteps
tensorflow::grappler::Cluster::DisableDetailedStats
tensorflow::grappler::Cluster::DetailedStatsEnabled
[single_machine] # tf_cluster
tensorflow::grappler::SingleMachine::SingleMachine
[op_level_cost_estimator] # tf_cluster
tensorflow::grappler::OpLevelCostEstimator::OpLevelCostEstimator
tensorflow::grappler::OpLevelCostEstimator::PredictCosts
tensorflow::grappler::OpLevelCostEstimator::GetDeviceInfo
[virtual_cluster] # tf_cluster
tensorflow::grappler::VirtualCluster::VirtualCluster
[graph_memory] # tf_cluster
tensorflow::grappler::GraphMemory::InferStatically
tensorflow::grappler::GraphMemory::InferDynamically
[measuring_cost_estimator] # tf_cluster
tensorflow::grappler::MeasuringCostEstimator::MeasuringCostEstimator
tensorflow::grappler::MeasuringCostEstimator::Initialize
tensorflow::grappler::MeasuringCostEstimator::PredictCosts
[devices] # tf_cluster
tensorflow::grappler::GetNumAvailableGPUs
tensorflow::grappler::GetNumAvailableLogicalCPUCores