[tf.data] Avoid a possible float-to-integer cast overflow by making Model::OutputTime and Model::TotalProcessingTime return double instead of int64.
PiperOrigin-RevId: 245738567
This commit is contained in:
parent
a709a894c2
commit
c38b41d7c8
@ -438,13 +438,13 @@ void Model::Optimize(int64 cpu_budget) {
|
|||||||
snapshot = output_->Snapshot(nullptr);
|
snapshot = output_->Snapshot(nullptr);
|
||||||
}
|
}
|
||||||
VLOG(2) << "Starting optimization of tunable parameters";
|
VLOG(2) << "Starting optimization of tunable parameters";
|
||||||
const int64 processing_time = TotalProcessingTime(snapshot);
|
const double processing_time = TotalProcessingTime(snapshot);
|
||||||
auto parameters = CollectTunableParameters(snapshot);
|
auto parameters = CollectTunableParameters(snapshot);
|
||||||
for (auto& pair : parameters) {
|
for (auto& pair : parameters) {
|
||||||
pair.second->value = 1;
|
pair.second->value = 1;
|
||||||
}
|
}
|
||||||
while (true) {
|
while (true) {
|
||||||
const int64 output_time = OutputTime(snapshot);
|
const double output_time = OutputTime(snapshot);
|
||||||
bool all_max = true;
|
bool all_max = true;
|
||||||
for (auto& pair : parameters) {
|
for (auto& pair : parameters) {
|
||||||
if (pair.second->value < pair.second->max) {
|
if (pair.second->value < pair.second->max) {
|
||||||
@ -455,15 +455,15 @@ void Model::Optimize(int64 cpu_budget) {
|
|||||||
if (output_time < processing_time / cpu_budget || all_max) {
|
if (output_time < processing_time / cpu_budget || all_max) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
int64 best_delta = -1;
|
double best_delta = -1.0L;
|
||||||
Parameter* best_parameter = nullptr;
|
Parameter* best_parameter = nullptr;
|
||||||
for (auto& pair : parameters) {
|
for (auto& pair : parameters) {
|
||||||
if (pair.second->value == pair.second->max) {
|
if (pair.second->value == pair.second->max) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
pair.second->value++;
|
pair.second->value++;
|
||||||
int64 new_output_time = OutputTime(snapshot);
|
double new_output_time = OutputTime(snapshot);
|
||||||
int64 delta = output_time - new_output_time;
|
double delta = output_time - new_output_time;
|
||||||
if (delta > best_delta) {
|
if (delta > best_delta) {
|
||||||
best_delta = delta;
|
best_delta = delta;
|
||||||
best_parameter = pair.second.get();
|
best_parameter = pair.second.get();
|
||||||
@ -551,7 +551,7 @@ std::map<string, std::shared_ptr<Parameter>> Model::CollectTunableParameters(
|
|||||||
return parameters;
|
return parameters;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 Model::OutputTime(std::shared_ptr<Node> node) {
|
double Model::OutputTime(std::shared_ptr<Node> node) {
|
||||||
std::vector<double> input_times(1, 0);
|
std::vector<double> input_times(1, 0);
|
||||||
// TODO(jsimsa): Now that we are accounting for buffer size in wait time
|
// TODO(jsimsa): Now that we are accounting for buffer size in wait time
|
||||||
// computation, assuming that the input is infinitely fast will result in
|
// computation, assuming that the input is infinitely fast will result in
|
||||||
@ -562,7 +562,7 @@ int64 Model::OutputTime(std::shared_ptr<Node> node) {
|
|||||||
return node->OutputTime(&input_times);
|
return node->OutputTime(&input_times);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 Model::TotalProcessingTime(std::shared_ptr<Node> node) {
|
double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
|
||||||
return node->TotalProcessingTime();
|
return node->TotalProcessingTime();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -481,10 +481,10 @@ class Model {
|
|||||||
std::shared_ptr<Node> node);
|
std::shared_ptr<Node> node);
|
||||||
|
|
||||||
// Collects the output time for the given node.
|
// Collects the output time for the given node.
|
||||||
int64 OutputTime(std::shared_ptr<Node> node);
|
double OutputTime(std::shared_ptr<Node> node);
|
||||||
|
|
||||||
// Collects the processing time for the given node.
|
// Collects the processing time for the given node.
|
||||||
int64 TotalProcessingTime(std::shared_ptr<Node> node);
|
double TotalProcessingTime(std::shared_ptr<Node> node);
|
||||||
|
|
||||||
// Used for coordination between different input pipeline threads. Exclusive
|
// Used for coordination between different input pipeline threads. Exclusive
|
||||||
// access is required only when adding or removing nodes. Concurrent access to
|
// access is required only when adding or removing nodes. Concurrent access to
|
||||||
|
Loading…
Reference in New Issue
Block a user