import functools

from sqlalchemy import create_engine, pool
from sqlalchemy.orm import sessionmaker
from twisted.internet import threads

from ..model.schema import Base
# Module-wide engine backed by an in-memory SQLite database; echo=True makes
# SQLAlchemy log the SQL statements it emits.
engine = create_engine('sqlite:///:memory:', echo=True)
def createDatabase():
    """Create all tables registered on ``Base.metadata`` via the module engine."""
    metadata = Base.metadata
    metadata.create_all(engine)
# Session factory bound to the module-level engine. WithSession overrides the
# bind with its own connection when it instantiates a session.
Session = sessionmaker(bind=engine)
def toThread(f):
    """Decorator that runs ``f`` through ``threads.deferToThread``.

    Calling the decorated function forwards all positional and keyword
    arguments to ``threads.deferToThread(f, ...)`` and returns whatever that
    call returns, instead of the direct result of ``f``.

    The returned wrapper keeps the name and docstring of ``f`` so that
    tracebacks, logs and introspection show the real function rather than
    a generic ``wrapper``.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        return threads.deferToThread(f, *args, **kwargs)

    return wrapper
class WithSession(object):
    """Context manager yielding a session bound to its own connection.

    On entry a connection is obtained from the module-level ``engine`` and a
    ``Session`` is created on top of it.  On exit:

    * if the ``with`` body raised, the session is rolled back;
    * otherwise, if ``autocommit`` is true, the session is committed;
    * in every case the session and then the connection are closed, even
      when the commit or rollback itself raises.

    Exceptions from the ``with`` body are never suppressed.
    """

    def __init__(self, autocommit=False):
        """Store the commit policy.

        :param autocommit: when true, commit on exit if the body did not raise.
        """
        self.autocommit = autocommit

    def __enter__(self):
        """Open the connection, create the session and return the session."""
        self.conn = engine.connect()
        self.session = Session(bind=self.conn)
        return self.session

    def __exit__(self, exc_type, exc_value, traceback):
        """Finish the transaction and release the session and connection."""
        try:
            if exc_type is not None:
                self.session.rollback()
            elif self.autocommit:
                self.session.commit()
        finally:
            # Nested so the connection is released even if closing the
            # session fails; previously a failing commit/rollback skipped
            # both close() calls and leaked the connection.
            try:
                self.session.close()
            finally:
                self.conn.close()
class DBDefer(object):
    """Decorator factory: run a function in a thread with a fresh session.

    An instance owns one engine.  Using the instance as a decorator returns
    a function that, when called, is dispatched through ``toThread`` and
    invokes the decorated function with an extra ``session`` keyword
    argument.  The session is rolled back if the function raises and is
    always closed afterwards.  This wrapper never commits; committing is
    left to the decorated function.
    """

    def __init__(self, dsn, poolclass=pool.SingletonThreadPool, *args, **kargs):
        """Create the engine and its session factory.

        :param dsn: database URL handed to ``create_engine``.
        :param poolclass: connection pool class handed to ``create_engine``.
        :param args: extra positional arguments forwarded to ``create_engine``.
        :param kargs: extra keyword arguments forwarded to ``create_engine``.
        """
        self.engine = create_engine(dsn, poolclass=poolclass, *args, **kargs)
        # Build the session factory once here rather than on every call of
        # every decorated function.
        self._session_factory = sessionmaker(bind=self.engine)

    def __call__(self, func):
        """Return ``func`` wrapped to run in a thread with a ``session`` kwarg.

        :param func: callable accepting a ``session`` keyword argument.
        :returns: the ``toThread``-decorated wrapper around ``func``.
        """
        session_factory = self._session_factory

        @toThread
        def wrapper(*args, **kwargs):
            session = session_factory()
            try:
                return func(*args, session=session, **kwargs)
            except BaseException:
                # Deliberately broad: undo pending work on any failure, then
                # let the original exception propagate unchanged.
                session.rollback()
                raise
            finally:
                session.close()

        return wrapper