]> jfr.im git - erebus.git/blob - erebus.py
cleanup some error messages
[erebus.git] / erebus.py
1 #!/usr/bin/python
2 # vim: fileencoding=utf-8
3
4 # Erebus IRC bot - Author: John Runyon
5 # main startup code
6
7 from __future__ import print_function
8
9 import os, sys, select, time, traceback, random, gc
10 import bot, config, ctlmod
11
class Erebus(object): #singleton to pass around
    """Core bot state: bots, sockets, known users/channels and hook registries.

    setup() creates one instance and stores it in the module-level ``main``
    global; the nested User and Channel classes reach the instance (and the
    database, via main.query()) through that global.
    """
    APIVERSION = 0
    RELEASE = 0

    # NOTE: the registries below are class attributes, so they are shared by
    # every Erebus instance. That is fine for the singleton usage in this file.
    bots = {}  # lowercased nick -> Bot (see newbot()/bot())
    fds = {}  # fileno -> object registered via newfd()
    numhandlers = {}  # key passed to hooknum() -> list of handlers
    msghandlers = {}  # word passed to hook() -> list of handlers
    chanhandlers = {}  # channel passed to hookchan() -> list of handlers
    exceptionhandlers = []  # list of (Exception_class, handler_function) tuples
    users = {}  # lowercased nick -> User (see user())
    chans = {}  # lowercased channel name -> Channel (see newchannel())

    class User(object):
        """An IRC user known to the bot, optionally logged in to an account ("auth").

        ``glevel`` is the user's global access level: -1 when not authed,
        otherwise the ``level`` column of the ``users`` table (0 if no row).
        """

        def __init__(self, nick, auth=None):
            self.nick = nick
            if auth is None:
                self.auth = None
            else:
                self.auth = auth.lower()
            self.checklevel()  # sets self.glevel

            self.chans = []  # channels this user has been seen in

        def bind_bot(self, bot):
            """Return a proxy of this user whose msg*() methods send through `bot`."""
            return main._BoundUser(self, bot)

        # The msg*() helpers send to this user through a randomly chosen bot.
        def msg(self, *args, **kwargs):
            main.randbot().msg(self, *args, **kwargs)

        def slowmsg(self, *args, **kwargs):
            main.randbot().slowmsg(self, *args, **kwargs)

        def fastmsg(self, *args, **kwargs):
            main.randbot().fastmsg(self, *args, **kwargs)

        def isauthed(self):
            return self.auth is not None

        def authed(self, auth):
            """Record the user's account name; the value '0' clears it. Refreshes glevel."""
            if auth == '0':
                self.auth = None
            else:
                self.auth = auth.lower()
            self.checklevel()

        def checklevel(self):
            """Recompute self.glevel from the ``users`` table and return it."""
            if self.auth is None:
                self.glevel = -1
            else:
                c = main.query("SELECT level FROM users WHERE auth = %s", (self.auth,))
                row = c.fetchone() if c else None
                self.glevel = row['level'] if row is not None else 0
            return self.glevel

        def setlevel(self, level, savetodb=True):
            """Set the global level, persisting it to the ``users`` table unless savetodb is False.

            A level of 0 deletes the user's row. Returns True on success and
            False if the database write reported failure (glevel is then left
            unchanged).
            """
            if not savetodb:
                self.glevel = level
                return True
            if level != 0:
                c = main.query("REPLACE INTO users (auth, level) VALUES (%s, %s)", (self.auth, level))
            else:
                c = main.query("DELETE FROM users WHERE auth = %s", (self.auth,))
                # 0 rows affected (no row to delete) is fine. query() returns
                # False on a DataError, and False == 0 in Python, so that case
                # must be excluded explicitly or a failure would count as success.
                if c == 0 and c is not False:
                    c = True
            if c:
                self.glevel = level
                return True
            return False

        def join(self, chan):
            if chan not in self.chans:
                self.chans.append(chan)

        def part(self, chan):
            """Forget `chan`; return True if the user is now in no known channels."""
            try:
                self.chans.remove(chan)
            except ValueError:  # chan wasn't tracked for this user
                pass
            return len(self.chans) == 0

        def quit(self):
            pass

        def nickchange(self, newnick):
            self.nick = newnick

        def __str__(self): return self.nick
        def __repr__(self): return "<User %r (%d)>" % (self.nick, self.glevel)

    class _BoundUser(object):
        """Proxy for a User whose msg*() methods use one specific bot.

        All other attribute reads and writes are forwarded to the wrapped User.
        """

        def __init__(self, user, bot):
            # Write to __dict__ directly: our __setattr__ forwards to the user.
            self.__dict__['_bound_user'] = user
            self.__dict__['_bound_bot'] = bot

        def __getattr__(self, name):
            return getattr(self._bound_user, name)

        def __setattr__(self, name, value):
            setattr(self._bound_user, name, value)

        def msg(self, *args, **kwargs):
            self._bound_bot.msg(self._bound_user, *args, **kwargs)

        def slowmsg(self, *args, **kwargs):
            self._bound_bot.slowmsg(self._bound_user, *args, **kwargs)

        def fastmsg(self, *args, **kwargs):
            self._bound_bot.fastmsg(self._bound_user, *args, **kwargs)

        def __repr__(self): return "<_BoundUser %r %r>" % (self._bound_user, self._bound_bot)

    class Channel(object):
        """A tracked channel: member/op/voice lists plus per-account levels from ``chusers``."""

        def __init__(self, name, bot):
            self.name = name
            self.bot = bot  # the Bot used by msg()/slowmsg()/fastmsg()
            self.levels = {}  # auth (account name) -> channel access level

            self.users = []
            self.voices = []
            self.ops = []

            self.deleting = False # if true, the bot will remove cached records of this channel when the bot sees that it has left the channel

            c = main.query("SELECT user, level FROM chusers WHERE chan = %s", (self.name,))
            if c:
                row = c.fetchone()
                while row is not None:
                    self.levels[row['user']] = row['level']
                    row = c.fetchone()

        def msg(self, *args, **kwargs):
            self.bot.msg(self, *args, **kwargs)

        def slowmsg(self, *args, **kwargs):
            self.bot.slowmsg(self, *args, **kwargs)

        def fastmsg(self, *args, **kwargs):
            self.bot.fastmsg(self, *args, **kwargs)

        def levelof(self, auth):
            """Return the channel level of account `auth` (0 if None or unknown)."""
            if auth is None:
                return 0
            return self.levels.get(auth.lower(), 0)

        def setlevel(self, auth, level, savetodb=True):
            """Set `auth`'s channel level, persisting to ``chusers`` unless savetodb is False.

            Returns True on success, False if the database write reported failure.
            """
            auth = auth.lower()
            if savetodb:
                c = main.query("REPLACE INTO chusers (chan, user, level) VALUES (%s, %s, %s)", (self.name, auth, level))
                if not c:
                    return False
            self.levels[auth] = level
            return True

        def userjoin(self, user, level=None):
            """Add `user` to the member list; level 'op' or 'voice' also marks that status."""
            if user not in self.users: self.users.append(user)
            if level == 'op' and user not in self.ops: self.ops.append(user)
            if level == 'voice' and user not in self.voices: self.voices.append(user)

        def userpart(self, user):
            if user in self.ops: self.ops.remove(user)
            if user in self.voices: self.voices.remove(user)
            if user in self.users: self.users.remove(user)

        def userop(self, user):
            if user in self.users and user not in self.ops: self.ops.append(user)

        def uservoice(self, user):
            if user in self.users and user not in self.voices: self.voices.append(user)

        def userdeop(self, user):
            if user in self.ops: self.ops.remove(user)

        def userdevoice(self, user):
            if user in self.voices: self.voices.remove(user)

        def __str__(self): return self.name
        def __repr__(self): return "<Channel %r>" % (self.name)

    def __init__(self, cfg):
        self.mustquit = None  # set to an exception instance to make loop() raise it
        self.starttime = time.time()
        self.cfg = cfg
        self.trigger = cfg.trigger
        if os.name == "posix":
            self.potype = "poll"
            self.po = select.poll()
        else: # f.e. os.name == "nt" (Windows)
            self.potype = "select"
            self.fdlist = []

    def query(self, sql, parameters=None, noretry=False):
        """Execute `sql` with `parameters` on the main connection (self.db).

        Callers use %s-style (paramstyle='format') placeholders in queries.
        There's no provision for a literal '%s' present inside the query; stuff
        it in a parameter instead.

        Returns the cursor when execute() returned a truthy value, otherwise
        execute()'s own return value (e.g. 0 rows affected with MySQLdb).
        Returns False on a DataError. Any other db_api.Error triggers dbsetup()
        and one retry with noretry=True; an error from that retry propagates.
        """
        if parameters is None:  # avoid a shared mutable default
            parameters = []
        if db_api.paramstyle == 'format' or db_api.paramstyle == 'pyformat': # mysql, postgresql
            # psycopg actually asks for a mapping with %(name)s style (pyformat) but it will accept %s style.
            pass
        elif db_api.paramstyle == 'qmark': # sqlite doesn't like %s style.
            # NOTE(review): str() also turns None into the string 'None' — confirm that is intended.
            parameters = [str(p) for p in parameters]
            sql = sql.replace('%s', '?') # hope that wasn't literal, oopsie

        log_noretry = ''
        if noretry:
            log_noretry = ', noretry=True'
        self.log("[SQL]", "?", "query(%r, %r%s)" % (sql, parameters, log_noretry))

        try:
            curs = self.db.cursor()
            res = curs.execute(sql, parameters)
            if res:
                return curs
            else:
                return res
        except db_api.DataError as e:
            self.log("[SQL]", ".", "DB DataError: %r" % (e))
            return False
        except db_api.Error as e:
            self.log("[SQL]", "!", "DB error! %r" % (e))
            if noretry:
                raise  # bare raise keeps the original traceback
            dbsetup()
            return self.query(sql, parameters, noretry=True)

    def querycb(self, cb, *args, **kwargs):
        """Run self.query(*args, **kwargs) in a new thread and pass its result to `cb`.

        NOTE: the thread goes through query(), i.e. the shared main connection self.db.
        """
        # TODO this should either get thrown out with getdb()/returndb(), or else be adjusted to make use of it.
        import threading  # not imported at module level; without this querycb() raised NameError

        def run_query():
            cb(self.query(*args, **kwargs))
        threading.Thread(target=run_query).start()

    def newbot(self, nick, user, bind, authname, authpass, server, port, realname):
        """Create a Bot and register it under its lowercased nick (it is not connected here)."""
        if bind is None: bind = ''
        obj = bot.Bot(self, nick, user, bind, authname, authpass, server, port, realname)
        self.bots[nick.lower()] = obj

    def newfd(self, obj, fileno):
        """Register `obj` as the owner of `fileno` and watch the fd for readability."""
        self.fds[fileno] = obj
        if self.potype == "poll":
            self.po.register(fileno, select.POLLIN)
        elif self.potype == "select":
            self.fdlist.append(fileno)

    def delfd(self, fileno):
        """Unregister `fileno` (KeyError if it was never registered)."""
        del self.fds[fileno]
        if self.potype == "poll":
            self.po.unregister(fileno)
        elif self.potype == "select":
            self.fdlist.remove(fileno)

    def bot(self, name): #get Bot() by name (nick)
        return self.bots[name.lower()]

    def fd(self, fileno): #get Bot() by fd/fileno
        return self.fds[fileno]

    def randbot(self): #get Bot() randomly
        return self.bots[random.choice(list(self.bots.keys()))]

    def user(self, _nick, send_who=False, create=True):
        """Return the User for `_nick` (case-insensitive).

        An unknown nick gets a new User when `create` is true, otherwise None.
        With send_who, a WHO query is sent via a random bot when the user is
        unknown or not authed.
        """
        nick = _nick.lower()

        if send_who and (nick not in self.users or not self.users[nick].isauthed()):
            self.randbot().conn.send("WHO %s n%%ant,1" % (nick))

        if nick in self.users:
            return self.users[nick]
        elif create:
            user = self.User(_nick)
            self.users[nick] = user
            return user
        else:
            return None

    def channel(self, name): #get Channel() by name
        return self.chans.get(name.lower())

    def newchannel(self, bot, name):
        """Create, register and return a Channel for `name` (stored lowercased)."""
        chan = self.Channel(name.lower(), bot)
        self.chans[name.lower()] = chan
        return chan

    def poll(self):
        """Wait up to 30 seconds and return the list of readable registered fds."""
        timeout_seconds = 30
        if self.potype == "poll":
            pollres = self.po.poll(timeout_seconds * 1000)  # poll() takes milliseconds
            return [fd for (fd, ev) in pollres]
        elif self.potype == "select":
            return select.select(self.fdlist, [], [], timeout_seconds)[0]

    def connectall(self):
        """Call connect() on every bot whose conn.state is 0."""
        for bot in self.bots.values():
            if bot.conn.state == 0:
                bot.connect()

    def module(self, name):
        return ctlmod.modules[name]

    def log(self, source, level, message):
        # Timestamp is seconds modulo 100000, to keep log lines short.
        print("%09.3f %s [%s] %s" % (time.time() % 100000, source, level, message))

    def getuserbyauth(self, auth):
        """Return every known User whose account name matches `auth` (case-insensitive)."""
        return [u for u in self.users.values() if u.auth == auth.lower()]

    def getdb(self):
        """Get a DB object. The object must be returned to the pool after us, using returndb(). This is intended for use from child threads.
        It should probably be treated as deprecated though. Where possible new modules should avoid using threads.
        In the future, timers will be provided (manipulating the timeout_seconds of the poll() method), and that should mostly be used in place of threading."""
        return self.dbs.pop()

    def returndb(self, db):
        self.dbs.append(db)

    #bind functions
    # Each registry maps a key to a list of handlers; unhook* of an
    # unregistered handler is a no-op.
    def hook(self, word, handler):
        self.msghandlers.setdefault(word, []).append(handler)

    def unhook(self, word, handler):
        if word in self.msghandlers and handler in self.msghandlers[word]:
            self.msghandlers[word].remove(handler)

    def hashook(self, word):
        return word in self.msghandlers and len(self.msghandlers[word]) != 0

    def gethook(self, word):
        return self.msghandlers[word]

    def hooknum(self, word, handler):
        self.numhandlers.setdefault(word, []).append(handler)

    def unhooknum(self, word, handler):
        if word in self.numhandlers and handler in self.numhandlers[word]:
            self.numhandlers[word].remove(handler)

    def hasnumhook(self, word):
        return word in self.numhandlers and len(self.numhandlers[word]) != 0

    def getnumhook(self, word):
        return self.numhandlers[word]

    def hookchan(self, chan, handler):
        self.chanhandlers.setdefault(chan, []).append(handler)

    def unhookchan(self, chan, handler):
        if chan in self.chanhandlers and handler in self.chanhandlers[chan]:
            self.chanhandlers[chan].remove(handler)

    def haschanhook(self, chan):
        return chan in self.chanhandlers and len(self.chanhandlers[chan]) != 0

    def getchanhook(self, chan):
        return self.chanhandlers[chan]

    def hookexception(self, exc, handler):
        """Register `handler` for exceptions that are instances of class `exc`."""
        self.exceptionhandlers.append((exc, handler))

    def unhookexception(self, exc, handler):
        # Like the other unhook* methods, removing an unregistered pair is a no-op.
        if (exc, handler) in self.exceptionhandlers:
            self.exceptionhandlers.remove((exc, handler))

    def hasexceptionhook(self, exc):
        """True if any registered exception class matches the exception instance `exc`."""
        return any(isinstance(exc, cls) for cls, _handler in self.exceptionhandlers)

    def getexceptionhook(self, exc):
        """Return a generator over the handlers whose exception class matches `exc`."""
        return (h for x, h in self.exceptionhandlers if isinstance(exc, x))
