1 from sqlalchemy import create_engine, pool |
|
2 |
|
3 from sqlalchemy.orm import sessionmaker |
|
4 |
|
5 |
|
6 from twisted.internet import threads |
|
7 |
|
8 from ..model.schema import Base |
|
9 |
|
# Module-level in-memory SQLite engine (echo=True logs every SQL statement).
#
# An in-memory SQLite database exists per *connection*.  For ':memory:' URLs
# SQLAlchemy's default pool is SingletonThreadPool, which gives every thread
# its own connection -- i.e. its own, empty database.  Because this module
# dispatches work to worker threads (toThread -> threads.deferToThread),
# tables created by createDatabase() on one thread would be missing on the
# others.  StaticPool keeps exactly one connection for the whole process, and
# check_same_thread=False allows the sqlite3 driver to use it from any thread.
# NOTE(review): all threads now share that single connection/transaction;
# confirm callers do not rely on per-thread isolation.
engine = create_engine(
    'sqlite:///:memory:',
    echo=True,
    poolclass=pool.StaticPool,
    connect_args={'check_same_thread': False},
)
|
11 |
|
def createDatabase():
    """Create all tables registered on ``Base.metadata`` using ``engine``."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
|
14 |
|
# Session factory bound to the module-level ``engine``.  Callers may still
# override the bind per session, e.g. ``Session(bind=some_connection)``.
Session = sessionmaker()
Session.configure(bind=engine)
|
16 |
|
def toThread(f):
    """Decorate ``f`` so that every call runs it via ``threads.deferToThread``.

    The returned callable takes the same arguments as ``f`` and returns the
    ``Deferred`` produced by ``threads.deferToThread(f, *args, **kwargs)``,
    which fires with ``f``'s result (or fails with the exception ``f`` raised).

    ``functools.wraps`` copies ``f``'s ``__name__``, ``__doc__`` etc. onto the
    wrapper (and sets ``__wrapped__``) so logs and introspection still see ``f``.
    """
    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        return threads.deferToThread(f, *args, **kwargs)
    return wrapper
|
21 |
|
22 |
|
class WithSession(object):
    """Context manager yielding an ORM session bound to its own connection.

    ``__enter__`` checks a connection out of the module-level ``engine`` and
    returns a new ``Session`` bound to it.  On exit:

    * if the ``with`` body raised, the session is rolled back;
    * otherwise, if ``autocommit`` is true, the session is committed;
    * in every case the session and then the connection are closed -- even
      when the rollback/commit itself raises.

    Exceptions from the ``with`` body are never suppressed.
    """

    def __init__(self, autocommit=False):
        # When true, a clean exit commits; otherwise nothing is committed.
        self.autocommit = autocommit

    def __enter__(self):
        conn = engine.connect()
        try:
            session = Session(bind=conn)
        except Exception:
            # Don't leak the checked-out connection if the session can't be built.
            conn.close()
            raise
        self.conn = conn
        self.session = session
        return session

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            if exc_type is not None:
                self.session.rollback()
            elif self.autocommit:
                self.session.commit()
        finally:
            # Always release resources, even if rollback()/commit() raised;
            # the connection is closed even if closing the session fails.
            try:
                self.session.close()
            finally:
                self.conn.close()
|
40 |
|
class DBDefer(object):
    """Decorator factory: run a function in a worker thread with a fresh session.

    ``DBDefer(dsn, ...)`` builds its own engine.  Applying the instance to a
    function returns a callable that:

    * runs the function through ``toThread`` (``threads.deferToThread``), so
      the caller gets back a ``Deferred``;
    * passes a new ORM session as the ``session`` keyword argument;
    * rolls the session back if the function raises (the exception still
      propagates), and always closes the session afterwards.

    The decorated function is responsible for committing its own work.
    """

    def __init__(self, dsn, poolclass=pool.SingletonThreadPool, *args, **kargs):
        # Extra positional/keyword arguments are forwarded to create_engine.
        self.engine = create_engine(dsn, poolclass=poolclass, *args, **kargs)
        # Build the session factory once: sessionmaker() creates a new class
        # on every invocation, so constructing it per call is wasted work.
        self.Session = sessionmaker(bind=self.engine)

    def __call__(self, func):
        import functools

        def run_with_session(*args, **kwargs):
            session = self.Session()
            try:
                return func(*args, session=session, **kwargs)
            except BaseException:
                # Discard partial work before re-raising; the exception then
                # reaches the caller through the Deferred.
                session.rollback()
                raise
            finally:
                session.close()

        # Copy func's metadata onto the thread-dispatching wrapper so the
        # decorated object still presents func's name and docstring.
        return functools.wraps(func)(toThread(run_with_session))
|