jfr.im git - erebus.git/blob - erebus.py
add sqlite database support
[erebus.git] / erebus.py
1 #!/usr/bin/python
2 # vim: fileencoding=utf-8
3
4 # Erebus IRC bot - Author: John Runyon
5 # main startup code
6
7 from __future__ import print_function
8
9 import os, sys, select, time, traceback, random, gc
10 import bot, config, ctlmod
11
class Erebus(object): # singleton to pass around
    """Shared state for the whole bot network.

    A single instance is created by setup() and stored in the module-global
    ``main``; the nested User/Channel classes reach it through that global.
    It owns the bot objects, the fileno -> object map used by the poll loop,
    user/channel tracking, the message/numeric/channel/exception hook tables
    and database access.
    """
    APIVERSION = 0
    RELEASE = 0

    # NOTE: these containers are class attributes, so they are shared by every
    # instance; that is acceptable only because the class is used as a singleton.
    bots = {}  # nick.lower() -> bot.Bot (see newbot)
    fds = {}  # fileno -> registered object (see newfd)
    numhandlers = {}  # word -> [handler, ...] (see hooknum)
    msghandlers = {}  # word -> [handler, ...] (see hook)
    chanhandlers = {}  # chan key -> [handler, ...] (see hookchan)
    exceptionhandlers = [] # list of (Exception_class, handler_function) tuples
    users = {}  # nick.lower() -> User (see user)
    chans = {}  # name.lower() -> Channel (see newchannel)

    class User(object):
        """A user seen on IRC, tracked by nick and (once known) lowercased auth name.

        ``glevel`` is the global access level: -1 when not authed, otherwise the
        ``level`` from the ``users`` table (0 when there is no row or the query
        returned nothing usable).
        """
        def __init__(self, nick, auth=None):
            self.nick = nick
            self.auth = None if auth is None else auth.lower()
            # Always set glevel: checklevel() yields -1 without touching the DB
            # when auth is None. Skipping it for unauthed users left the
            # attribute missing (e.g. repr() raised AttributeError).
            self.checklevel()

            self.chans = []

        def bind_bot(self, bot):
            """Return a proxy for this user whose msg methods go through *bot*."""
            return main._BoundUser(self, bot)

        # The unbound send helpers pick a random bot to deliver through.
        def msg(self, *args, **kwargs):
            main.randbot().msg(self, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            main.randbot().slowmsg(self, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            main.randbot().fastmsg(self, *args, **kwargs)

        def isauthed(self):
            return self.auth is not None

        def authed(self, auth):
            """Record a new auth name ('0' means not authed) and refresh glevel."""
            if auth == '0': self.auth = None
            else: self.auth = auth.lower()
            self.checklevel()

        def checklevel(self):
            """Recompute and return self.glevel (-1 if unauthed, else DB level or 0)."""
            if self.auth is None:
                self.glevel = -1
            else:
                c = main.query("SELECT level FROM users WHERE auth = %s", (self.auth,))
                if c:
                    row = c.fetchone()
                    if row is not None:
                        self.glevel = row['level']
                    else:
                        self.glevel = 0
                else:
                    self.glevel = 0
            return self.glevel

        def setlevel(self, level, savetodb=True):
            """Set the global level, persisting it first when *savetodb* is true.

            Level 0 deletes the user's row instead of storing it. Returns True on
            success; returns False (leaving glevel unchanged) if the query failed.
            """
            if savetodb:
                if level != 0:
                    c = main.query("REPLACE INTO users (auth, level) VALUES (%s, %s)", (self.auth, level))
                else:
                    c = main.query("DELETE FROM users WHERE auth = %s", (self.auth,))
                    # A row count of 0 (no rows affected) is fine, but query()'s
                    # False failure value also compares equal to 0, so exclude it.
                    if c is not False and c == 0:
                        c = True
                if c:
                    self.glevel = level
                    return True
                else:
                    return False
            else:
                self.glevel = level
                return True

        def join(self, chan):
            if chan not in self.chans: self.chans.append(chan)
        def part(self, chan):
            """Forget *chan*; return True if the user is now in no tracked channels."""
            try:
                self.chans.remove(chan)
            except ValueError: # wasn't tracked in that channel
                pass
            return len(self.chans) == 0
        def quit(self):
            pass
        def nickchange(self, newnick):
            self.nick = newnick

        def __str__(self): return self.nick
        def __repr__(self): return "<User %r (%d)>" % (self.nick, self.glevel)

    class _BoundUser(object):
        """Proxy forwarding attribute access to a User, but messaging via one bot."""
        def __init__(self, user, bot):
            # Write through __dict__ directly: our __setattr__ forwards to the user.
            self.__dict__['_bound_user'] = user
            self.__dict__['_bound_bot'] = bot
        def __getattr__(self, name):
            return getattr(self._bound_user, name)
        def __setattr__(self, name, value):
            setattr(self._bound_user, name, value)
        def msg(self, *args, **kwargs):
            self._bound_bot.msg(self._bound_user, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            self._bound_bot.slowmsg(self._bound_user, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            self._bound_bot.fastmsg(self._bound_user, *args, **kwargs)
        def __repr__(self): return "<_BoundUser %r %r>" % (self._bound_user, self._bound_bot)

    class Channel(object):
        """A channel handled by *bot*, with per-auth levels loaded from ``chusers``."""
        def __init__(self, name, bot):
            self.name = name
            self.bot = bot
            self.levels = {}  # auth -> channel level

            self.users = []
            self.voices = []
            self.ops = []

            c = main.query("SELECT user, level FROM chusers WHERE chan = %s", (self.name,))
            if c:
                for row in c.fetchall():
                    self.levels[row['user']] = row['level']

        def msg(self, *args, **kwargs):
            self.bot.msg(self, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            self.bot.slowmsg(self, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            self.bot.fastmsg(self, *args, **kwargs)

        def levelof(self, auth):
            """Return the channel level of *auth* (case-insensitive); 0 if None/unknown."""
            if auth is None:
                return 0
            return self.levels.get(auth.lower(), 0)

        def setlevel(self, auth, level, savetodb=True):
            """Set *auth*'s channel level, persisting first when *savetodb*; False on DB failure."""
            auth = auth.lower()
            if savetodb:
                c = main.query("REPLACE INTO chusers (chan, user, level) VALUES (%s, %s, %s)", (self.name, auth, level))
                if c:
                    self.levels[auth] = level
                    return True
                else:
                    return False
            else:
                self.levels[auth] = level
                return True

        def userjoin(self, user, level=None):
            if user not in self.users: self.users.append(user)
            if level == 'op' and user not in self.ops: self.ops.append(user)
            if level == 'voice' and user not in self.voices: self.voices.append(user)
        def userpart(self, user):
            if user in self.ops: self.ops.remove(user)
            if user in self.voices: self.voices.remove(user)
            if user in self.users: self.users.remove(user)

        def userop(self, user):
            if user in self.users and user not in self.ops: self.ops.append(user)
        def uservoice(self, user):
            if user in self.users and user not in self.voices: self.voices.append(user)
        def userdeop(self, user):
            if user in self.ops: self.ops.remove(user)
        def userdevoice(self, user):
            if user in self.voices: self.voices.remove(user)

        def __str__(self): return self.name
        def __repr__(self): return "<Channel %r>" % (self.name)

    def __init__(self, cfg):
        self.mustquit = None  # loop() raises this once it is set to an exception
        self.starttime = time.time()
        self.cfg = cfg
        self.trigger = cfg.trigger
        if os.name == "posix":
            self.potype = "poll"
            self.po = select.poll()
        else: # f.e. os.name == "nt" (Windows)
            self.potype = "select"
            self.fdlist = []

    def query(self, sql, parameters=None, noretry=False):
        """Execute *sql* with *parameters* on the main connection (self.db).

        Callers use %s-style (paramstyle='format') placeholders in queries; for
        qmark drivers (sqlite) they are rewritten to '?'. There's no provision
        for a literal '%s' inside the query; stuff it in a parameter instead.

        Returns the cursor when execute() returns something truthy, otherwise
        that falsy result (a row count of 0 is treated as "no rows" by callers).
        Returns False on DataError. On any other DB error, reconnects via
        dbsetup() and retries once; if that retry fails too, the error is re-raised.
        """
        if parameters is None:
            parameters = [] # same value the old (mutable) default supplied
        if db_api.paramstyle == 'format' or db_api.paramstyle == 'pyformat': # mysql, postgresql
            # psycopg actually asks for a mapping with %(name)s style (pyformat) but it will accept %s style.
            pass
        elif db_api.paramstyle == 'qmark': # sqlite doesn't like %s style.
            # Parameters are stringified as before, except None, which must stay
            # None so it binds as SQL NULL rather than the text 'None'.
            parameters = [p if p is None else str(p) for p in parameters]
            sql = sql.replace('%s', '?') # hope that wasn't literal, oopsie

        log_noretry = ', noretry=True' if noretry else ''
        self.log("[SQL]", "?", "query(%r, %r%s)" % (sql, parameters, log_noretry))

        try:
            curs = self.db.cursor()
            res = curs.execute(sql, parameters)
            return curs if res else res
        except db_api.DataError as e:
            self.log("[SQL]", ".", "DB DataError: %r" % (e))
            return False
        except db_api.Error as e:
            self.log("[SQL]", "!", "DB error! %r" % (e))
            if noretry:
                raise
            dbsetup()
            return self.query(sql, parameters, noretry=True)

    def querycb(self, cb, *args, **kwargs):
        """Run query(*args, **kwargs) on a new thread and pass its result to *cb*."""
        # TODO this should either get thrown out with getdb()/returndb(), or else be adjusted to make use of it.
        import threading # not imported at file level; without this querycb raised NameError
        def run_query():
            cb(self.query(*args, **kwargs))
        threading.Thread(target=run_query).start()

    def newbot(self, nick, user, bind, authname, authpass, server, port, realname):
        """Create a bot.Bot and register it under its lowercased nick."""
        if bind is None: bind = ''
        obj = bot.Bot(self, nick, user, bind, authname, authpass, server, port, realname)
        self.bots[nick.lower()] = obj

    def newfd(self, obj, fileno):
        """Register *obj* as the handler for *fileno* and start watching it for input."""
        self.fds[fileno] = obj
        if self.potype == "poll":
            self.po.register(fileno, select.POLLIN)
        elif self.potype == "select":
            self.fdlist.append(fileno)
    def delfd(self, fileno):
        """Unregister *fileno* (KeyError if it was never registered)."""
        del self.fds[fileno]
        if self.potype == "poll":
            self.po.unregister(fileno)
        elif self.potype == "select":
            self.fdlist.remove(fileno)

    def bot(self, name): #get Bot() by name (nick)
        return self.bots[name.lower()]
    def fd(self, fileno): #get Bot() by fd/fileno
        return self.fds[fileno]
    def randbot(self): #get Bot() randomly
        return self.bots[random.choice(list(self.bots.keys()))]

    def user(self, _nick, send_who=False, create=True):
        """Return the User for *_nick* (case-insensitive).

        With *send_who*, a WHO is sent (through a random bot) when the user is
        unknown or not authed. Unknown users are created unless *create* is
        false, in which case None is returned.
        """
        nick = _nick.lower()

        if send_who and (nick not in self.users or not self.users[nick].isauthed()):
            self.randbot().conn.send("WHO %s n%%ant,1" % (nick))

        if nick in self.users:
            return self.users[nick]
        elif create:
            user = self.User(_nick)
            self.users[nick] = user
            return user
        else:
            return None
    def channel(self, name): #get Channel() by name
        return self.chans.get(name.lower())

    def newchannel(self, bot, name):
        """Create, register (by lowercased name) and return a Channel."""
        chan = self.Channel(name.lower(), bot)
        self.chans[name.lower()] = chan
        return chan

    def poll(self, timeout_seconds=30):
        """Wait up to *timeout_seconds* and return the registered filenos that are ready."""
        if self.potype == "poll":
            pollres = self.po.poll(timeout_seconds * 1000) # poll() takes milliseconds
            return [fd for (fd, ev) in pollres]
        elif self.potype == "select":
            return select.select(self.fdlist, [], [], timeout_seconds)[0]

    def connectall(self):
        """Call connect() on every bot whose conn.state is 0 (presumably: not connected)."""
        for bot in self.bots.values():
            if bot.conn.state == 0:
                bot.connect()

    def module(self, name):
        return ctlmod.modules[name]

    def log(self, source, level, message):
        print("%09.3f %s [%s] %s" % (time.time() % 100000, source, level, message))

    def getuserbyauth(self, auth):
        """Return every known User whose auth matches *auth* case-insensitively."""
        auth = auth.lower()
        return [u for u in self.users.values() if u.auth == auth]

    def getdb(self):
        """Get a DB object. The object must be returned to the pool after us, using returndb(). This is intended for use from child threads.
        It should probably be treated as deprecated though. Where possible new modules should avoid using threads.
        In the future, timers will be provided (manipulating the timeout_seconds of the poll() method), and that should mostly be used in place of threading.
        Raises IndexError if the pool is empty."""
        return self.dbs.pop()

    def returndb(self, db):
        self.dbs.append(db)

    #bind functions
    def hook(self, word, handler):
        self.msghandlers.setdefault(word, []).append(handler)
    def unhook(self, word, handler):
        if word in self.msghandlers and handler in self.msghandlers[word]:
            self.msghandlers[word].remove(handler)
    def hashook(self, word):
        return word in self.msghandlers and len(self.msghandlers[word]) != 0
    def gethook(self, word):
        return self.msghandlers[word]

    def hooknum(self, word, handler):
        self.numhandlers.setdefault(word, []).append(handler)
    def unhooknum(self, word, handler):
        if word in self.numhandlers and handler in self.numhandlers[word]:
            self.numhandlers[word].remove(handler)
    def hasnumhook(self, word):
        return word in self.numhandlers and len(self.numhandlers[word]) != 0
    def getnumhook(self, word):
        return self.numhandlers[word]

    def hookchan(self, chan, handler):
        self.chanhandlers.setdefault(chan, []).append(handler)
    def unhookchan(self, chan, handler):
        if chan in self.chanhandlers and handler in self.chanhandlers[chan]:
            self.chanhandlers[chan].remove(handler)
    def haschanhook(self, chan):
        return chan in self.chanhandlers and len(self.chanhandlers[chan]) != 0
    def getchanhook(self, chan):
        return self.chanhandlers[chan]

    def hookexception(self, exc, handler):
        self.exceptionhandlers.append((exc, handler))
    def unhookexception(self, exc, handler):
        self.exceptionhandlers.remove((exc, handler))
    def hasexceptionhook(self, exc):
        return any(isinstance(exc, x) for x, h in self.exceptionhandlers)
    def getexceptionhook(self, exc):
        """Return a lazy generator of handlers whose exception class matches *exc*."""
        return (h for x,h in self.exceptionhandlers if isinstance(exc, x))
365
366
def dbsetup():
    """(Re)create the database connections on ``main`` from the config.

    Clears main.db and main.dbs, then picks an initialiser from the
    ``erebus.dbtype`` option ('mysql' when unset, or 'sqlite'). An unknown type
    is only logged, which leaves main.db as None. Erebus.query() also calls this
    to reconnect after a database error.
    """
    main.db, main.dbs = None, []
    dbtype = cfg.get('erebus', 'dbtype', 'mysql')
    initialisers = {
        'mysql': _dbsetup_mysql,
        'sqlite': _dbsetup_sqlite,
    }
    initialiser = initialisers.get(dbtype)
    if initialiser is None:
        main.log('*', '!', 'Unknown dbtype in config: %s' % (dbtype))
    else:
        initialiser()
377
def _dbsetup_mysql():
    """Open MySQL connections using the cfg.db* settings.

    ``num_db_connections - 1`` connections are appended to the main.dbs pool
    and one more becomes main.db. All of them use a DictCursor, so rows are
    indexed by column name. Also binds the module-global ``db_api`` to MySQLdb.
    """
    global db_api
    import MySQLdb as db_api
    import MySQLdb.cursors

    def open_connection():
        # Every connection shares the same settings; only the destination differs.
        return db_api.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass,
                              db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor)

    pool_size = cfg.get('erebus', 'num_db_connections', 2) - 1
    main.dbs.extend(open_connection() for _ in range(pool_size))
    main.db = open_connection()
384
385 def _dbsetup_sqlite():
386 global db_api
387 import sqlite3 as db_api
388 for i in range(cfg.get('erebus', 'num_db_connections', 2)):
389 main.db = db_api.connect(cfg.dbhost)
390 main.db.row_factory = db_api.Row
391 main.db.isolation_level = None
392 main.dbs.append(main.db)
393
def setup():
    """Initialise the bot network.

    Loads 'bot.config' into the global ``cfg``, optionally enables GC leak
    debugging, writes the pidfile, creates the global ``main`` Erebus instance,
    connects the database and loads every autoload module flagged 1. Finally it
    creates a bot for each active row of the ``bots`` table and connects them.
    """
    global cfg, main

    cfg = config.Config('bot.config')

    if cfg.getboolean('debug', 'gc'):
        gc.set_debug(gc.DEBUG_LEAK)

    # Context manager: the pidfile is closed even if the write fails.
    with open(cfg.pidfile, 'w') as pidfile:
        pidfile.write(str(os.getpid()))

    main = Erebus(cfg)
    dbsetup()

    autoloads = [mod for mod, yes in cfg.items('autoloads') if int(yes) == 1]
    for mod in autoloads:
        ctlmod.load(main, mod)

    c = main.query("SELECT nick, user, bind, authname, authpass FROM bots WHERE active = 1")
    if c:
        rows = c.fetchall()
        c.close()
        for row in rows:
            main.newbot(row['nick'], row['user'], row['bind'], row['authname'], row['authpass'], cfg.host, cfg.port, cfg.realname)
    main.connectall()
420
def loop():
    """Run one iteration of the main event loop.

    Waits via main.poll() for ready filenos. For each one, it reads pending lines
    with getdata() and hands each line to parse(). If getdata() returns None or
    raises, close() is called on the object. Exceptions from parse() are logged
    and swallowed, so one bad line cannot stop the bot. Afterwards, if
    main.mustquit has been set, it is raised.
    """
    poready = main.poll()
    for fileno in poready:
        # Handling an earlier fileno may already have unregistered this one;
        # main.fd() would then raise an unhandled KeyError.
        if fileno not in main.fds:
            continue
        obj = main.fd(fileno)
        try:
            data = obj.getdata()
        except Exception:
            main.log('*', '!', 'Super-mega-emergency: getdata raised exception for socket %d' % (fileno))
            traceback.print_exc()
            data = None
        if data is None:
            obj.close()
            continue
        for line in data:
            # If parsing an earlier line unregistered this fileno, drop the
            # remaining lines instead of failing a lookup for each of them.
            if main.fds.get(fileno) is not obj:
                break
            if cfg.getboolean('debug', 'io'):
                main.log(str(obj), 'I', line)
            try:
                obj.parse(line)
            except Exception:
                main.log('*', '!', 'Super-mega-emergency: parse raised exception for socket %d data %r' % (fileno, line))
                traceback.print_exc()
    if main.mustquit is not None:
        main.log('*', '!', 'Core exiting due to: %s' % (main.mustquit))
        raise main.mustquit
444
if __name__ == '__main__':
    # Rotate the previous log into oldlogs/. This is best-effort: a missing
    # logfile is simply not rotated. NOTE(review): if oldlogs/ does not exist,
    # the rename also fails and the open() below truncates the old log.
    try:
        os.rename('logfile', 'oldlogs/%s' % (time.time()))
    except OSError:
        pass
    sys.stdout = open('logfile', 'w', 1) # buffering=1: line-buffered
    sys.stderr = sys.stdout
    setup()
    while True:
        loop()