367
368
def dbsetup():
    """(Re)initialise the database connections on the global ``main`` instance.

    Clears main.db/main.dbs, reads ``[erebus] dbtype`` from the config
    (default 'mysql') and runs the matching backend helper. An unknown dbtype
    is only logged, which leaves main.db as None and main.dbs empty.
    Erebus.query() also calls this to reconnect after a database error.
    """
    main.db = None
    main.dbs = []
    backend = cfg.get('erebus', 'dbtype', 'mysql')
    backend_setup = {
        'mysql': _dbsetup_mysql,
        'sqlite': _dbsetup_sqlite,
    }.get(backend)
    if backend_setup is None:
        main.log('*', '!', 'Unknown dbtype in config: %s' % (backend))
        return
    backend_setup()
379
def _dbsetup_mysql():
    """Connect to MySQL and install MySQLdb as the module-level ``db_api``.

    Opens (num_db_connections - 1) connections into main.dbs for
    getdb()/returndb(), then one more as main.db. All of them use DictCursor,
    so rows support row['column'] access.
    """
    global db_api
    import MySQLdb as db_api, MySQLdb.cursors

    def connect():
        return db_api.connect(
            host=cfg.dbhost,
            user=cfg.dbuser,
            passwd=cfg.dbpass,
            db=cfg.dbname,
            cursorclass=MySQLdb.cursors.DictCursor,
        )

    pool_size = cfg.get('erebus', 'num_db_connections', 2) - 1
    for _ in range(pool_size):
        main.dbs.append(connect())
    main.db = connect()
