]> jfr.im git - erebus.git/commitdiff
add ability to hook extra sockets
authorJohn Runyon <redacted>
Thu, 2 Nov 2023 00:02:32 +0000 (18:02 -0600)
committerJohn Runyon <redacted>
Thu, 2 Nov 2023 00:02:32 +0000 (18:02 -0600)
ctlmod.py
erebus.py
modlib.py
modules/basic_socket.py [new file with mode: 0644]

index c273638013bb96f94cddcc139397224c0a069577..3e30892d25fbb7c19e32c014c99cce82dbacd9c1 100644 (file)
--- a/ctlmod.py
+++ b/ctlmod.py
@@ -4,7 +4,7 @@
 
 from __future__ import print_function
 
-import sys, time, importlib
+import sys, time, importlib, traceback
 import modlib
 
 if sys.version_info.major >= 3:
@@ -31,6 +31,8 @@ def load(parent, modname, dependent=False):
                        print("failed: %s)" % (modstatus), end=' ')
                else:
                        print("failed: %s." % (modstatus))
+                       if isinstance(modstatus, modlib.error) and isinstance(modstatus.errormsg, BaseException):
+                               traceback.print_exception(modstatus.errormsg)
        elif modstatus == True:
                if dependent:
                        print("OK)", end=' ')
@@ -85,7 +87,10 @@ def _load(parent, modname, dependent=False):
                        #swallow errors loading - softdeps are preferred, not required
 
 
-               ret = mod.modstart(parent)
+               try:
+                       ret = mod.modstart(parent)
+               except Exception as e:
+                       return modlib.error(e)
                if ret is None:
                        ret = True
                if not ret:
index 91e01bfad02c4e91e0ac47c924811315ef3882c4..47f9a59d6a4633263166f82eb4664ead4e625f10 100644 (file)
--- a/erebus.py
+++ b/erebus.py
@@ -6,7 +6,7 @@
 
 from __future__ import print_function
 
-import os, sys, select, MySQLdb, MySQLdb.cursors, time, random, gc
+import os, sys, select, MySQLdb, MySQLdb.cursors, time, traceback, random, gc
 import bot, config, ctlmod
 
 class Erebus(object): #singleton to pass around
@@ -214,6 +214,12 @@ class Erebus(object): #singleton to pass around
                        self.po.register(fileno, select.POLLIN)
                elif self.potype == "select":
                        self.fdlist.append(fileno)
+       def delfd(self, fileno): # undo the registration shown just above (po.register / fdlist.append) and drop the self.fds entry
+               del self.fds[fileno] # raises KeyError if fileno has no entry in self.fds; nothing below runs in that case
+               if self.potype == "poll":
+                       self.po.unregister(fileno)
+               elif self.potype == "select":
+                       self.fdlist.remove(fileno)
 
        def bot(self, name): #get Bot() by name (nick)
                return self.bots[name.lower()]
@@ -363,8 +369,21 @@ def setup():
 def loop():
        poready = main.poll()
        for fileno in poready:
-               for line in main.fd(fileno).getdata():
-                       main.fd(fileno).parse(line)
+               try:
+                       data = main.fd(fileno).getdata()
+               except:
+                       main.log('*', '!', 'Super-mega-emergency: getdata raised exception for socket %d' % (fileno))
+                       traceback.print_exc()
+                       data = None
+               if data is None:
+                       main.fd(fileno).close()
+               else:
+                       for line in data:
+                               try:
+                                       main.fd(fileno).parse(line)
+                               except:
+                                       main.log('*', '!', 'Super-mega-emergency: parse raised exception for socket %d data %r' % (fileno, line))
+                                       traceback.print_exc()
        if main.mustquit is not None:
                main.log('*', '!', 'Core exiting due to: %s' % (main.mustquit))
                raise main.mustquit
index ddb201a252b9cc70e02ccc0d335f6eb105b2fc83..a04dd7878e22898f222cb06318ddcf6569b7e502 100644 (file)
--- a/modlib.py
+++ b/modlib.py
@@ -4,6 +4,7 @@
 # This file is released into the public domain; see http://unlicense.org/
 
 import sys
+import socket
 from functools import wraps
 
 if sys.version_info.major < 3:
@@ -59,10 +60,12 @@ class modlib(object):
        WRONGARGS = "Wrong number of arguments."
 
        def __init__(self, name):
-               self.hooks = {}
-               self.chanhooks = {}
-               self.exceptionhooks = []
-               self.numhooks = {}
+               self.hooks = {} # {command:handler}
+               self.chanhooks = {} # {channel:handler}
+               self.exceptionhooks = [] # [(exception,handler)]
+               self.numhooks = {} # {numeric:handler}
+               self.sockhooks = [] # [(af,ty,address,handler_class)]
+               self.sockets = [] # [(sock,obj)]
                self.helps = []
                self.parent = None
 
@@ -86,6 +89,8 @@ class modlib(object):
                        parent.hookexception(exc, func)
                for num, func in self.numhooks.items():
                        parent.hooknum(num, func)
+               for hookdata in self.sockhooks:
+                       self._create_socket(*hookdata)
 
                for func, args, kwargs in self.helps:
                        try:
@@ -103,6 +108,8 @@ class modlib(object):
                        parent.unhookexception(exc, func)
                for num, func in self.numhooks.items():
                        parent.unhooknum(num, func)
+               for sock, obj in self.sockets:
+                       self._destroy_socket(sock, obj)
 
                for func, args, kwargs in self.helps:
                        try:
@@ -163,6 +170,36 @@ class modlib(object):
                        return func
                return realhook
 
+       def bind_tcp(self, host, port): # class decorator: listen on TCP/IPv4 (host, port); the class is instantiated once per accepted client
+               return self._hooksocket(socket.AF_INET, socket.SOCK_STREAM, (host, port))
+       def bind_udp(self, host, port): # class decorator for a UDP/IPv4 socket bound to (host, port)
+               return self._hooksocket(socket.AF_INET, socket.SOCK_DGRAM, (host, port)) # NOTE(review): _create_socket always calls listen() and _ListenSocket.getdata calls accept(); neither is valid on SOCK_DGRAM, so this looks unusable as written -- confirm
+       def bind_unix(self, path): # class decorator: listen on a stream-type Unix domain socket bound to `path`
+               return self._hooksocket(socket.AF_UNIX, socket.SOCK_STREAM, path) # NOTE(review): nothing here unlinks `path`, so a leftover socket file presumably makes a later bind() fail -- confirm
+       def _hooksocket(self, af, ty, address): # returns a class decorator that records (af, ty, address, cls) in self.sockhooks
+               def realhook(cls):
+                       if not (hasattr(cls, 'getdata') and callable(cls.getdata)):
+                               # Check early that the object implements getdata.
+                               # If getdata ever returns a non-empty list, then a parse method must also exist, but we don't check that.
+                               raise Exception('Attempted to hook a socket without a class to process data')
+                       self.sockhooks.append((af, ty, address, cls))
+                       if self.parent is not None: # a parent is already attached: open the socket now instead of waiting for modstart
+                               self._create_socket(af, ty, address, cls)
+                       return cls # hand the class back unchanged so the decorated name still refers to it
+               return realhook
+       def _create_socket(self, af, ty, address, cls): # open, bind and listen on one hooked socket, then register its wrapper with the core
+               sock = socket.socket(af, ty)
+               sock.setblocking(False) # same effect as OR-ing SOCK_NONBLOCK into the type, but that constant is not available on every platform
+               sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+               sock.bind(address)
+               sock.listen(5) # listen before recording the socket, so a failure here leaves no unregistered entry for modstop to trip over
+               obj = _ListenSocket(self, sock, cls)
+               self.sockets.append((sock,obj))
+               self.parent.newfd(obj, sock.fileno())
+               self.parent.log(repr(obj), '?', 'Socket ready to accept new connections')
+       def _destroy_socket(self, sock, obj): # `sock` is unused here; closing is delegated entirely to the wrapper object
+               obj.close()
+
        def mod(self, modname):
                if self.parent is not None:
                        return self.parent.module(modname)
@@ -211,3 +248,44 @@ class modlib(object):
                        self.helps.append((func, args, kwargs))
                        return func
                return realhook
+
+class _ListenSocket(object): # wraps one listening socket; registered with the core, which calls getdata() when the socket is readable
+       def __init__(self, lib, sock, cls):
+               self.clients = [] # [(client_sock, handler_obj)] for every connection accepted so far and not yet closed
+               self.lib = lib # owning modlib instance; lib.parent is the core
+               self.sock = sock # the bound, listening socket
+               self.cls = cls # handler class, instantiated with each accepted client socket
+
+       def _make_closer(self, obj, client): # build the close() that getdata installs on each handler object
+               def close():
+                       self.lib.parent.log(repr(self), '?', 'Closing child socket %d' % (client.fileno()))
+                       try:
+                               obj.closing()
+                       except AttributeError: pass # the closing() hook is optional
+                       self.lib.parent.delfd(client.fileno())
+                       try:
+                               client.shutdown(socket.SHUT_RDWR)
+                       except socket.error: pass # e.g. the connection was already reset; still close the fd and forget the client below
+                       client.close()
+                       self.clients.remove((client,obj))
+               return close
+
+       def getdata(self): # accept one pending connection and register a handler object for it with the core
+               client, addr = self.sock.accept()
+               obj = self.cls(client)
+               obj.close = self._make_closer(obj, client) # shadows any close() the handler class defines; the handler's own cleanup goes in closing()
+               self.lib.parent.log(repr(self), '?', 'New connection %d from %s' % (client.fileno(), addr))
+               self.clients.append((client,obj))
+               self.lib.parent.newfd(obj, client.fileno())
+               return [] # the listener itself never produces lines to parse
+
+       def close(self): # close the listener (if still open), then every client accepted through it
+               self.lib.parent.log(repr(self), '?', 'Socket closing')
+               if self.sock.fileno() != -1:
+                       self.lib.parent.delfd(self.sock.fileno())
+                       self.sock.shutdown(socket.SHUT_RDWR) # NOTE(review): shutdown() on a listening socket may raise on some platforms -- confirm
+                       self.sock.close()
+               for client, obj in list(self.clients): # iterate over a copy: each obj.close() removes its own entry from self.clients
+                       obj.close()
+
+       def __repr__(self): return '<_ListenSocket #%d>' % (self.sock.fileno())
diff --git a/modules/basic_socket.py b/modules/basic_socket.py
new file mode 100644 (file)
index 0000000..3d3b7e9
--- /dev/null
@@ -0,0 +1,75 @@
+# Erebus IRC bot - Author: Erebus Team
+# vim: fileencoding=utf-8
+# This file is released into the public domain; see http://unlicense.org/
+
+# module info
+modinfo = {
+       'author': 'Erebus Team',
+       'license': 'public domain',
+       'compatible': [0], # presumably the core API version(s) this module supports -- confirm against ctlmod
+       'depends': [],
+       'softdeps': ['help'], # soft dependency: per ctlmod, errors loading softdeps are swallowed ("preferred, not required")
+}
+
+# preamble
+import modlib
+lib = modlib.modlib(__name__) # per-module hook registry; the bind_* decorators used below record socket hooks on it
+modstart = lib.modstart # ctlmod calls mod.modstart(parent) when loading; modlib's modstart creates the hooked sockets
+modstop = lib.modstop # modlib's modstop closes every socket this module opened
+
+# module code
+
+# Note: bind_* does all of the following:
+# - create a socket `sock = socket.socket()`
+# - bind the socket `sock.bind()`
+# - listen on the socket `sock.listen()`
+# - accept `sock.accept()`
+#
+# Once a connection is accepted, your class is instantiated with the client socket.
+# - When data comes in on the client socket, your `getdata` method will be called. It should return a list of strings.
+# - For each element in the list returned by `getdata`, `parse` will be called.
+# - When the socket is being closed by the bot (e.g. your module is unloaded), the optional method `closing` will be called.
+#   Then the bot will call `sock.shutdown()` and `sock.close()` for you.
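+# Error handling: when the other side closes the socket, `getdata` should return None and the bot will close the socket. If `getdata` raises, the bot logs the traceback and closes the socket too; if `parse` raises, the traceback is logged and the socket stays open.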
+#
+# You can interact with the rest of the bot through `lib.parent`.
+@lib.bind_tcp('0.0.0.0', 12543) # NOTE(review): listens on all interfaces and nothing here authenticates clients -- treat as an example only
+class BasicServer(object): # example handler: one instance per accepted client; relays each received line to IRC
+       def __init__(self, sock):
+               self.buffer = b'' # bytes received so far that do not yet form a complete line
+               self.sock = sock # the accepted client socket
+
+       def getdata(self): # called by the core when the socket is readable; returns a list of lines, or None to request closing
+               recvd = self.sock.recv(8192)
+               if recvd == b"":
+                       if len(self.buffer) != 0:
+                               # Process what's left in the buffer. We'll get called again after.
+                               remaining_buf = self.buffer.decode('utf-8', 'backslashreplace')
+                               self.buffer = b""
+                               return [remaining_buf]
+                       else:
+                               # Nothing left in the buffer. Return None to signal the core to close this socket.
+                               return None
+               self.buffer += recvd
+               lines = []
+
+               while b"\n" in self.buffer:
+                       pieces = self.buffer.split(b"\n", 1)
+                       lines.append(pieces[0].decode('utf-8', 'backslashreplace').rstrip("\r")) # drop the CR of a CRLF line ending
+                       self.buffer = pieces[1] # keep the unterminated remainder for the next call
+
+               return lines
+
+       def parse(self, line): # called by the core once per line returned from getdata
+               peer = self.sock.getpeername()
+               lib.parent.randbot().msg('#', "%s:%d says: %s" % (peer[0], peer[1], line))
+
+       def send(self, line): # write one CRLF-terminated line back to the client
+               self.sock.sendall(line.encode('utf-8', 'backslashreplace')+b"\r\n")
+
+       def _getsockerr(self): # pending socket error code, or None if it cannot be queried
+               import socket # this module has no top-level socket import; needed for the constants and the error type
+               try: # SO_ERROR might not exist on all platforms
+                       return self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+               except (AttributeError, socket.error):
+                       return None