123 lines
3.4 KiB
Python
123 lines
3.4 KiB
Python
import subprocess
|
|
from pathlib import Path
|
|
from random import Random
|
|
from typing import Dict, Optional, Set
|
|
|
|
import attr
|
|
from helpers import DirectoryDescriptor, FileDescriptor
|
|
from immutabledict import immutabledict
|
|
|
|
# Default piles configuration text: a single pile, "main", stored under the
# path "main" and including the "precious" label. The leading newline keeps it
# separated from whatever text it is appended after.
DEFAULT_PILES_SECTION = (
    "\n"
    "[piles.main]\n"
    'path = "main"\n'
    'included_labels = ["precious"]\n'
)
|
|
|
|
# Alternative piles configuration text with three piles -- "pocket",
# "precious" and "bulky" -- each stored under a path of the same name and each
# including only the label of the same name.
MULTI_PILES_SECTION = (
    "\n"
    "[piles.pocket]\n"
    'path = "pocket"\n'
    'included_labels = ["pocket"]\n'
    "\n"
    "[piles.precious]\n"
    'path = "precious"\n'
    'included_labels = ["precious"]\n'
    "\n"
    "[piles.bulky]\n"
    'path = "bulky"\n'
    'included_labels = ["bulky"]\n'
)
|
|
|
|
|
|
def get_hostname():
    """Return the output of the ``hostname`` command as a stripped string.

    The command's raw bytes have surrounding whitespace removed and are then
    decoded with the default codec.

    Raises:
        subprocess.CalledProcessError: if ``hostname`` exits with a non-zero
            status.
    """
    raw_output = subprocess.check_output("hostname")
    trimmed = raw_output.strip()
    return trimmed.decode()
|
|
|
|
|
|
def set_up_simple_datman(
    path: Path,
    custom_extra_test: Optional[str] = None,
    piles_section: str = DEFAULT_PILES_SECTION,
):
    """Initialise a datman directory at *path* and append test configuration.

    Creates *path* if it does not exist yet (parents are not created), runs
    ``datman init`` with *path* as the working directory, and then appends to
    ``datman.toml`` inside *path*:

    * a ``[sources.srca]`` section whose ``directory`` is ``path/srca`` and
      whose ``hostname`` is the value returned by ``get_hostname()``;
    * the given *piles_section* text;
    * *custom_extra_test*, verbatim, when it is non-empty.

    Raises:
        subprocess.CalledProcessError: if ``datman init`` (or the ``hostname``
            lookup) exits with a non-zero status.
    """
    path.mkdir(exist_ok=True)
    subprocess.check_call(("datman", "init"), cwd=path)

    source_dir = path / "srca"
    config_path = path / "datman.toml"

    # Append mode: the file is extended, never truncated.
    with config_path.open("a") as config_file:
        source_section = f"""
[sources.srca]
directory = "{source_dir}"
hostname = "{get_hostname()}"
"""
        config_file.write(source_section + piles_section)
        if custom_extra_test:
            config_file.write(custom_extra_test)
|
|
|
|
|
|
def save_labelling_rules(path: Path, rules: Dict[str, str]):
    """Write *rules* to *path*, compressed by the external ``zstd`` command.

    The uncompressed payload holds one ``key<TAB>value`` line per rule, in the
    dict's iteration order, followed by a terminating ``---`` line. The payload
    is piped through ``zstd - --stdout`` and the compressed output is written
    to *path*, replacing any existing file.

    Args:
        path: destination file for the compressed rules.
        rules: mapping whose keys and values are written as rule lines.

    Raises:
        ChildProcessError: if ``zstd`` exits with a non-zero status.
    """
    rule_lines = [f"{rule_k}\t{rule_v}\n" for rule_k, rule_v in rules.items()]
    rule_lines.append("---\n")
    payload = "".join(rule_lines).encode()

    with path.open("wb") as fout:
        # subprocess.run feeds the payload, closes stdin and always waits for
        # the child, so no process or pipe is left behind if something goes
        # wrong, and an early zstd exit is reported through its return code
        # rather than as a BrokenPipeError from a manual write.
        completed = subprocess.run(
            ["zstd", "-", "--stdout"], input=payload, stdout=fout
        )

    if completed.returncode != 0:
        raise ChildProcessError(f"zstd failed with {completed.returncode}.")
