jfr.im git - erebus.git/blame_incremental - erebus.py
admin_config - add !getconfig, remove some unused functions
[erebus.git] / erebus.py
... / ...
Commit | Line | Data
1#!/usr/bin/python
2# vim: fileencoding=utf-8
3
4# Erebus IRC bot - Author: John Runyon
5# main startup code
6
7from __future__ import print_function
8
9import os, sys, select, time, traceback, random, gc
10import bot, config, ctlmod, modlib
11
class Erebus(object):  # singleton to pass around
    """Central state shared by the whole bot.

    setup() creates one instance and stores it in the module-global ``main``.
    It holds the bots, sockets, users, channels, DB access and hook registries.

    NOTE: the registries below are *class* attributes, so they are shared by
    every instance. That is harmless for a singleton but worth knowing.
    """
    APIVERSION = 0
    RELEASE = 0

    bots = {}  # nick.lower() -> bot.Bot (filled by newbot())
    fds = {}  # fileno -> modlib.Socketlike object (filled by newfd())
    numhandlers = {}  # word -> list of handlers (hooknum())
    msghandlers = {}  # word -> list of handlers (hook())
    chanhandlers = {}  # channel -> list of handlers (hookchan())
    exceptionhandlers = []  # list of (Exception_class, handler_function) tuples
    users = {}  # nick.lower() -> User (filled by user())
    chans = {}  # name.lower() -> Channel (filled by newchannel())

    class User(object):
        """A user seen on IRC, tracked by nick and an optional lowercased auth.

        glevel is the global access level: -1 when not authed, otherwise the
        level stored for auth in the ``users`` table (0 when no row is found).
        """

        def __init__(self, nick, auth=None):
            self.nick = nick
            if auth is None:
                self.auth = None
            else:
                self.auth = auth.lower()
            # Always run checklevel(). For an unauthed user it only sets
            # glevel = -1 without touching the DB, and __repr__ needs glevel.
            self.checklevel()

            self.chans = []  # channels this user has been seen in

        def bind_bot(self, bot):
            """Return a proxy of this user whose msg helpers go through `bot`."""
            return main._BoundUser(self, bot)

        # The msg helpers below send through a randomly chosen bot.
        def msg(self, *args, **kwargs):
            main.randbot().msg(self, *args, **kwargs)

        def slowmsg(self, *args, **kwargs):
            main.randbot().slowmsg(self, *args, **kwargs)

        def fastmsg(self, *args, **kwargs):
            main.randbot().fastmsg(self, *args, **kwargs)

        def isauthed(self):
            return self.auth is not None

        def authed(self, auth):
            """Record a new auth name ('0' clears it) and refresh glevel."""
            if auth == '0':
                self.auth = None
            else:
                self.auth = auth.lower()
            self.checklevel()

        def checklevel(self):
            """Recompute glevel (see class docstring) and return it."""
            if self.auth is None:
                self.glevel = -1
            else:
                c = main.query("SELECT level FROM users WHERE auth = %s", (self.auth,))
                if c:
                    row = c.fetchone()
                    if row is not None:
                        self.glevel = row['level']
                    else:
                        self.glevel = 0
                else:
                    self.glevel = 0
            return self.glevel

        def setlevel(self, level, savetodb=True):
            """Set the global level and return True on success.

            With savetodb, level 0 deletes the user's row and any other level
            is written with REPLACE. glevel only changes if the query worked.
            """
            if not savetodb:
                self.glevel = level
                return True

            if level != 0:
                c = main.query("REPLACE INTO users (auth, level) VALUES (%s, %s)", (self.auth, level))
            else:
                c = main.query("DELETE FROM users WHERE auth = %s", (self.auth,))
                # 0 means no rows affected, which is fine. query() reports
                # failure with False, and False == 0 in Python, so the identity
                # check keeps a failed DELETE from being treated as success.
                if c == 0 and c is not False:
                    c = True
            if c:
                self.glevel = level
                return True
            return False

        def join(self, chan):
            if chan not in self.chans:
                self.chans.append(chan)

        def part(self, chan):
            """Forget `chan`; return True if the user is now in no known channel."""
            if chan in self.chans:
                self.chans.remove(chan)
            return len(self.chans) == 0

        def quit(self):
            pass

        def nickchange(self, newnick):
            self.nick = newnick

        def __str__(self):
            return self.nick

        def __repr__(self):
            return "<User %r (%d)>" % (self.nick, self.glevel)

    class _BoundUser(object):
        """Proxy for a User that sends its messages through one specific bot.

        Attribute reads and writes fall through to the wrapped User. Only
        msg/slowmsg/fastmsg are overridden.
        """

        def __init__(self, user, bot):
            # Write straight into __dict__ because __setattr__ forwards to the user.
            self.__dict__['_bound_user'] = user
            self.__dict__['_bound_bot'] = bot

        def __getattr__(self, name):
            return getattr(self._bound_user, name)

        def __setattr__(self, name, value):
            setattr(self._bound_user, name, value)

        def msg(self, *args, **kwargs):
            self._bound_bot.msg(self._bound_user, *args, **kwargs)

        def slowmsg(self, *args, **kwargs):
            self._bound_bot.slowmsg(self._bound_user, *args, **kwargs)

        def fastmsg(self, *args, **kwargs):
            self._bound_bot.fastmsg(self._bound_user, *args, **kwargs)

        def __repr__(self):
            return "<_BoundUser %r %r>" % (self._bound_user, self._bound_bot)

    class Channel(object):
        """A channel as seen by one bot, with membership lists and per-auth levels."""

        def __init__(self, name, bot):
            self.name = name
            self.bot = bot
            self.levels = {}  # auth -> channel level, loaded from the chusers table

            self.users = []
            self.voices = []
            self.ops = []

            self.deleting = False  # if true, the bot will remove cached records of this channel when the bot sees that it has left the channel

            c = main.query("SELECT user, level FROM chusers WHERE chan = %s", (self.name,))
            if c:
                row = c.fetchone()
                while row is not None:
                    self.levels[row['user']] = row['level']
                    row = c.fetchone()

        def msg(self, *args, **kwargs):
            self.bot.msg(self, *args, **kwargs)

        def slowmsg(self, *args, **kwargs):
            self.bot.slowmsg(self, *args, **kwargs)

        def fastmsg(self, *args, **kwargs):
            self.bot.fastmsg(self, *args, **kwargs)

        def levelof(self, auth):
            """Return the channel level for `auth`, or 0 if auth is None or unknown."""
            if auth is None:
                return 0
            return self.levels.get(auth.lower(), 0)

        def setlevel(self, auth, level, savetodb=True):
            """Set `auth`'s channel level and return True on success.

            With savetodb the level is written first and only cached if the
            query succeeded.
            """
            auth = auth.lower()
            if savetodb:
                c = main.query("REPLACE INTO chusers (chan, user, level) VALUES (%s, %s, %s)", (self.name, auth, level))
                if not c:
                    return False
            self.levels[auth] = level
            return True

        def userjoin(self, user, level=None):
            if user not in self.users:
                self.users.append(user)
            if level == 'op' and user not in self.ops:
                self.ops.append(user)
            if level == 'voice' and user not in self.voices:
                self.voices.append(user)

        def userpart(self, user):
            if user in self.ops:
                self.ops.remove(user)
            if user in self.voices:
                self.voices.remove(user)
            if user in self.users:
                self.users.remove(user)

        def userop(self, user):
            if user in self.users and user not in self.ops:
                self.ops.append(user)

        def uservoice(self, user):
            if user in self.users and user not in self.voices:
                self.voices.append(user)

        def userdeop(self, user):
            if user in self.ops:
                self.ops.remove(user)

        def userdevoice(self, user):
            if user in self.voices:
                self.voices.remove(user)

        def __str__(self):
            return self.name

        def __repr__(self):
            return "<Channel %r>" % (self.name)

    def __init__(self, cfg):
        self.mustquit = None  # loop() raises this once it is set to something other than None
        self.starttime = time.time()
        self.cfg = cfg
        self.trigger = cfg.trigger
        if os.name == "posix":
            self.potype = "poll"
            self.po = select.poll()
        else:  # f.e. os.name == "nt" (Windows)
            self.potype = "select"
            self.fdlist = []

    def query(self, sql, parameters=None, noretry=False):
        """Run one SQL statement on the main connection (self.db).

        Callers use %s-style (paramstyle='format') placeholders. They are
        rewritten to '?' for qmark drivers (sqlite3). There's no provision for
        a literal '%s' inside the query; pass it as a parameter instead.

        Returns the cursor when execute() returned a truthy value. Otherwise
        it returns that falsy value (MySQLdb's execute() returns a row count,
        so a statement touching no rows yields 0). Returns False on DataError.
        On any other DB error the connections are rebuilt with dbsetup() and
        the query is retried once; a second failure is re-raised.
        """
        if parameters is None:
            parameters = []
        # mysql/postgresql drivers (paramstyle 'format'/'pyformat') take %s as-is.
        # psycopg actually asks for a mapping with %(name)s style (pyformat) but it will accept %s style.
        if db_api.paramstyle == 'qmark':  # sqlite doesn't like %s style.
            # Stringify values as before, but keep None as SQL NULL instead of
            # storing the literal text 'None'.
            parameters = [p if p is None else str(p) for p in parameters]
            sql = sql.replace('%s', '?')  # hope that wasn't literal, oopsie

        log_noretry = ', noretry=True' if noretry else ''
        self.log("[SQL]", "?", "query(%r, %r%s)" % (sql, parameters, log_noretry))

        try:
            curs = self.db.cursor()
            res = curs.execute(sql, parameters)
            if res:
                return curs
            return res
        except db_api.DataError as e:
            self.log("[SQL]", ".", "DB DataError: %r" % (e))
            return False
        except db_api.Error as e:
            self.log("[SQL]", "!", "DB error! %r" % (e))
            if noretry:
                raise  # bare raise keeps the original traceback
            dbsetup()
            return self.query(sql, parameters, noretry=True)

    def querycb(self, cb, *args, **kwargs):
        """Run query(*args, **kwargs) on a new thread and pass the result to cb."""
        # TODO this should either get thrown out with getdb()/returndb(), or else be adjusted to make use of it.
        # NOTE(review): the worker uses self.db, the main connection; sqlite3
        # connections reject use from other threads by default - confirm.
        import threading  # not imported at module level

        def run_query():
            cb(self.query(*args, **kwargs))
        threading.Thread(target=run_query).start()

    def newbot(self, nick, user, bind, authname, authpass, server, port, realname):
        """Create a bot.Bot and register it under nick.lower()."""
        if bind is None:
            bind = ''
        # `bot` here is the module-level `bot` module, not the bot() method.
        obj = bot.Bot(self, nick, user, bind, authname, authpass, server, port, realname)
        self.bots[nick.lower()] = obj

    def newfd(self, obj, fileno):
        """Register `obj` to receive data for `fileno` and start polling it."""
        if not isinstance(obj, modlib.Socketlike):
            raise Exception('Attempted to hook a socket without a class to process data')
        self.fds[fileno] = obj
        if self.potype == "poll":
            self.po.register(fileno, select.POLLIN)
        elif self.potype == "select":
            self.fdlist.append(fileno)

    def delfd(self, fileno):
        """Stop polling `fileno` and forget its handler object."""
        del self.fds[fileno]
        if self.potype == "poll":
            self.po.unregister(fileno)
        elif self.potype == "select":
            self.fdlist.remove(fileno)

    def bot(self, name):  # get Bot() by name (nick)
        return self.bots[name.lower()]

    def fd(self, fileno):  # get Bot() by fd/fileno
        return self.fds[fileno]

    def randbot(self):  # get Bot() randomly
        return self.bots[random.choice(list(self.bots.keys()))]

    def user(self, _nick, send_who=False, create=True):
        """Return the User for `_nick`, creating it unless create is False.

        With send_who, a WHO is sent (via a random bot) when the user is
        unknown or not authed.
        """
        nick = _nick.lower()

        if send_who and (nick not in self.users or not self.users[nick].isauthed()):
            self.randbot().conn.send("WHO %s n%%ant,1" % (nick))

        if nick in self.users:
            return self.users[nick]
        elif create:
            user = self.User(_nick)
            self.users[nick] = user
            return user
        else:
            return None

    def channel(self, name):  # get Channel() by name
        return self.chans.get(name.lower())

    def newchannel(self, bot, name):
        """Create a Channel (stored under its lowercased name) and return it."""
        chan = self.Channel(name.lower(), bot)
        self.chans[name.lower()] = chan
        return chan

    def poll(self, timeout_seconds=30):
        """Wait up to timeout_seconds for readable fds; return their filenos."""
        if self.potype == "poll":
            pollres = self.po.poll(timeout_seconds * 1000)  # poll() takes milliseconds
            return [fd for (fd, ev) in pollres]
        elif self.potype == "select":
            return select.select(self.fdlist, [], [], timeout_seconds)[0]

    def connectall(self):
        """Call connect() on every bot whose conn.state is 0."""
        for b in self.bots.values():
            if b.conn.state == 0:
                b.connect()

    def module(self, name):
        """Return the loaded module `name` from ctlmod (KeyError if absent)."""
        return ctlmod.modules[name]

    def log(self, source, level, message):
        """Print a log line prefixed with time.time() modulo 100000."""
        print("%09.3f %s [%s] %s" % (time.time() % 100000, source, level, message))

    def getuserbyauth(self, auth):
        """Return every known User whose auth equals auth.lower()."""
        return [u for u in self.users.values() if u.auth == auth.lower()]

    def getdb(self):
        """Get a DB object. The object must be returned to the pool after us, using returndb(). This is intended for use from child threads.
        It should probably be treated as deprecated though. Where possible new modules should avoid using threads.
        In the future, timers will be provided (manipulating the timeout_seconds of the poll() method), and that should mostly be used in place of threading.
        Raises IndexError if the pool is empty."""
        return self.dbs.pop()

    def returndb(self, db):
        self.dbs.append(db)

    # bind functions
    def hook(self, word, handler):
        self.msghandlers.setdefault(word, []).append(handler)

    def unhook(self, word, handler):
        if word in self.msghandlers and handler in self.msghandlers[word]:
            self.msghandlers[word].remove(handler)

    def hashook(self, word):
        return word in self.msghandlers and len(self.msghandlers[word]) != 0

    def gethook(self, word):
        return self.msghandlers[word]

    def hooknum(self, word, handler):
        self.numhandlers.setdefault(word, []).append(handler)

    def unhooknum(self, word, handler):
        if word in self.numhandlers and handler in self.numhandlers[word]:
            self.numhandlers[word].remove(handler)

    def hasnumhook(self, word):
        return word in self.numhandlers and len(self.numhandlers[word]) != 0

    def getnumhook(self, word):
        return self.numhandlers[word]

    def hookchan(self, chan, handler):
        self.chanhandlers.setdefault(chan, []).append(handler)

    def unhookchan(self, chan, handler):
        if chan in self.chanhandlers and handler in self.chanhandlers[chan]:
            self.chanhandlers[chan].remove(handler)

    def haschanhook(self, chan):
        return chan in self.chanhandlers and len(self.chanhandlers[chan]) != 0

    def getchanhook(self, chan):
        return self.chanhandlers[chan]

    def hookexception(self, exc, handler):
        self.exceptionhandlers.append((exc, handler))

    def unhookexception(self, exc, handler):
        self.exceptionhandlers.remove((exc, handler))

    def hasexceptionhook(self, exc):
        """True if any registered exception class matches `exc` (an instance)."""
        return any(isinstance(exc, x) for x, _ in self.exceptionhandlers)

    def getexceptionhook(self, exc):
        """Lazily yield the handlers whose exception class matches `exc`."""
        return (h for x, h in self.exceptionhandlers if isinstance(exc, x))
