]> jfr.im git - erebus.git/blame - erebus.py
add sqlite database support
[erebus.git] / erebus.py
CommitLineData
b25d4368 1#!/usr/bin/python
4477123d 2# vim: fileencoding=utf-8
b25d4368 3
931c88a4 4# Erebus IRC bot - Author: John Runyon
5# main startup code
6
a28e2ae9 7from __future__ import print_function
8
5b8f6176 9import os, sys, select, time, traceback, random, gc
db50981b 10import bot, config, ctlmod
b25d4368 11
class Erebus(object): #singleton to pass around
    """Core bot state shared by every module: bots, sockets, users, channels,
    hook tables and database access.

    setup() creates exactly one instance and stores it in the module-level name
    `main`; the nested User/Channel classes refer to that global directly.
    """
    APIVERSION = 0
    RELEASE = 0

    # NOTE: these containers are class attributes and therefore shared by all
    # instances; that is only safe because Erebus is used as a singleton.
    bots = {}  # lowercased nick -> bot object (see newbot())
    fds = {}  # fileno -> object registered with newfd()
    numhandlers = {}  # word -> [handler, ...] (see hooknum())
    msghandlers = {}  # word -> [handler, ...] (see hook())
    chanhandlers = {}  # channel -> [handler, ...] (see hookchan())
    exceptionhandlers = [] # list of (Exception_class, handler_function) tuples
    users = {}  # lowercased nick -> User (see user())
    chans = {}  # lowercased channel name -> Channel (see newchannel())

    class User(object):
        """An IRC user, identified by nick and optionally logged in to an account (auth)."""

        def __init__(self, nick, auth=None):
            self.nick = nick
            if auth is None:
                self.auth = None
            else:
                self.auth = auth.lower()
            self.checklevel()

            self.chans = []  # channels passed to join() and not yet part()ed

        def bind_bot(self, bot):
            """Return a proxy for this user whose msg methods are sent through `bot`."""
            return main._BoundUser(self, bot)

        # Messaging helpers: delegate to a randomly chosen bot.
        def msg(self, *args, **kwargs):
            main.randbot().msg(self, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            main.randbot().slowmsg(self, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            main.randbot().fastmsg(self, *args, **kwargs)

        def isauthed(self):
            return self.auth is not None

        def authed(self, auth):
            """Set the user's account name (lowercased; '0' clears it) and reload glevel."""
            if auth == '0': self.auth = None
            else: self.auth = auth.lower()
            self.checklevel()

        def checklevel(self):
            """Reload and return self.glevel.

            -1 when not authed; otherwise the `level` stored in the users table for
            this auth, or 0 if there is no row or the query returned nothing usable.
            """
            if self.auth is None:
                self.glevel = -1
            else:
                c = main.query("SELECT level FROM users WHERE auth = %s", (self.auth,))
                row = c.fetchone() if c else None
                self.glevel = row['level'] if row is not None else 0
            return self.glevel

        def setlevel(self, level, savetodb=True):
            """Set this user's global level, optionally persisting it.

            With savetodb, a non-zero level is REPLACEd into the users table and a
            level of 0 deletes the row. Returns True on success; returns False (and
            leaves glevel unchanged) if the database write failed.
            """
            if not savetodb:
                self.glevel = level
                return True
            if level != 0:
                c = main.query("REPLACE INTO users (auth, level) VALUES (%s, %s)", (self.auth, level))
            else:
                c = main.query("DELETE FROM users WHERE auth = %s", (self.auth,))
                # A DELETE matching no rows reports 0 affected rows, which is fine.
                # Check identity against False first: query() returns False on a
                # DataError, and `False == 0` is True in Python.
                if c is not False and c == 0:
                    c = True
            if c:
                self.glevel = level
                return True
            return False

        def join(self, chan):
            if chan not in self.chans: self.chans.append(chan)

        def part(self, chan):
            """Forget `chan` for this user; return True if they are now in no channels."""
            try:
                self.chans.remove(chan)
            except ValueError:  # not recorded as being in that channel
                pass
            return len(self.chans) == 0

        def quit(self):
            pass

        def nickchange(self, newnick):
            self.nick = newnick

        def __str__(self): return self.nick
        def __repr__(self): return "<User %r (%d)>" % (self.nick, self.glevel)

    class _BoundUser(object):
        """Proxy for a User that forwards attribute access to it, but sends
        msg/slowmsg/fastmsg through one specific bot."""

        def __init__(self, user, bot):
            # Write via __dict__ because __setattr__ below forwards to the user.
            self.__dict__['_bound_user'] = user
            self.__dict__['_bound_bot'] = bot
        def __getattr__(self, name):
            return getattr(self._bound_user, name)
        def __setattr__(self, name, value):
            setattr(self._bound_user, name, value)
        def msg(self, *args, **kwargs):
            self._bound_bot.msg(self._bound_user, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            self._bound_bot.slowmsg(self._bound_user, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            self._bound_bot.fastmsg(self._bound_user, *args, **kwargs)
        def __repr__(self): return "<_BoundUser %r %r>" % (self._bound_user, self._bound_bot)

    class Channel(object):
        """An IRC channel handled by `bot`, with member lists and per-auth access levels."""

        def __init__(self, name, bot):
            self.name = name
            self.bot = bot
            self.levels = {}  # auth -> level, loaded from the chusers table

            self.users = []
            self.voices = []
            self.ops = []

            c = main.query("SELECT user, level FROM chusers WHERE chan = %s", (self.name,))
            if c:
                for row in iter(c.fetchone, None):
                    self.levels[row['user']] = row['level']

        def msg(self, *args, **kwargs):
            self.bot.msg(self, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            self.bot.slowmsg(self, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            self.bot.fastmsg(self, *args, **kwargs)

        def levelof(self, auth):
            """Return the channel level of `auth` (case-insensitive); 0 if None or unknown."""
            if auth is None:
                return 0
            return self.levels.get(auth.lower(), 0)

        def setlevel(self, auth, level, savetodb=True):
            """Set the channel level of `auth`, optionally saving it to chusers.
            Returns False (level unchanged) if the database write failed."""
            auth = auth.lower()
            if savetodb:
                c = main.query("REPLACE INTO chusers (chan, user, level) VALUES (%s, %s, %s)", (self.name, auth, level))
                if c:
                    self.levels[auth] = level
                    return True
                else:
                    return False
            else:
                self.levels[auth] = level
                return True

        def userjoin(self, user, level=None):
            if user not in self.users: self.users.append(user)
            if level == 'op' and user not in self.ops: self.ops.append(user)
            if level == 'voice' and user not in self.voices: self.voices.append(user)
        def userpart(self, user):
            if user in self.ops: self.ops.remove(user)
            if user in self.voices: self.voices.remove(user)
            if user in self.users: self.users.remove(user)

        def userop(self, user):
            if user in self.users and user not in self.ops: self.ops.append(user)
        def uservoice(self, user):
            if user in self.users and user not in self.voices: self.voices.append(user)
        def userdeop(self, user):
            if user in self.ops: self.ops.remove(user)
        def userdevoice(self, user):
            if user in self.voices: self.voices.remove(user)

        def __str__(self): return self.name
        def __repr__(self): return "<Channel %r>" % (self.name)

    def __init__(self, cfg):
        self.mustquit = None  # set to an exception instance to make loop() raise it
        self.starttime = time.time()
        self.cfg = cfg
        self.trigger = cfg.trigger
        if os.name == "posix":
            self.potype = "poll"
            self.po = select.poll()
        else: # f.e. os.name == "nt" (Windows)
            self.potype = "select"
            self.fdlist = []

    def query(self, sql, parameters=None, noretry=False):
        """Execute `sql` on self.db and return the result.

        `sql` uses %s placeholders. For qmark drivers (sqlite3) they are rewritten
        to '?', and parameters are converted with str(), except None, which is kept
        so it binds as SQL NULL rather than the text 'None'.

        Returns the cursor when the driver's execute() result is truthy, otherwise
        that falsy result itself (e.g. a row count of 0); returns False on a
        DataError. On any other DB error, dbsetup() is run and the query retried
        once; if the retry also fails, the error is re-raised.
        """
        if parameters is None:
            parameters = []
        # Callers use %s-style (paramstyle='format') placeholders in queries.
        # There's no provision for a literal '%s' present inside the query; stuff it in a parameter instead.
        if db_api.paramstyle == 'format' or db_api.paramstyle == 'pyformat': # mysql, postgresql
            # psycopg actually asks for a mapping with %(name)s style (pyformat) but it will accept %s style.
            pass
        elif db_api.paramstyle == 'qmark': # sqlite doesn't like %s style.
            parameters = [p if p is None else str(p) for p in parameters]
            sql = sql.replace('%s', '?') # hope that wasn't literal, oopsie

        log_noretry = ''
        if noretry:
            log_noretry = ', noretry=True'
        self.log("[SQL]", "?", "query(%r, %r%s)" % (sql, parameters, log_noretry))

        try:
            curs = self.db.cursor()
            res = curs.execute(sql, parameters)
            if res:
                return curs
            else:
                return res
        except db_api.DataError as e:
            self.log("[SQL]", ".", "DB DataError: %r" % (e))
            return False
        except db_api.Error as e:
            self.log("[SQL]", "!", "DB error! %r" % (e))
            if not noretry:
                dbsetup()
                return self.query(sql, parameters, noretry=True)
            else:
                raise  # bare raise keeps the original traceback (py2 `raise e` drops it)

    def querycb(self, cb, *args, **kwargs):
        """Run self.query(*args, **kwargs) on a new thread and call cb(result).

        Returns the started threading.Thread so callers may join() it.
        """
        # TODO this should either get thrown out with getdb()/returndb(), or else be adjusted to make use of it.
        # NOTE(review): this uses self.db from a worker thread; sqlite3 connections
        # refuse cross-thread use by default (check_same_thread) - confirm per driver.
        import threading  # not imported at module level
        def run_query():
            cb(self.query(*args, **kwargs))
        thread = threading.Thread(target=run_query)
        thread.start()
        return thread

    def newbot(self, nick, user, bind, authname, authpass, server, port, realname):
        """Create a bot.Bot and register it under its lowercased nick."""
        if bind is None: bind = ''
        obj = bot.Bot(self, nick, user, bind, authname, authpass, server, port, realname)
        self.bots[nick.lower()] = obj

    def newfd(self, obj, fileno):
        """Register `obj` as the handler for `fileno` and start polling it for input."""
        self.fds[fileno] = obj
        if self.potype == "poll":
            self.po.register(fileno, select.POLLIN)
        elif self.potype == "select":
            self.fdlist.append(fileno)
    def delfd(self, fileno):
        """Unregister `fileno` (KeyError if it was never registered)."""
        del self.fds[fileno]
        if self.potype == "poll":
            self.po.unregister(fileno)
        elif self.potype == "select":
            self.fdlist.remove(fileno)

    def bot(self, name): #get Bot() by name (nick)
        return self.bots[name.lower()]
    def fd(self, fileno): #get Bot() by fd/fileno
        return self.fds[fileno]
    def randbot(self): #get Bot() randomly
        return random.choice(list(self.bots.values()))

    def user(self, _nick, send_who=False, create=True):
        """Look up a User by nick (case-insensitive).

        If send_who is set and the user is unknown or not authed, a WHO request is
        sent through a random bot. Unknown users are created when `create` is true,
        otherwise None is returned.
        """
        nick = _nick.lower()

        if send_who and (nick not in self.users or not self.users[nick].isauthed()):
            self.randbot().conn.send("WHO %s n%%ant,1" % (nick))

        if nick in self.users:
            return self.users[nick]
        elif create:
            user = self.User(_nick)
            self.users[nick] = user
            return user
        else:
            return None
    def channel(self, name): #get Channel() by name
        return self.chans.get(name.lower())

    def newchannel(self, bot, name):
        """Create a Channel handled by `bot` and register it under its lowercased name."""
        chan = self.Channel(name.lower(), bot)
        self.chans[name.lower()] = chan
        return chan

    def poll(self):
        """Wait up to 30 seconds and return the list of readable filenos."""
        timeout_seconds = 30
        if self.potype == "poll":
            pollres = self.po.poll(timeout_seconds * 1000)  # poll() takes milliseconds
            return [fd for (fd, ev) in pollres]
        elif self.potype == "select":
            return select.select(self.fdlist, [], [], timeout_seconds)[0]

    def connectall(self):
        """Call connect() on every bot whose connection state is 0."""
        for b in self.bots.values():  # `b`, not `bot`: don't shadow the bot module
            if b.conn.state == 0:
                b.connect()

    def module(self, name):
        return ctlmod.modules[name]

    def log(self, source, level, message):
        print("%09.3f %s [%s] %s" % (time.time() % 100000, source, level, message))

    def getuserbyauth(self, auth):
        """Return every known User whose auth equals `auth` (case-insensitive)."""
        wanted = auth.lower()
        return [u for u in self.users.values() if u.auth == wanted]

    def getdb(self):
        """Get a DB object. The object must be returned to the pool after us, using returndb(). This is intended for use from child threads.
        It should probably be treated as deprecated though. Where possible new modules should avoid using threads.
        In the future, timers will be provided (manipulating the timeout_seconds of the poll() method), and that should mostly be used in place of threading."""
        return self.dbs.pop()

    def returndb(self, db):
        self.dbs.append(db)

    #bind functions
    def hook(self, word, handler):
        self.msghandlers.setdefault(word, []).append(handler)
    def unhook(self, word, handler):
        if word in self.msghandlers and handler in self.msghandlers[word]:
            self.msghandlers[word].remove(handler)
    def hashook(self, word):
        return word in self.msghandlers and len(self.msghandlers[word]) != 0
    def gethook(self, word):
        return self.msghandlers[word]

    def hooknum(self, word, handler):
        self.numhandlers.setdefault(word, []).append(handler)
    def unhooknum(self, word, handler):
        if word in self.numhandlers and handler in self.numhandlers[word]:
            self.numhandlers[word].remove(handler)
    def hasnumhook(self, word):
        return word in self.numhandlers and len(self.numhandlers[word]) != 0
    def getnumhook(self, word):
        return self.numhandlers[word]

    def hookchan(self, chan, handler):
        self.chanhandlers.setdefault(chan, []).append(handler)
    def unhookchan(self, chan, handler):
        if chan in self.chanhandlers and handler in self.chanhandlers[chan]:
            self.chanhandlers[chan].remove(handler)
    def haschanhook(self, chan):
        return chan in self.chanhandlers and len(self.chanhandlers[chan]) != 0
    def getchanhook(self, chan):
        return self.chanhandlers[chan]

    def hookexception(self, exc, handler):
        self.exceptionhandlers.append((exc, handler))
    def unhookexception(self, exc, handler):
        # Raises ValueError if this (exc, handler) pair was never hooked.
        self.exceptionhandlers.remove((exc, handler))
    def hasexceptionhook(self, exc):
        return any(isinstance(exc, x) for x, h in self.exceptionhandlers)
    def getexceptionhook(self, exc):
        """Lazily yield every handler whose exception class matches `exc`."""
        return (h for x, h in self.exceptionhandlers if isinstance(exc, x))
365
586997a7 366
def dbsetup():
    """(Re)create the database connections for the configured backend.

    Clears main.db and the main.dbs pool, then dispatches on the
    [erebus] dbtype config value ('mysql' by default, or 'sqlite'). An
    unrecognised value is only logged, leaving main.db as None.
    """
    main.db = None
    main.dbs = []
    backend = cfg.get('erebus', 'dbtype', 'mysql')
    if backend == 'sqlite':
        _dbsetup_sqlite()
    elif backend == 'mysql':
        _dbsetup_mysql()
    else:
        main.log('*', '!', 'Unknown dbtype in config: %s' % (backend))
377
def _dbsetup_mysql():
    """Open MySQL connections using cfg.dbhost/dbuser/dbpass/dbname.

    Binds the module-global db_api to MySQLdb, appends
    num_db_connections - 1 pooled connections to main.dbs, then opens a
    separate primary connection as main.db. All connections use DictCursor,
    so rows are indexed by column name.
    """
    global db_api
    import MySQLdb as db_api, MySQLdb.cursors

    def open_connection():
        return db_api.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass, db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor)

    pool_size = cfg.get('erebus', 'num_db_connections', 2) - 1
    for _ in range(pool_size):
        main.dbs.append(open_connection())
    main.db = open_connection()
384
6b4ba0b6
JR
385def _dbsetup_sqlite():
386 global db_api
387 import sqlite3 as db_api
388 for i in range(cfg.get('erebus', 'num_db_connections', 2)):
389 main.db = db_api.connect(cfg.dbhost)
390 main.db.row_factory = db_api.Row
391 main.db.isolation_level = None
392 main.dbs.append(main.db)
586997a7 393
def setup():
    """Create the global `cfg` and `main` objects and bring the bot up.

    Reads bot.config, optionally enables GC leak debugging, writes our PID to
    cfg.pidfile, creates the Erebus singleton, opens the database, loads every
    module marked 1 under [autoloads], then creates a bot for each active row of
    the bots table and connects them.
    """
    global cfg, main

    cfg = config.Config('bot.config')

    if cfg.getboolean('debug', 'gc'):
        gc.set_debug(gc.DEBUG_LEAK)

    # `with` closes the pidfile even if the write fails.
    with open(cfg.pidfile, 'w') as pidfile:
        pidfile.write(str(os.getpid()))

    main = Erebus(cfg)
    dbsetup()

    autoloads = [mod for mod, yes in cfg.items('autoloads') if int(yes) == 1]
    for mod in autoloads:
        ctlmod.load(main, mod)

    c = main.query("SELECT nick, user, bind, authname, authpass FROM bots WHERE active = 1")
    if c:
        rows = c.fetchall()
        c.close()
        for row in rows:
            main.newbot(row['nick'], row['user'], row['bind'], row['authname'], row['authpass'], cfg.host, cfg.port, cfg.realname)
    main.connectall()
b25d4368 420
def loop():
    """Run one iteration of the main event loop.

    Waits for readable sockets via main.poll(), reads each one's pending lines
    with getdata() and feeds them to its parse(). A None result from getdata()
    (or an exception from it) closes that socket. Exceptions from getdata() and
    parse() are logged and swallowed, so one bad socket or line cannot stop the
    bot. Afterwards, if main.mustquit has been set, it is logged and raised.
    """
    poready = main.poll()
    for fileno in poready:
        # Deliberate catch-all: this is the outermost boundary for socket handlers.
        try:
            data = main.fd(fileno).getdata()
        except:
            main.log('*', '!', 'Super-mega-emergency: getdata raised exception for socket %d' % (fileno))
            traceback.print_exc()
            data = None
        if data is None:
            main.fd(fileno).close()
        else:
            for line in data:
                # Parsing an earlier line may have unregistered this socket (delfd).
                # Looking it up again would raise KeyError; with debug io on, that
                # lookup sits outside the try below and would crash the loop.
                # Drop the remaining lines instead.
                if fileno not in main.fds:
                    break
                if cfg.getboolean('debug', 'io'):
                    main.log(str(main.fd(fileno)), 'I', line)
                try:
                    main.fd(fileno).parse(line)
                except:
                    main.log('*', '!', 'Super-mega-emergency: parse raised exception for socket %d data %r' % (fileno, line))
                    traceback.print_exc()
    if main.mustquit is not None:
        main.log('*', '!', 'Core exiting due to: %s' % (main.mustquit))
        raise main.mustquit
b25d4368 444
if __name__ == '__main__':
    # Keep the previous run's log by moving it into oldlogs/, named by timestamp.
    # Best-effort: a missing logfile or oldlogs/ directory is not fatal, but
    # only OS errors are ignored (a bare except would also hide real bugs).
    try:
        os.rename('logfile', 'oldlogs/%s' % (time.time()))
    except OSError:
        pass
    sys.stdout = open('logfile', 'w', 1)  # buffering=1: line-buffered text file
    sys.stderr = sys.stdout
    setup()
    while True:
        loop()