Fix 64-bit integer portability problems in TensorFlow grappler.
Removes reliance on the assumption that tensorflow::int64 is long long. This is intended to eventually enable changing the definition to int64_t from <cstdint>. PiperOrigin-RevId: 291219532 Change-Id: I7aee94f28022bfc3bd7cceb9a8943e97d47c41d0
This commit is contained in:
parent
4a561b44b9
commit
7b0cb335fb
tensorflow/core/grappler/costs
@ -138,6 +138,8 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":cost_estimator",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:gpu_id",
|
||||
@ -231,6 +233,7 @@ cc_library(
|
||||
hdrs = ["virtual_scheduler.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":cost_estimator",
|
||||
":graph_properties",
|
||||
":op_context",
|
||||
":utils",
|
||||
@ -242,7 +245,8 @@ cc_library(
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:utils",
|
||||
"//tensorflow/core/grappler/costs:cost_estimator",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,10 +16,12 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
@ -37,7 +39,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/bits.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/byte_order.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -140,7 +141,7 @@ static void ExtractExtraProperties(
|
||||
}
|
||||
AttrValue attr;
|
||||
attr.set_i(stat.length);
|
||||
string attr_key = strings::StrCat("input_", i, "_filesize");
|
||||
string attr_key = absl::StrCat("input_", i, "_filesize");
|
||||
(*op_info->mutable_attr())[attr_key] = attr;
|
||||
}
|
||||
}
|
||||
@ -149,7 +150,7 @@ static void ExtractExtraProperties(
|
||||
// in the op itself is not sufficient to predict the op memory.
|
||||
if (op_def && i < op_def->input_arg_size() &&
|
||||
op_def->input_arg(i).name().find("handle") != string::npos) {
|
||||
string new_key = strings::StrCat("parent_", i, "_op");
|
||||
string new_key = absl::StrCat("parent_", i, "_op");
|
||||
AttrValue attr;
|
||||
attr.set_s(input_node->op());
|
||||
(*op_info->mutable_attr())[new_key] = attr;
|
||||
@ -353,49 +354,31 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) {
|
||||
}
|
||||
|
||||
string TensorSizeHistogram::ToString() const {
|
||||
string r;
|
||||
char buf[200];
|
||||
snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_);
|
||||
r.append(buf);
|
||||
r.append(strings::HumanReadableNumBytes(Average()));
|
||||
r.append(", Min: ");
|
||||
r.append(strings::HumanReadableNumBytes(min_));
|
||||
r.append(", Max: ");
|
||||
r.append(strings::HumanReadableNumBytes(max_));
|
||||
r.append("\n------------------------------------------------------\n");
|
||||
string r = absl::StrFormat(
|
||||
"Count: %lld, Average: %s, Min: %s, Max: %s"
|
||||
"\n------------------------------------------------------\n",
|
||||
num_elem_, strings::HumanReadableNumBytes(Average()),
|
||||
strings::HumanReadableNumBytes(min_),
|
||||
strings::HumanReadableNumBytes(max_));
|
||||
const double mult = num_elem_ > 0 ? 100.0 / num_elem_ : 0.0;
|
||||
uint64 cumul_sum = 0;
|
||||
|
||||
const int size_string_width = 12;
|
||||
for (int i = 0; i < buckets_.size(); i++) {
|
||||
if (buckets_[i] == 0) continue;
|
||||
cumul_sum += buckets_[i];
|
||||
r.append("[ ");
|
||||
if (i == 0) {
|
||||
r.append(size_string_width - 2, ' ');
|
||||
r.append("0B");
|
||||
} else {
|
||||
uint64 left = 1ULL << (i - 1);
|
||||
const auto left_string = strings::HumanReadableNumBytes(left);
|
||||
r.append(size_string_width - left_string.size(), ' ');
|
||||
r.append(left_string);
|
||||
}
|
||||
r.append(", ");
|
||||
uint64 left = i == 0 ? 0ULL : 1ULL << (i - 1);
|
||||
uint64 right = 1ULL << i;
|
||||
const auto right_string = strings::HumanReadableNumBytes(right);
|
||||
r.append(size_string_width - right_string.size(), ' ');
|
||||
r.append(right_string);
|
||||
snprintf(buf, sizeof(buf), ") %7lld %7.3f%% %7.3f%% ",
|
||||
buckets_[i], // count
|
||||
mult * buckets_[i], // percentage
|
||||
mult * cumul_sum); // cum percentage
|
||||
r.append(buf);
|
||||
absl::StrAppendFormat(&r, "[ %12s, %12s) %7d %7.3f%% %7.3f%% ",
|
||||
strings::HumanReadableNumBytes(left),
|
||||
strings::HumanReadableNumBytes(right),
|
||||
buckets_[i], // count
|
||||
mult * buckets_[i], // percentage
|
||||
mult * cumul_sum); // cumulative percentage
|
||||
|
||||
// Add hash marks based on percentage; 40 marks for 100%.
|
||||
auto marks = static_cast<int>(
|
||||
(static_cast<double>(40 * buckets_[i] + (num_elem_ >> 1)) / num_elem_));
|
||||
r.append(marks, '#');
|
||||
r.push_back('\n');
|
||||
absl::StrAppendFormat(&r, "%s\n", std::string(marks, '#'));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
@ -422,7 +405,7 @@ string GetDeviceClassForNonChannelDevice(const string& device_name) {
|
||||
}
|
||||
if (parsed) {
|
||||
const string jobname = parsed_name.has_job ? parsed_name.job : "";
|
||||
return strings::StrCat("/", jobname, "/", parsed_name.type);
|
||||
return absl::StrCat("/", jobname, "/", parsed_name.type);
|
||||
} else {
|
||||
return "Unclassified";
|
||||
}
|
||||
@ -440,7 +423,7 @@ string GetDeviceClass(const string& device_name) {
|
||||
const auto src_device_full = device_name.substr(
|
||||
from_loc + from.size(), to_loc - (from_loc + from.size()));
|
||||
const auto dst_device_full = device_name.substr(to_loc + to.size());
|
||||
return strings::StrCat(
|
||||
return absl::StrCat(
|
||||
"Channel", ": ", GetDeviceClassForNonChannelDevice(src_device_full),
|
||||
" -> ", GetDeviceClassForNonChannelDevice(dst_device_full));
|
||||
} else {
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_replace.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -27,8 +29,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
@ -37,7 +37,6 @@ namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::absl::StrCat;
|
||||
using ::tensorflow::strings::HumanReadableNumBytes;
|
||||
|
||||
constexpr char kAttrInputSrc[] = "input_source_";
|
||||
@ -606,15 +605,15 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
||||
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
|
||||
// Replace the ":" characters that may be present in the device name with "_".
|
||||
// This makes it possible to then use the resulting string in a node name.
|
||||
return str_util::StringReplace(placer_->get_canonical_device_name(*node), ":",
|
||||
"_", true);
|
||||
return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
|
||||
{{":", "_"}});
|
||||
}
|
||||
|
||||
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
|
||||
const NodeDef* to) const {
|
||||
CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
|
||||
return StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from), "_to_",
|
||||
SanitizedDeviceName(to));
|
||||
return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
|
||||
"_to_", SanitizedDeviceName(to));
|
||||
}
|
||||
|
||||
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
||||
@ -636,9 +635,9 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
||||
auto input_node_port_num = NodePosition(input_name);
|
||||
string src_name;
|
||||
if (input_node_port_num >= 0) {
|
||||
src_name = StrCat(from->name(), "_", input_node_port_num);
|
||||
src_name = absl::StrCat(from->name(), "_", input_node_port_num);
|
||||
} else {
|
||||
src_name = StrCat(from->name(), "_minus1");
|
||||
src_name = absl::StrCat(from->name(), "_minus1");
|
||||
}
|
||||
|
||||
// _Send op.
|
||||
@ -967,11 +966,10 @@ Costs VirtualScheduler::Summary() const {
|
||||
op_cost_pair.second.intermediate_memory_time.count();
|
||||
const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
|
||||
if (cost) { // Skip printing out zero-cost ops.
|
||||
VLOG(1) << strings::Printf(
|
||||
" + %30s : %c %10lld / %10lld / %10lld / %10lld", op.c_str(),
|
||||
(is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost),
|
||||
static_cast<int64>(compute_cost), static_cast<int64>(memory_cost),
|
||||
static_cast<int64>(intermediate_memory_cost));
|
||||
VLOG(1) << absl::StrFormat(" + %30s : %c %10d / %10d / %10d / %10d", op,
|
||||
(is_op_cost_accurate ? ' ' : '~'), cost,
|
||||
compute_cost, memory_cost,
|
||||
intermediate_memory_cost);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1072,13 +1070,10 @@ Costs VirtualScheduler::Summary() const {
|
||||
: 0.0;
|
||||
if (cost || mem_usage_percent > 1.0) {
|
||||
// Print out only non-zero cost ops or ops with > 1% memory usage.
|
||||
VLOG(1) << strings::Printf(
|
||||
" + %30s : %c %10lld / %10lld / %10lld / %10lld",
|
||||
op.c_str(), (is_op_cost_accurate ? ' ' : '~'),
|
||||
static_cast<int64>(cost),
|
||||
static_cast<int64>(compute_cost),
|
||||
static_cast<int64>(memory_cost),
|
||||
static_cast<int64>(intermediate_memory_cost))
|
||||
VLOG(1) << absl::StrFormat(
|
||||
" + %30s : %c %10d / %10d / %10d / %10d", op.c_str(),
|
||||
(is_op_cost_accurate ? ' ' : '~'), cost, compute_cost,
|
||||
memory_cost, intermediate_memory_cost)
|
||||
<< " (" << HumanReadableNumBytes(op_mem_usage) << " ["
|
||||
<< mem_usage_percent << "%] "
|
||||
<< (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");
|
||||
|
Loading…
Reference in New Issue
Block a user