You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
isso/isso/db.py

260 lines
7.8 KiB

# -*- encoding: utf-8 -*-
from __future__ import unicode_literals
import logging
import sqlite3
import binascii
import operator
import threading
import os.path
from collections import defaultdict
from sqlalchemy import Table, Column, MetaData, create_engine
from sqlalchemy import ForeignKey, Integer, Float, String, LargeBinary
from sqlalchemy.sql import select
# Module-level logger, named "isso" rather than after this module; Foo.migrate
# uses it to report schema migrations and how many rows they changed.
logger = logging.getLogger("isso")
class Adapter(object):
    """SQLAlchemy storage backend.

    Declares the ``comments``, ``threads`` and ``preferences`` tables on a
    fresh ``MetaData``, creates whichever of them do not yet exist in the
    database addressed by *db* (an SQLAlchemy engine URL), and wraps the
    preferences table in a :class:`Preferences` key-value store.
    """

    # Newest schema version; Foo stamps new databases with it and migrates
    # older ones up to it.
    MAX_VERSION = 3

    def __init__(self, db):
        self.engine = create_engine(db, echo=False)
        meta = self.metadata = MetaData()

        # One row per comment. ``thread`` has no explicit type (None), so
        # SQLAlchemy derives it from the referenced ``threads.id`` column.
        self.comments = Table(
            "comments", meta,
            Column("id", Integer, primary_key=True),
            Column("parent", Integer),
            Column("thread", None, ForeignKey("threads.id")),
            Column("created", Float),
            Column("modified", Float),
            Column("mode", Integer),
            Column("remote_addr", String(48)),  # XXX use a BigInt
            Column("text", String(65535)),
            Column("author", String(255)),
            Column("email", String(255)),
            Column("website", String(255)),
            Column("likes", Integer),
            Column("dislikes", Integer),
            Column("voters", LargeBinary(256)),
        )

        # One row per distinct page URI.
        self.threads = Table(
            "threads", meta,
            Column("id", Integer, primary_key=True),
            Column("uri", String(255), unique=True),
            Column("title", String(255)),
        )

        # Plain key/value settings, exposed via Preferences below.
        prefs_table = Table(
            "preferences", meta,
            Column("key", String(255), primary_key=True),
            Column("value", String(255)),
        )

        # create_all only creates tables that are missing.
        meta.create_all(self.engine)
        self.preferences = Preferences(self.engine, prefs_table)

    @property
    def transaction(self):
        """Return ``engine.begin()``: a context manager yielding a connection
        inside a transaction that commits on success and rolls back on error.
        """
        return self.engine.begin()
class Preferences(object):
    """A simple key-value store backed by an SQL table.

    *engine* is an SQLAlchemy engine and *preferences* a ``Table`` with
    ``key`` and ``value`` columns. On construction every key listed in
    :attr:`defaults` that is absent from the table is inserted.
    """

    # Seed values for missing keys. The session key is random and is generated
    # once, when this class is defined. ``b2a_hex`` returns bytes; decode the
    # (pure ASCII) hex digits to text because ``value`` is a String column.
    defaults = [
        ("session-key", binascii.b2a_hex(os.urandom(24)).decode("ascii")),
    ]

    def __init__(self, engine, preferences):
        self.engine = engine
        self.preferences = preferences

        for key, value in Preferences.defaults:
            if self.get(key) is None:
                self.set(key, value)

    def get(self, key, default=None):
        """Return the value stored under *key*, or *default* if the key is absent."""
        query = (select([self.preferences.c.value])
                 .where(self.preferences.c.key == key))
        # Same connection/transaction pattern as Adapter.transaction; the row
        # is fetched before the connection is released.
        with self.engine.begin() as conn:
            row = conn.execute(query).fetchone()
        return default if row is None else row[0]

    def set(self, key, value):
        """Store *value* under *key*, replacing any existing value.

        The previous insert-only version raised an integrity error for an
        existing key; this updates first and inserts only if no row matched.
        """
        table = self.preferences
        with self.engine.begin() as conn:
            result = conn.execute(
                table.update().where(table.c.key == key).values(value=value))
            if result.rowcount == 0:
                conn.execute(table.insert().values(key=key, value=value))
