jfr.im git - erebus.git/commitdiff
add new abc for sockets
author John Runyon <redacted>
Wed, 8 May 2024 11:58:37 +0000 (05:58 -0600)
committer John Runyon <redacted>
Wed, 8 May 2024 11:58:37 +0000 (05:58 -0600)
erebus.py
modlib.py
modules/sockets.py

index a93e1e0ee17c90b5d3b6d7c96cd818bc0f1c38cd..fb6fcc80fc703480887c2be351c368fb1a7e5669 100644 (file)
--- a/erebus.py
+++ b/erebus.py
@@ -7,7 +7,7 @@
 from __future__ import print_function
 
 import os, sys, select, time, traceback, random, gc
-import bot, config, ctlmod
+import bot, config, ctlmod, modlib
 
 class Erebus(object): #singleton to pass around
        APIVERSION = 0
@@ -241,6 +241,8 @@ class Erebus(object): #singleton to pass around
                self.bots[nick.lower()] = obj
 
        def newfd(self, obj, fileno):
+               if not isinstance(obj, modlib.Socketlike):
+                       raise Exception('Attempted to hook a socket without a class to process data')
                self.fds[fileno] = obj
                if self.potype == "poll":
                        self.po.register(fileno, select.POLLIN)
index b36c153287f4979c107a8768c60b91536949ec15..11ca1fb7a4acc454cab79b2254583fe8d1d885d8 100644 (file)
--- a/modlib.py
+++ b/modlib.py
@@ -3,6 +3,7 @@
 # module helper functions, see modules/modtest.py for usage
 # This file is released into the public domain; see http://unlicense.org/
 
+import abc
 import sys
 import socket
 from functools import wraps
@@ -60,6 +61,7 @@ class modlib(object):
        WRONGARGS = "Wrong number of arguments."
 
        def __init__(self, name):
+               self.Socketlike = Socketlike
                self.hooks = {} # {command:handler}
                self.chanhooks = {} # {channel:handler}
                self.exceptionhooks = [] # [(exception,handler)]
@@ -206,9 +208,7 @@ class modlib(object):
                return self._hooksocket(socket.AF_UNIX, socket.SOCK_STREAM, path, data)
        def _hooksocket(self, af, ty, address, data):
                def realhook(cls):
-                       if not (hasattr(cls, 'getdata') and callable(cls.getdata)):
-                               # Check early that the object implements getdata.
-                               # If getdata ever returns a non-empty list, then a parse method must also exist, but we don't check that.
+                       if not issubclass(cls, Socketlike):
                                raise Exception('Attempted to hook a socket without a class to process data')
                        self.sockhooks.append((af, ty, address, cls, data))
                        if self.parent is not None:
@@ -322,7 +322,52 @@ class modlib(object):
                        return func
                return realhook
 
-class _ListenSocket(object):
+class Socketlike(abc.ABC):
+       """Base class for objects hooked to a socket: getdata() reads from self.sock, and each item it returns is handed to parse()."""
+       def __init__(self, sock, data):
+               """This default method saves the socket in self.sock and creates self.buffer for getdata(). The data is discarded."""
+               self.sock = sock
+               self.buffer = b''
+
+       def getdata(self):
+               """This default method gets LF or CRLF separated lines from the socket and returns a list of completely-seen lines (line endings stripped) to the core.
+               This should work well for most line-based protocols (like IRC). At EOF it returns any unterminated remainder once, then None."""
+               recvd = self.sock.recv(8192)
+               if recvd == b"": # EOF
+                       if self.buffer:
+                               # Process what's left in the buffer. We'll get called again after.
+                               remaining_buf = self.buffer.decode('utf-8', 'backslashreplace').rstrip("\r")
+                               self.buffer = b""
+                               return [remaining_buf]
+                       else:
+                               # Nothing left in the buffer. Return None to signal the core to close this socket.
+                               return None
+               self.buffer += recvd
+               # Pieces before the last LF are complete lines (CR of CRLF stripped below); the piece after it stays buffered for the next call.
+               *complete, self.buffer = self.buffer.split(b"\n")
+               return [piece.decode('utf-8', 'backslashreplace').rstrip("\r") for piece in complete]
+
+       def __str__(self):
+               return '%s#%d' % (self.__class__.__name__, self.sock.fileno())
+       def __repr__(self):
+               try:
+                       peer = self.sock.getpeername()
+               except OSError: # e.g. a listening or closed socket has no peer
+                       peer = '?'
+               if isinstance(peer, tuple): # AF_INET is (host, port); AF_INET6 appends flowinfo, scope_id. AF_UNIX yields a path instead.
+                       peer = '%s:%d' % peer[:2]
+               return '<%s.%s #%d %s>' % (self.__class__.__module__, self.__class__.__name__, self.sock.fileno(), peer)
+
+       @abc.abstractmethod
+       def parse(self, chunk): """Handle one item returned by getdata() (a str line with the default getdata()). Subclasses must override this."""
+
+       @classmethod
+       def __subclasshook__(cls, C): # duck typing: counts as Socketlike if parse and getdata are each defined somewhere in C's MRO
+               if cls is Socketlike and all(any(m in B.__dict__ for B in C.__mro__) for m in ('parse', 'getdata')):
+                       return True
+               return NotImplemented
+
+class _ListenSocket(Socketlike):
        def __init__(self, lib, sock, cls, data):
                self.clients = []
                self.lib = lib
