[tf.data] Avoid calling the CollectTunableParameters function repeatedly.

PiperOrigin-RevId: 319168243
Change-Id: I8f3f0389820250bd0f6f9727f4ad10b6cf5c8ea9
Author: Jay Shi (committed by TensorFlower Gardener)
Date: 2020-06-30 23:06:13 -07:00
Parent: 3252cc128c
Commit: 6d0cf63bb2
2 changed files with 8 additions and 5 deletions
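
The change below hoists an expensive traversal out of a helper: instead of CollectEssentialParallelism re-collecting the tunable parameters itself, OptimizeGradientDescent collects them once and passes the map in by const reference, so the node tree is walked once per optimization round instead of twice. A minimal standalone sketch of the pattern follows; the types and names in it are illustrative stand-ins, not the actual tf.data classes.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Illustrative stand-ins for the real Node/Parameter types.
struct Parameter {
  double value = 1.0;
};
using ParamMap = std::unordered_map<std::string, std::shared_ptr<Parameter>>;

// Stands in for the expensive recursive walk over the node tree.
ParamMap CollectTunableParameters() {
  std::cout << "walking the node tree...\n";
  return {{"ParallelMap::parallelism", std::make_shared<Parameter>()}};
}

// Before the change, the helper re-ran the traversal internally; after it,
// the caller hands in the already-collected map by const reference.
ParamMap CollectEssentialParallelism(const ParamMap& parameters) {
  ParamMap essential;
  for (const auto& entry : parameters) {
    essential.insert(entry);  // toy filter: keep every parameter
  }
  return essential;
}

int main() {
  auto parameters = CollectTunableParameters();              // traversal runs once
  auto essential = CollectEssentialParallelism(parameters);  // result is reused
  std::cout << essential.size() << " essential parameter(s)\n";
}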

tensorflow/core/framework/model.cc

@@ -1320,14 +1320,14 @@ Model::CollectTunableParameters(std::shared_ptr<Node> node) {
 }
 
 absl::flat_hash_map<string, std::shared_ptr<Parameter>>
-Model::CollectEssentialParallelism(std::shared_ptr<Node> node) {
+Model::CollectEssentialParallelism(
+    std::shared_ptr<Node> node,
+    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters) {
   // Parallelism parameter is considered to be essential if the corresponding
   // transformation's processing time is greater than essential rate times the
   // average transformation self processing time.
   constexpr double kEssentialRate = 0.3L;
-  absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters;
-  node->CollectTunableParameters(&parameters);
   absl::flat_hash_map<string, double> processing_times;
   double processing_time = node->TotalProcessingTime(&processing_times);
   double uniform_share =
       processing_time / static_cast<double>(processing_times.size());
@@ -1350,7 +1350,7 @@ void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget) {
   }
   VLOG(2) << "Starting optimization of tunable parameters with GradientDescent";
   auto parameters = CollectTunableParameters(snapshot);
-  auto essential_parameters = CollectEssentialParallelism(snapshot);
+  auto essential_parameters = CollectEssentialParallelism(snapshot, parameters);
   // We add the number of model's buffered bytes because it is excluded from the
   // memory budget, but it is included in the maximum number of buffered bytes.
   ram_budget += TotalBufferedBytes(snapshot);
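
For reference, the heuristic the modified function applies, per the comment in the first hunk: a parallelism parameter counts as essential when its transformation's self processing time exceeds kEssentialRate (0.3) times the uniform share, i.e. the average self processing time across all transformations. Below is a simplified restatement of that predicate, using assumed plain-double inputs rather than the real Node accessors.

#include <cstddef>

// Restates the essentiality test from the hunk above: a parallelism
// parameter is essential when its transformation's self processing time
// exceeds kEssentialRate times the average self processing time.
constexpr double kEssentialRate = 0.3;

bool IsEssentialParallelism(double node_self_time, double total_self_time,
                            std::size_t num_transformations) {
  const double uniform_share =
      total_self_time / static_cast<double>(num_transformations);
  return node_self_time > kEssentialRate * uniform_share;
}

For example, with three transformations whose self times are 10 ms, 2 ms, and 1 ms, the uniform share is about 4.3 ms, so the 1 ms transformation falls below the 0.3 x 4.3 = 1.3 ms threshold and its parallelism is not treated as essential.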

tensorflow/core/framework/model.h

@@ -628,7 +628,10 @@ class Model {
   // relative to other transformations. The collected parameters are returned
   // as a mapping from a (unique) node name to a parallelism parameter.
   absl::flat_hash_map<string, std::shared_ptr<Parameter>>
-  CollectEssentialParallelism(std::shared_ptr<Node> node);
+  CollectEssentialParallelism(
+      std::shared_ptr<Node> node,
+      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
+          parameters);
 
   // This optimization algorithm starts by setting all tunable parallelism
   // parameters to the minimum value. It then repeatedly identifies the