]> jfr.im git - erebus.git/blobdiff - bot.py
bot - allow hooks for numerics/commands that get sent without source
[erebus.git] / bot.py
diff --git a/bot.py b/bot.py
index 3e02f9443007e9b2d07124d6f902821ff54a6102..793e27f6b6c923ad162e0f70ee41317c105a40b1 100644 (file)
--- a/bot.py
+++ b/bot.py
@@ -4,11 +4,9 @@
 # Erebus IRC bot - Author: John Runyon
 # "Bot" and "BotConnection" classes (handling a specific "arm")
 
-import socket, sys, time, threading, os, random
+import socket, sys, time, threading, os, random, struct # struct packs the SO_LINGER option in BotConnection.connect
 from collections import deque
 
-MAXLEN = 400 # arbitrary max length of a command generated by Bot.msg functions
-
 if sys.version_info.major < 3:
        timerbase = threading._Timer
        stringbase = basestring
@@ -38,6 +36,11 @@ class Bot(object):
                self.authname = authname
                self.authpass = authpass
 
+               self.connecttime = 0 # time.time() at which we received numeric 001; 0 until then
+               self.server = server # the address we try to (re-)connect to
+               self.port = port # port used together with self.server
+               self.servername = server # the name of the server we got connected to (overwritten from numeric 004)
+
                curs = self.parent.query("SELECT chname FROM chans WHERE bot = %s AND active = 1", (self.permnick,))
                if curs:
                        chansres = curs.fetchall()
@@ -55,6 +58,7 @@ class Bot(object):
                self.slowmsgqueue = deque()
                self._makemsgtimer()
                self._msgtimer.start()
+               self.joined_chans = False # becomes True once the configured channels have been joined after registration
 
        def __del__(self):
                try:
@@ -88,20 +92,22 @@ class Bot(object):
                        self.log('I', line)
                pieces = line.split()
 
+               if pieces[0][0] == ":": # ":source COMMAND args..." -- NOTE(review): assumes line is non-empty, otherwise pieces[0] raises IndexError
+                       numeric = pieces[1] # command/numeric follows the source
+               else:
+                       numeric = pieces[0] # sourceless line, e.g. PING or ERROR
+
                # dispatch dict