365 def hasexceptionhook(self, exc):
366 return any((True for x,h in self.exceptionhandlers if isinstance(exc, x)))
367 def getexceptionhook(self, exc):
368 return (h for x,h in self.exceptionhandlers if isinstance(exc, x))
369
370
371def dbsetup():
372 main.db = None
373 main.dbs = []
374 dbtype = cfg.get('erebus', 'dbtype', 'mysql')
375 if dbtype == 'mysql':
376 _dbsetup_mysql()
377 elif dbtype == 'sqlite':
378 _dbsetup_sqlite()
379 else:
380 main.log('*', '!', 'Unknown dbtype in config: %s' % (dbtype))
381
382def _dbsetup_mysql():
383 global db_api
384 import MySQLdb as db_api, MySQLdb.cursors
385 for i in range(cfg.get('erebus', 'num_db_connections', 2)-1):
386 main.dbs.append(db_api.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass, db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor))
387 main.db = db_api.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass, db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor)
388
389def _dbsetup_sqlite():
390 global db_api
391 import sqlite3 as db_api
392 for i in range(cfg.get('erebus', 'num_db_connections', 2)):
393 main.db = db_api.connect(cfg.dbhost)
394 main.db.row_factory = db_api.Row
395 main.db.isolation_level = None
396 main.dbs.append(main.db)
397
398def setup():
399 global cfg, main
400
401 cfg = config.Config('bot.config')
402
403 if cfg.getboolean('debug', 'gc'):
404 gc.set_debug(gc.DEBUG_LEAK)
405
406 pidfile = open(cfg.pidfile, 'w')
407 pidfile.write(str(os.getpid()))
408 pidfile.close()
409
410 main = Erebus(cfg)
411 dbsetup()
412
413 autoloads = [mod for mod, yes in cfg.items('autoloads') if int(yes) == 1]
414 for mod in autoloads:
415 ctlmod.load(main, mod)
416
417 c = main.query("SELECT nick, user, bind, authname, authpass FROM bots WHERE active = 1")
418 if c:
419 rows = c.fetchall()
420 c.close()
421 for row in rows:
422 main.newbot(row['nick'], row['user'], row['bind'], row['authname'], row['authpass'], cfg.host, cfg.port, cfg.realname)
423 main.connectall()
424
425def loop():
426 poready = main.poll()
427 for fileno in poready:
428 try:
429 data = main.fd(fileno).getdata()
430 except:
431 main.log('*', '!', 'Error receiving data: getdata raised exception for socket %d, closing' % (fileno))
432 traceback.print_exc()
433 data = None
434 if data is None:
435 main.fd(fileno).close()
436 else:
437 for line in data:
438 if cfg.getboolean('debug', 'io'):
439 main.log(str(main.fd(fileno)), 'I', line)
440 try:
441 main.fd(fileno).parse(line)
442 except:
443 main.log('*', '!', 'Error receiving data: parse raised exception for socket %d data %r, ignoring' % (fileno, line))
444 traceback.print_exc()
445 if main.mustquit is not None:
446 main.log('*', '!', 'Core exiting due to: %s' % (main.mustquit))
447 raise main.mustquit
448
449if __name__ == '__main__':
450 try: os.rename('logfile', 'oldlogs/%s' % (time.time()))
451 except: pass
452 sys.stdout = open('logfile', 'w', 1)
453 sys.stderr = sys.stdout
454 setup()
455 while True: loop()