Postgres fixes. Much more robust and actually usable now. (#6)

This commit is contained in:
reivilibre 2020-10-24 13:34:28 +01:00 committed by GitHub
parent 04241dc47f
commit 11a62d0023
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 73 additions and 23 deletions

View File

@ -5,6 +5,15 @@ from scone.head.recipe import Recipe, RecipeContext
from scone.head.utils import check_type
def postgres_dodgy_escape_literal(unescaped: str) -> str:
python_esc = repr(unescaped)
if python_esc[0] == '"':
return "E'" + python_esc[1:-1].replace("'", "\\'") + "'"
else:
assert python_esc[0] == "'"
return "E" + python_esc
class PostgresDatabase(Recipe):
_NAME = "pg-db"
@ -14,19 +23,25 @@ class PostgresDatabase(Recipe):
self.database_name = check_type(args.get("name"), str)
self.owner = check_type(args.get("owner"), str)
self.encoding = args.get("encoding", "utf8")
self.collate = args.get("collate", "en_GB.utf8")
self.ctype = args.get("ctype", "en_GB.utf8")
# en_GB.UTF-8 may have perf impact and needs to be installed as a locale
# with locale-gen on Ubuntu. In short, a pain.
# C or POSIX is recommended.
self.collate = args.get("collate", "C")
self.ctype = args.get("ctype", "C")
self.template = args.get("template", "template0")
def prepare(self, preparation: Preparation, head: Head) -> None:
super().prepare(preparation, head)
# todo
preparation.provides("postgres-database", self.database_name)
preparation.needs("postgres-user", self.owner)
async def cook(self, kitchen: Kitchen) -> None:
ch = await kitchen.start(PostgresTransaction("postgres"))
ch = await kitchen.start(
PostgresTransaction("postgres", use_transaction_block=False)
)
await ch.send(
(
"SELECT 1 AS count FROM pg_catalog.pg_database WHERE datname = ?;",
"SELECT 1 AS count FROM pg_catalog.pg_database WHERE datname = $1",
self.database_name,
)
)
@ -40,9 +55,9 @@ class PostgresDatabase(Recipe):
CREATE DATABASE {self.database_name}
WITH OWNER {self.owner}
ENCODING {self.encoding}
LC_COLLATE {self.collate}
LC_CTYPE {self.ctype}
TEMPLATE {self.template};
LC_COLLATE {postgres_dodgy_escape_literal(self.collate)}
LC_CTYPE {postgres_dodgy_escape_literal(self.ctype)}
TEMPLATE {postgres_dodgy_escape_literal(self.template)}
"""
await ch.send((q,))
@ -64,13 +79,13 @@ class PostgresUser(Recipe):
def prepare(self, preparation: Preparation, head: Head) -> None:
super().prepare(preparation, head)
# todo
preparation.provides("postgres-user", self.user_name)
async def cook(self, kitchen: Kitchen) -> None:
ch = await kitchen.start(PostgresTransaction("postgres"))
await ch.send(
(
"SELECT 1 AS count FROM pg_catalog.pg_user WHERE usename = ?;",
"SELECT 1 AS count FROM pg_catalog.pg_user WHERE usename = $1",
self.user_name,
)
)
@ -80,13 +95,16 @@ class PostgresUser(Recipe):
await ch.wait_close()
return
# this is close enough to Postgres escaping I believe.
escaped_password = postgres_dodgy_escape_literal(str(self.password))
q = f"""
CREATE ROLE {self.user_name}
WITH PASSWORD ?
LOGIN;
WITH PASSWORD {escaped_password}
LOGIN
"""
await ch.send((q, self.password))
await ch.send((q,))
res = await ch.recv()
if len(res) != 0:
raise RuntimeError("expected empty result set.")

View File

@ -1,28 +1,60 @@
import logging
import attr
try:
import asyncpg
except ImportError:
asyncpg = None
from scone.common.chanpro import Channel
from scone.sous import Utensil
from scone.sous.utensils import Worktop
logger = logging.getLogger(__name__)
if not asyncpg:
logger.info("asyncpg not found, install if you need Postgres support")
@attr.s(auto_attribs=True)
class PostgresTransaction(Utensil):
database: str
async def execute(self, channel: Channel, worktop: Worktop) -> None:
import asyncpg
# statements like CREATE DATABASE are not permitted in transactions.
use_transaction_block: bool = True
conn = await asyncpg.connect(database=self.database)
try:
async with conn.transaction():
while True:
query, *args = await channel.recv()
if query is None:
break
async def execute(self, channel: Channel, worktop: Worktop) -> None:
if not asyncpg:
raise RuntimeError("asyncpg is not installed.")
async def queryloop():
while True:
next_input = await channel.recv()
if next_input is None:
return
query, *args = next_input
if query is None:
break
try:
results = [
dict(record) for record in await conn.fetch(query, *args)
]
except asyncpg.PostgresError:
logger.error(
"Failed query %s with args %r", query, args, exc_info=True
)
await channel.close("Query error")
raise
await channel.send(results)
await channel.send(results)
conn = await asyncpg.connect(database=self.database)
try:
if self.use_transaction_block:
async with conn.transaction():
await queryloop()
else:
await queryloop()
finally:
await conn.close()