[tf.data] Avoiding a possible float-to-integer-cast overflow.

PiperOrigin-RevId: 245738567
This commit is contained in:
Jiri Simsa 2019-04-29 06:59:14 -07:00 committed by TensorFlower Gardener
parent a709a894c2
commit c38b41d7c8
2 changed files with 9 additions and 9 deletions

View File

@ -438,13 +438,13 @@ void Model::Optimize(int64 cpu_budget) {
snapshot = output_->Snapshot(nullptr);
}
VLOG(2) << "Starting optimization of tunable parameters";
const int64 processing_time = TotalProcessingTime(snapshot);
const double processing_time = TotalProcessingTime(snapshot);
auto parameters = CollectTunableParameters(snapshot);
for (auto& pair : parameters) {
pair.second->value = 1;
}
while (true) {
const int64 output_time = OutputTime(snapshot);
const double output_time = OutputTime(snapshot);
bool all_max = true;
for (auto& pair : parameters) {
if (pair.second->value < pair.second->max) {
@ -455,15 +455,15 @@ void Model::Optimize(int64 cpu_budget) {
if (output_time < processing_time / cpu_budget || all_max) {
break;
}
int64 best_delta = -1;
double best_delta = -1.0L;
Parameter* best_parameter = nullptr;
for (auto& pair : parameters) {
if (pair.second->value == pair.second->max) {
continue;
}
pair.second->value++;
int64 new_output_time = OutputTime(snapshot);
int64 delta = output_time - new_output_time;
double new_output_time = OutputTime(snapshot);
double delta = output_time - new_output_time;
if (delta > best_delta) {
best_delta = delta;
best_parameter = pair.second.get();
@ -551,7 +551,7 @@ std::map<string, std::shared_ptr<Parameter>> Model::CollectTunableParameters(
return parameters;
}
int64 Model::OutputTime(std::shared_ptr<Node> node) {
double Model::OutputTime(std::shared_ptr<Node> node) {
std::vector<double> input_times(1, 0);
// TODO(jsimsa): Now that we are accounting for buffer size in wait time
// computation, assuming that the input is infinitely fast will result in
@ -562,7 +562,7 @@ int64 Model::OutputTime(std::shared_ptr<Node> node) {
return node->OutputTime(&input_times);
}
int64 Model::TotalProcessingTime(std::shared_ptr<Node> node) {
double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
return node->TotalProcessingTime();
}

View File

@ -481,10 +481,10 @@ class Model {
std::shared_ptr<Node> node);
// Collects the output time for the given node.
int64 OutputTime(std::shared_ptr<Node> node);
double OutputTime(std::shared_ptr<Node> node);
// Collects the processing time for the given node.
int64 TotalProcessingTime(std::shared_ptr<Node> node);
double TotalProcessingTime(std::shared_ptr<Node> node);
// Used for coordination between different input pipeline threads. Exclusive
// access is required only when adding or removing nodes. Concurrent access to