386
def _dbsetup_sqlite():
    """Open SQLite connections to the database file cfg.dbhost and install sqlite3 as ``db_api``.

    Mirrors _dbsetup_mysql(): main.dbs receives (num_db_connections - 1)
    connections for getdb()/returndb(), and main.db is a separate, dedicated
    connection for the main thread. Previously the last connection was both
    main.db and a pool member. getdb() pops from the end, so it handed that
    shared connection to a child thread first.
    """
    global db_api
    import sqlite3 as db_api

    def connect(**kwargs):
        conn = db_api.connect(cfg.dbhost, **kwargs)
        conn.row_factory = db_api.Row  # allows row['column'] access like MySQL's DictCursor
        conn.isolation_level = None  # autocommit mode
        return conn

    # Pool connections are handed to child threads by getdb(), so they must be
    # usable outside the thread that created them. Each one is used by one
    # thread at a time (pop/append).
    for _ in range(cfg.get('erebus', 'num_db_connections', 2) - 1):
        main.dbs.append(connect(check_same_thread=False))
    main.db = connect()
395
def setup():
    """Initialise the bot and store the instance in the ``cfg``/``main`` globals.

    Steps:
      1. Load bot.config and optionally enable gc leak debugging.
      2. Write the pidfile.
      3. Create the Erebus instance and connect to the database.
      4. Load the modules flagged 1 in [autoloads].
      5. Create a Bot for every active row in ``bots`` and connect all bots.
    """
    global cfg, main

    cfg = config.Config('bot.config')

    if cfg.getboolean('debug', 'gc'):
        gc.set_debug(gc.DEBUG_LEAK)

    # `with` closes the pidfile even if the write raises.
    with open(cfg.pidfile, 'w') as pidfile:
        pidfile.write(str(os.getpid()))

    main = Erebus(cfg)
    dbsetup()

    autoloads = [mod for mod, yes in cfg.items('autoloads') if int(yes) == 1]
    for mod in autoloads:
        ctlmod.load(main, mod)

    c = main.query("SELECT nick, user, bind, authname, authpass FROM bots WHERE active = 1")
    if c:
        rows = c.fetchall()
        c.close()
        for row in rows:
            main.newbot(row['nick'], row['user'], row['bind'], row['authname'], row['authpass'], cfg.host, cfg.port, cfg.realname)
    main.connectall()