class Transaction(object):
    """Context manager running a block inside ``BEGIN IMMEDIATE``.

    SQLite's default DEFERRED mode only takes the write lock at the first
    write. Concurrent writers from several processes can then collide when
    upgrading their locks. IMMEDIATE takes the write lock at BEGIN, so
    writers are serialized while other connections can still read.

    On normal exit the transaction is committed. If the block raises, it is
    rolled back and the exception propagates. The connection's original
    ``isolation_level`` is restored in every case, including when BEGIN
    itself fails (e.g. the database is locked).
    """

    def __init__(self, con):
        # con: a sqlite3.Connection (e.g. SQLite3.connection).
        self.con = con

    def __enter__(self):
        self._orig = self.con.isolation_level
        self.con.isolation_level = "IMMEDIATE"
        try:
            self.con.execute("BEGIN IMMEDIATE")
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo the
            # isolation-level change here instead of leaking it.
            self.con.isolation_level = self._orig
            raise
        return self.con

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is not None:
                self.con.rollback()
            else:
                self.con.commit()
        finally:
            self.con.isolation_level = self._orig
class SQLite3(object):
    """SQLite3 connection pool across multiple threads. Implementation idea
    from `Peewee <https://github.com/coleifer/peewee>`_.

    Each thread lazily opens its own ``sqlite3.Connection`` (kept in a
    ``threading.local``) in autocommit mode (``isolation_level=None``).
    Explicit transactions go through :attr:`transaction`.
    """

    def __init__(self, db):
        # Database path; a leading "~" is expanded to the user's home.
        self.db = os.path.expanduser(db)
        self.lock = threading.Lock()
        self.local = threading.local()

    def connect(self):
        """Open a new connection for the calling thread."""
        with self.lock:
            self.local.conn = sqlite3.connect(self.db, isolation_level=None)

    def close(self):
        """Close the calling thread's connection, if it has one.

        Safe to call before any connection was opened and more than once.
        """
        with self.lock:
            conn = getattr(self.local, "conn", None)
            if conn is not None:
                conn.close()
            self.local.conn = None

    def execute(self, sql, args=()):
        """Execute *sql* with *args* on this thread's connection.

        *sql* may be a list/tuple of fragments, which are joined with spaces.
        Returns the resulting cursor.
        """
        if isinstance(sql, (list, tuple)):
            sql = ' '.join(sql)
        return self.connection.execute(sql, args)

    @property
    def connection(self):
        """This thread's connection, opened on first use."""
        if getattr(self.local, "conn", None) is None:
            self.connect()
        return self.local.conn

    @property
    def transaction(self):
        """A Transaction (BEGIN IMMEDIATE) context manager on this thread's connection."""
        return Transaction(self.connection)

    @property
    def total_changes(self):
        """``total_changes`` of this thread's connection."""
        return self.connection.total_changes
