Postgres fixes. Much more robust and actually usable now. (#6)
This commit is contained in:
parent
04241dc47f
commit
11a62d0023
@ -5,6 +5,15 @@ from scone.head.recipe import Recipe, RecipeContext
|
|||||||
from scone.head.utils import check_type
|
from scone.head.utils import check_type
|
||||||
|
|
||||||
|
|
||||||
|
def postgres_dodgy_escape_literal(unescaped: str) -> str:
|
||||||
|
python_esc = repr(unescaped)
|
||||||
|
if python_esc[0] == '"':
|
||||||
|
return "E'" + python_esc[1:-1].replace("'", "\\'") + "'"
|
||||||
|
else:
|
||||||
|
assert python_esc[0] == "'"
|
||||||
|
return "E" + python_esc
|
||||||
|
|
||||||
|
|
||||||
class PostgresDatabase(Recipe):
|
class PostgresDatabase(Recipe):
|
||||||
_NAME = "pg-db"
|
_NAME = "pg-db"
|
||||||
|
|
||||||
@ -14,19 +23,25 @@ class PostgresDatabase(Recipe):
|
|||||||
self.database_name = check_type(args.get("name"), str)
|
self.database_name = check_type(args.get("name"), str)
|
||||||
self.owner = check_type(args.get("owner"), str)
|
self.owner = check_type(args.get("owner"), str)
|
||||||
self.encoding = args.get("encoding", "utf8")
|
self.encoding = args.get("encoding", "utf8")
|
||||||
self.collate = args.get("collate", "en_GB.utf8")
|
# en_GB.UTF-8 may have perf impact and needs to be installed as a locale
|
||||||
self.ctype = args.get("ctype", "en_GB.utf8")
|
# with locale-gen on Ubuntu. In short, a pain.
|
||||||
|
# C or POSIX is recommended.
|
||||||
|
self.collate = args.get("collate", "C")
|
||||||
|
self.ctype = args.get("ctype", "C")
|
||||||
self.template = args.get("template", "template0")
|
self.template = args.get("template", "template0")
|
||||||
|
|
||||||
def prepare(self, preparation: Preparation, head: Head) -> None:
|
def prepare(self, preparation: Preparation, head: Head) -> None:
|
||||||
super().prepare(preparation, head)
|
super().prepare(preparation, head)
|
||||||
# todo
|
preparation.provides("postgres-database", self.database_name)
|
||||||
|
preparation.needs("postgres-user", self.owner)
|
||||||
|
|
||||||
async def cook(self, kitchen: Kitchen) -> None:
|
async def cook(self, kitchen: Kitchen) -> None:
|
||||||
ch = await kitchen.start(PostgresTransaction("postgres"))
|
ch = await kitchen.start(
|
||||||
|
PostgresTransaction("postgres", use_transaction_block=False)
|
||||||
|
)
|
||||||
await ch.send(
|
await ch.send(
|
||||||
(
|
(
|
||||||
"SELECT 1 AS count FROM pg_catalog.pg_database WHERE datname = ?;",
|
"SELECT 1 AS count FROM pg_catalog.pg_database WHERE datname = $1",
|
||||||
self.database_name,
|
self.database_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -40,9 +55,9 @@ class PostgresDatabase(Recipe):
|
|||||||
CREATE DATABASE {self.database_name}
|
CREATE DATABASE {self.database_name}
|
||||||
WITH OWNER {self.owner}
|
WITH OWNER {self.owner}
|
||||||
ENCODING {self.encoding}
|
ENCODING {self.encoding}
|
||||||
LC_COLLATE {self.collate}
|
LC_COLLATE {postgres_dodgy_escape_literal(self.collate)}
|
||||||
LC_CTYPE {self.ctype}
|
LC_CTYPE {postgres_dodgy_escape_literal(self.ctype)}
|
||||||
TEMPLATE {self.template};
|
TEMPLATE {postgres_dodgy_escape_literal(self.template)}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await ch.send((q,))
|
await ch.send((q,))
|
||||||
@ -64,13 +79,13 @@ class PostgresUser(Recipe):
|
|||||||
|
|
||||||
def prepare(self, preparation: Preparation, head: Head) -> None:
|
def prepare(self, preparation: Preparation, head: Head) -> None:
|
||||||
super().prepare(preparation, head)
|
super().prepare(preparation, head)
|
||||||
# todo
|
preparation.provides("postgres-user", self.user_name)
|
||||||
|
|
||||||
async def cook(self, kitchen: Kitchen) -> None:
|
async def cook(self, kitchen: Kitchen) -> None:
|
||||||
ch = await kitchen.start(PostgresTransaction("postgres"))
|
ch = await kitchen.start(PostgresTransaction("postgres"))
|
||||||
await ch.send(
|
await ch.send(
|
||||||
(
|
(
|
||||||
"SELECT 1 AS count FROM pg_catalog.pg_user WHERE usename = ?;",
|
"SELECT 1 AS count FROM pg_catalog.pg_user WHERE usename = $1",
|
||||||
self.user_name,
|
self.user_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -80,13 +95,16 @@ class PostgresUser(Recipe):
|
|||||||
await ch.wait_close()
|
await ch.wait_close()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# this is close enough to Postgres escaping I believe.
|
||||||
|
escaped_password = postgres_dodgy_escape_literal(str(self.password))
|
||||||
|
|
||||||
q = f"""
|
q = f"""
|
||||||
CREATE ROLE {self.user_name}
|
CREATE ROLE {self.user_name}
|
||||||
WITH PASSWORD ?
|
WITH PASSWORD {escaped_password}
|
||||||
LOGIN;
|
LOGIN
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await ch.send((q, self.password))
|
await ch.send((q,))
|
||||||
res = await ch.recv()
|
res = await ch.recv()
|
||||||
if len(res) != 0:
|
if len(res) != 0:
|
||||||
raise RuntimeError("expected empty result set.")
|
raise RuntimeError("expected empty result set.")
|
||||||
|
@ -1,28 +1,60 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
try:
|
||||||
|
import asyncpg
|
||||||
|
except ImportError:
|
||||||
|
asyncpg = None
|
||||||
|
|
||||||
from scone.common.chanpro import Channel
|
from scone.common.chanpro import Channel
|
||||||
from scone.sous import Utensil
|
from scone.sous import Utensil
|
||||||
from scone.sous.utensils import Worktop
|
from scone.sous.utensils import Worktop
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if not asyncpg:
|
||||||
|
logger.info("asyncpg not found, install if you need Postgres support")
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True)
|
@attr.s(auto_attribs=True)
|
||||||
class PostgresTransaction(Utensil):
|
class PostgresTransaction(Utensil):
|
||||||
database: str
|
database: str
|
||||||
|
|
||||||
async def execute(self, channel: Channel, worktop: Worktop) -> None:
|
# statements like CREATE DATABASE are not permitted in transactions.
|
||||||
import asyncpg
|
use_transaction_block: bool = True
|
||||||
|
|
||||||
conn = await asyncpg.connect(database=self.database)
|
async def execute(self, channel: Channel, worktop: Worktop) -> None:
|
||||||
try:
|
if not asyncpg:
|
||||||
async with conn.transaction():
|
raise RuntimeError("asyncpg is not installed.")
|
||||||
|
|
||||||
|
async def queryloop():
|
||||||
while True:
|
while True:
|
||||||
query, *args = await channel.recv()
|
next_input = await channel.recv()
|
||||||
|
if next_input is None:
|
||||||
|
return
|
||||||
|
query, *args = next_input
|
||||||
if query is None:
|
if query is None:
|
||||||
break
|
break
|
||||||
|
try:
|
||||||
results = [
|
results = [
|
||||||
dict(record) for record in await conn.fetch(query, *args)
|
dict(record) for record in await conn.fetch(query, *args)
|
||||||
]
|
]
|
||||||
|
except asyncpg.PostgresError:
|
||||||
|
logger.error(
|
||||||
|
"Failed query %s with args %r", query, args, exc_info=True
|
||||||
|
)
|
||||||
|
await channel.close("Query error")
|
||||||
|
raise
|
||||||
|
|
||||||
await channel.send(results)
|
await channel.send(results)
|
||||||
|
|
||||||
|
conn = await asyncpg.connect(database=self.database)
|
||||||
|
try:
|
||||||
|
if self.use_transaction_block:
|
||||||
|
async with conn.transaction():
|
||||||
|
await queryloop()
|
||||||
|
else:
|
||||||
|
await queryloop()
|
||||||
finally:
|
finally:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user