]> jfr.im git - irc/quakenet/qwebirc.git/blobdiff - qwebirc/engines/ajaxengine.py
tidy up autobahn support -- now requires 0.17.2
[irc/quakenet/qwebirc.git] / qwebirc / engines / ajaxengine.py
index a7acb464e6ee8bec3b9a8b1c2ce180d7dd3930bb..33bc67bf8585601125ff3424e037535d9339de06 100644 (file)
@@ -2,11 +2,36 @@ from twisted.web import resource, server, static, error as http_error
 from twisted.names import client
 from twisted.internet import reactor, error
 from authgateengine import login_optional, getSessionData
 from twisted.names import client
 from twisted.internet import reactor, error
 from authgateengine import login_optional, getSessionData
-import simplejson, md5, sys, os, time, config, qwebirc.config_options as config_options, traceback, socket
+import md5, sys, os, time, config, qwebirc.config_options as config_options, traceback, socket
 import qwebirc.ircclient as ircclient
 from adminengine import AdminEngineAction
 from qwebirc.util import HitCounter
 import qwebirc.dns as qdns
 import qwebirc.ircclient as ircclient
 from adminengine import AdminEngineAction
 from qwebirc.util import HitCounter
 import qwebirc.dns as qdns
+import qwebirc.util.qjson as json
+import urlparse
+import qwebirc.util.autobahn_check as autobahn_check
+
+TRANSPORTS = ["longpoll"]
+
+has_websocket = False
+autobahn_status = autobahn_check.check()
+if autobahn_status == True:
+  import autobahn
+  import autobahn.twisted.websocket
+  import autobahn.twisted.resource
+  has_websocket = True
+  TRANSPORTS.append("websocket")
+elif autobahn_status == False:
+  # they've been warned already
+  pass
+else:
+  print >>sys.stderr, "WARNING:"
+  print >>sys.stderr, "  %s" % autobahn_status
+  print >>sys.stderr, "  as a result websocket support is disabled."
+  print >>sys.stderr, "  upgrade your version of autobahn from http://autobahn.ws/python/getstarted/"
+
+BAD_SESSION_MESSAGE = "Invalid session, this most likely means the server has restarted; close this dialog and then try refreshing the page."
+MAX_SEQNO = 9223372036854775807  # 2**63 - 1... yeah it doesn't wrap
 Sessions = {}
 
 def get_session_id():
 Sessions = {}
 
 def get_session_id():
@@ -21,25 +46,10 @@ class AJAXException(Exception):
 class IDGenerationException(Exception):
   pass
 
 class IDGenerationException(Exception):
   pass
 
-class PassthruException(Exception):
+class LineTooLongException(Exception):
   pass
   pass
-  
-NOT_DONE_YET = None
 
 
-def jsondump(fn):
-  def decorator(*args, **kwargs):
-    try:
-      x = fn(*args, **kwargs)
-      if x is None:
-        return server.NOT_DONE_YET
-      x = (True, x)
-    except AJAXException, e:
-      x = (False, e[0])
-    except PassthruException, e:
-      return str(e)
-      
-    return simplejson.dumps(x)
-  return decorator
+EMPTY_JSON_LIST = json.dumps([])
 
 def cleanupSession(id):
   try:
 
 def cleanupSession(id):
   try:
@@ -52,37 +62,43 @@ class IRCSession:
     self.id = id
     self.subscriptions = []
     self.buffer = []
     self.id = id
     self.subscriptions = []
     self.buffer = []
+    self.old_buffer = None
     self.buflen = 0
     self.throttle = 0
     self.schedule = None
     self.closed = False
     self.cleanupschedule = None
     self.buflen = 0
     self.throttle = 0
     self.schedule = None
     self.closed = False
     self.cleanupschedule = None
+    self.pubSeqNo = -1
+    self.subSeqNo = 0
 
 
-  def subscribe(self, channel, notifier):
-    timeout_entry = reactor.callLater(config.HTTP_AJAX_REQUEST_TIMEOUT, self.timeout, channel)
-    def cancel_timeout(result):
-      if channel in self.subscriptions:
-        self.subscriptions.remove(channel)
-      try:
-        timeout_entry.cancel()
-      except error.AlreadyCalled:
-        pass
-    notifier.addCallbacks(cancel_timeout, cancel_timeout)
-    
+  def subscribe(self, channel, seqNo=None):
     if len(self.subscriptions) >= config.MAXSUBSCRIPTIONS:
       self.subscriptions.pop(0).close()
 
     if len(self.subscriptions) >= config.MAXSUBSCRIPTIONS:
       self.subscriptions.pop(0).close()
 