class Foo(object):
    """DB-dependent wrapper around SQLite3.

    Runs migrations if ``user_version`` is older than ``MAX_VERSION`` and
    registers a trigger for automated orphan removal.
    """

    MAX_VERSION = 3

    def __init__(self, conn, conf):
        # conn must provide execute(sql, args) and a ``transaction`` context
        # manager (presumably an SQLite3 instance); conf must provide
        # has_option()/get() (read by the version 1 -> 2 migration).
        self.connection = conn
        self.conf = conf

        # Checked before the wrappers below run (they presumably create the
        # tables): no matching table means this is a brand-new database.
        rv = self.execute([
            "SELECT name FROM sqlite_master"
            " WHERE type='table' AND name IN ('threads', 'comments', 'preferences')"]
        ).fetchone()

        # NOTE(review): the Preferences class defined in this file takes
        # (engine, preferences); this one-argument call raises TypeError
        # unless another Preferences is in scope. Threads, Comments and Guard
        # are not defined in the visible section either — confirm.
        self.preferences = Preferences(self)
        self.threads = Threads(self)
        self.comments = Comments(self)
        self.guard = Guard(self)

        if rv is None:
            # Fresh database: it is created at the newest schema directly.
            self.execute("PRAGMA user_version = %i" % Adapter.MAX_VERSION)
        else:
            self.migrate(to=Adapter.MAX_VERSION)

        # Delete threads that no longer have any comment.
        # NOTE(review): the SQLAlchemy comments table (Adapter) names this
        # column "thread", not "tid" — confirm which schema this runs against.
        self.execute([
            'CREATE TRIGGER IF NOT EXISTS remove_stale_threads',
            'AFTER DELETE ON comments',
            'BEGIN',
            ' DELETE FROM threads WHERE id NOT IN (SELECT tid FROM comments);',
            'END'])

    @property
    def version(self):
        """Current schema version (SQLite ``PRAGMA user_version``)."""
        return self.execute("PRAGMA user_version").fetchone()[0]

    def migrate(self, to):
        """Upgrade the schema one version at a time until it reaches *to*.

        Does nothing if the database is already at version *to* or newer.
        Each step runs in its own transaction and sets ``user_version`` to
        its target version inside that transaction. Stops early if no step
        is known for the current version.
        """
        if self.version >= to:
            return

        logger.info("migrate database from version %i to %i", self.version, to)

        steps = {
            0: self._migrate_to_1,
            1: self._migrate_to_2,
            2: self._migrate_to_3,
        }
        while self.version < to:
            step = steps.get(self.version)
            if step is None:
                break
            step()

    def _migrate_to_1(self):
        """Version 0 -> 1: re-initialize every comment's voters blob.

        A bug in the bloomfilter signature added older commenters' IP
        addresses to the current voters blob.
        """
        from isso.utils import Bloomfilter

        try:
            to_blob = buffer  # Python 2: sqlite3 binds buffer as BLOB
        except NameError:
            to_blob = memoryview  # Python 3: ``buffer`` no longer exists

        bf = to_blob(Bloomfilter(iterable=["127.0.0.0"]).array)
        with self.connection.transaction as con:
            con.execute('UPDATE comments SET voters=?', (bf, ))
            con.execute('PRAGMA user_version = 1')
            logger.info("%i rows changed", con.total_changes)

    def _migrate_to_2(self):
        """Version 1 -> 2: copy ``[general] session-key`` from the
        configuration (if set) into the preferences table."""
        with self.connection.transaction as con:
            if self.conf.has_option("general", "session-key"):
                con.execute('UPDATE preferences SET value=? WHERE key=?', (
                    self.conf.get("general", "session-key"), "session-key"))
            con.execute('PRAGMA user_version = 2')
            logger.info("%i rows changed", con.total_changes)

    def _migrate_to_3(self):
        """Version 2 -> 3: limit the maximum nesting level to 1.

        Every descendant of a top-level comment (``parent IS NULL``) is
        re-parented directly onto that top-level comment.
        """
        def first_column(rows):
            return [row[0] for row in rows]

        with self.connection.transaction as con:
            top = first_column(con.execute(
                "SELECT id FROM comments WHERE parent IS NULL").fetchall())

            # top-level id -> ids of all of its descendants (depth-first walk)
            flattened = defaultdict(set)
            for top_id in top:
                pending = [top_id]
                while pending:
                    children = first_column(con.execute(
                        "SELECT id FROM comments WHERE parent=?", (pending.pop(), )))
                    pending.extend(children)
                    flattened[top_id].update(children)

            for top_id, descendants in flattened.items():
                for child_id in descendants:
                    con.execute("UPDATE comments SET parent=? WHERE id=?", (top_id, child_id))

            con.execute('PRAGMA user_version = 3')
            logger.info("%i rows changed", con.total_changes)

    def execute(self, sql, args=()):
        """Forward to the underlying connection's ``execute``."""
        return self.connection.execute(sql, args)