[tf.data] Avoid calling CollectTunableParameters function repeatedly.
PiperOrigin-RevId: 319168243
Change-Id: I8f3f0389820250bd0f6f9727f4ad10b6cf5c8ea9

parent 3252cc128c
commit 6d0cf63bb2
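The refactoring follows a simple hoisting pattern: instead of having CollectEssentialParallelism rebuild the tunable-parameter map with its own pipeline traversal, the caller collects the map once and passes it in by const reference, so every consumer shares a single traversal. Below is a minimal standalone sketch of that pattern; Node, Parameter, and the Analyze* helpers are hypothetical simplifications for illustration (std::map stands in for absl::flat_hash_map so the sketch builds without Abseil), not the real tf.data model API.

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-ins for the tf.data model types.
struct Parameter {
  double value = 1.0;
};

struct Node {
  std::string name;

  // Pretend this is the expensive part: in the real model it walks every
  // node in the input pipeline to gather autotunable parameters.
  void CollectTunableParameters(
      std::map<std::string, std::shared_ptr<Parameter>>* out) const {
    (*out)[name] = std::make_shared<Parameter>();
  }
};

// Before: each analysis helper rebuilt the map with its own traversal.
void AnalyzeBefore(const Node& node) {
  std::map<std::string, std::shared_ptr<Parameter>> parameters;
  node.CollectTunableParameters(&parameters);  // repeated work per caller
}

// After: the caller collects once and threads the result through by
// const reference, mirroring the signature change in this commit.
void AnalyzeAfter(
    const Node& /*node*/,
    const std::map<std::string, std::shared_ptr<Parameter>>& parameters) {
  std::cout << "reusing " << parameters.size() << " parameter(s)\n";
}

int main() {
  Node root{"map_and_batch"};
  std::map<std::string, std::shared_ptr<Parameter>> parameters;
  root.CollectTunableParameters(&parameters);  // collected exactly once
  AnalyzeAfter(root, parameters);              // shared by every consumer
  AnalyzeBefore(root);  // the old pattern this commit removes
}
```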
@@ -1320,14 +1320,14 @@ Model::CollectTunableParameters(std::shared_ptr<Node> node) {
 }
 
 absl::flat_hash_map<string, std::shared_ptr<Parameter>>
-Model::CollectEssentialParallelism(std::shared_ptr<Node> node) {
+Model::CollectEssentialParallelism(
+    std::shared_ptr<Node> node,
+    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters) {
   // Parallelism parameter is considered to be essential if the corresponding
   // transformations's processing time is greater than essential rate times the
   // average transformation self processing time.
   constexpr double kEssentialRate = 0.3L;
 
-  absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters;
-  node->CollectTunableParameters(&parameters);
   absl::flat_hash_map<string, double> processing_times;
   double processing_time = node->TotalProcessingTime(&processing_times);
   double uniform_share =
@@ -1350,7 +1350,7 @@ void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget) {
   }
   VLOG(2) << "Starting optimization of tunable parameters with GradientDescent";
   auto parameters = CollectTunableParameters(snapshot);
-  auto essential_parameters = CollectEssentialParallelism(snapshot);
+  auto essential_parameters = CollectEssentialParallelism(snapshot, parameters);
   // We add the number of model's buffered bytes because it is excluded from the
   // memory budget, but it is included in the maximum number of buffered bytes.
   ram_budget += TotalBufferedBytes(snapshot);
@@ -628,7 +628,10 @@ class Model {
   // relative to other transformations. The collected parameters are returned
   // as a mapping from a (unique) node name to a parallelism parameter.
   absl::flat_hash_map<string, std::shared_ptr<Parameter>>
-  CollectEssentialParallelism(std::shared_ptr<Node> node);
+  CollectEssentialParallelism(
+      std::shared_ptr<Node> node,
+      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
+          parameters);
 
   // This optimization algorithm starts by setting all tunable parallelism
   // parameters to the minimum value. It then repeatedly identifies the
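For context on the logic the new parameter feeds into: per the comment in the first hunk, a parallelism parameter counts as essential when its transformation's processing time exceeds kEssentialRate (0.3) times the average transformation self-processing time. A rough standalone sketch of just that threshold test, with a hypothetical helper name (EssentialNodes) and toy numbers rather than the real Model implementation:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Threshold from the diff: a parameter is "essential" if its
// transformation's processing time exceeds kEssentialRate times the
// average self-processing time across transformations.
constexpr double kEssentialRate = 0.3L;

// Hypothetical helper (not the real tf.data code): given each node's
// self-processing time, return the names whose parallelism parameters
// would be considered essential under the rule above.
std::vector<std::string> EssentialNodes(
    const std::map<std::string, double>& processing_times) {
  double total = 0.0;
  for (const auto& [name, time] : processing_times) total += time;
  const double average = total / processing_times.size();
  std::vector<std::string> essential;
  for (const auto& [name, time] : processing_times) {
    if (time > kEssentialRate * average) essential.push_back(name);
  }
  return essential;
}

int main() {
  // Toy self-processing times (arbitrary units).
  const std::map<std::string, double> processing_times = {
      {"map", 9.0}, {"batch", 2.0}, {"prefetch", 1.0}};
  // average = 4.0, threshold = 1.2: "map" and "batch" qualify.
  for (const auto& name : EssentialNodes(processing_times)) {
    std::cout << name << " is essential\n";
  }
}
```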