+    if seqNo is not None and seqNo < self.subSeqNo:
+      if self.old_buffer is None or seqNo != self.old_buffer[0]:
+        channel.write(json.dumps([False, "Unable to reconnect -- sequence number too old."]), seqNo + 1)
+        return
+
+      if not channel.write(self.old_buffer[1], self.old_buffer[0] + 1):
+        return
+
     self.subscriptions.append(channel)
     self.subscriptions.append(channel)
-    self.flush()
-      
+    self.flush(seqNo)
+
+  def unsubscribe(self, channel):
+    try:
+      self.subscriptions.remove(channel)
+    except ValueError:
+      pass
+
   def timeout(self, channel):
     if self.schedule:
       return
   def timeout(self, channel):
     if self.schedule:
       return
-      
-    channel.write(simplejson.dumps([]))
-    if channel in self.subscriptions:
-      self.subscriptions.remove(channel)
-      
+
+    self.unsubscribe(channel)
+    channel.write(EMPTY_JSON_LIST, self.subSeqNo)
+
   def flush(self, scheduled=False):
     if scheduled:
       self.schedule = None
   def flush(self, scheduled=False):
     if scheduled:
       self.schedule = None
@@ -102,20 +118,23 @@ class IRCSession:
         if not self.schedule:
           self.schedule = reactor.callLater(0, self.flush, True)
         return
         if not self.schedule:
           self.schedule = reactor.callLater(0, self.flush, True)
         return
-        
+
     self.throttle = t + config.UPDATE_FREQ
 
     self.throttle = t + config.UPDATE_FREQ
 
-    encdata = simplejson.dumps(self.buffer)
+    encdata = json.dumps(self.buffer)
+    self.old_buffer = (self.subSeqNo, encdata)
+    self.subSeqNo+=1
     self.buffer = []
     self.buflen = 0
 
     self.buffer = []
     self.buflen = 0
 
-    newsubs = []
-    for x in self.subscriptions:
-      if x.write(encdata):
+    subs = self.subscriptions
+    self.subscriptions = newsubs = []
+
+    for x in subs:
+      if x.write(encdata, self.subSeqNo):
         newsubs.append(x)
 
         newsubs.append(x)
 
-    self.subscriptions = newsubs
-    if self.closed and not self.subscriptions:
+    if self.closed and not newsubs:
       cleanupSession(self.id)
 
   def event(self, data):
       cleanupSession(self.id)
 
   def event(self, data):
@@ -129,9 +148,18 @@ class IRCSession:
     self.buflen = newbuflen
     self.flush()
     
     self.buflen = newbuflen
     self.flush()
     
-  def push(self, data):
-    if not self.closed:
-      self.client.write(data)
+  def push(self, data, seq_no=None):
+    if self.closed:
+      return
+
+    if len(data) > config.MAXLINELEN:
+      raise LineTooLongException
+
+    if seq_no is not None:
+      if seq_no <= self.pubSeqNo:
+        return
+      self.pubSeqNo = seq_no
+    self.client.write(data)
 
   def disconnect(self):
     # keep the session hanging around for a few seconds so the
 
   def disconnect(self):
     # keep the session hanging around for a few seconds so the
@@ -144,12 +172,12 @@ class IRCSession:
 def connect_notice(line):
   return "c", "NOTICE", "", ("AUTH", "*** (qwebirc) %s" % line)
 
 def connect_notice(line):
   return "c", "NOTICE", "", ("AUTH", "*** (qwebirc) %s" % line)
 
-class Channel:
+class RequestChannel(object):
   def __init__(self, request):
     self.request = request
   def __init__(self, request):
     self.request = request
-  
-class SingleUseChannel(Channel):
-  def write(self, data):
+
+  def write(self, data, seqNo):
+    self.request.setHeader("n", str(seqNo))
     self.request.write(data)
     self.request.finish()
     return False
     self.request.write(data)
     self.request.finish()
     return False
@@ -157,11 +185,6 @@ class SingleUseChannel(Channel):
   def close(self):
     self.request.finish()
     
   def close(self):
     self.request.finish()
     
-class MultipleUseChannel(Channel):
-  def write(self, data):
-    self.request.write(data)
-    return True
-
 class AJAXEngine(resource.Resource):
   isLeaf = True
   
 class AJAXEngine(resource.Resource):
   isLeaf = True
   
@@ -170,20 +193,22 @@ class AJAXEngine(resource.Resource):
     self.__connect_hit = HitCounter()
     self.__total_hit = HitCounter()
     
     self.__connect_hit = HitCounter()
     self.__total_hit = HitCounter()
     
