[tf.data] Avoiding a possible float-to-integer-cast overflow.

PiperOrigin-RevId: 245738567
This commit is contained in:
Jiri Simsa 2019-04-29 06:59:14 -07:00 committed by TensorFlower Gardener
parent a709a894c2
commit c38b41d7c8
2 changed files with 9 additions and 9 deletions

View File

@ -438,13 +438,13 @@ void Model::Optimize(int64 cpu_budget) {
snapshot = output_->Snapshot(nullptr); snapshot = output_->Snapshot(nullptr);
} }
VLOG(2) << "Starting optimization of tunable parameters"; VLOG(2) << "Starting optimization of tunable parameters";
const int64 processing_time = TotalProcessingTime(snapshot); const double processing_time = TotalProcessingTime(snapshot);
auto parameters = CollectTunableParameters(snapshot); auto parameters = CollectTunableParameters(snapshot);
for (auto& pair : parameters) { for (auto& pair : parameters) {
pair.second->value = 1; pair.second->value = 1;
} }
while (true) { while (true) {
const int64 output_time = OutputTime(snapshot); const double output_time = OutputTime(snapshot);
bool all_max = true; bool all_max = true;
for (auto& pair : parameters) { for (auto& pair : parameters) {
if (pair.second->value < pair.second->max) { if (pair.second->value < pair.second->max) {
@ -455,15 +455,15 @@ void Model::Optimize(int64 cpu_budget) {
if (output_time < processing_time / cpu_budget || all_max) { if (output_time < processing_time / cpu_budget || all_max) {
break; break;
} }
int64 best_delta = -1; double best_delta = -1.0L;
Parameter* best_parameter = nullptr; Parameter* best_parameter = nullptr;
for (auto& pair : parameters) { for (auto& pair : parameters) {
if (pair.second->value == pair.second->max) { if (pair.second->value == pair.second->max) {
continue; continue;
} }
pair.second->value++; pair.second->value++;
int64 new_output_time = OutputTime(snapshot); double new_output_time = OutputTime(snapshot);
int64 delta = output_time - new_output_time; double delta = output_time - new_output_time;
if (delta > best_delta) { if (delta > best_delta) {
best_delta = delta; best_delta = delta;
best_parameter = pair.second.get(); best_parameter = pair.second.get();
@ -551,7 +551,7 @@ std::map<string, std::shared_ptr<Parameter>> Model::CollectTunableParameters(
return parameters; return parameters;
} }
int64 Model::OutputTime(std::shared_ptr<Node> node) { double Model::OutputTime(std::shared_ptr<Node> node) {
std::vector<double> input_times(1, 0); std::vector<double> input_times(1, 0);
// TODO(jsimsa): Now that we are accounting for buffer size in wait time // TODO(jsimsa): Now that we are accounting for buffer size in wait time
// computation, assuming that the input is infinitely fast will result in // computation, assuming that the input is infinitely fast will result in
@ -562,7 +562,7 @@ int64 Model::OutputTime(std::shared_ptr<Node> node) {
return node->OutputTime(&input_times); return node->OutputTime(&input_times);
} }
int64 Model::TotalProcessingTime(std::shared_ptr<Node> node) { double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
return node->TotalProcessingTime(); return node->TotalProcessingTime();
} }

View File

@ -481,10 +481,10 @@ class Model {
std::shared_ptr<Node> node); std::shared_ptr<Node> node);
// Collects the output time for the given node. // Collects the output time for the given node.
int64 OutputTime(std::shared_ptr<Node> node); double OutputTime(std::shared_ptr<Node> node);
// Collects the processing time for the given node. // Collects the processing time for the given node.
int64 TotalProcessingTime(std::shared_ptr<Node> node); double TotalProcessingTime(std::shared_ptr<Node> node);
// Used for coordination between different input pipeline threads. Exclusive // Used for coordination between different input pipeline threads. Exclusive
// access is required only when adding or removing nodes. Concurrent access to // access is required only when adding or removing nodes. Concurrent access to