[tf.data] Avoiding a possible float-to-integer-cast overflow.
PiperOrigin-RevId: 245738567
This commit is contained in:
parent
a709a894c2
commit
c38b41d7c8
@ -438,13 +438,13 @@ void Model::Optimize(int64 cpu_budget) {
|
||||
snapshot = output_->Snapshot(nullptr);
|
||||
}
|
||||
VLOG(2) << "Starting optimization of tunable parameters";
|
||||
const int64 processing_time = TotalProcessingTime(snapshot);
|
||||
const double processing_time = TotalProcessingTime(snapshot);
|
||||
auto parameters = CollectTunableParameters(snapshot);
|
||||
for (auto& pair : parameters) {
|
||||
pair.second->value = 1;
|
||||
}
|
||||
while (true) {
|
||||
const int64 output_time = OutputTime(snapshot);
|
||||
const double output_time = OutputTime(snapshot);
|
||||
bool all_max = true;
|
||||
for (auto& pair : parameters) {
|
||||
if (pair.second->value < pair.second->max) {
|
||||
@ -455,15 +455,15 @@ void Model::Optimize(int64 cpu_budget) {
|
||||
if (output_time < processing_time / cpu_budget || all_max) {
|
||||
break;
|
||||
}
|
||||
int64 best_delta = -1;
|
||||
double best_delta = -1.0L;
|
||||
Parameter* best_parameter = nullptr;
|
||||
for (auto& pair : parameters) {
|
||||
if (pair.second->value == pair.second->max) {
|
||||
continue;
|
||||
}
|
||||
pair.second->value++;
|
||||
int64 new_output_time = OutputTime(snapshot);
|
||||
int64 delta = output_time - new_output_time;
|
||||
double new_output_time = OutputTime(snapshot);
|
||||
double delta = output_time - new_output_time;
|
||||
if (delta > best_delta) {
|
||||
best_delta = delta;
|
||||
best_parameter = pair.second.get();
|
||||
@ -551,7 +551,7 @@ std::map<string, std::shared_ptr<Parameter>> Model::CollectTunableParameters(
|
||||
return parameters;
|
||||
}
|
||||
|
||||
int64 Model::OutputTime(std::shared_ptr<Node> node) {
|
||||
double Model::OutputTime(std::shared_ptr<Node> node) {
|
||||
std::vector<double> input_times(1, 0);
|
||||
// TODO(jsimsa): Now that we are accounting for buffer size in wait time
|
||||
// computation, assuming that the input is infinitely fast will result in
|
||||
@ -562,7 +562,7 @@ int64 Model::OutputTime(std::shared_ptr<Node> node) {
|
||||
return node->OutputTime(&input_times);
|
||||
}
|
||||
|
||||
int64 Model::TotalProcessingTime(std::shared_ptr<Node> node) {
|
||||
double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
|
||||
return node->TotalProcessingTime();
|
||||
}
|
||||
|
||||
|
@ -481,10 +481,10 @@ class Model {
|
||||
std::shared_ptr<Node> node);
|
||||
|
||||
// Collects the output time for the given node.
|
||||
int64 OutputTime(std::shared_ptr<Node> node);
|
||||
double OutputTime(std::shared_ptr<Node> node);
|
||||
|
||||
// Collects the processing time for the given node.
|
||||
int64 TotalProcessingTime(std::shared_ptr<Node> node);
|
||||
double TotalProcessingTime(std::shared_ptr<Node> node);
|
||||
|
||||
// Used for coordination between different input pipeline threads. Exclusive
|
||||
// access is required only when adding or removing nodes. Concurrent access to
|
||||
|
Loading…
Reference in New Issue
Block a user