]> jfr.im git - erebus.git/blob - erebus.py
update comments
[erebus.git] / erebus.py
1 #!/usr/bin/python
2 # vim: fileencoding=utf-8
3
4 # Erebus IRC bot - Author: John Runyon
5 # main startup code
6
7 from __future__ import print_function
8
9 import os, sys, select, MySQLdb, MySQLdb.cursors, time, traceback, random, gc
10 import bot, config, ctlmod
11
class Erebus(object): #singleton to pass around
    """Core bot state: bots, sockets, users, channels and handler hooks.

    One instance is created by setup() and stored in the module-global
    ``main``.  NOTE: the registries below are *class* attributes, so every
    instance shares them; this is only safe because the class is used as a
    singleton.
    """
    APIVERSION = 0
    RELEASE = 0

    bots = {}              # lowercased nick -> bot.Bot (see newbot)
    fds = {}               # fileno -> object owning that socket (see newfd)
    numhandlers = {}       # word -> [handler, ...] (see hooknum)
    msghandlers = {}       # trigger word -> [handler, ...] (see hook)
    chanhandlers = {}      # channel -> [handler, ...] (see hookchan)
    exceptionhandlers = [] # list of (Exception_class, handler_function) tuples
    users = {}             # lowercased nick -> User (see user)
    chans = {}             # lowercased channel name -> Channel (see newchannel)

    class User(object):
        """A user seen on IRC: a nick plus an optional (lowercased) auth name."""

        def __init__(self, nick, auth=None):
            self.nick = nick
            if auth is None:
                self.auth = None
            else:
                self.auth = auth.lower()
            # Always populate self.glevel (-1 when not authed).  Previously it
            # was only set for authed users, so __repr__ and glevel checks on
            # an unauthed User raised AttributeError.
            self.checklevel()

            self.chans = []

        def bind_bot(self, bot):
            """Return a proxy of this user whose msg methods go through `bot`."""
            return main._BoundUser(self, bot)

        # The msg helpers deliver via a randomly chosen bot.
        def msg(self, *args, **kwargs):
            main.randbot().msg(self, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            main.randbot().slowmsg(self, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            main.randbot().fastmsg(self, *args, **kwargs)

        def isauthed(self):
            return self.auth is not None

        def authed(self, auth):
            """Record a new auth name ('0' means not authed) and refresh glevel."""
            if auth == '0': self.auth = None
            else: self.auth = auth.lower()
            self.checklevel()

        def checklevel(self):
            """Recompute and return self.glevel.

            -1 if not authed; otherwise the `level` stored in the users table
            for this auth, or 0 when there is no row or the query failed.
            """
            if self.auth is None:
                self.glevel = -1
                return self.glevel
            self.glevel = 0
            c = main.query("SELECT level FROM users WHERE auth = %s", (self.auth,))
            if c:
                row = c.fetchone()
                if row is not None:
                    self.glevel = row['level']
            return self.glevel

        def setlevel(self, level, savetodb=True):
            """Set the global level, optionally persisting it; return success.

            A level of 0 is stored by deleting the user's row.
            """
            if savetodb:
                if level != 0:
                    c = main.query("REPLACE INTO users (auth, level) VALUES (%s, %s)", (self.auth, level))
                else:
                    c = main.query("DELETE FROM users WHERE auth = %s", (self.auth,))
                    # query() returns 0 when no rows were affected, which is
                    # fine here.  It returns False on a DataError, and
                    # `False == 0` is True, so exclude False explicitly or a
                    # failed DELETE would be reported as success.
                    if c is not False and c == 0:
                        c = True
                if c:
                    self.glevel = level
                    return True
                else:
                    return False
            else:
                self.glevel = level
                return True

        def join(self, chan):
            if chan not in self.chans: self.chans.append(chan)
        def part(self, chan):
            """Forget `chan`; return True if no tracked channels remain."""
            try:
                self.chans.remove(chan)
            except ValueError:
                pass  # the channel was not being tracked for this user
            return len(self.chans) == 0
        def quit(self):
            pass
        def nickchange(self, newnick):
            self.nick = newnick

        def __str__(self): return self.nick
        def __repr__(self): return "<User %r (%d)>" % (self.nick, self.glevel)

    class _BoundUser(object):
        """Proxy for a User that sends messages through one specific bot.

        Attribute reads and writes are forwarded to the wrapped User.
        """
        def __init__(self, user, bot):
            # Write via __dict__ because __setattr__ forwards to the User.
            self.__dict__['_bound_user'] = user
            self.__dict__['_bound_bot'] = bot
        def __getattr__(self, name):
            return getattr(self._bound_user, name)
        def __setattr__(self, name, value):
            setattr(self._bound_user, name, value)
        def msg(self, *args, **kwargs):
            self._bound_bot.msg(self._bound_user, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            self._bound_bot.slowmsg(self._bound_user, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            self._bound_bot.fastmsg(self._bound_user, *args, **kwargs)
        def __repr__(self): return "<_BoundUser %r %r>" % (self._bound_user, self._bound_bot)

    class Channel(object):
        """A channel joined by `bot`, with per-auth levels and member lists."""

        def __init__(self, name, bot):
            self.name = name
            self.bot = bot
            # auth -> channel level, loaded from the chusers table.  Keys are
            # taken from the DB as-is; setlevel() stores them lowercased.
            self.levels = {}

            self.users = []
            self.voices = []
            self.ops = []

            c = main.query("SELECT user, level FROM chusers WHERE chan = %s", (self.name,))
            if c:
                row = c.fetchone()
                while row is not None:
                    self.levels[row['user']] = row['level']
                    row = c.fetchone()

        def msg(self, *args, **kwargs):
            self.bot.msg(self, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            self.bot.slowmsg(self, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            self.bot.fastmsg(self, *args, **kwargs)

        def levelof(self, auth):
            """Return the channel level for `auth` (case-insensitive), else 0."""
            if auth is None:
                return 0
            return self.levels.get(auth.lower(), 0)

        def setlevel(self, auth, level, savetodb=True):
            """Set `auth`'s channel level, optionally persisting; return success."""
            auth = auth.lower()
            if savetodb:
                c = main.query("REPLACE INTO chusers (chan, user, level) VALUES (%s, %s, %s)", (self.name, auth, level))
                if c:
                    self.levels[auth] = level
                    return True
                else:
                    return False
            else:
                self.levels[auth] = level
                return True

        def userjoin(self, user, level=None):
            """Track `user` in the channel; level may be 'op' or 'voice'."""
            if user not in self.users: self.users.append(user)
            if level == 'op' and user not in self.ops: self.ops.append(user)
            if level == 'voice' and user not in self.voices: self.voices.append(user)
        def userpart(self, user):
            if user in self.ops: self.ops.remove(user)
            if user in self.voices: self.voices.remove(user)
            if user in self.users: self.users.remove(user)

        def userop(self, user):
            if user in self.users and user not in self.ops: self.ops.append(user)
        def uservoice(self, user):
            if user in self.users and user not in self.voices: self.voices.append(user)
        def userdeop(self, user):
            if user in self.ops: self.ops.remove(user)
        def userdevoice(self, user):
            if user in self.voices: self.voices.remove(user)

        def __str__(self): return self.name
        def __repr__(self): return "<Channel %r>" % (self.name)

    def __init__(self, cfg):
        self.mustquit = None  # set to an exception instance to stop loop()
        self.starttime = time.time()
        self.cfg = cfg
        self.trigger = cfg.trigger
        if os.name == "posix":
            self.potype = "poll"
            self.po = select.poll()
        else: # f.e. os.name == "nt" (Windows)
            self.potype = "select"
            self.fdlist = []

    def query(self, *args, **kwargs):
        """Execute SQL on self.db (args are passed to cursor.execute).

        Returns the cursor when execute() reports a non-zero result, otherwise
        that result (e.g. 0 rows).  Returns False on a MySQL DataError.  On
        any other MySQL error the connections are re-created once via
        dbsetup() and the query retried; pass noretry=True to disable that.
        """
        noretry = kwargs.pop('noretry', False)

        self.log("[SQL]", "?", "query(%s, %s)" % (', '.join([repr(i) for i in args]), ', '.join([str(key)+"="+repr(kwargs[key]) for key in kwargs])))
        try:
            curs = self.db.cursor()
            res = curs.execute(*args, **kwargs)
            if res:
                return curs
            else:
                return res
        except MySQLdb.DataError as e:
            self.log("[SQL]", ".", "MySQL DataError: %r" % (e))
            return False
        except MySQLdb.MySQLError as e:
            self.log("[SQL]", "!", "MySQL error! %r" % (e))
            if not noretry:
                dbsetup()
                return self.query(*args, noretry=True, **kwargs)
            else:
                raise  # bare raise keeps the original traceback (Py2 too)

    def querycb(self, cb, *args, **kwargs):
        """Run query(*args, **kwargs) in a new thread and pass its result to cb."""
        # Imported here because the module header does not import threading;
        # without it this method raised NameError.
        import threading
        # NOTE(review): the worker thread uses the shared self.db connection;
        # DB-API connections are not necessarily thread-safe - confirm.
        def run_query():
            cb(self.query(*args, **kwargs))
        threading.Thread(target=run_query).start()

    def newbot(self, nick, user, bind, authname, authpass, server, port, realname):
        """Create a bot.Bot and register it under its lowercased nick."""
        if bind is None: bind = ''
        obj = bot.Bot(self, nick, user, bind, authname, authpass, server, port, realname)
        self.bots[nick.lower()] = obj

    def newfd(self, obj, fileno):
        """Register `obj` as owner of `fileno` and watch it for input."""
        self.fds[fileno] = obj
        if self.potype == "poll":
            self.po.register(fileno, select.POLLIN)
        elif self.potype == "select":
            self.fdlist.append(fileno)
    def delfd(self, fileno):
        """Stop watching `fileno` and forget its owner."""
        del self.fds[fileno]
        if self.potype == "poll":
            self.po.unregister(fileno)
        elif self.potype == "select":
            self.fdlist.remove(fileno)

    def bot(self, name): #get Bot() by name (nick)
        return self.bots[name.lower()]
    def fd(self, fileno): #get Bot() by fd/fileno
        return self.fds[fileno]
    def randbot(self): #get Bot() randomly
        return self.bots[random.choice(list(self.bots.keys()))]

    def user(self, _nick, send_who=False, create=True):
        """Return the User for `_nick` (case-insensitive).

        If send_who is set and the nick is unknown or not authed, a random bot
        sends a WHO for it.  Unknown nicks are created when `create` is true,
        otherwise None is returned.
        """
        nick = _nick.lower()

        if send_who and (nick not in self.users or not self.users[nick].isauthed()):
            self.randbot().conn.send("WHO %s n%%ant,1" % (nick))

        if nick in self.users:
            return self.users[nick]
        elif create:
            user = self.User(_nick)
            self.users[nick] = user
            return user
        else:
            return None
    def channel(self, name): #get Channel() by name
        return self.chans.get(name.lower())

    def newchannel(self, bot, name):
        """Create, register and return a Channel for `name` on `bot`."""
        chan = self.Channel(name.lower(), bot)
        self.chans[name.lower()] = chan
        return chan

    def poll(self):
        """Wait up to 30 seconds; return the list of readable file numbers."""
        timeout_seconds = 30
        if self.potype == "poll":
            pollres = self.po.poll(timeout_seconds * 1000)  # poll() takes ms
            return [fd for (fd, ev) in pollres]
        elif self.potype == "select":
            return select.select(self.fdlist, [], [], timeout_seconds)[0]

    def connectall(self):
        """Call connect() on every bot whose connection state is 0."""
        for bot in self.bots.values():
            if bot.conn.state == 0:
                bot.connect()

    def module(self, name):
        return ctlmod.modules[name]

    def log(self, source, level, message):
        print("%09.3f %s [%s] %s" % (time.time() % 100000, source, level, message))

    def getuserbyauth(self, auth):
        """Return every known User whose auth matches `auth` (case-insensitive)."""
        wanted = auth.lower()
        return [u for u in self.users.values() if u.auth == wanted]

    def getdb(self):
        """Get a DB object. The object must be returned to the pool after us, using returndb().

        Raises IndexError when the pool is empty.
        """
        return self.dbs.pop()

    def returndb(self, db):
        """Give a connection obtained from getdb() back to the pool."""
        self.dbs.append(db)

    #bind functions
    def hook(self, word, handler):
        self.msghandlers.setdefault(word, []).append(handler)
    def unhook(self, word, handler):
        if word in self.msghandlers and handler in self.msghandlers[word]:
            self.msghandlers[word].remove(handler)
    def hashook(self, word):
        return word in self.msghandlers and len(self.msghandlers[word]) != 0
    def gethook(self, word):
        return self.msghandlers[word]

    def hooknum(self, word, handler):
        self.numhandlers.setdefault(word, []).append(handler)
    def unhooknum(self, word, handler):
        if word in self.numhandlers and handler in self.numhandlers[word]:
            self.numhandlers[word].remove(handler)
    def hasnumhook(self, word):
        return word in self.numhandlers and len(self.numhandlers[word]) != 0
    def getnumhook(self, word):
        return self.numhandlers[word]

    def hookchan(self, chan, handler):
        self.chanhandlers.setdefault(chan, []).append(handler)
    def unhookchan(self, chan, handler):
        if chan in self.chanhandlers and handler in self.chanhandlers[chan]:
            self.chanhandlers[chan].remove(handler)
    def haschanhook(self, chan):
        return chan in self.chanhandlers and len(self.chanhandlers[chan]) != 0
    def getchanhook(self, chan):
        return self.chanhandlers[chan]

    def hookexception(self, exc, handler):
        self.exceptionhandlers.append((exc, handler))
    def unhookexception(self, exc, handler):
        self.exceptionhandlers.remove((exc, handler))
    def hasexceptionhook(self, exc):
        """True if any registered exception class matches the instance `exc`."""
        return any((True for x,h in self.exceptionhandlers if isinstance(exc, x)))
    def getexceptionhook(self, exc):
        """Generator of handlers whose exception class matches `exc`."""
        return (h for x,h in self.exceptionhandlers if isinstance(exc, x))
355
356
def dbsetup():
    """(Re)create the MySQL connections from the module-global cfg.

    main.db becomes the connection used by Erebus.query(); main.dbs is the
    pool served by getdb()/returndb() and receives num_db_connections - 1
    connections (main.db accounts for the remaining one).
    """
    def _connect():
        # One DictCursor connection built from the [db] settings in cfg.
        return MySQLdb.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass, db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor)

    main.db = None
    main.dbs = []
    # int() so this works whether cfg.get returns the raw string read from
    # the config file or the int default (string - 1 raised TypeError).
    for _ in range(int(cfg.get('erebus', 'num_db_connections', 2)) - 1):
        main.dbs.append(_connect())
    main.db = _connect()
363
def setup():
    """Load config, write the pidfile, connect to MySQL, load modules, start bots.

    Sets the module globals `cfg` and `main`.
    """
    global cfg, main

    cfg = config.Config('bot.config')

    if cfg.getboolean('debug', 'gc'):
        gc.set_debug(gc.DEBUG_LEAK)

    # Context manager so the pidfile is closed even if the write fails.
    with open(cfg.pidfile, 'w') as pidfile:
        pidfile.write(str(os.getpid()))

    main = Erebus(cfg)
    dbsetup()

    # Modules listed in [autoloads] with value 1 are loaded at startup.
    autoloads = [mod for mod, yes in cfg.items('autoloads') if int(yes) == 1]
    for mod in autoloads:
        ctlmod.load(main, mod)

    c = main.query("SELECT nick, user, bind, authname, authpass FROM bots WHERE active = 1")
    if c:
        rows = c.fetchall()
        c.close()
        for row in rows:
            main.newbot(row['nick'], row['user'], row['bind'], row['authname'], row['authpass'], cfg.host, cfg.port, cfg.realname)
    main.connectall()
390
def loop():
    """Run one iteration of the main event loop.

    Waits (via main.poll) for readable sockets and reads each one with
    getdata().  Sockets whose getdata() returned None (or raised) are closed;
    otherwise every returned line is passed to the owner's parse().  Finally
    raises main.mustquit if it has been set.
    """
    poready = main.poll()
    for fileno in poready:
        # Look the owner up once.  A handler run for an earlier socket (or an
        # earlier line) may have unregistered this fd; re-looking it up on
        # every line could then raise KeyError outside any try and kill the
        # main loop.
        try:
            owner = main.fd(fileno)
        except KeyError:
            continue  # already closed/unregistered; nothing to read
        try:
            data = owner.getdata()
        except Exception:
            main.log('*', '!', 'Super-mega-emergency: getdata raised exception for socket %d' % (fileno))
            traceback.print_exc()
            data = None
        if data is None:
            owner.close()
            continue
        debug_io = cfg.getboolean('debug', 'io')
        for line in data:
            if debug_io:
                main.log(str(owner), 'I', line)
            try:
                owner.parse(line)
            except Exception:
                # Keep the bot alive on handler errors, but no longer
                # swallow KeyboardInterrupt/SystemExit as the bare except did.
                main.log('*', '!', 'Super-mega-emergency: parse raised exception for socket %d data %r' % (fileno, line))
                traceback.print_exc()
    if main.mustquit is not None:
        main.log('*', '!', 'Core exiting due to: %s' % (main.mustquit))
        raise main.mustquit
414
if __name__ == '__main__':
    # Archive the previous run's log as oldlogs/<timestamp>.  This is
    # best-effort: a missing logfile (first run) or a missing oldlogs/
    # directory is ignored.  Only OSError is caught, not everything.
    try:
        os.rename('logfile', 'oldlogs/%s' % (time.time()))
    except OSError:
        pass
    # Line-buffered (buffering=1) so the log can be followed while running.
    sys.stdout = open('logfile', 'w', 1)
    sys.stderr = sys.stdout
    setup()
    while True:
        loop()