diff --git a/scone/default/recipes/mysql.py b/scone/default/recipes/mysql.py new file mode 100644 index 0000000..5f89da7 --- /dev/null +++ b/scone/default/recipes/mysql.py @@ -0,0 +1,149 @@ +# Copyright 2020, Olivier 'reivilibre'. +# +# This file is part of Scone. +# +# Scone is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Scone is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Scone. If not, see . +from typing import List + +from scone.default.utensils.db_utensils import PostgresTransaction, MysqlTransaction +from scone.head.head import Head +from scone.head.kitchen import Kitchen, Preparation +from scone.head.recipe import Recipe, RecipeContext +from scone.head.utils import check_type, check_type_opt + + +def mysql_dodgy_escape_literal(unescaped: str) -> str: + python_esc = repr(unescaped) + if python_esc[0] == '"': + return "'" + python_esc[1:-1].replace("'", "\\'") + "'" + else: + assert python_esc[0] == "'" + return python_esc + + +def mysql_dodgy_escape_username(unescaped: str) -> str: + parts = unescaped.split("@") + if len(parts) != 2: + raise ValueError(f"{unescaped!r} is not a valid sconified mysql user name.") + return mysql_dodgy_escape_literal(parts[0]) + "@" + mysql_dodgy_escape_literal(parts[1]) + + +class MysqlDatabase(Recipe): + _NAME = "mysql-db" + + def __init__(self, recipe_context: RecipeContext, args: dict, head): + super().__init__(recipe_context, args, head) + + self.database_name = check_type(args.get("name"), str) + self.charset = args.get("charset", "utf8mb4") + self.collate = args.get("collate", "utf8mb4_unicode_ci") + self.grant_all_to = check_type_opt(args.get("grant_all_to"), List[str]) + + def prepare(self, preparation: Preparation, head: Head) -> None: + super().prepare(preparation, head) + preparation.provides("mysql-database", self.database_name) + if self.grant_all_to: + for user in self.grant_all_to: + preparation.needs("mysql-user", user) + + async def cook(self, kitchen: Kitchen) -> None: + ch = await kitchen.start(MysqlTransaction("mysql", "root", unix_socket=True)) + await ch.send( + ( + "SHOW DATABASES LIKE %s", + self.database_name, + ) + ) + dbs = await ch.recv() + if len(dbs) > 0: + await ch.send(None) + await ch.wait_close() + return + + q = f""" + CREATE DATABASE {self.database_name} + CHARACTER SET = {mysql_dodgy_escape_literal(self.charset)} + COLLATE = {mysql_dodgy_escape_literal(self.collate)} + """ + + await ch.send((q,)) + res = await ch.recv() + if len(res) != 0: + raise RuntimeError("expected empty result set.") + + if self.grant_all_to: + for user in self.grant_all_to: + q = f""" + GRANT ALL PRIVILEGES ON {self.database_name}.* + TO {mysql_dodgy_escape_username(user)} + """ + await ch.send((q,)) + res = await ch.recv() + if len(res) != 0: + raise RuntimeError("expected empty result set.") + + q = f""" + FLUSH PRIVILEGES + """ + await ch.send((q,)) + res = await ch.recv() + if len(res) != 0: + raise RuntimeError("expected empty result set.") + + await ch.send(None) + await ch.wait_close() + + +class MysqlUser(Recipe): + _NAME = "mysql-user" + + def __init__(self, recipe_context: RecipeContext, args: dict, head): + super().__init__(recipe_context, args, head) + + self.user_name = check_type(args.get("name"), str) + self.password = check_type(args.get("password"), str) + + def prepare(self, preparation: Preparation, head: Head) -> None: + super().prepare(preparation, head) + preparation.provides("mysql-user", self.user_name) + + async def cook(self, kitchen: Kitchen) -> None: + ch = await kitchen.start(MysqlTransaction("mysql", "root", unix_socket=True)) + await ch.send( + ( + "SELECT 1 AS count FROM mysql.user " + "WHERE CONCAT(user, '@', host) = %s", + self.user_name, + ) + ) + dbs = await ch.recv() + if len(dbs) > 0 and dbs[0]["count"] == 1: + await ch.send(None) + await ch.wait_close() + return + + # this is close enough to MySQL escaping I believe. + escaped_password = mysql_dodgy_escape_literal(str(self.password)) + + q = f""" + CREATE USER {mysql_dodgy_escape_username(self.user_name)} + IDENTIFIED BY {escaped_password} + """ + + await ch.send((q,)) + res = await ch.recv() + if len(res) != 0: + raise RuntimeError("expected empty result set.") + await ch.send(None) + await ch.wait_close() diff --git a/scone/default/utensils/db_utensils.py b/scone/default/utensils/db_utensils.py index 6969575..967ecac 100644 --- a/scone/default/utensils/db_utensils.py +++ b/scone/default/utensils/db_utensils.py @@ -24,6 +24,11 @@ try: except ImportError: asyncpg = None +try: + from mysql import connector as mysql_connector +except ImportError: + mysql_connector = None + from scone.common.chanpro import Channel from scone.sous import Utensil from scone.sous.utensils import Worktop @@ -75,3 +80,54 @@ class PostgresTransaction(Utensil): await queryloop() finally: await conn.close() + + +@attr.s(auto_attribs=True) +class MysqlTransaction(Utensil): + database: str + user: str + unix_socket: bool = False + + async def execute(self, channel: Channel, worktop: Worktop) -> None: + if not mysql_connector: + raise RuntimeError("mysql-connector-python is not installed.") + + async def queryloop(): + while True: + next_input = await channel.recv() + if next_input is None: + return + query, *args = next_input + if query is None: + break + try: + cur.execute(query, tuple(args)) + + if conn.unread_result: + names = cur.column_names + results = [ + dict(zip(names, rectuple)) for rectuple in cur.fetchall() + ] + else: + results = [] + except mysql_connector.errors.Error: + logger.error( + "Failed query %s with args %r", query, args, exc_info=True + ) + await channel.close("Query error") + raise + + await channel.send(results) + + # TODO(perf): make async + + unix_socket = "/var/run/mysqld/mysqld.sock" if self.unix_socket else None + + conn = mysql_connector.connect(database=self.database, user=self.user, unix_socket=unix_socket) + cur = conn.cursor() + try: + await queryloop() + # autocommit disabled in this mode by default + conn.commit() + finally: + conn.close() diff --git a/setup.py b/setup.py index af41020..ac50ce5 100644 --- a/setup.py +++ b/setup.py @@ -34,9 +34,9 @@ REQUIRED = [ EX_SOUS_BASE = [] EX_SOUS_PG = ["asyncpg"] +EX_SOUS_MYSQL = ["mysql-connector-python"] -EX_SOUS_ALL = EX_SOUS_BASE + EX_SOUS_PG - +EX_SOUS_ALL = EX_SOUS_BASE + EX_SOUS_PG + EX_SOUS_MYSQL # What packages are optional? @@ -54,6 +54,7 @@ EXTRAS = { "sous": EX_SOUS_ALL, "sous-core": EX_SOUS_BASE, "sous-pg": EX_SOUS_PG, + "sous-mysql": EX_SOUS_MYSQL, "docker": ["docker"] # TODO do this more properly if we can... }