+
+if has_websocket:
+    # NOTE(review): has_websocket is defined elsewhere, presumably by an
+    # import guard around autobahn -- confirm before relying on it.
+    class WebSocketChannel(object):
+        """Session-facing channel that forwards output to a websocket protocol.
+
+        ``channel`` is the protocol object being wrapped; it must provide
+        ``send(message_type, message)`` and ``close()``.
+        """
+
+        def __init__(self, channel):
+            self.channel = channel
+
+        def write(self, data, seqNo):
+            """Send *data* tagged with *seqNo* as a "c" message; always returns True."""
+            framed = "%d,%s" % (seqNo, data)
+            self.channel.send("c", framed)
+            return True
+
+        def close(self):
+            """Close the wrapped protocol connection."""
+            self.channel.close()
+
+    class WebSocketEngineProtocol(autobahn.twisted.websocket.WebSocketServerProtocol):
+        """Websocket server protocol that bridges one client to a session.
+
+        Every client message is a one-character type followed by a payload:
+
+        * Before authentication only ``s<seqNo>,<sessionId>`` is accepted. The
+          session is looked up with ``Sessions.get``; on success ``sTrue`` is
+          sent and a ``WebSocketChannel`` wrapping this protocol is subscribed
+          to the session from *seqNo*. Without a valid ``s`` message within 5
+          seconds the connection is closed.
+        * After authentication only ``p<seqNo>,<data>`` is accepted, and
+          ``ircclient.irc_decode(data)`` is pushed to the session.
+
+        *seqNo* must be an integer in ``[0, MAX_SEQNO]``. Any malformed or
+        unexpected message closes the connection (code 4999 plus a reason).
+        """
+        # CLOSED is entered as soon as close()/onClose() runs, so frames that
+        # are still being delivered are ignored instead of being processed.
+        AWAITING_AUTH, AUTHED, CLOSED = 0, 1, 2
+
+        def __init__(self, *args, **kwargs):
+            super(WebSocketEngineProtocol, self).__init__(*args, **kwargs)
+            self.__state = self.AWAITING_AUTH
+            self.__session = None
+            self.__channel = None
+            self.__timeout = None
+
+        def onOpen(self):
+            # The client must authenticate within 5 seconds.
+            self.__timeout = reactor.callLater(5, self.close, "Authentication timeout")
+
+        def onClose(self, wasClean, code, reason):
+            self.__state = self.CLOSED
+            self.__cancelTimeout()
+            self.__unsubscribe()
+
+        def onMessage(self, msg, isBinary):
+            # we don't bother checking the Origin header, as if you can auth then you've been able to pass the browser's
+            # normal origin handling (POSTed the new connection request and managed to get the session id)
+            state = self.__state
+            if state == self.CLOSED:
+                # We already asked to close; drop anything still arriving
+                # (previously a late "s" frame could re-authenticate).
+                return
+
+            message_type, message = msg[:1], msg[1:]
+            if state == self.AWAITING_AUTH and message_type == "s":  # subscribe
+                self.__handleSubscribe(message)
+            elif state == self.AUTHED and message_type == "p":  # push
+                self.__handlePush(message)
+            else:
+                self.close("Bad message type")
+
+        def __parseSeqMessage(self, message):
+            """Split a ``<seqNo>,<payload>`` body and validate *seqNo*.
+
+            Returns ``(seq_no, payload)`` with ``seq_no`` as an int. On a
+            malformed body the connection is closed and ``None`` is returned;
+            the caller must then stop processing the message.
+            """
+            tokens = message.split(",", 1)
+            if len(tokens) != 2:
+                self.close("Bad tokens")
+                return None
+
+            try:
+                seq_no = int(tokens[0])
+            except ValueError:
+                seq_no = None
+            if seq_no is None or seq_no < 0 or seq_no > MAX_SEQNO:
+                # Stop here: previously processing continued after this close.
+                self.close("Bad value")
+                return None
+            return seq_no, tokens[1]
+
+        def __handleSubscribe(self, message):
+            """Authenticate with the body of an ``s`` message and subscribe."""
+            parsed = self.__parseSeqMessage(message)
+            if parsed is None:
+                return
+            seq_no, session_id = parsed
+
+            session = Sessions.get(session_id)
+            if not session:
+                self.close(BAD_SESSION_MESSAGE)
+                return
+
+            self.__cancelTimeout()
+            self.__session = session
+            self.send("s", "True")
+            self.__state = self.AUTHED
+            self.__channel = WebSocketChannel(self)
+            session.subscribe(self.__channel, seq_no)
+
+        def __handlePush(self, message):
+            """Push the decoded payload of a ``p`` message to the session."""
+            parsed = self.__parseSeqMessage(message)
+            if parsed is None:
+                # Connection already closed and session detached; pushing
+                # here would dereference a None session.
+                return
+            # The sequence number is validated but not otherwise used here.
+            data = parsed[1]
+            self.__session.push(ircclient.irc_decode(data))
+
+        def __cancelTimeout(self):
+            """Cancel the authentication timeout if it is still pending."""
+            timeout, self.__timeout = self.__timeout, None
+            # active() is False once the call has fired or been cancelled,
+            # e.g. when close() is itself running as the timeout callback.
+            if timeout is not None and timeout.active():
+                timeout.cancel()
+
+        def __unsubscribe(self):
+            """Detach our channel from the session, if any; safe to call twice."""
+            # Clear the reference before calling out so a re-entrant close()
+            # cannot unsubscribe the same channel a second time.
+            session, self.__session = self.__session, None
+            if session:
+                session.unsubscribe(self.__channel)
+
+        def close(self, reason=None):
+            """Start a websocket close handshake and detach from the session.
+
+            Uses close code 4999 with *reason* when one is given, else 4998.
+            """
+            self.__state = self.CLOSED
+            self.__cancelTimeout()
+            if reason:
+                self.sendClose(4999, unicode(reason))
+            else:
+                self.sendClose(4998)
+            self.__unsubscribe()
+
+        def send(self, message_type, message):
+            """Send *message_type* immediately followed by *message* as one message."""
+            self.sendMessage(message_type + message)
+
+    class WebSocketResource(autobahn.twisted.resource.WebSocketResource):
+        """autobahn WebSocketResource that disables the HTTP channel timeout."""
+
+        def render(self, request):
+            # setTimeout(None) turns off the HTTP channel's idle timeout,
+            # presumably so it cannot drop the long-lived upgraded connection.
+            http_channel = request.channel
+            http_channel.setTimeout(None)
+            base = autobahn.twisted.resource.WebSocketResource
+            return base.render(self, request)
+
+    def WebSocketEngine(path=None):
+        """Build the autobahn resource that serves the websocket engine.
+
+        *path* is accepted for interface compatibility but is not used.
+        """
+        factory = autobahn.twisted.websocket.WebSocketServerFactory("ws://localhost")
+        factory.externalPort = None
+        factory.protocol = WebSocketEngineProtocol
+        # 512-byte caps on both whole messages and individual frames;
+        # tcpNoDelay=False leaves Nagle's algorithm enabled.
+        options = dict(
+            maxMessagePayloadSize=512,
+            maxFramePayloadSize=512,
+            tcpNoDelay=False,
+        )
+        factory.setProtocolOptions(**options)
+        return WebSocketResource(factory)
+