]>
jfr.im git - erebus.git/blob - erebus.py
2 # vim: fileencoding=utf-8
4 # Erebus IRC bot - Author: John Runyon
7 from __future__
import print_function
9 import os
, sys
, select
, MySQLdb
, MySQLdb
.cursors
, time
, traceback
, random
, gc
10 import bot
, config
, ctlmod
12 class Erebus(object): #singleton to pass around
21 exceptionhandlers
= [] # list of (Exception_class, handler_function) tuples
26 def __init__(self
, nick
, auth
=None):
31 self
.auth
= auth
.lower()
def msg(self, *args, **kwargs):
    """Send a message to this user through a randomly picked bot.

    The picked bot comes from ``main.randbot()`` (module-level ``main``
    object). This user is passed as the first argument; everything else is
    forwarded unchanged. Returns None.
    """
    sender = main.randbot()
    sender.msg(self, *args, **kwargs)

def slowmsg(self, *args, **kwargs):
    """Same as msg(), but dispatches to the picked bot's slowmsg()."""
    sender = main.randbot()
    sender.slowmsg(self, *args, **kwargs)

def fastmsg(self, *args, **kwargs):
    """Same as msg(), but dispatches to the picked bot's fastmsg()."""
    sender = main.randbot()
    sender.fastmsg(self, *args, **kwargs)
44 return self
.auth
is not None
46 def authed(self
, auth
):
47 if auth
== '0': self
.auth
= None
48 else: self
.auth
= auth
.lower()
55 c
= main
.query("SELECT level FROM users WHERE auth = %s", (self
.auth
,))
59 self
.glevel
= row
['level']
66 def setlevel(self
, level
, savetodb
=True):
69 c
= main
.query("REPLACE INTO users (auth, level) VALUES (%s, %s)", (self
.auth
, level
))
71 c
= main
.query("DELETE FROM users WHERE auth = %s", (self
.auth
,))
72 if c
== 0: # no rows affected
84 if chan
not in self
.chans
: self
.chans
.append(chan
)
87 self
.chans
.remove(chan
)
89 return len(self
.chans
) == 0
92 def nickchange(self
, newnick
):
def __str__(self):
    """Return the user's nick."""
    return self.nick

def __repr__(self):
    """Debug representation: the nick (repr'd) and the integer global level."""
    template = "<User %r (%d)>"
    return template % (self.nick, self.glevel)
98 class Channel(object):
99 def __init__(self
, name
, bot
):
108 c
= main
.query("SELECT user, level FROM chusers WHERE chan = %s", (self
.name
,))
111 while row
is not None:
112 self
.levels
[row
['user']] = row
['level']
def msg(self, *args, **kwargs):
    """Send a message to this channel via the bot stored on ``self.bot``.

    The channel itself is passed as the first argument; all other
    arguments are forwarded unchanged. Returns None.
    """
    owner = self.bot
    owner.msg(self, *args, **kwargs)

def slowmsg(self, *args, **kwargs):
    """Same as msg(), but dispatches to the owning bot's slowmsg()."""
    owner = self.bot
    owner.slowmsg(self, *args, **kwargs)

def fastmsg(self, *args, **kwargs):
    """Same as msg(), but dispatches to the owning bot's fastmsg()."""
    owner = self.bot
    owner.fastmsg(self, *args, **kwargs)
123 def levelof(self
, auth
):
127 if auth
in self
.levels
:
128 return self
.levels
[auth
]
132 def setlevel(self
, auth
, level
, savetodb
=True):
135 c
= main
.query("REPLACE INTO chusers (chan, user, level) VALUES (%s, %s, %s)", (self
.name
, auth
, level
))
137 self
.levels
[auth
] = level
142 self
.levels
[auth
] = level
def userjoin(self, user, level=None):
    """Record that user is in the channel, optionally with a status.

    level 'op' also adds the user to self.ops; 'voice' adds them to
    self.voices. No list ever receives a duplicate entry.
    """
    rosters = [self.users]
    if level == 'op':
        rosters.append(self.ops)
    elif level == 'voice':
        rosters.append(self.voices)
    for roster in rosters:
        if user not in roster:
            roster.append(user)

