Fix 64-bit integer portability problems in TensorFlow grappler.

Removes reliance on the assumption that tensorflow::int64 is long long. This is intended to eventually enable changing the definition to int64_t from <cstdint>.

PiperOrigin-RevId: 291219532
Change-Id: I7aee94f28022bfc3bd7cceb9a8943e97d47c41d0
This commit is contained in:
A. Unique TensorFlower 2020-01-23 12:22:57 -08:00 committed by TensorFlower Gardener
parent 4a561b44b9
commit 7b0cb335fb
3 changed files with 42 additions and 60 deletions
tensorflow/core/grappler/costs

View File

@ -138,6 +138,8 @@ tf_cuda_library(
deps = [
":cost_estimator",
"//third_party/eigen3",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:gpu_id",
@ -231,6 +233,7 @@ cc_library(
hdrs = ["virtual_scheduler.h"],
visibility = ["//visibility:public"],
deps = [
":cost_estimator",
":graph_properties",
":op_context",
":utils",
@ -242,7 +245,8 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:utils",
"//tensorflow/core/grappler/costs:cost_estimator",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)

View File

@ -16,10 +16,12 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/utils.h"
#include <stddef.h>
#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
@ -37,7 +39,6 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@ -140,7 +141,7 @@ static void ExtractExtraProperties(
}
AttrValue attr;
attr.set_i(stat.length);
string attr_key = strings::StrCat("input_", i, "_filesize");
string attr_key = absl::StrCat("input_", i, "_filesize");
(*op_info->mutable_attr())[attr_key] = attr;
}
}
@ -149,7 +150,7 @@ static void ExtractExtraProperties(
// in the op itself is not sufficient to predict the op memory.
if (op_def && i < op_def->input_arg_size() &&
op_def->input_arg(i).name().find("handle") != string::npos) {
string new_key = strings::StrCat("parent_", i, "_op");
string new_key = absl::StrCat("parent_", i, "_op");
AttrValue attr;
attr.set_s(input_node->op());
(*op_info->mutable_attr())[new_key] = attr;
@ -353,49 +354,31 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) {
}
string TensorSizeHistogram::ToString() const {
string r;
char buf[200];
snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_);
r.append(buf);
r.append(strings::HumanReadableNumBytes(Average()));
r.append(", Min: ");
r.append(strings::HumanReadableNumBytes(min_));
r.append(", Max: ");
r.append(strings::HumanReadableNumBytes(max_));
r.append("\n------------------------------------------------------\n");
string r = absl::StrFormat(
"Count: %lld, Average: %s, Min: %s, Max: %s"
"\n------------------------------------------------------\n",
num_elem_, strings::HumanReadableNumBytes(Average()),
strings::HumanReadableNumBytes(min_),
strings::HumanReadableNumBytes(max_));
const double mult = num_elem_ > 0 ? 100.0 / num_elem_ : 0.0;
uint64 cumul_sum = 0;
const int size_string_width = 12;
for (int i = 0; i < buckets_.size(); i++) {
if (buckets_[i] == 0) continue;
cumul_sum += buckets_[i];
r.append("[ ");
if (i == 0) {
r.append(size_string_width - 2, ' ');
r.append("0B");
} else {
uint64 left = 1ULL << (i - 1);
const auto left_string = strings::HumanReadableNumBytes(left);
r.append(size_string_width - left_string.size(), ' ');
r.append(left_string);
}
r.append(", ");
uint64 left = i == 0 ? 0ULL : 1ULL << (i - 1);
uint64 right = 1ULL << i;
const auto right_string = strings::HumanReadableNumBytes(right);
r.append(size_string_width - right_string.size(), ' ');
r.append(right_string);
snprintf(buf, sizeof(buf), ") %7lld %7.3f%% %7.3f%% ",
buckets_[i], // count
mult * buckets_[i], // percentage
mult * cumul_sum); // cum percentage
r.append(buf);
absl::StrAppendFormat(&r, "[ %12s, %12s) %7d %7.3f%% %7.3f%% ",
strings::HumanReadableNumBytes(left),
strings::HumanReadableNumBytes(right),
buckets_[i], // count
mult * buckets_[i], // percentage
mult * cumul_sum); // cumulative percentage
// Add hash marks based on percentage; 40 marks for 100%.
auto marks = static_cast<int>(
(static_cast<double>(40 * buckets_[i] + (num_elem_ >> 1)) / num_elem_));
r.append(marks, '#');
r.push_back('\n');
absl::StrAppendFormat(&r, "%s\n", std::string(marks, '#'));
}
return r;
}
@ -422,7 +405,7 @@ string GetDeviceClassForNonChannelDevice(const string& device_name) {
}
if (parsed) {
const string jobname = parsed_name.has_job ? parsed_name.job : "";
return strings::StrCat("/", jobname, "/", parsed_name.type);
return absl::StrCat("/", jobname, "/", parsed_name.type);
} else {
return "Unclassified";
}
@ -440,7 +423,7 @@ string GetDeviceClass(const string& device_name) {
const auto src_device_full = device_name.substr(
from_loc + from.size(), to_loc - (from_loc + from.size()));
const auto dst_device_full = device_name.substr(to_loc + to.size());
return strings::StrCat(
return absl::StrCat(
"Channel", ": ", GetDeviceClassForNonChannelDevice(src_device_full),
" -> ", GetDeviceClassForNonChannelDevice(dst_device_full));
} else {

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -27,8 +29,6 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -37,7 +37,6 @@ namespace grappler {
namespace {
using ::absl::StrCat;
using ::tensorflow::strings::HumanReadableNumBytes;
constexpr char kAttrInputSrc[] = "input_source_";
@ -606,15 +605,15 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const {
string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
// Replace the ":" characters that may be present in the device name with "_".
// This makes it possible to then use the resulting string in a node name.
return str_util::StringReplace(placer_->get_canonical_device_name(*node), ":",
"_", true);
return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
{{":", "_"}});
}
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
const NodeDef* to) const {
CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
return StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from), "_to_",
SanitizedDeviceName(to));
return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
"_to_", SanitizedDeviceName(to));
}
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
@ -636,9 +635,9 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
auto input_node_port_num = NodePosition(input_name);
string src_name;
if (input_node_port_num >= 0) {
src_name = StrCat(from->name(), "_", input_node_port_num);
src_name = absl::StrCat(from->name(), "_", input_node_port_num);
} else {
src_name = StrCat(from->name(), "_minus1");
src_name = absl::StrCat(from->name(), "_minus1");
}
// _Send op.
@ -967,11 +966,10 @@ Costs VirtualScheduler::Summary() const {
op_cost_pair.second.intermediate_memory_time.count();
const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
if (cost) { // Skip printing out zero-cost ops.
VLOG(1) << strings::Printf(
" + %30s : %c %10lld / %10lld / %10lld / %10lld", op.c_str(),
(is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost),
static_cast<int64>(compute_cost), static_cast<int64>(memory_cost),
static_cast<int64>(intermediate_memory_cost));
VLOG(1) << absl::StrFormat(" + %30s : %c %10d / %10d / %10d / %10d", op,
(is_op_cost_accurate ? ' ' : '~'), cost,
compute_cost, memory_cost,
intermediate_memory_cost);
}
}
@ -1072,13 +1070,10 @@ Costs VirtualScheduler::Summary() const {
: 0.0;
if (cost || mem_usage_percent > 1.0) {
// Print out only non-zero cost ops or ops with > 1% memory usage.
VLOG(1) << strings::Printf(
" + %30s : %c %10lld / %10lld / %10lld / %10lld",
op.c_str(), (is_op_cost_accurate ? ' ' : '~'),
static_cast<int64>(cost),
static_cast<int64>(compute_cost),
static_cast<int64>(memory_cost),
static_cast<int64>(intermediate_memory_cost))
VLOG(1) << absl::StrFormat(
" + %30s : %c %10d / %10d / %10d / %10d", op.c_str(),
(is_op_cost_accurate ? ' ' : '~'), cost, compute_cost,
memory_cost, intermediate_memory_cost)
<< " (" << HumanReadableNumBytes(op_mem_usage) << " ["
<< mem_usage_percent << "%] "
<< (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");