iro/test_helpers/dbtestcase.py
branchdevel
changeset 230 448dd8d36839
parent 219 4e9d79c35088
child 241 546316b0b09c
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/iro/test_helpers/dbtestcase.py	Sun Mar 18 14:05:11 2012 +0100
@@ -0,0 +1,78 @@
+from twisted.trial import unittest
+
+from sqlalchemy import create_engine, pool
+from tempfile import mkdtemp
+
+from shutil import rmtree
+import atexit 
+from ngdatabase.mysql import Server, createConfig, Database
+
+from iro.model import setEngine, setPool
+from iro.model.utils import WithSession
+from iro.model.schema import Base
+
+from iro.controller.pool import dbPool
+
+class DBTestCase(unittest.TestCase):
+    '''a TestCase with DB connection
+    you have to set self.session manually in setUp'''
+    def __init__(self,*args,**kwargs):
+        unittest.TestCase.__init__(self,*args,**kwargs)
+        self.engine = md.engine
+
+    def setUp(self):
+        md.setUp()
+
+    def tearDown(self):
+        self.__cleanDB()
+    
+    def session(self,autocommit=True):
+        '''returns a WithSession'''
+        return WithSession(self.engine,autocommit)
+    
+    def __cleanDB(self):
+        '''cleaning database'''
+        with self.session() as session:
+            for table in reversed(Base.metadata.sorted_tables):
+                session.execute(table.delete())
+
+
+class SampleDatabase(Database):
+    def createPassword(self):
+        self.password="test"
+        return self.password
+
+class ModuleData(object):
+    def __init__(self):
+        self.create()
+
+    def close(self):
+        if self.valid:
+            self.server.stop()
+            rmtree(self.tdir)
+            self.valid= False
+
+    def create(self):
+        self.tdir = mkdtemp(prefix='iro-mysql-')
+        self.server = Server('%s/my.cnf'%self.tdir)
+        self.db = SampleDatabase("test","test",'%s/my.cnf'%self.tdir)
+        self.dburl = 'mysql://test:test@localhost/test?unix_socket=%s/socket'%self.tdir
+        self.engine = create_engine(self.dburl, 
+                poolclass = pool.SingletonThreadPool,  pool_size=dbPool.maxthreads, )#echo=True)
+        self.valid = False
+
+    def setUp(self):
+        if not self.valid:
+            with open('%s/my.cnf'%self.tdir,'w') as cnf:
+                cnf.write(createConfig(self.tdir))
+            self.server.create()
+            self.server.start()
+            self.db.create()
+            Base.metadata.create_all(self.engine)
+            setEngine(self.engine)
+            setPool(dbPool)
+            self.valid = True
+    
+
+md=ModuleData()
+atexit.register(md.close)