diff --git a/scone/head/cli/__init__.py b/scone/head/cli/__init__.py index b7b09db..934361b 100644 --- a/scone/head/cli/__init__.py +++ b/scone/head/cli/__init__.py @@ -75,7 +75,7 @@ async def cli_async() -> int: if argp.menu: menu_subset = argp.menu.split(",") - head = Head.open(str(cdir), menu_subset) + head = Head.open(str(cdir)) eprint(head.debug_info()) @@ -92,6 +92,9 @@ async def cli_async() -> int: eprint(f"Selected the following souss: {', '.join(hosts)}") + head.load_variables(hosts) + head.load_menus(menu_subset, hosts) + eprint("Preparing recipes…") prepare = Preparation(head) diff --git a/scone/head/head.py b/scone/head/head.py index cd9c4c9..8cb8ade 100644 --- a/scone/head/head.py +++ b/scone/head/head.py @@ -21,7 +21,7 @@ import re import sys from os import path from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple, cast +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, cast import toml from nacl.encoding import URLSafeBase64Encoder @@ -58,7 +58,7 @@ class Head: self.pools = pools @staticmethod - def open(directory: str, menu_subset: Optional[List[str]] = None): + def open(directory: str): with open(path.join(directory, "scone.head.toml")) as head_toml: head_data = toml.load(head_toml) @@ -84,8 +84,6 @@ class Head: pools = Pools() head = Head(directory, recipe_loader, sous, groups, secret_access, pools) - head._load_variables() - head._load_menus(menu_subset) return head def _preload_variables(self, who_for: str) -> Tuple[dict, dict]: @@ -131,12 +129,16 @@ class Head: return out_chilled, out_frozen - def _load_variables(self): + def load_variables(self, host_subset: Optional[Set[str]]): preload: Dict[str, Tuple[dict, dict]] = dict() for who_name in itertools.chain(self.souss, self.groups): + # TODO(performance): don't preload vars for deselected souss and + # groups preload[who_name] = self._preload_variables(who_name) for sous_name in self.souss: + if host_subset and sous_name not in host_subset: + continue order = ["all"] order += [ group @@ -159,14 +161,14 @@ class Head: self.variables[sous_name] = sous_vars - def _load_menus(self, subset: Optional[List[str]]): + def load_menus(self, subset: Optional[List[str]], host_subset: Set[str]): loader = MenuLoader(Path(self.directory, "menu"), self) if subset: for unit in subset: loader.load(unit) else: loader.load_menus_in_dir() - loader.dagify_all() + loader.dagify_all(host_subset) # TODO remove # def _construct_hostmenu_for( diff --git a/scone/head/menu_reader.py b/scone/head/menu_reader.py index a37b4da..421b030 100644 --- a/scone/head/menu_reader.py +++ b/scone/head/menu_reader.py @@ -20,7 +20,7 @@ import os import typing from collections import defaultdict, deque from pathlib import Path -from typing import Any, Deque, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Deque, Dict, Iterable, List, Optional, Set, Tuple, Union import attr import textx @@ -359,6 +359,7 @@ class MenuLoader: hierarchical_source: str, fors: Tuple[ForDirective, ...], applicable_souss: Iterable[str], + sous_mask: Optional[Set[str]], applicable_user: Optional[str], ): recipe_class = self._head.recipe_loader.get_class(recipe.kind) @@ -372,6 +373,9 @@ class MenuLoader: if recipe.sous_directive: applicable_souss = self._head.get_souss_for_hostspec(recipe.sous_directive) + if sous_mask: + applicable_souss = set(applicable_souss) + applicable_souss.intersection_update(sous_mask) for sous in applicable_souss: if not applicable_user: @@ -403,6 +407,7 @@ class MenuLoader: hierarchical_source: str, fors: Tuple[ForDirective, ...], applicable_souss: Iterable[str], + sous_mask: Optional[Set[str]], applicable_user: Optional[str], ): fors = fors + tuple(block.for_directives) @@ -412,6 +417,9 @@ class MenuLoader: if block.sous_directive: applicable_souss = self._head.get_souss_for_hostspec(block.sous_directive) + if sous_mask: + applicable_souss = set(applicable_souss) + applicable_souss.intersection_update(sous_mask) for content in block.contents: if isinstance(content, MenuBlock): @@ -421,6 +429,7 @@ class MenuLoader: f"{hierarchical_source}.{block_name}", fors, applicable_souss, + sous_mask, applicable_user, ) elif isinstance(content, MenuRecipe): @@ -429,6 +438,7 @@ class MenuLoader: hierarchical_source, fors, applicable_souss, + sous_mask, applicable_user, ) else: @@ -439,6 +449,7 @@ class MenuLoader: recipe: MenuRecipe, fors: Tuple[ForDirective, ...], applicable_souss: Iterable[str], + sous_mask: Optional[Set[str]], ): # TODO(feature): add edges @@ -447,6 +458,9 @@ class MenuLoader: if recipe.sous_directive: applicable_souss = self._head.get_souss_for_hostspec(recipe.sous_directive) + if sous_mask: + applicable_souss = set(applicable_souss) + applicable_souss.intersection_update(sous_mask) for sous in applicable_souss: sous_vars = self._head.variables[sous] @@ -490,6 +504,7 @@ class MenuLoader: block: MenuBlock, fors: Tuple[ForDirective, ...], applicable_souss: Iterable[str], + sous_mask: Optional[Set[str]], ): # XXX pass down specific edges here @@ -499,23 +514,31 @@ class MenuLoader: if block.sous_directive: applicable_souss = self._head.get_souss_for_hostspec(block.sous_directive) + if sous_mask: + applicable_souss = set(applicable_souss) + applicable_souss.intersection_update(sous_mask) for content in block.contents: if isinstance(content, MenuBlock): - self.postdagify_block(content, fors, applicable_souss) + self.postdagify_block(content, fors, applicable_souss, sous_mask) elif isinstance(content, MenuRecipe): - self.postdagify_recipe(content, fors, applicable_souss) + self.postdagify_recipe(content, fors, applicable_souss, sous_mask) else: raise ValueError(f"{content}?") - def dagify_all(self): + def dagify_all(self, sous_subset: Optional[Set[str]]): for name, unit in self._units.items(): self.dagify_block( - unit, name, tuple(), self._head.get_souss_for_hostspec("all"), None + unit, + name, + tuple(), + self._head.get_souss_for_hostspec("all"), + sous_subset, + None, ) for _name, unit in self._units.items(): self.postdagify_block( - unit, tuple(), self._head.get_souss_for_hostspec("all") + unit, tuple(), self._head.get_souss_for_hostspec("all"), sous_subset ) def _for_apply(