|
|
|
|
|
|
def generate_labels(
    dir_descriptor: DirectoryDescriptor,
    rng: Random,
    dict_in_place: Optional[Dict[str, str]] = None,
    prefix: str = "",
) -> Dict[str, str]:
    """Assign a label to every entry of a directory tree.

    The directory itself (keyed by *prefix*) is labelled ``"?"``, the split
    marker. Each sub-directory is processed recursively under
    ``prefix + "/" + name``, and each file gets one label drawn by *rng* from
    ``"bulky"``, ``"precious"``, ``"pocket"`` and ``"!"``. Entries that are
    neither a ``DirectoryDescriptor`` nor a ``FileDescriptor`` are left
    unlabelled.

    Args:
        dir_descriptor: directory whose ``contents`` mapping is walked.
        rng: random source used to pick file labels.
        dict_in_place: mapping to fill in; a new dict is created when this is
            ``None``. A supplied dict is mutated and returned, even when empty.
        prefix: path key of *dir_descriptor* itself.

    Returns:
        The mapping from path key to label (the same object as
        *dict_in_place* when one was supplied).
    """
    # Test against None rather than truthiness: an empty dict handed in by the
    # caller must still be the one that gets populated.
    if dict_in_place is None:
        dict_in_place = {}

    # split on this.
    dict_in_place[prefix] = "?"

    for name, descriptor in dir_descriptor.contents.items():
        if isinstance(descriptor, DirectoryDescriptor):
            generate_labels(descriptor, rng, dict_in_place, prefix + "/" + name)
        elif isinstance(descriptor, FileDescriptor):
            dict_in_place[prefix + "/" + name] = rng.choice(
                ["bulky", "precious", "pocket", "!"]
            )

    return dict_in_place
|
|
|
|
|
|
def filter_descriptor_by_label(
    labels: Set[str],
    orig: DirectoryDescriptor,
    label_map: Dict[str, str],
    prefix: str = "",
) -> DirectoryDescriptor:
    """Return a copy of *orig* keeping only entries whose label is selected.

    Each entry is looked up in *label_map* under ``prefix + "/" + name``. An
    entry is kept when its label is in *labels* or is the split marker
    ``"?"``. Directories are filtered recursively; a ``"?"`` directory that
    ends up with no contents is dropped.

    Args:
        labels: labels to keep.
        orig: directory descriptor to filter; it is not modified.
        label_map: label for every path key reachable from *orig*.
        prefix: path key of *orig* itself.

    Returns:
        *orig* evolved with the filtered contents as an ``immutabledict``.

    Raises:
        KeyError: if an entry's path key is missing from *label_map*.
        AssertionError: if a file carries the ``"?"`` split marker.
        ValueError: if a kept entry is neither a directory nor a file
            descriptor.
    """
    kept = {}

    for entry_name, entry in orig.contents.items():
        entry_path = prefix + "/" + entry_name
        entry_label = label_map[entry_path]
        is_split = entry_label == "?"

        # Neither a split marker nor one of the requested labels: skip it.
        if not is_split and entry_label not in labels:
            continue

        if isinstance(entry, DirectoryDescriptor):
            filtered_dir = filter_descriptor_by_label(
                labels, entry, label_map, entry_path
            )
            # don't include splits that are empty.
            if is_split and not filtered_dir.contents:
                continue
            kept[entry_name] = filtered_dir
        elif isinstance(entry, FileDescriptor):
            assert not is_split, "why is there a split filter on a file descriptor?"
            kept[entry_name] = entry
        else:
            raise ValueError("what kind of descriptor is value?")

    return attr.evolve(orig, contents=immutabledict(kept))
|