From 11a62d00234bc166d0d27e176afec67579234b53 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Sat, 24 Oct 2020 13:34:28 +0100 Subject: [PATCH] Postgres fixes. Much more robust and actually usable now. (#6) --- scone/default/recipes/postgres.py | 44 ++++++++++++++++------- scone/default/utensils/db_utensils.py | 52 +++++++++++++++++++++------ 2 files changed, 73 insertions(+), 23 deletions(-) diff --git a/scone/default/recipes/postgres.py b/scone/default/recipes/postgres.py index 4d40b75..601cdc8 100644 --- a/scone/default/recipes/postgres.py +++ b/scone/default/recipes/postgres.py @@ -5,6 +5,15 @@ from scone.head.recipe import Recipe, RecipeContext from scone.head.utils import check_type +def postgres_dodgy_escape_literal(unescaped: str) -> str: + python_esc = repr(unescaped) + if python_esc[0] == '"': + return "E'" + python_esc[1:-1].replace("'", "\\'") + "'" + else: + assert python_esc[0] == "'" + return "E" + python_esc + + class PostgresDatabase(Recipe): _NAME = "pg-db" @@ -14,19 +23,25 @@ class PostgresDatabase(Recipe): self.database_name = check_type(args.get("name"), str) self.owner = check_type(args.get("owner"), str) self.encoding = args.get("encoding", "utf8") - self.collate = args.get("collate", "en_GB.utf8") - self.ctype = args.get("ctype", "en_GB.utf8") + # en_GB.UTF-8 may have perf impact and needs to be installed as a locale + # with locale-gen on Ubuntu. In short, a pain. + # C or POSIX is recommended. + self.collate = args.get("collate", "C") + self.ctype = args.get("ctype", "C") self.template = args.get("template", "template0") def prepare(self, preparation: Preparation, head: Head) -> None: super().prepare(preparation, head) - # todo + preparation.provides("postgres-database", self.database_name) + preparation.needs("postgres-user", self.owner) async def cook(self, kitchen: Kitchen) -> None: - ch = await kitchen.start(PostgresTransaction("postgres")) + ch = await kitchen.start( + PostgresTransaction("postgres", use_transaction_block=False) + ) await ch.send( ( - "SELECT 1 AS count FROM pg_catalog.pg_database WHERE datname = ?;", + "SELECT 1 AS count FROM pg_catalog.pg_database WHERE datname = $1", self.database_name, ) ) @@ -40,9 +55,9 @@ class PostgresDatabase(Recipe): CREATE DATABASE {self.database_name} WITH OWNER {self.owner} ENCODING {self.encoding} - LC_COLLATE {self.collate} - LC_CTYPE {self.ctype} - TEMPLATE {self.template}; + LC_COLLATE {postgres_dodgy_escape_literal(self.collate)} + LC_CTYPE {postgres_dodgy_escape_literal(self.ctype)} + TEMPLATE {postgres_dodgy_escape_literal(self.template)} """ await ch.send((q,)) @@ -64,13 +79,13 @@ class PostgresUser(Recipe): def prepare(self, preparation: Preparation, head: Head) -> None: super().prepare(preparation, head) - # todo + preparation.provides("postgres-user", self.user_name) async def cook(self, kitchen: Kitchen) -> None: ch = await kitchen.start(PostgresTransaction("postgres")) await ch.send( ( - "SELECT 1 AS count FROM pg_catalog.pg_user WHERE usename = ?;", + "SELECT 1 AS count FROM pg_catalog.pg_user WHERE usename = $1", self.user_name, ) ) @@ -80,13 +95,16 @@ class PostgresUser(Recipe): await ch.wait_close() return + # this is close enough to Postgres escaping I believe. + escaped_password = postgres_dodgy_escape_literal(str(self.password)) + q = f""" CREATE ROLE {self.user_name} - WITH PASSWORD ? - LOGIN; + WITH PASSWORD {escaped_password} + LOGIN """ - await ch.send((q, self.password)) + await ch.send((q,)) res = await ch.recv() if len(res) != 0: raise RuntimeError("expected empty result set.") diff --git a/scone/default/utensils/db_utensils.py b/scone/default/utensils/db_utensils.py index 3d00a82..ba89fb8 100644 --- a/scone/default/utensils/db_utensils.py +++ b/scone/default/utensils/db_utensils.py @@ -1,28 +1,60 @@ +import logging + import attr +try: + import asyncpg +except ImportError: + asyncpg = None + from scone.common.chanpro import Channel from scone.sous import Utensil from scone.sous.utensils import Worktop +logger = logging.getLogger(__name__) + +if not asyncpg: + logger.info("asyncpg not found, install if you need Postgres support") + @attr.s(auto_attribs=True) class PostgresTransaction(Utensil): database: str - async def execute(self, channel: Channel, worktop: Worktop) -> None: - import asyncpg + # statements like CREATE DATABASE are not permitted in transactions. + use_transaction_block: bool = True - conn = await asyncpg.connect(database=self.database) - try: - async with conn.transaction(): - while True: - query, *args = await channel.recv() - if query is None: - break + async def execute(self, channel: Channel, worktop: Worktop) -> None: + if not asyncpg: + raise RuntimeError("asyncpg is not installed.") + + async def queryloop(): + while True: + next_input = await channel.recv() + if next_input is None: + return + query, *args = next_input + if query is None: + break + try: results = [ dict(record) for record in await conn.fetch(query, *args) ] + except asyncpg.PostgresError: + logger.error( + "Failed query %s with args %r", query, args, exc_info=True + ) + await channel.close("Query error") + raise - await channel.send(results) + await channel.send(results) + + conn = await asyncpg.connect(database=self.database) + try: + if self.use_transaction_block: + async with conn.transaction(): + await queryloop() + else: + await queryloop() finally: await conn.close()