-               zero = { #things to look for without source
-                       'NOTICE': self._gotconnected,
-                       'PING': self._gotping,
-                       'ERROR': self._goterror,
-               }
-               one = { #things to look for after source
+               dispatch = { # command/numeric -> handler; used whether or not the line had a :source prefix
                        'NOTICE': self._gotconnected,
                        '001': self._got001,
+                       '004': self._got004,
                        '376': self._gotRegistered,
                        '422': self._gotRegistered,
                        'PRIVMSG': self._gotprivmsg,
                        '353': self._got353, #NAMES
                        '354': self._got354, #WHO
+                       '396': self._gotHiddenHost, # hidden host has been set
                        '433': self._got433, #nick in use
                        'JOIN': self._gotjoin,
                        'PART': self._gotpart,
@@ -109,20 +115,20 @@ class Bot(object):
                        'QUIT': self._gotquit,
                        'NICK': self._gotnick,
                        'MODE': self._gotmode,
+                       'PING': self._gotping, # previously only matched on sourceless lines
+                       'ERROR': self._goterror, # previously only matched on sourceless lines
                }
 
-               if self.parent.hasnumhook(pieces[1]):
-                       hooks = self.parent.getnumhook(pieces[1])
+               if self.parent.hasnumhook(numeric): # numeric hooks now also fire for sourceless lines such as PING/ERROR
+                       hooks = self.parent.getnumhook(numeric)
                        for callback in hooks:
                                try:
                                        callback(self, line)
                                except Exception:
                                        self.__debug_cbexception("numhook", line)
 
-               if pieces[0] in zero:
-                       zero[pieces[0]](pieces)
-               elif pieces[1] in one:
-                       one[pieces[1]](pieces)
+               if numeric in dispatch:
+                       dispatch[numeric](pieces) # NOTE(review): handlers get the raw token list, so token indices differ between sourced and sourceless lines
 
        def _gotconnected(self, pieces):
                if not self.conn.registered():
@@ -139,10 +145,12 @@ class Bot(object):
                        curs = self.parent.query("UPDATE bots SET connected = 0")
                        curs.close()
                except: pass
-               sys.exit(2)
-               os._exit(2)
+               os._exit(2) # can't use sys.exit since we might be in a sub-thread (SystemExit would only end that thread)
        def _got001(self, pieces):
-               pass # wait until the end of MOTD instead
+               # We wait until the end of MOTD instead to consider ourselves registered, but consider uptime as of 001
+               self.connecttime = time.time()
+       def _got004(self, pieces): # numeric 004: records the server name
+               self.servername = pieces[3] # presumably ":src 004 <nick> <servername> ...", i.e. the server's self-reported name
        def _gotRegistered(self, pieces):
                self.conn.registered(True)
 
@@ -151,9 +159,16 @@ class Bot(object):
 
                self.conn.send("MODE %s +x" % (pieces[2]))
                if self.authname is not None and self.authpass is not None:
-                       self.conn.send("AUTH %s %s" % (self.authname, self.authpass))
-               for c in self.chans:
-                       self.join(c.name)
+                       self.conn.send(self.parent.cfg.get('erebus', 'auth_command', "AUTH %s %s") % (self.authname, self.authpass)) # auth_command is a %-format string taking (authname, authpass)
+               if not self.parent.cfg.getboolean('erebus', 'wait_for_hidden_host'): # otherwise autojoin is deferred to _gotHiddenHost (numeric 396)
+                       for c in self.chans:
+                               self.join(c.name)
+                       self.joined_chans = True
+       def _gotHiddenHost(self, pieces): # numeric 396; with wait_for_hidden_host set, the only autojoin path in this diff -- if 396 never arrives, nothing is joined
+               if not self.joined_chans and self.parent.cfg.getboolean('erebus', 'wait_for_hidden_host'): # NOTE(review): joined_chans is only set False at construction as far as visible, so a later re-registration would not rejoin -- confirm
+                       for c in self.chans:
+                               self.join(c.name)
+                       self.joined_chans = True
        def _gotprivmsg(self, pieces):
                nick = pieces[0].split('!')[0][1:]
                user = self.parent.user(nick)
@@ -289,7 +304,7 @@ class Bot(object):
                        if msg.startswith("\001"): #ctcp
                                msg = msg.strip("\001")
                                if msg == "VERSION":
-                                       self.msg(user, "\001VERSION Erebus v%d.%d - http://github.com/zonidjan/erebus" % (self.parent.APIVERSION, self.parent.RELEASE))
+                                       self.msg(user, "\001VERSION Erebus v%d.%d - http://jfr.im/git/erebus.git" % (self.parent.APIVERSION, self.parent.RELEASE)) # CTCP VERSION reply; msg() picks NOTICE for non-channel targets
                                return
 
                triggerused = msg.startswith(self.parent.trigger)
@@ -374,54 +389,56 @@ class Bot(object):
                else:
                        self.msg(user, msg)
 
-       def msg(self, target, msg, truncate=False):
-               if self.parent.cfg.getboolean('erebus', 'nofakelag'): return self.fastmsg(target, msg)
-               cmd = self._formatmsg(target, msg)
-               if len(cmd) > MAXLEN:
-                       if not truncate:
-                               return False
-                       else:
-                               cmd = cmd[:MAXLEN]
-               if self.conn.exceeded or self.conn.bytessent+len(cmd) >= self.conn.recvq:
-                       self.msgqueue.append(cmd)
-               else:
-                       self.conn.send(cmd)
-               self.conn.exceeded = True
-               return True
+       def _msg(self, target, msg, truncate, append_callback, msgtype):
+               """
+               Does the work for msg/slowmsg/fastmsg: format the line, enforce the length limit (return False, or truncate when truncate is set), then send it or hand it to append_callback, which appends to the correct queue. Returns True once the line was sent or queued.
 
-       def slowmsg(self, target, msg, truncate=False):
+               In the case of fastmsg, self.conn.exceeded may be True, however, in this case append_callback=self.conn.send, so it will still be sent immediately. The nofakelag option makes msg/slowmsg behave the same way.
+               """
-               if self.parent.cfg.getboolean('erebus', 'nofakelag'): return self.fastmsg(target, msg)
+               if self.parent.cfg.getboolean('erebus', 'nofakelag'): append_callback = self.conn.send # send immediately; calling self.fastmsg here would recurse back into _msg forever
-               cmd = self._formatmsg(target, msg)
-               if len(cmd) > MAXLEN:
+
+               cmd = self._formatmsg(target, msg, msgtype)
+               # The max length is much shorter than recvq (510) because of the length the server adds on about the source (us).
+               # If you know your hostmask, you can of course figure the exact length, but it's very difficult to reliably know your hostmask.
+               maxlen = (
+                       self.conn.recvq
+                       - 63 # max hostname len
+                       - 11 # max ident len
+                       - 4  # the symbols in ":nick!user@host " (including the space before the command)
+                       - len(self.nick)
+               )
+               if len(cmd) > maxlen:
                        if not truncate:
                                return False
                        else:
-                               cmd = cmd[:MAXLEN]
+                               cmd = cmd[:maxlen]
+
                if self.conn.exceeded or self.conn.bytessent+len(cmd) >= self.conn.recvq:
-                       self.slowmsgqueue.append(cmd)
+                       append_callback(cmd)
                else:
                        self.conn.send(cmd)
-               self.conn.exceeded = True
-               return True
 
-       def fastmsg(self, target, msg, truncate=False):
-               cmd = self._formatmsg(target, msg)
-               if len(cmd) > MAXLEN:
-                       if not truncate:
-                               return False
-                       else:
-                               cmd = cmd[:MAXLEN]
-               self.conn.send(cmd)
                self.conn.exceeded = True
                return True
 
-       def _formatmsg(self, target, msg):
+       def msg(self, target, msg, truncate=False, msgtype=None): # no keyword-only "*": that syntax is a SyntaxError on Python 2, which this file still supports
+               """Send msg to target through msgqueue (fake lag). msgtype must be a valid IRC command, i.e. NOTICE or PRIVMSG; or leave as None to use default (PRIVMSG for #channels, NOTICE otherwise)."""
+               return self._msg(target, msg, truncate, self.msgqueue.append, msgtype)
+
+       def slowmsg(self, target, msg, truncate=False, msgtype=None): # like msg, but queues onto slowmsgqueue
+               return self._msg(target, msg, truncate, self.slowmsgqueue.append, msgtype)
+
+       def fastmsg(self, target, msg, truncate=False, msgtype=None): # like msg, but always sends immediately via conn.send
+               return self._msg(target, msg, truncate, self.conn.send, msgtype)
+
+       def _formatmsg(self, target, msg, msgtype):
                if target is None or msg is None:
                        return self.__debug_nomsg(target, msg)
 
                target = str(target)
 
-               if target.startswith('#'): command = "PRIVMSG %s :%s" % (target, msg)
+               if msgtype is not None: command = "%s %s :%s" % (msgtype, target, msg)
+               elif target.startswith('#'): command = "PRIVMSG %s :%s" % (target, msg)
                else: command = "NOTICE %s :%s" % (target, msg)
 
                return command
@@ -481,12 +498,21 @@ class BotConnection(object):
                self.state = 0 # 0=disconnected, 1=registering, 2=connected
 
                self.bytessent = 0
-               self.recvq = 500
+               self.recvq = 510 # max IRC line length excluding the trailing CRLF (512 including it)
                self.exceeded = False
                self._nowrite = False
 
        def connect(self):
-               self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+               if self.parent.parent.cfg.getboolean('erebus', 'tls'):
+                       import ssl
+                       undersocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                       context = ssl.create_default_context()
+                       self.socket = context.wrap_socket(undersocket, server_hostname=self.server) # server_hostname is used for SNI and certificate hostname checking
+               else:
+                       self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+               self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # IPPROTO_TCP is the POSIX-specified level for TCP_* options such as TCP_NODELAY
+               self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 0, 0)) # l_onoff=0: lingering disabled; NOTE(review): 'ii' matches struct linger on Linux -- confirm on other platforms
+               self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                self.socket.bind((self.bind, 0))
                self.socket.connect((self.server, self.port))
                return True