-  @jsondump
   def render_POST(self, request):
     path = request.path[len(self.prefix):]
     if path[0] == "/":
       handler = self.COMMANDS.get(path[1:])
       if handler is not None:
   def render_POST(self, request):
     path = request.path[len(self.prefix):]
     if path[0] == "/":
       handler = self.COMMANDS.get(path[1:])
       if handler is not None:
-        return handler(self, request)
-        
-    raise PassthruException, http_error.NoResource().render(request)
+        try:
+          return handler(self, request)
+        except AJAXException, e:
+          return json.dumps((False, e[0]))
+
+    return "404" ## TODO: tidy up
 
   def newConnection(self, request):
     ticket = login_optional(request)
     
 
   def newConnection(self, request):
     ticket = login_optional(request)
     
-    _, ip, port = request.transport.getPeer()
+    ip = request.getClientIP()
 
     nick = request.args.get("nick")
     if not nick:
 
     nick = request.args.get("nick")
     if not nick:
@@ -194,7 +219,7 @@ class AJAXEngine(resource.Resource):
     if password is not None:
       password = ircclient.irc_decode(password[0])
       
     if password is not None:
       password = ircclient.irc_decode(password[0])
       
-    for i in xrange(10):
+    for i in range(10):
       id = get_session_id()
       if not Sessions.get(id):
         break
       id = get_session_id()
       if not Sessions.get(id):
         break
@@ -242,8 +267,8 @@ class AJAXEngine(resource.Resource):
 
     Sessions[id] = session
     
 
     Sessions[id] = session
     
-    return id
-  
+    return json.dumps((True, id, TRANSPORTS))
+
   def getSession(self, request):
     bad_session_message = "Invalid session, this most likely means the server has restarted; close this dialog and then try refreshing the page."
     
   def getSession(self, request):
     bad_session_message = "Invalid session, this most likely means the server has restarted; close this dialog and then try refreshing the page."
     
@@ -257,26 +282,51 @@ class AJAXEngine(resource.Resource):
     return session
     
   def subscribe(self, request):
     return session
     
   def subscribe(self, request):
-    request.channel.cancelTimeout()
-    self.getSession(request).subscribe(SingleUseChannel(request), request.notifyFinish())
-    return NOT_DONE_YET
+    request.channel.setTimeout(None)
+
+    channel = RequestChannel(request)
+    session = self.getSession(request)
+    notifier = request.notifyFinish()
+
+    seq_no = request.args.get("n")
+    try:
+      if seq_no is not None:
+        seq_no = int(seq_no[0])
+        if seq_no < 0 or seq_no > MAX_SEQNO:
+          raise ValueError
+    except ValueError:
+      raise AJAXEngine, "Bad sequence number"
+
+    session.subscribe(channel, seq_no)
+
+    timeout_entry = reactor.callLater(config.HTTP_AJAX_REQUEST_TIMEOUT, session.timeout, channel)
+    def cancel_timeout(result):
+      try:
+        timeout_entry.cancel()
+      except error.AlreadyCalled:
+        pass
+      session.unsubscribe(channel)
+    notifier.addCallbacks(cancel_timeout, cancel_timeout)
+    return server.NOT_DONE_YET
 
   def push(self, request):
     command = request.args.get("c")
     if command is None:
       raise AJAXException, "No command specified."
     self.__total_hit()
 
   def push(self, request):
     command = request.args.get("c")
     if command is None:
       raise AJAXException, "No command specified."
     self.__total_hit()
-    
-    decoded = ircclient.irc_decode(command[0])
-    
-    session = self.getSession(request)
 
 
-    if len(decoded) > config.MAXLINELEN:
-      session.disconnect()
-      raise AJAXException, "Line too long."
+    seq_no = request.args.get("n")
+    try:
+      if seq_no is not None:
+        seq_no = int(seq_no[0])
+        if seq_no < 0 or seq_no > MAX_SEQNO:
+          raise ValueError
+    except ValueError:
+      raise AJAXEngine("Bad sequence number %r" % seq_no)
 
 
+    session = self.getSession(request)
     try:
     try:
-      session.push(decoded)
+      session.push(ircclient.irc_decode(command[0]), seq_no)
     except AttributeError: # occurs when we haven't noticed an error
       session.disconnect()
       raise AJAXException, "Connection closed by server; try reconnecting by reloading the page."
     except AttributeError: # occurs when we haven't noticed an error
       session.disconnect()
       raise AJAXException, "Connection closed by server; try reconnecting by reloading the page."