def userpart(self, user):
    """Forget user entirely: drop them from ops, voices and users (in that order)."""
    for roster in (self.ops, self.voices, self.users):
        if user in roster:
            roster.remove(user)
def userop(self, user):
    """Mark user as op; ignored if they are not in the channel or already opped."""
    if user not in self.users or user in self.ops:
        return
    self.ops.append(user)

def uservoice(self, user):
    """Mark user as voiced; ignored if they are not in the channel or already voiced."""
    if user not in self.users or user in self.voices:
        return
    self.voices.append(user)

def userdeop(self, user):
    """Remove op status from user, if they have it."""
    ops = self.ops
    if user in ops:
        ops.remove(user)

def userdevoice(self, user):
    """Remove voice status from user, if they have it."""
    voices = self.voices
    if user in voices:
        voices.remove(user)
def __str__(self):
    """Return the channel name."""
    return self.name

def __repr__(self):
    """Debug representation of the channel, e.g. ``<Channel '#chan'>``."""
    # Explicit one-element tuple: a bare parenthesised value would be
    # unpacked by %-formatting if it ever were a tuple, raising TypeError.
    return "<Channel %r>" % (self.name,)
166 def __init__(self
, cfg
):
168 self
.starttime
= time
.time()
170 self
.trigger
= cfg
.trigger
171 if os
.name
== "posix":
173 self
.po
= select
.poll()
174 else: # f.e. os.name == "nt" (Windows)
175 self
.potype
= "select"
178 def query(self
, *args
, **kwargs
):
179 if 'noretry' in kwargs
:
180 noretry
= kwargs
['noretry']
181 del kwargs
['noretry']
185 self
.log("[SQL]", "?", "query(%s, %s)" % (', '.join([repr(i
) for i
in args
]), ', '.join([str(key
)+"="+repr(kwargs
[key
]) for key
in kwargs
])))
187 curs
= self
.db
.cursor()
188 res
= curs
.execute(*args
, **kwargs
)
193 except MySQLdb
.MySQLError
as e
:
194 self
.log("[SQL]", "!", "MySQL error! %r" % (e
))
197 return self
.query(*args
, noretry
=True, **kwargs
)
201 def querycb(self
, cb
, *args
, **kwargs
):
203 cb(self
.query(*args
, **kwargs
))
204 threading
.Thread(target
=run_query
).start()
def newbot(self, nick, user, bind, authname, authpass, server, port, realname):
    """Build a bot.Bot and store it in self.bots under its lower-cased nick.

    A bind of None is replaced with '' before construction. The new Bot
    receives this object as its first argument. Returns None.
    """
    bind = '' if bind is None else bind
    created = bot.Bot(self, nick, user, bind, authname, authpass, server, port, realname)
    self.bots[nick.lower()] = created
def newfd(self, obj, fileno):
    """Map fileno to obj and start watching it for input.

    self.potype picks the mechanism: "select" appends fileno to
    self.fdlist; "poll" registers it with self.po for select.POLLIN.
    Any other value only records the mapping in self.fds.
    """
    self.fds[fileno] = obj
    mechanism = self.potype
    if mechanism == "select":
        self.fdlist.append(fileno)
    elif mechanism == "poll":
        self.po.register(fileno, select.POLLIN)
217 def delfd(self
, fileno
):
219 if self
.potype
== "poll":
220 self
.po
.unregister(fileno
)
221 elif self
.potype
== "select":
222 self
.fdlist
.remove(fileno
)
def bot(self, name):
    """Return the Bot registered under nick name (case-insensitive)."""
    key = name.lower()
    return self.bots[key]

def fd(self, fileno):
    """Return the object registered for file descriptor fileno."""
    owners = self.fds
    return owners[fileno]

