Fix 64-bit integer portability problems in TensorFlow kernels.

Removes reliance on the assumption that tensorflow::int64 is long long. This is a step toward eventually changing its definition to int64_t from <cstdint>.

PiperOrigin-RevId: 290872365
Change-Id: I18534aeabf153d65c3521599855f8cca279fce51
Authored by A. Unique TensorFlower on 2020-01-21 19:11:06 -08:00; committed by TensorFlower Gardener
parent 2cbb324ceb
commit a1bc56203f
14 changed files with 22 additions and 20 deletions
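Most of the kernel changes below follow one pattern: an LL/ll literal hard-codes long long, which stops matching tensorflow::int64 once that alias is no longer defined as long long, and std::max/std::min then fail to deduce a single template argument. A minimal sketch of the failure and of the braced-initializer fix (the int64 alias below is a local stand-in, not TensorFlow's actual typedef):

#include <algorithm>
#include <cstdint>

// Stand-in for tensorflow::int64: on LP64 platforms int64_t is typically
// long, not long long, even though both are 64 bits wide.
using int64 = std::int64_t;

int64 ClampToAtLeastOne(int64 limit) {
  // return std::max(limit, 1LL);    // ill-formed once int64 is long:
  //                                 // std::max cannot deduce one T from
  //                                 // long and long long.
  return std::max(limit, int64{1});  // portable: both arguments are int64
}

Braced initialization also rejects narrowing, so int64{1} stays correct even if the alias later changes width.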

View File

@@ -5369,6 +5369,7 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/eigen3",
"@com_google_absl//absl/strings:str_format",
],
alwayslink = 1,
)

View File

@@ -432,7 +432,7 @@ void SerialDeviceBatchScheduler<TaskType>::ProcessBatches() {
// the desired target pending.
in_flight_batches_limit_ +=
std::round(options_.target_pending - avg_pending);
in_flight_batches_limit_ = std::max(in_flight_batches_limit_, 1LL);
in_flight_batches_limit_ = std::max(in_flight_batches_limit_, int64{1});
in_flight_batches_limit_ =
std::min(in_flight_batches_limit_, options_.num_batch_threads);
// Add extra processing threads if necessary.

View File

@@ -433,6 +433,7 @@ tf_kernel_library(
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
],
)

View File

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <random>
#include "absl/strings/str_format.h"
#include "absl/time/clock.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -1420,9 +1421,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
string GetSnapshotFilename() {
mutex_lock l(mu_);
string snapshot_data_filename = io::JoinPath(
run_dir_,
absl::StrCat(strings::Printf("%08llu", next_file_index_),
".snapshot"));
run_dir_, absl::StrFormat("%08u.snapshot", next_file_index_));
next_file_index_++;
return snapshot_data_filename;
}
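The hunk above replaces strings::Printf with absl::StrFormat. The %08llu conversion is only correct if next_file_index_ is exactly unsigned long long, whereas StrFormat checks its arguments at compile time and its %u/%d conversions accept any integer width, so the format string no longer encodes the size of int64. A sketch of the same call in isolation (the function name and parameter are illustrative, not TensorFlow's API):

#include <cstdint>
#include <string>

#include "absl/strings/str_format.h"

// Hypothetical stand-in for the kernel's next_file_index_ counter.
std::string SnapshotFilename(std::uint64_t next_file_index) {
  // %08u zero-pads any integral argument to eight digits; the argument
  // type is checked at compile time rather than encoded in the format.
  return absl::StrFormat("%08u.snapshot", next_file_index);
}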

View File

@@ -68,9 +68,9 @@ class RangeDatasetOp::Dataset : public DatasetBase {
int64 Cardinality() const override {
if (step_ > 0) {
return std::max(0LL, (stop_ - start_ - 1) / step_ + 1);
return std::max(int64{0}, (stop_ - start_ - 1) / step_ + 1);
} else {
return std::max(0LL, (start_ - stop_ - 1) / -step_ + 1);
return std::max(int64{0}, (start_ - stop_ - 1) / -step_ + 1);
}
}

View File

@@ -72,7 +72,7 @@ class SkipDatasetOp::Dataset : public DatasetBase {
if (n == kInfiniteCardinality || n == kUnknownCardinality) {
return n;
}
return count_ < 0 ? 0 : std::max(0LL, n - count_);
return count_ < 0 ? 0 : std::max(int64{0}, n - count_);
}
Status CheckExternalState() const override {

View File

@@ -607,7 +607,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
// so the factor 0.01 (i.e. 1/100) with a max of 10000, was chosen to limit
// the work unit cost to an operating range in which it empirically performed
// best.
const int64 work_unit_cost = std::max(int64{10000}, work_unit_size / 100LL);
const int64 work_unit_cost = std::max(int64{10000}, work_unit_size / 100);
const DeviceBase::CpuWorkerThreads& worker_threads =
*(context->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,

View File

@@ -718,8 +718,8 @@ inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
.unaryExpr(int64_right_shift_op<32>())) +
(input_offset_fp - output_offset_fp + rounding_delta);
auto intermediate = fp_value.unaryExpr(int64_right_shift_op<fp_shift>());
auto input_requantized = intermediate.cwiseMax(0LL)
.cwiseMin(255LL)
auto input_requantized = intermediate.cwiseMax(int64{0})
.cwiseMin(int64{255})
.template cast<int32>()
.template cast<quint8>();
output->flat<quint8>().device(device) = input_requantized;

View File

@@ -275,7 +275,7 @@ class ResizeAreaOp : public OpKernel {
private:
static EIGEN_ALWAYS_INLINE int64 Bound(int64 val, int64 limit) {
return std::min(limit - 1ll, std::max(int64{0}, val));
return std::min(limit - 1, std::max(int64{0}, val));
}
bool align_corners_;

View File

@@ -66,7 +66,7 @@ const float* GetCoeffsTable(const bool use_keys_cubic) {
}
inline int64 Bound(int64 val, int64 limit) {
return std::min(limit - 1ll, std::max(int64{0}, val));
return std::min(limit - 1, std::max(int64{0}, val));
}
struct WeightsAndIndices {

View File

@@ -81,7 +81,7 @@ class ResizeBicubicOpTest : public OpsTestBase {
// Used in the baseline implementation
inline int64 Bound(int64 val, int64 limit) {
return std::min(limit - 1ll, std::max(int64{0}, val));
return std::min(limit - 1, std::max(int64{0}, val));
}
// Used in the baseline implementation

View File

@@ -18,6 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <stdint.h>
#include <atomic>
#include <limits>
#include <memory>
@@ -25,6 +26,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/strings/str_format.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -47,7 +49,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -103,7 +104,7 @@ struct ComputeOptions {
static_cast<int64>(num_dense_features) <=
std::numeric_limits<int>::max(),
errors::InvalidArgument(
strings::Printf("Too many feature groups: %lld > %d",
absl::StrFormat("Too many feature groups: %d > %d",
static_cast<int64>(num_sparse_features) +
static_cast<int64>(num_dense_features),
std::numeric_limits<int>::max())));

View File

@@ -202,9 +202,9 @@ class SparseReduceOp : public OpKernel {
}
auto CoordinatesToFlatIndex = [](ArraySlice<int64> coords,
ArraySlice<int64> strides) {
ArraySlice<int64> strides) -> int64 {
if (strides.empty()) { // Reduce all.
return 0LL;
return 0;
}
CHECK_EQ(coords.size(), strides.size());
int64 idx = 0;
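The change above also pins the lambda's return type rather than relying on deduction: with an auto-deduced return type, the early return 0LL; and the later return idx; (an int64) only produce the same type when int64 is long long. A small sketch of the pattern, with ArraySlice replaced by std::vector and a placeholder body:

#include <cstdint>
#include <vector>

using int64 = std::int64_t;  // stand-in for tensorflow::int64

int64 SumOrZero(const std::vector<int64>& strides) {
  // The explicit -> int64 return type lets every return statement convert
  // to int64, so a plain 0 is fine for the early exit.
  auto flat_index = [](const std::vector<int64>& strides) -> int64 {
    if (strides.empty()) return 0;
    int64 idx = 0;
    // Placeholder arithmetic; the real kernel combines coords and strides.
    for (int64 s : strides) idx += s;
    return idx;
  };
  return flat_index(strides);
}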

View File

@@ -308,15 +308,15 @@ TEST(SparseUtils, FindConfigValueForKey) {
TEST(SparseUtils, GetLinearBucket) {
EXPECT_EQ(11, GetLinearBucket(11, 5));
EXPECT_EQ(11, GetLinearBucket(12, 5));
EXPECT_EQ(1, GetLinearBucket(4ll, 5ll));
EXPECT_EQ(1, GetLinearBucket(int64{4}, int64{5}));
}
TEST(SparseUtils, GetPowerBucket) {
EXPECT_EQ(6, GetPowerBucket(11, 5));
EXPECT_EQ(6, GetPowerBucket(12, 5));
EXPECT_EQ(1332, GetPowerBucket(1335, 11));
EXPECT_EQ(5, GetPowerBucket(5ll, 4ll));
EXPECT_EQ(1, GetPowerBucket(4ll, 1ll));
EXPECT_EQ(5, GetPowerBucket(int64{5}, int64{4}));
EXPECT_EQ(1, GetPowerBucket(int64{4}, int64{1}));
}
} // namespace
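The test now spells its 64-bit arguments as int64{...} instead of relying on the ll suffix. The sparse_utils helpers appear to be templated on the index type, so GetPowerBucket(5ll, 4ll) instantiates them for long long, which only coincides with the intended 64-bit instantiation while int64 is defined as long long. A sketch under that assumption (the template below is a placeholder, not the real bucketing logic):

#include <cstdint>

using int64 = std::int64_t;  // stand-in for tensorflow::int64

// Assumed shape of the helper: templated on the index type.
template <typename T>
T GetLinearBucket(T value, T /*bucket_size*/) {
  return value;  // placeholder body only
}

void Demo() {
  GetLinearBucket(4ll, 5ll);            // instantiates T = long long
  GetLinearBucket(int64{4}, int64{5});  // instantiates T = int64, the type
                                        // the library is actually built for
}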