[tf.data] Avoid calling CollectTunableParameters function repeatedly.
PiperOrigin-RevId: 319168243
Change-Id: I8f3f0389820250bd0f6f9727f4ad10b6cf5c8ea9
parent 3252cc128c
commit 6d0cf63bb2
@@ -1320,14 +1320,14 @@ Model::CollectTunableParameters(std::shared_ptr<Node> node) {
 }
 
 absl::flat_hash_map<string, std::shared_ptr<Parameter>>
-Model::CollectEssentialParallelism(std::shared_ptr<Node> node) {
+Model::CollectEssentialParallelism(
+    std::shared_ptr<Node> node,
+    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters) {
   // Parallelism parameter is considered to be essential if the corresponding
   // transformations's processing time is greater than essential rate times the
   // average transformation self processing time.
   constexpr double kEssentialRate = 0.3L;
 
-  absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters;
-  node->CollectTunableParameters(&parameters);
   absl::flat_hash_map<string, double> processing_times;
   double processing_time = node->TotalProcessingTime(&processing_times);
   double uniform_share =
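The two deleted lines at the end of this hunk are the point of the change: CollectEssentialParallelism used to rebuild the tunable-parameter map with its own node->CollectTunableParameters(&parameters) traversal, even though its caller had just built the identical map. Below is a minimal self-contained sketch of that hoist-the-traversal pattern, not the actual TensorFlow code: std::unordered_map stands in for absl::flat_hash_map, and the Node/Parameter types and the "essential" test are invented placeholders.

// Sketch only: illustrates the refactor, not TensorFlow's real types.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Parameter {
  double value = 1.0;
};

struct Node {
  std::string name;
  std::shared_ptr<Parameter> parallelism;  // null if the node is not tunable
  std::vector<std::shared_ptr<Node>> inputs;

  // Walks the entire subtree; this is the cost the commit stops paying twice.
  void CollectTunableParameters(
      std::unordered_map<std::string, std::shared_ptr<Parameter>>* out) {
    if (parallelism) (*out)[name] = parallelism;
    for (auto& input : inputs) input->CollectTunableParameters(out);
  }
};

// After the change: the analysis receives the pre-collected map read-only
// instead of re-traversing the graph to rebuild it.
std::unordered_map<std::string, std::shared_ptr<Parameter>>
CollectEssentialParallelism(
    const std::shared_ptr<Node>& node,
    const std::unordered_map<std::string, std::shared_ptr<Parameter>>&
        parameters) {
  (void)node;  // kept only to mirror the real signature in this sketch
  std::unordered_map<std::string, std::shared_ptr<Parameter>> essential;
  for (const auto& pair : parameters) {
    // Invented stand-in for the real processing-time / kEssentialRate test.
    if (pair.second->value >= 1.0) essential.insert(pair);
  }
  return essential;
}

int main() {
  auto root = std::make_shared<Node>();
  root->name = "parallel_map";
  root->parallelism = std::make_shared<Parameter>();

  // One traversal up front...
  std::unordered_map<std::string, std::shared_ptr<Parameter>> parameters;
  root->CollectTunableParameters(&parameters);

  // ...shared by every downstream analysis pass.
  auto essential = CollectEssentialParallelism(root, parameters);
  std::cout << essential.size() << " essential parameter(s)\n";
}

Because the map holds shared_ptr values, passing it down by const reference is essentially free, while every avoided CollectTunableParameters call saves a full walk of the dataset graph.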
@@ -1350,7 +1350,7 @@ void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget) {
   }
   VLOG(2) << "Starting optimization of tunable parameters with GradientDescent";
   auto parameters = CollectTunableParameters(snapshot);
-  auto essential_parameters = CollectEssentialParallelism(snapshot);
+  auto essential_parameters = CollectEssentialParallelism(snapshot, parameters);
   // We add the number of model's buffered bytes because it is excluded from the
   // memory budget, but it is included in the maximum number of buffered bytes.
   ram_budget += TotalBufferedBytes(snapshot);
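At the call site the fix is a single extra argument: OptimizeGradientDescent already materializes the parameter map on the line above, so the second traversal hidden inside CollectEssentialParallelism was duplicated work. Forwarding the map by const reference both removes that traversal and documents that the analysis only reads the parameters.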
@@ -628,7 +628,10 @@ class Model {
   // relative to other transformations. The collected parameters are returned
   // as a mapping from a (unique) node name to a parallelism parameter.
   absl::flat_hash_map<string, std::shared_ptr<Parameter>>
-  CollectEssentialParallelism(std::shared_ptr<Node> node);
+  CollectEssentialParallelism(
+      std::shared_ptr<Node> node,
+      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
+          parameters);
 
   // This optimization algorithm starts by setting all tunable parallelism
   // parameters to the minimum value. It then repeatedly identifies the