]> jfr.im git - erebus.git/commitdiff
prep for alternate databases
authorJohn Runyon <redacted>
Wed, 8 May 2024 07:26:13 +0000 (01:26 -0600)
committerJohn Runyon <redacted>
Wed, 8 May 2024 07:29:42 +0000 (01:29 -0600)
erebus.py

index cd85e3e3b238ba345e3d8ff8b3300439ec0b5d5e..1a06ed3b5f6714b8136291465518e3f4613fa373 100644 (file)
--- a/erebus.py
+++ b/erebus.py
@@ -6,7 +6,7 @@
 
 from __future__ import print_function
 
 
 from __future__ import print_function
 
-import os, sys, select, MySQLdb, MySQLdb.cursors, time, traceback, random, gc
+import os, sys, select, time, traceback, random, gc
 import bot, config, ctlmod
 
 class Erebus(object): #singleton to pass around
 import bot, config, ctlmod
 
 class Erebus(object): #singleton to pass around
@@ -209,11 +209,11 @@ class Erebus(object): #singleton to pass around
                                return curs
                        else:
                                return res
                                return curs
                        else:
                                return res
-               except MySQLdb.DataError as e:
-                       self.log("[SQL]", ".", "MySQL DataError: %r" % (e))
+               except db_api.DataError as e:
+                       self.log("[SQL]", ".", "DB DataError: %r" % (e))
                        return False
                        return False
-               except MySQLdb.MySQLError as e:
-                       self.log("[SQL]", "!", "MySQL error! %r" % (e))
+               except db_api.Error as e:
+                       self.log("[SQL]", "!", "DB error! %r" % (e))
                        if not noretry:
                                dbsetup()
                                return self.query(*args, noretry=True, **kwargs)
                        if not noretry:
                                dbsetup()
                                return self.query(*args, noretry=True, **kwargs)
@@ -357,9 +357,19 @@ class Erebus(object): #singleton to pass around
 def dbsetup():
        main.db = None
        main.dbs = []
 def dbsetup():
        main.db = None
        main.dbs = []
+       dbtype = cfg.get('erebus', 'dbtype', 'mysql')
+       if dbtype == 'mysql':
+               _dbsetup_mysql()
+       else:
+               main.log('*', '!', 'Unknown dbtype in config: %s' % (dbtype))
+
+def _dbsetup_mysql():
+       global db_api
+       import MySQLdb as db_api, MySQLdb.cursors
        for i in range(cfg.get('erebus', 'num_db_connections', 2)-1):
        for i in range(cfg.get('erebus', 'num_db_connections', 2)-1):
-               main.dbs.append(MySQLdb.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass, db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor))
-       main.db = MySQLdb.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass, db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor)
+               main.dbs.append(db_api.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass, db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor))
+       main.db = db_api.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass, db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor)
+
 
 def setup():
        global cfg, main
 
 def setup():
        global cfg, main