422
def loop():
    """Run one iteration of the main event loop.

    Waits (via main.poll()) for readable sockets, reads each one with
    getdata(), and feeds every received line to its owner's parse().
    - Exceptions from getdata() or parse() are logged with a traceback and do
      not stop the bot.
    - A socket whose getdata() returns None or raises is closed.
    - Afterwards, if main.mustquit has been set, it is logged and raised.
    """
    for fileno in main.poll():
        owner = main.fds.get(fileno)
        if owner is None:
            # Unregistered while an earlier socket in this batch was handled;
            # calling main.fd() here used to raise an uncaught KeyError.
            continue
        try:
            data = owner.getdata()
        except Exception:
            main.log('*', '!', 'Error receiving data: getdata raised exception for socket %d, closing' % (fileno))
            traceback.print_exc()
            data = None
        if data is None:
            owner.close()
            continue
        debug_io = cfg.getboolean('debug', 'io')
        for line in data:
            if main.fds.get(fileno) is not owner:
                # parse() of an earlier line closed or replaced this socket;
                # the remaining lines belong to a connection that is gone.
                break
            if debug_io:
                main.log(str(owner), 'I', line)
            try:
                owner.parse(line)
            except Exception:
                main.log('*', '!', 'Error receiving data: parse raised exception for socket %d data %r, ignoring' % (fileno, line))
                traceback.print_exc()
    if main.mustquit is not None:
        main.log('*', '!', 'Core exiting due to: %s' % (main.mustquit))
        raise main.mustquit
446
if __name__ == '__main__':
    # Rotate the previous log into oldlogs/. This is best effort: it fails
    # e.g. on the first run (no logfile yet) or when oldlogs/ is missing.
    try:
        os.rename('logfile', 'oldlogs/%s' % (time.time()))
    except OSError:
        pass
    sys.stdout = open('logfile', 'w', 1)  # buffering=1: line-buffered text file
    sys.stderr = sys.stdout
    setup()
    while True:
        loop()