def randbot(self):
    """Return a randomly chosen registered Bot.

    random.choice raises IndexError when no bots are registered.
    """
    nicks = list(self.bots)
    return self.bots[random.choice(nicks)]
231 def user(self
, _nick
, send_who
=False, create
=True):
234 if send_who
and (nick
not in self
.users
or not self
.users
[nick
].isauthed()):
235 self
.randbot().conn
.send("WHO %s n%%ant,1" % (nick
))
237 if nick
in self
.users
:
238 return self
.users
[nick
]
240 user
= self
.User(_nick
)
241 self
.users
[nick
] = user
245 def channel(self
, name
): #get Channel() by name
246 if name
.lower() in self
.chans
:
247 return self
.chans
[name
.lower()]
251 def newchannel(self
, bot
, name
):
252 chan
= self
.Channel(name
.lower(), bot
)
253 self
.chans
[name
.lower()] = chan
258 if self
.potype
== "poll":
259 pollres
= self
.po
.poll(timeout_seconds
* 1000)
260 return [fd
for (fd
, ev
) in pollres
]
261 elif self
.potype
== "select":
262 return select
.select(self
.fdlist
, [], [], timeout_seconds
)[0]
264 def connectall(self
):
265 for bot
in self
.bots
.values():
266 if bot
.conn
.state
== 0:
def module(self, name):
    """Return the entry stored under name in ctlmod.modules.

    Lookup errors from ctlmod.modules propagate to the caller.
    """
    registry = ctlmod.modules
    return registry[name]
def log(self, source, level, message):
    """Print one log line to stdout: ``<stamp> <source> [<level>] <message>``.

    The stamp is time.time() modulo 100000, formatted with 3 decimals and
    zero-padded to 9 characters.
    """
    stamp = time.time() % 100000
    line = "%09.3f %s [%s] %s" % (stamp, source, level, message)
    print(line)
def getuserbyauth(self, auth):
    """Return a list of every known user whose .auth equals auth.lower().

    Only the argument is lower-cased here; User objects lower-case their
    auth when it is set. Returns an empty list when nothing matches.
    """
    found = []
    for candidate in self.users.values():
        if candidate.auth == auth.lower():
            found.append(candidate)
    return found
279 """Get a DB object. The object must be returned to the pool after us, using returndb()."""
280 return self
.dbs
.pop()
282 def returndb(self
, db
):
286 def hook(self
, word
, handler
):
288 self
.msghandlers
[word
].append(handler
)
290 self
.msghandlers
[word
] = [handler
]
def unhook(self, word, handler):
    """Remove handler from the message handlers for word; no-op if absent."""
    if word not in self.msghandlers:
        return
    registered = self.msghandlers[word]
    if handler in registered:
        registered.remove(handler)

def hashook(self, word):
    """Return True if at least one message handler is registered for word."""
    if word not in self.msghandlers:
        return False
    return len(self.msghandlers[word]) > 0

def gethook(self, word):
    """Return the (live) list of message handlers for word; lookup errors propagate."""
    registered = self.msghandlers[word]
    return registered
299 def hooknum(self
, word
, handler
):
301 self
.numhandlers
[word
].append(handler
)
303 self
.numhandlers
[word
] = [handler
]
def unhooknum(self, word, handler):
    """Remove handler from the numeric-reply handlers for word; no-op if absent."""
    if word not in self.numhandlers:
        return
    registered = self.numhandlers[word]
    if handler in registered:
        registered.remove(handler)

def hasnumhook(self, word):
    """Return True if at least one numeric-reply handler is registered for word."""
    if word not in self.numhandlers:
        return False
    return len(self.numhandlers[word]) > 0

def getnumhook(self, word):
    """Return the (live) list of numeric-reply handlers for word; lookup errors propagate."""
    registered = self.numhandlers[word]
    return registered