@@ -343,6 +388,10 @@ class _ListenSocket(object):
                        self.clients.remove((client,obj))
                return close
 
+       def parse(self, chunk):
+               """Never called (getdata will never return a non-empty array), but Socketlike requires parse(self, chunk), so match its signature."""
+               pass
+
        def getdata(self):
                client, addr = self.sock.accept()
                obj = self.cls(client, self.data)
index f9577e18fbbee6eb70f0941d86720624bd6237ca..af650ed5f3a5a68e9e7dc7dc4ac748a5e77d7cda 100644 (file)
--- a/modules/sockets.py
+++ b/modules/sockets.py
@@ -47,66 +47,32 @@ def modstart(parent, *args, **kwargs):
 modstop = lib.modstop
 
 # module code
+class BasicServer(lib.Socketlike):
+       """Relays each line read from a hooked socket to the IRC channel given as bind() data. getdata(), __str__() and __repr__() come from lib.Socketlike."""
+       def __init__(self, sock, data):
+               super(BasicServer, self).__init__(sock, data)
+               self.chan = data # this one class is bound once per [sockets] entry (see gotParent), so the target channel arrives per-connection as data
+
+       def parse(self, line):
+               try:
+                       bot = lib.parent.channel(self.chan).bot
+               except AttributeError: # <class 'AttributeError'> 'NoneType' object has no attribute 'bot'
+                       bot = lib.parent.randbot()
+               maxlen = bot.maxmsglen() - len("PRIVMSG  :") - len(self.chan)
+               while len(line) > maxlen: # split long lines at the last space before maxlen, or hard-cut if there is none
+                       cutat = line.rfind(' ', 0, maxlen)
+                       if cutat == -1:
+                               cutat = maxlen
+                       bot.msg(self.chan, line[0:cutat])
+                       line = line[cutat:].strip()
+               bot.msg(self.chan, line)
+
+       def send(self, line):
+               # lib.Socketlike defines no send(); write one CRLF-terminated line, logging it first when cfg's debug/io flag is true.
+               if lib.parent.parent.cfg.getboolean('debug', 'io'):
+                       lib.parent.log(str(self), 'O', line)
+               self.sock.sendall(line.encode('utf-8', 'backslashreplace')+b"\r\n")
 
 def gotParent(parent):
        for bindto, channel in parent.cfg.items('sockets'):
-               @lib.bind(bindto, data=channel)
-               class BasicServer(object):
-                       def __init__(self, sock, data):
-                               # NB neither directly referencing `channel`, nor trying to pass it through a default-arg-to-a-lambda like the python docs suggest, works here.
-                               # Yay python. At least passing it via bind works.
-                               self.chan = data
-                               self.buffer = b''
-                               self.sock = sock
-
-                       def getdata(self):
-                               recvd = self.sock.recv(8192)
-                               if recvd == b"": # EOF
-                                       if len(self.buffer) != 0:
-                                               # Process what's left in the buffer. We'll get called again after.
-                                               remaining_buf = self.buffer.decode('utf-8', 'backslashreplace')
-                                               self.buffer = b""
-                                               return [remaining_buf]
-                                       else:
-                                               # Nothing left in the buffer. Return None to signal the core to close this socket.
-                                               return None
-                               self.buffer += recvd
-                               lines = []
-
-                               while b"\n" in self.buffer:
-                                       pieces = self.buffer.split(b"\n", 1)
-                                       s = pieces[0].decode('utf-8', 'backslashreplace').rstrip("\r")
-                                       lines.append(pieces[0].decode('utf-8', 'backslashreplace'))
-                                       self.buffer = pieces[1]
-
-                               return lines
-
-                       def parse(self, line):
-                               try:
-                                       bot = lib.parent.channel(self.chan).bot
-                               except AttributeError: # <class 'AttributeError'> 'NoneType' object has no attribute 'bot'
-                                       bot = lib.parent.randbot()
-                               maxlen = bot.maxmsglen() - len("PRIVMSG  :") - len(self.chan)
-                               while len(line) > maxlen:
-                                       cutat = line.rfind(' ', 0, maxlen)
-                                       if cutat == -1:
-                                               cutat = maxlen
-                                       bot.msg(self.chan, line[0:cutat])
-                                       line = line[cutat:].strip()
-                               bot.msg(self.chan, line)
-
-                       def send(self, line):
-                               if lib.parent.parent.cfg.getboolean('debug', 'io'):
-                                       lib.parent.log(str(self), 'O', line)
-                               self.sock.sendall(line.encode('utf-8', 'backslashreplace')+b"\r\n")
-
-                       def _getsockerr(self):
-                               try: # SO_ERROR might not exist on all platforms
-                                       return self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
-                               except:
-                                       return None
-
-                       def __str__(self):
-                               return '%s#%d' % (__name__, self.sock.fileno())
-                       def __repr__(self):
-                               return '<%s #%d %s:%d>' % ((__name__, self.sock.fileno())+self.sock.getpeername())
+               lib.bind(bindto, data=channel)(BasicServer) # hook the shared module-level BasicServer once per [sockets] entry; the target channel is passed as data