@@ -285,7 +335,7 @@ class AJAXEngine(resource.Resource):
       traceback.print_exc(file=sys.stderr)
       raise AJAXException, "Unknown error."
   
       traceback.print_exc(file=sys.stderr)
       raise AJAXException, "Unknown error."
   
-    return True
+    return json.dumps((True, True))
   
   def closeById(self, k):
     s = Sessions.get(k)
   
   def closeById(self, k):
     s = Sessions.get(k)
@@ -303,3 +353,120 @@ class AJAXEngine(resource.Resource):
     
   COMMANDS = dict(p=push, n=newConnection, s=subscribe)
   
     
   COMMANDS = dict(p=push, n=newConnection, s=subscribe)
   
+if has_websocket:
+  class WebSocketChannel(object):
+    def __init__(self, channel):
+      self.channel = channel
+
+    def write(self, data, seqNo):
+      self.channel.send("c", "%d,%s" % (seqNo, data))
+      return True
+
+    def close(self):
+      self.channel.close()
+
+  class WebSocketEngineProtocol(autobahn.twisted.websocket.WebSocketServerProtocol):
+    AWAITING_AUTH, AUTHED = 0, 1
+
+    def __init__(self, *args, **kwargs):
+      super(WebSocketEngineProtocol, self).__init__(*args, **kwargs)
+      self.__state = self.AWAITING_AUTH
+      self.__session = None
+      self.__channel = None
+      self.__timeout = None
+
+    def onOpen(self):
+      self.__timeout = reactor.callLater(5, self.close, "Authentication timeout")
+
+    def onClose(self, wasClean, code, reason):
+      self.__cancelTimeout()
+      if self.__session:
+        self.__session.unsubscribe(self.__channel)
+        self.__session = None
+
+    def onMessage(self, msg, isBinary):
+      # we don't bother checking the Origin header, as if you can auth then you've been able to pass the browser's
+      # normal origin handling (POSTed the new connection request and managed to get the session id)
+      state = self.__state
+      message_type, message = msg[:1], msg[1:]
+      if state == self.AWAITING_AUTH:
+        if message_type == "s":  # subscribe
+          tokens = message.split(",", 1)
+          if len(tokens) != 2:
+            self.close("Bad tokens")
+            return
+
+          seq_no, message = tokens[0], tokens[1]
+          try:
+            seq_no = int(seq_no)
+            if seq_no < 0 or seq_no > MAX_SEQNO:
+              raise ValueError
+          except ValueError:
+            self.close("Bad value")
+
+          session = Sessions.get(message)
+          if not session:
+            self.close(BAD_SESSION_MESSAGE)
+            return
+
+          self.__cancelTimeout()
+          self.__session = session
+          self.send("s", "True")
+          self.__state = self.AUTHED
+          self.__channel = WebSocketChannel(self)
+          session.subscribe(self.__channel, seq_no)
+          return
+      elif state == self.AUTHED:
+        if message_type == "p":  # push
+          tokens = message.split(",", 1)
+          if len(tokens) != 2:
+            self.close("Bad tokens")
+            return
+
+          seq_no, message = tokens[0], tokens[1]
+          try:
+            seq_no = int(seq_no)
+            if seq_no < 0 or seq_no > MAX_SEQNO:
+              raise ValueError
+          except ValueError:
+            self.close("Bad value")
+          self.__session.push(ircclient.irc_decode(message))
+          return
+
+      self.close("Bad message type")
+
+    def __cancelTimeout(self):
+      if self.__timeout is not None:
+        try:
+          self.__timeout.cancel()
+        except error.AlreadyCalled:
+          pass
+        self.__timeout = None
+
+    def close(self, reason=None):
+      self.__cancelTimeout()
+      if reason:
+        self.sendClose(4999, unicode(reason))
+      else:
+        self.sendClose(4998)
+
+      if self.__session:
+        self.__session.unsubscribe(self.__channel)
+        self.__session = None
+
+    def send(self, message_type, message):
+      self.sendMessage(message_type + message)
+
+  class WebSocketResource(autobahn.twisted.resource.WebSocketResource):
+    def render(self, request):
+      request.channel.setTimeout(None)
+      return autobahn.twisted.resource.WebSocketResource.render(self, request)
+
+  def WebSocketEngine(path=None):
+    factory = autobahn.twisted.websocket.WebSocketServerFactory("ws://localhost")
+    factory.externalPort = None
+    factory.protocol = WebSocketEngineProtocol
+    factory.setProtocolOptions(maxMessagePayloadSize=512, maxFramePayloadSize=512, tcpNoDelay=False)
+    resource = WebSocketResource(factory)
+    return resource
+