312 def hookchan(self
, chan
, handler
):
314 self
.chanhandlers
[chan
].append(handler
)
316 self
.chanhandlers
[chan
] = [handler
]
def unhookchan(self, chan, handler):
    """Remove handler from the handlers registered for chan; no-op if absent."""
    if chan not in self.chanhandlers:
        return
    registered = self.chanhandlers[chan]
    if handler in registered:
        registered.remove(handler)

def haschanhook(self, chan):
    """Return True if at least one handler is registered for chan."""
    if chan not in self.chanhandlers:
        return False
    return len(self.chanhandlers[chan]) > 0

def getchanhook(self, chan):
    """Return the (live) list of handlers for chan; lookup errors propagate."""
    registered = self.chanhandlers[chan]
    return registered
def hookexception(self, exc, handler):
    """Register handler for exceptions that are instances of the class exc.

    Pairs are appended to self.exceptionhandlers as (exc, handler) tuples,
    so matching handlers are later produced in registration order.
    """
    self.exceptionhandlers.append((exc, handler))

def unhookexception(self, exc, handler):
    """Remove a previously registered (exc, handler) pair.

    Like unhook(), unhooknum() and unhookchan(), this is a no-op when the
    pair is not registered, instead of raising ValueError.
    """
    entry = (exc, handler)
    if entry in self.exceptionhandlers:
        self.exceptionhandlers.remove(entry)

def hasexceptionhook(self, exc):
    """Return True if any registered exception class matches exc via isinstance()."""
    return any(isinstance(exc, cls) for cls, _ in self.exceptionhandlers)

def getexceptionhook(self, exc):
    """Return a lazy generator of handlers whose class matches exc, in registration order."""
    return (handler for cls, handler in self.exceptionhandlers if isinstance(exc, cls))
338 for i
in range(cfg
.get('erebus', 'num_db_connections', 2)-1):
339 main
.dbs
.append(MySQLdb
.connect(host
=cfg
.dbhost
, user
=cfg
.dbuser
, passwd
=cfg
.dbpass
, db
=cfg
.dbname
, cursorclass
=MySQLdb
.cursors
.DictCursor
))
340 main
.db
= MySQLdb
.connect(host
=cfg
.dbhost
, user
=cfg
.dbuser
, passwd
=cfg
.dbpass
, db
=cfg
.dbname
, cursorclass
=MySQLdb
.cursors
.DictCursor
)
345 cfg
= config
.Config('bot.config')
347 if cfg
.getboolean('debug', 'gc'):
348 gc
.set_debug(gc
.DEBUG_LEAK
)
350 pidfile
= open(cfg
.pidfile
, 'w')
351 pidfile
.write(str(os
.getpid()))
357 autoloads
= [mod
for mod
, yes
in cfg
.items('autoloads') if int(yes
) == 1]
358 for mod
in autoloads
:
359 ctlmod
.load(main
, mod
)
361 c
= main
.query("SELECT nick, user, bind, authname, authpass FROM bots WHERE active = 1")
366 main
.newbot(row
['nick'], row
['user'], row
['bind'], row
['authname'], row
['authpass'], cfg
.host
, cfg
.port
, cfg
.realname
)
370 poready
= main
.poll()
371 for fileno
in poready
:
373 data
= main
.fd(fileno
).getdata()
375 main
.log('*', '!', 'Super-mega-emergency: getdata raised exception for socket %d' % (fileno
))
376 traceback
.print_exc()
379 main
.fd(fileno
).close()
382 if cfg
.getboolean('debug', 'io'):
383 main
.log(str(main
.fd(fileno
)), 'I', line
)
385 main
.fd(fileno
).parse(line
)
387 main
.log('*', '!', 'Super-mega-emergency: parse raised exception for socket %d data %r' % (fileno
, line
))
388 traceback
.print_exc()
389 if main
.mustquit
is not None:
390 main
.log('*', '!', 'Core exiting due to: %s' % (main
.mustquit
))
393 if __name__
== '__main__':
394 try: os
.rename('logfile', 'oldlogs/%s' % (time
.time()))
396 sys
.stdout
= open('logfile', 'w', 1)
397 sys
.stderr
= sys
.stdout