]> jfr.im git - erebus.git/blob - erebus.py
use new bind_bot, make sure the bot-on-channel sends the replies
[erebus.git] / erebus.py
1 #!/usr/bin/python
2 # vim: fileencoding=utf-8
3
4 # Erebus IRC bot - Author: John Runyon
5 # main startup code
6
from __future__ import print_function

import gc
import os
import random
import select
import sys
import threading
import time
import traceback

import MySQLdb
import MySQLdb.cursors

import bot
import config
import ctlmod
11
class Erebus(object):  # singleton to pass around
    """Central bot state shared by the core loop and all loaded modules.

    Holds the running bots, the fd -> owner map used by the poll loop, the
    known users and channels, and the hook tables modules register handlers in.

    NOTE: the containers below are class attributes, so they are shared by
    every Erebus instance; the class is only meant to be instantiated once.
    """
    APIVERSION = 0
    RELEASE = 0

    bots = {}  # lowercased nick -> Bot, filled by newbot()
    fds = {}  # fileno -> object owning that fd, filled by newfd()
    numhandlers = {}  # word -> [handler, ...], see hooknum()
    msghandlers = {}  # word -> [handler, ...], see hook()
    chanhandlers = {}  # channel -> [handler, ...], see hookchan()
    exceptionhandlers = []  # list of (Exception_class, handler_function) tuples
    users = {}  # lowercased nick -> User, see user()
    chans = {}  # lowercased channel name -> Channel, see newchannel()

    class User(object):
        """An IRC user, identified by nick and optionally by account (auth) name.

        ``glevel`` is the global access level: -1 when not authed, otherwise
        the ``level`` column of the ``users`` table (0 when no row is found or
        the query returns nothing).
        """
        def __init__(self, nick, auth=None):
            self.nick = nick
            if auth is None:
                self.auth = None
            else:
                self.auth = auth.lower()
            self.checklevel()

            self.chans = []  # channels this user is tracked in; see join()/part()

        def bind_bot(self, bot):
            """Return a proxy of this user whose msg methods send through *bot*."""
            return main._BoundUser(self, bot)

        # Unbound messaging: a randomly chosen bot delivers the message.
        def msg(self, *args, **kwargs):
            main.randbot().msg(self, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            main.randbot().slowmsg(self, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            main.randbot().fastmsg(self, *args, **kwargs)

        def isauthed(self):
            return self.auth is not None

        def authed(self, auth):
            """Record the user's account name (an auth of '0' clears it) and refresh glevel."""
            if auth == '0':
                self.auth = None
            else:
                self.auth = auth.lower()
            self.checklevel()

        def checklevel(self):
            """Recompute self.glevel from the users table and return it."""
            if self.auth is None:
                self.glevel = -1
            else:
                c = main.query("SELECT level FROM users WHERE auth = %s", (self.auth,))
                if c:
                    row = c.fetchone()
                    if row is not None:
                        self.glevel = row['level']
                    else:
                        self.glevel = 0
                else:
                    self.glevel = 0
            return self.glevel

        def setlevel(self, level, savetodb=True):
            """Set the global level, writing it to the database unless savetodb is False.

            A level of 0 deletes the user's row instead of storing it (deleting
            a non-existent row counts as success). Returns True on success and
            False when the REPLACE reported no affected rows.
            """
            if savetodb:
                if level != 0:
                    c = main.query("REPLACE INTO users (auth, level) VALUES (%s, %s)", (self.auth, level))
                else:
                    c = main.query("DELETE FROM users WHERE auth = %s", (self.auth,))
                    if c == 0:  # no rows affected
                        c = True  # is fine
                if c:
                    self.glevel = level
                    return True
                else:
                    return False
            else:
                self.glevel = level
                return True

        def join(self, chan):
            if chan not in self.chans:
                self.chans.append(chan)
        def part(self, chan):
            """Forget *chan*; return True if the user is now in no tracked channel."""
            try:
                self.chans.remove(chan)
            except ValueError:  # chan was not tracked for this user
                pass
            return len(self.chans) == 0
        def quit(self):
            pass
        def nickchange(self, newnick):
            self.nick = newnick

        def __str__(self): return self.nick
        def __repr__(self): return "<User %r (%d)>" % (self.nick, self.glevel)

    class _BoundUser(object):
        """Proxy around a User whose msg/slowmsg/fastmsg go through one specific bot.

        All other attribute reads and writes are forwarded to the wrapped User.
        """
        def __init__(self, user, bot):
            # Write into __dict__ directly: our __setattr__ would forward to the user.
            self.__dict__['_bound_user'] = user
            self.__dict__['_bound_bot'] = bot
        def __getattr__(self, name):
            return getattr(self._bound_user, name)
        def __setattr__(self, name, value):
            setattr(self._bound_user, name, value)
        def msg(self, *args, **kwargs):
            self._bound_bot.msg(self._bound_user, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            self._bound_bot.slowmsg(self._bound_user, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            self._bound_bot.fastmsg(self._bound_user, *args, **kwargs)
        def __str__(self):
            # Special methods are looked up on the type, so __getattr__ never
            # forwards str(); delegate explicitly so "%s" % user yields the nick.
            return str(self._bound_user)
        def __repr__(self): return "<_BoundUser %r %r>" % (self._bound_user, self._bound_bot)

    class Channel(object):
        """A channel the bot is on: member lists plus per-account levels from chusers."""
        def __init__(self, name, bot):
            self.name = name
            self.bot = bot  # the bot used for msg/slowmsg/fastmsg to this channel
            self.levels = {}  # auth -> channel access level, loaded from chusers

            self.users = []
            self.voices = []
            self.ops = []

            c = main.query("SELECT user, level FROM chusers WHERE chan = %s", (self.name,))
            if c:
                row = c.fetchone()
                while row is not None:
                    self.levels[row['user']] = row['level']
                    row = c.fetchone()

        def msg(self, *args, **kwargs):
            self.bot.msg(self, *args, **kwargs)
        def slowmsg(self, *args, **kwargs):
            self.bot.slowmsg(self, *args, **kwargs)
        def fastmsg(self, *args, **kwargs):
            self.bot.fastmsg(self, *args, **kwargs)

        def levelof(self, auth):
            """Return the channel level of account *auth* (0 if None or unknown)."""
            if auth is None:
                return 0
            return self.levels.get(auth.lower(), 0)

        def setlevel(self, auth, level, savetodb=True):
            """Set *auth*'s channel level, writing it to chusers unless savetodb is False.

            Returns False only when the database write reported no affected rows.
            """
            auth = auth.lower()
            if savetodb:
                c = main.query("REPLACE INTO chusers (chan, user, level) VALUES (%s, %s, %s)", (self.name, auth, level))
                if c:
                    self.levels[auth] = level
                    return True
                else:
                    return False
            else:
                self.levels[auth] = level
                return True

        def userjoin(self, user, level=None):
            """Add *user*; level 'op' or 'voice' also adds them to that list."""
            if user not in self.users: self.users.append(user)
            if level == 'op' and user not in self.ops: self.ops.append(user)
            if level == 'voice' and user not in self.voices: self.voices.append(user)
        def userpart(self, user):
            if user in self.ops: self.ops.remove(user)
            if user in self.voices: self.voices.remove(user)
            if user in self.users: self.users.remove(user)

        def userop(self, user):
            if user in self.users and user not in self.ops: self.ops.append(user)
        def uservoice(self, user):
            if user in self.users and user not in self.voices: self.voices.append(user)
        def userdeop(self, user):
            if user in self.ops: self.ops.remove(user)
        def userdevoice(self, user):
            if user in self.voices: self.voices.remove(user)

        def __str__(self): return self.name
        def __repr__(self): return "<Channel %r>" % (self.name)

    def __init__(self, cfg):
        self.mustquit = None  # set to an exception instance to make loop() raise it
        self.starttime = time.time()
        self.cfg = cfg
        self.trigger = cfg.trigger
        if os.name == "posix":
            self.potype = "poll"
            self.po = select.poll()
        else:  # f.e. os.name == "nt" (Windows)
            self.potype = "select"
            self.fdlist = []

    def query(self, *args, **kwargs):
        """Execute an SQL statement on self.db.

        Takes cursor.execute()'s arguments plus an optional ``noretry`` keyword.
        Returns the cursor when execute() returns a truthy row count, otherwise
        execute()'s falsy return value. On a MySQL error the connections are
        rebuilt with dbsetup() and the statement retried once; if the retry
        also fails the error is re-raised.
        """
        noretry = kwargs.pop('noretry', False)

        self.log("[SQL]", "?", "query(%s, %s)" % (', '.join([repr(i) for i in args]), ', '.join([str(key)+"="+repr(kwargs[key]) for key in kwargs])))
        try:
            curs = self.db.cursor()
            res = curs.execute(*args, **kwargs)
            if res:
                return curs
            else:
                return res
        except MySQLdb.MySQLError as e:
            self.log("[SQL]", "!", "MySQL error! %r" % (e))
            if not noretry:
                dbsetup()
                return self.query(*args, noretry=True, **kwargs)
            # Bare raise keeps the original traceback (``raise e`` drops it on Python 2).
            raise

    def querycb(self, cb, *args, **kwargs):
        """Run query(*args, **kwargs) on a new thread and pass its result to *cb*.

        Returns the started Thread so callers may join() it if they need to.
        NOTE(review): query() uses self.db, the same connection the main thread
        uses; DB-API connections are not guaranteed thread-safe - consider a
        pooled connection (getdb()/returndb()) here.
        """
        def run_query():
            cb(self.query(*args, **kwargs))
        thread = threading.Thread(target=run_query)
        thread.start()
        return thread

    def newbot(self, nick, user, bind, authname, authpass, server, port, realname):
        """Create a Bot and register it under its lowercased nick."""
        if bind is None: bind = ''
        obj = bot.Bot(self, nick, user, bind, authname, authpass, server, port, realname)
        self.bots[nick.lower()] = obj

    def newfd(self, obj, fileno):
        """Register *fileno* (owned by *obj*) with the poller."""
        self.fds[fileno] = obj
        if self.potype == "poll":
            self.po.register(fileno, select.POLLIN)
        elif self.potype == "select":
            self.fdlist.append(fileno)
    def delfd(self, fileno):
        """Unregister *fileno* from the poller and the fd map."""
        del self.fds[fileno]
        if self.potype == "poll":
            self.po.unregister(fileno)
        elif self.potype == "select":
            self.fdlist.remove(fileno)

    def bot(self, name):  # get Bot() by name (nick)
        return self.bots[name.lower()]
    def fd(self, fileno):  # get Bot() by fd/fileno
        return self.fds[fileno]
    def randbot(self):  # get Bot() randomly
        return self.bots[random.choice(list(self.bots.keys()))]

    def user(self, _nick, send_who=False, create=True):
        """Return the User for *_nick* (case-insensitive), creating it if *create*.

        With send_who, a WHO request is sent through a random bot when the user
        is unknown or not authed. Returns None if unknown and create is False.
        """
        nick = _nick.lower()

        if send_who and (nick not in self.users or not self.users[nick].isauthed()):
            self.randbot().conn.send("WHO %s n%%ant,1" % (nick))

        if nick in self.users:
            return self.users[nick]
        elif create:
            user = self.User(_nick)
            self.users[nick] = user
            return user
        else:
            return None
    def channel(self, name):  # get Channel() by name
        return self.chans.get(name.lower())

    def newchannel(self, bot, name):
        """Create, register and return a Channel for *name*, served by *bot*."""
        chan = self.Channel(name.lower(), bot)
        self.chans[name.lower()] = chan
        return chan

    def poll(self):
        """Wait up to 30 seconds and return the list of readable registered fds."""
        timeout_seconds = 30
        if self.potype == "poll":
            pollres = self.po.poll(timeout_seconds * 1000)
            return [fd for (fd, ev) in pollres]
        elif self.potype == "select":
            return select.select(self.fdlist, [], [], timeout_seconds)[0]

    def connectall(self):
        """Call connect() on every bot whose conn.state is 0."""
        for bot in self.bots.values():
            if bot.conn.state == 0:
                bot.connect()

    def module(self, name):
        return ctlmod.modules[name]

    def log(self, source, level, message):
        print("%09.3f %s [%s] %s" % (time.time() % 100000, source, level, message))

    def getuserbyauth(self, auth):
        """Return every known User whose auth equals *auth* (case-insensitive)."""
        return [u for u in self.users.values() if u.auth == auth.lower()]

    def getdb(self):
        """Get a DB object. The object must be returned to the pool after us, using returndb()."""
        return self.dbs.pop()

    def returndb(self, db):
        self.dbs.append(db)

    # bind functions
    def hook(self, word, handler):
        self.msghandlers.setdefault(word, []).append(handler)
    def unhook(self, word, handler):
        if word in self.msghandlers and handler in self.msghandlers[word]:
            self.msghandlers[word].remove(handler)
    def hashook(self, word):
        return word in self.msghandlers and len(self.msghandlers[word]) != 0
    def gethook(self, word):
        return self.msghandlers[word]

    def hooknum(self, word, handler):
        self.numhandlers.setdefault(word, []).append(handler)
    def unhooknum(self, word, handler):
        if word in self.numhandlers and handler in self.numhandlers[word]:
            self.numhandlers[word].remove(handler)
    def hasnumhook(self, word):
        return word in self.numhandlers and len(self.numhandlers[word]) != 0
    def getnumhook(self, word):
        return self.numhandlers[word]

    def hookchan(self, chan, handler):
        self.chanhandlers.setdefault(chan, []).append(handler)
    def unhookchan(self, chan, handler):
        if chan in self.chanhandlers and handler in self.chanhandlers[chan]:
            self.chanhandlers[chan].remove(handler)
    def haschanhook(self, chan):
        return chan in self.chanhandlers and len(self.chanhandlers[chan]) != 0
    def getchanhook(self, chan):
        return self.chanhandlers[chan]

    def hookexception(self, exc, handler):
        self.exceptionhandlers.append((exc, handler))
    def unhookexception(self, exc, handler):
        self.exceptionhandlers.remove((exc, handler))
    def hasexceptionhook(self, exc):
        """True if any registered handler's exception class matches *exc*."""
        return any(isinstance(exc, x) for x, h in self.exceptionhandlers)
    def getexceptionhook(self, exc):
        """Lazily yield the handlers whose exception class matches *exc*."""
        return (h for x, h in self.exceptionhandlers if isinstance(exc, x))
352
353
def dbsetup():
    """(Re)build the bot's MySQL connections from the current config.

    main.dbs becomes a pool of ``num_db_connections - 1`` connections (see
    Erebus.getdb/returndb) and main.db one further connection, which is the
    one Erebus.query uses. All connections use DictCursor, so rows are dicts.
    """
    def open_connection():
        return MySQLdb.connect(host=cfg.dbhost, user=cfg.dbuser, passwd=cfg.dbpass, db=cfg.dbname, cursorclass=MySQLdb.cursors.DictCursor)

    main.db = None
    main.dbs = []
    pool_size = cfg.get('erebus', 'num_db_connections', 2) - 1
    for _ in range(pool_size):
        main.dbs.append(open_connection())
    main.db = open_connection()
360
361 def setup():
362 global cfg, main
363
364 cfg = config.Config('bot.config')
365
366 if cfg.getboolean('debug', 'gc'):
367 gc.set_debug(gc.DEBUG_LEAK)
368
369 pidfile = open(cfg.pidfile, 'w')
370 pidfile.write(str(os.getpid()))
371 pidfile.close()
372
373 main = Erebus(cfg)
374 dbsetup()
375
376 autoloads = [mod for mod, yes in cfg.items('autoloads') if int(yes) == 1]
377 for mod in autoloads:
378 ctlmod.load(main, mod)
379
380 c = main.query("SELECT nick, user, bind, authname, authpass FROM bots WHERE active = 1")
381 if c:
382 rows = c.fetchall()
383 c.close()
384 for row in rows:
385 main.newbot(row['nick'], row['user'], row['bind'], row['authname'], row['authpass'], cfg.host, cfg.port, cfg.realname)
386 main.connectall()
387
def loop():
    """Run one iteration of the main event loop.

    Waits (via main.poll()) for readable fds, reads each one's data with
    getdata() and feeds every line to the owner's parse(). An owner whose
    getdata() returns None, or raises, is closed. Exceptions from parse() are
    logged and do not stop processing. Raises main.mustquit once it is set.
    """
    poready = main.poll()
    for fileno in poready:
        try:
            # Resolve the owner once: parse() may close the connection (and
            # unregister fileno) before the remaining lines are processed.
            obj = main.fd(fileno)
        except KeyError:
            # No longer registered (e.g. closed earlier in this poll round).
            continue
        try:
            data = obj.getdata()
        except Exception:
            main.log('*', '!', 'Super-mega-emergency: getdata raised exception for socket %d' % (fileno))
            traceback.print_exc()
            data = None
        if data is None:
            obj.close()
            continue
        for line in data:
            if cfg.getboolean('debug', 'io'):
                main.log(str(obj), 'I', line)
            try:
                obj.parse(line)
            except Exception:
                main.log('*', '!', 'Super-mega-emergency: parse raised exception for socket %d data %r' % (fileno, line))
                traceback.print_exc()
    if main.mustquit is not None:
        main.log('*', '!', 'Core exiting due to: %s' % (main.mustquit))
        raise main.mustquit
411
if __name__ == '__main__':
    # Keep the previous run's log under oldlogs/<timestamp>, if there is one.
    try:
        os.rename('logfile', 'oldlogs/%s' % (time.time()))
    except OSError:
        # No previous logfile (or no oldlogs/ directory): nothing to keep.
        pass
    # Line-buffered (buffering=1) so each log line reaches the file promptly.
    sys.stdout = open('logfile', 'w', 1)
    sys.stderr = sys.stdout
    setup()
    while True:
        loop()