From: John Runyon Date: Thu, 2 Nov 2023 00:02:32 +0000 (-0600) Subject: add ability to hook extra sockets X-Git-Url: https://jfr.im/git/erebus.git/commitdiff_plain/9d44d267b7cb9639739979dc5f0c4b7828be9a4f add ability to hook extra sockets --- diff --git a/ctlmod.py b/ctlmod.py index c273638..3e30892 100644 --- a/ctlmod.py +++ b/ctlmod.py @@ -4,7 +4,7 @@ from __future__ import print_function -import sys, time, importlib +import sys, time, importlib, traceback import modlib if sys.version_info.major >= 3: @@ -31,6 +31,8 @@ def load(parent, modname, dependent=False): print("failed: %s)" % (modstatus), end=' ') else: print("failed: %s." % (modstatus)) + if isinstance(modstatus, modlib.error) and isinstance(modstatus.errormsg, BaseException): + traceback.print_exception(modstatus.errormsg) elif modstatus == True: if dependent: print("OK)", end=' ') @@ -85,7 +87,10 @@ def _load(parent, modname, dependent=False): #swallow errors loading - softdeps are preferred, not required - ret = mod.modstart(parent) + try: + ret = mod.modstart(parent) + except Exception as e: + return modlib.error(e) if ret is None: ret = True if not ret: diff --git a/erebus.py b/erebus.py index 91e01bf..47f9a59 100644 --- a/erebus.py +++ b/erebus.py @@ -6,7 +6,7 @@ from __future__ import print_function -import os, sys, select, MySQLdb, MySQLdb.cursors, time, random, gc +import os, sys, select, MySQLdb, MySQLdb.cursors, time, traceback, random, gc import bot, config, ctlmod class Erebus(object): #singleton to pass around @@ -214,6 +214,12 @@ class Erebus(object): #singleton to pass around self.po.register(fileno, select.POLLIN) elif self.potype == "select": self.fdlist.append(fileno) + def delfd(self, fileno): + del self.fds[fileno] + if self.potype == "poll": + self.po.unregister(fileno) + elif self.potype == "select": + self.fdlist.remove(fileno) def bot(self, name): #get Bot() by name (nick) return self.bots[name.lower()] @@ -363,8 +369,21 @@ def setup(): def loop(): poready = main.poll() for 
fileno in poready: - for line in main.fd(fileno).getdata(): - main.fd(fileno).parse(line) + try: + data = main.fd(fileno).getdata() + except: + main.log('*', '!', 'Super-mega-emergency: getdata raised exception for socket %d' % (fileno)) + traceback.print_exc() + data = None + if data is None: + main.fd(fileno).close() + else: + for line in data: + try: + main.fd(fileno).parse(line) + except: + main.log('*', '!', 'Super-mega-emergency: parse raised exception for socket %d data %r' % (fileno, line)) + traceback.print_exc() if main.mustquit is not None: main.log('*', '!', 'Core exiting due to: %s' % (main.mustquit)) raise main.mustquit diff --git a/modlib.py b/modlib.py index ddb201a..a04dd78 100644 --- a/modlib.py +++ b/modlib.py @@ -4,6 +4,7 @@ # This file is released into the public domain; see http://unlicense.org/ import sys +import socket from functools import wraps if sys.version_info.major < 3: @@ -59,10 +60,12 @@ class modlib(object): WRONGARGS = "Wrong number of arguments." def __init__(self, name): - self.hooks = {} - self.chanhooks = {} - self.exceptionhooks = [] - self.numhooks = {} + self.hooks = {} # {command:handler} + self.chanhooks = {} # {channel:handler} + self.exceptionhooks = [] # [(exception,handler)] + self.numhooks = {} # {numeric:handler} + self.sockhooks = [] # [(af,ty,address,handler_class)] + self.sockets = [] # [(sock,obj)] self.helps = [] self.parent = None @@ -86,6 +89,8 @@ class modlib(object): parent.hookexception(exc, func) for num, func in self.numhooks.items(): parent.hooknum(num, func) + for hookdata in self.sockhooks: + self._create_socket(*hookdata) for func, args, kwargs in self.helps: try: @@ -103,6 +108,8 @@ class modlib(object): parent.unhookexception(exc, func) for num, func in self.numhooks.items(): parent.unhooknum(num, func) + for sock, obj in self.sockets: + self._destroy_socket(sock, obj) for func, args, kwargs in self.helps: try: @@ -163,6 +170,36 @@ class modlib(object): return func return realhook + def 
bind_tcp(self, host, port): + return self._hooksocket(socket.AF_INET, socket.SOCK_STREAM, (host, port)) + def bind_udp(self, host, port): + return self._hooksocket(socket.AF_INET, socket.SOCK_DGRAM, (host, port)) + def bind_unix(self, path): + return self._hooksocket(socket.AF_UNIX, socket.SOCK_STREAM, path) + def _hooksocket(self, af, ty, address): + def realhook(cls): + if not (hasattr(cls, 'getdata') and callable(cls.getdata)): + # Check early that the object implements getdata. + # If getdata ever returns a non-empty list, then a parse method must also exist, but we don't check that. + raise Exception('Attempted to hook a socket without a class to process data') + self.sockhooks.append((af, ty, address, cls)) + if self.parent is not None: + self._create_socket(af, ty, address, cls) + return cls + return realhook + def _create_socket(self, af, ty, address, cls): + ty = ty | socket.SOCK_NONBLOCK + sock = socket.socket(af, ty) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + obj = _ListenSocket(self, sock, cls) + self.sockets.append((sock,obj)) + sock.listen(5) + self.parent.newfd(obj, sock.fileno()) + self.parent.log(repr(obj), '?', 'Socket ready to accept new connections') + def _destroy_socket(self, sock, obj): + obj.close() + def mod(self, modname): if self.parent is not None: return self.parent.module(modname) @@ -211,3 +248,44 @@ class modlib(object): self.helps.append((func, args, kwargs)) return func return realhook + +class _ListenSocket(object): + def __init__(self, lib, sock, cls): + self.clients = [] + self.lib = lib + self.sock = sock + self.cls = cls + + def _make_closer(self, obj, client): + def close(): + print(repr(self), repr(obj)) + self.lib.parent.log(repr(self), '?', 'Closing child socket %d' % (client.fileno())) + try: + obj.closing() + except AttributeError: + pass + self.lib.parent.delfd(client.fileno()) + client.shutdown(socket.SHUT_RDWR) + client.close() + self.clients.remove((client,obj)) + return close 
+ + def getdata(self): + client, addr = self.sock.accept() + obj = self.cls(client) + obj.close = self._make_closer(obj, client) + self.lib.parent.log(repr(self), '?', 'New connection %d from %s' % (client.fileno(), addr)) + self.clients.append((client,obj)) + self.lib.parent.newfd(obj, client.fileno()) + return [] + + def close(self): + self.lib.parent.log(repr(self), '?', 'Socket closing') + if self.sock.fileno() != -1: + self.lib.parent.delfd(self.sock.fileno()) + self.sock.shutdown(socket.SHUT_RDWR) + self.sock.close() + for client, obj in self.clients: + obj.close() + + def __repr__(self): return '<_ListenSocket #%d>' % (self.sock.fileno()) diff --git a/modules/basic_socket.py b/modules/basic_socket.py new file mode 100644 index 0000000..3d3b7e9 --- /dev/null +++ b/modules/basic_socket.py @@ -0,0 +1,75 @@ +# Erebus IRC bot - Author: Erebus Team +# vim: fileencoding=utf-8 +# This file is released into the public domain; see http://unlicense.org/ + +# module info +modinfo = { + 'author': 'Erebus Team', + 'license': 'public domain', + 'compatible': [0], + 'depends': [], + 'softdeps': ['help'], +} + +# preamble +import modlib +lib = modlib.modlib(__name__) +modstart = lib.modstart +modstop = lib.modstop + +# module code + +# Note: bind_* does all of the following: +# - create a socket `sock = socket.socket()` +# - bind the socket `sock.bind()` +# - listen on the socket `sock.listen()` +# - accept `sock.accept()` +# +# Once a connection is accepted, your class is instantiated with the client socket. +# - When data comes in on the client socket, your `getdata` method will be called. It should return a list of strings. +# - For each element in the list returned by `getdata`, `parse` will be called. +# - When the socket is being closed by the bot (f.e. your module is unloaded), the optional method `closing` will be called. +# Then the bot will call `sock.shutdown()` and `sock.close()` for you. +# XXX error handling? what happens when the other side closes the socket? 
+#
+# You can interact with the rest of the bot through `lib.parent`.
+@lib.bind_tcp('0.0.0.0', 12543)
+class BasicServer(object): # instantiated once per accepted connection, with the client socket
+    def __init__(self, sock):
+        self.buffer = b'' # bytes received so far that do not yet form a complete line
+        self.sock = sock
+
+    def getdata(self): # returns a list of decoded lines, or None to ask the core to close this socket
+        recvd = self.sock.recv(8192)
+        if recvd == b"":
+            if len(self.buffer) != 0:
+                # Process what's left in the buffer. We'll get called again after.
+                remaining_buf = self.buffer.decode('utf-8', 'backslashreplace')
+                self.buffer = b""
+                return [remaining_buf]
+            else:
+                # Nothing left in the buffer. Return None to signal the core to close this socket.
+                return None
+        self.buffer += recvd
+        lines = []
+
+        while b"\n" in self.buffer:
+            raw, _, self.buffer = self.buffer.partition(b"\n")
+            # Strip the CR of a CRLF line ending so parse() receives the bare line.
+            lines.append(raw.decode('utf-8', 'backslashreplace').rstrip("\r"))
+
+        return lines
+
+    def parse(self, line): # called once for each element returned by getdata
+        peer = self.sock.getpeername()
+        lib.parent.randbot().msg('#', "%s:%d says: %s" % (peer[0], peer[1], line))
+
+    def send(self, line): # encode `line` as UTF-8 and send it followed by CRLF
+        self.sock.sendall(line.encode('utf-8', 'backslashreplace')+b"\r\n")
+
+    def _getsockerr(self): # pending socket error code (SO_ERROR), or None if it cannot be read
+        import socket # imported here because this module has no top-level `import socket`
+        try: # SO_ERROR might not exist on all platforms
+            return self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+        except Exception:
+            return None