]> jfr.im git - irc/quakenet/qwebirc.git/blobdiff - qwebirc/engines/ajaxengine.py
tidy up autobahn support -- now requires 0.17.2
[irc/quakenet/qwebirc.git] / qwebirc / engines / ajaxengine.py
index d30b0c86e803bc39c41af2a3e911e51a5ecc0478..33bc67bf8585601125ff3424e037535d9339de06 100644 (file)
@@ -2,11 +2,36 @@ from twisted.web import resource, server, static, error as http_error
 from twisted.names import client
 from twisted.internet import reactor, error
 from authgateengine import login_optional, getSessionData
-import simplejson, md5, sys, os, time, config, weakref, traceback
+import md5, sys, os, time, config, qwebirc.config_options as config_options, traceback, socket
 import qwebirc.ircclient as ircclient
 from adminengine import AdminEngineAction
 from qwebirc.util import HitCounter
+import qwebirc.dns as qdns
+import qwebirc.util.qjson as json
+import urlparse
+import qwebirc.util.autobahn_check as autobahn_check
 
+TRANSPORTS = ["longpoll"]  # transports offered to clients in the newConnection reply
+
+has_websocket = False  # becomes True below only if autobahn was imported
+autobahn_status = autobahn_check.check()  # True = usable, False = unusable (user already warned), else a reason message
+if autobahn_status == True:
+  import autobahn
+  import autobahn.twisted.websocket
+  import autobahn.twisted.resource
+  has_websocket = True
+  TRANSPORTS.append("websocket")
+elif autobahn_status == False:
+  # websocket support is unavailable and the user has already been warned, so stay quiet here
+  pass
+else:  # check() returned a reason autobahn can't be used: report it and continue without websockets
+  print >>sys.stderr, "WARNING:"
+  print >>sys.stderr, "  %s" % autobahn_status
+  print >>sys.stderr, "  as a result websocket support is disabled."
+  print >>sys.stderr, "  upgrade your version of autobahn from http://autobahn.ws/python/getstarted/"
+
+BAD_SESSION_MESSAGE = "Invalid session, this most likely means the server has restarted; close this dialog and then try refreshing the page."
+MAX_SEQNO = 9223372036854775807  # 2**63 - 1: largest client sequence number accepted; counters are not expected to wrap
 Sessions = {}
 
 def get_session_id():
@@ -21,25 +46,10 @@ class AJAXException(Exception):
 class IDGenerationException(Exception):
   pass
 
-class PassthruException(Exception):
+class LineTooLongException(Exception):  # raised by IRCSession.push when a line exceeds config.MAXLINELEN
   pass
-  
-NOT_DONE_YET = None
 
-def jsondump(fn):
-  def decorator(*args, **kwargs):
-    try:
-      x = fn(*args, **kwargs)
-      if x is None:
-        return server.NOT_DONE_YET
-      x = (True, x)
-    except AJAXException, e:
-      x = (False, e[0])
-    except PassthruException, e:
-      return str(e)
-      
-    return simplejson.dumps(x)
-  return decorator
+EMPTY_JSON_LIST = json.dumps([])  # pre-encoded empty batch; written to a long-poll request that times out
 
 def cleanupSession(id):
   try:
@@ -52,36 +62,43 @@ class IRCSession:
     self.id = id
     self.subscriptions = []
     self.buffer = []
+    self.old_buffer = None  # (seqNo, encoded batch) of the last flush, replayed to a client that missed it
+    self.buflen = 0  # total len() of the entries currently in self.buffer, checked against config.MAXBUFLEN
     self.throttle = 0
     self.schedule = None
     self.closed = False
     self.cleanupschedule = None
+    self.pubSeqNo = -1  # highest push sequence number applied; pushes at or below it are ignored
+    self.subSeqNo = 0  # number of batches flushed so far; each batch is sent tagged with the incremented value
 
-  def subscribe(self, channel, notifier):
-    timeout_entry = reactor.callLater(config.HTTP_AJAX_REQUEST_TIMEOUT, self.timeout, channel)
-    def cancel_timeout(result):
-      if channel in self.subscriptions:
-        self.subscriptions.remove(channel)
-      try:
-        timeout_entry.cancel()
-      except error.AlreadyCalled:
-        pass
-    notifier.addCallbacks(cancel_timeout, cancel_timeout)
-    
+  def subscribe(self, channel, seqNo=None):  # seqNo: optional client-supplied tag of the last batch it saw
     if len(self.subscriptions) >= config.MAXSUBSCRIPTIONS:
       self.subscriptions.pop(0).close()
 
+    if seqNo is not None and seqNo < self.subSeqNo:  # client is behind; only the most recent batch is kept
+      if self.old_buffer is None or seqNo != self.old_buffer[0]:
+        channel.write(json.dumps([False, "Unable to reconnect -- sequence number too old."]), seqNo + 1)
+        return
+
+      if not channel.write(self.old_buffer[1], self.old_buffer[0] + 1):  # False: the one-shot channel is spent
+        return
+
     self.subscriptions.append(channel)
-    self.flush()
-      
+    self.flush()  # no argument: flush()'s only parameter is the "scheduled" flag, not a sequence number
+
+  def unsubscribe(self, channel):  # remove channel from the subscriber list; a no-op if it is not there
+    try:
+      self.subscriptions.remove(channel)
+    except ValueError:
+      pass  # already removed (e.g. flush() dropped it after a one-shot write)
+
   def timeout(self, channel):
     if self.schedule:
       return
-      
-    channel.write(simplejson.dumps([]))
-    if channel in self.subscriptions:
-      self.subscriptions.remove(channel)
-      
+
+    self.unsubscribe(channel)
+    channel.write(EMPTY_JSON_LIST, self.subSeqNo)  # empty batch tagged with the current seqNo (nothing new to send)
+
   def flush(self, scheduled=False):
     if scheduled:
       self.schedule = None
@@ -101,34 +118,48 @@ class IRCSession:
         if not self.schedule:
           self.schedule = reactor.callLater(0, self.flush, True)
         return
-        
+
     self.throttle = t + config.UPDATE_FREQ
 
-    encdata = simplejson.dumps(self.buffer)
+    encdata = json.dumps(self.buffer)
+    self.old_buffer = (self.subSeqNo, encdata)  # kept so subscribe() can replay it to a client that missed it
+    self.subSeqNo+=1  # the batch goes out tagged with this incremented value
     self.buffer = []
-    
-    newsubs = []
-    for x in self.subscriptions:
-      if x.write(encdata):
+    self.buflen = 0
+
+    subs = self.subscriptions
+    self.subscriptions = newsubs = []  # swap first: a write can lead to unsubscribe() mid-loop (request finish callbacks)
+
+    for x in subs:
+      if x.write(encdata, self.subSeqNo):  # True (websocket) keeps the channel; False (one-shot request) drops it
         newsubs.append(x)
 
-    self.subscriptions = newsubs
-    if self.closed and not self.subscriptions:
+    if self.closed and not newsubs:
       cleanupSession(self.id)
 
   def event(self, data):
-    bufferlen = sum(map(len, self.buffer))
-    if bufferlen + len(data) > config.MAXBUFLEN:
+    newbuflen = (self.buflen if self.buffer else 0) + len(data)  # overflow empties buffer without zeroing buflen
+    if newbuflen > config.MAXBUFLEN:
       self.buffer = []
       self.client.error("Buffer overflow.")
       return
 
     self.buffer.append(data)
+    self.buflen = newbuflen
     self.flush()
     
-  def push(self, data):
-    if not self.closed:
-      self.client.write(data)
+  def push(self, data, seq_no=None):  # send one line from the web client to its IRC connection
+    if self.closed:
+      return
+
+    if len(data) > config.MAXLINELEN:
+      raise LineTooLongException
+
+    if seq_no is not None:  # ignore pushes whose sequence number was already seen (repeats / reordering)
+      if seq_no <= self.pubSeqNo:
+        return
+      self.pubSeqNo = seq_no
+    self.client.write(data)
 
   def disconnect(self):
     # keep the session hanging around for a few seconds so the
@@ -137,12 +168,16 @@ class IRCSession:
 
     reactor.callLater(5, cleanupSession, self.id)
 
-class Channel:
+# DANGER! Breach of encapsulation: hand-builds an event tuple for session.event() (presumably mirroring ircclient's)
+def connect_notice(line):
+  return "c", "NOTICE", "", ("AUTH", "*** (qwebirc) %s" % line)
+
+class RequestChannel(object):  # long-poll channel wrapping a single HTTP request
   def __init__(self, request):
     self.request = request
-  
-class SingleUseChannel(Channel):
-  def write(self, data):
+
+  def write(self, data, seqNo):  # completes the request and returns False, so flush() drops this channel
+    self.request.setHeader("n", str(seqNo))  # batch sequence number travels in the "n" response header
     self.request.write(data)
     self.request.finish()
     return False
@@ -150,11 +185,6 @@ class SingleUseChannel(Channel):
   def close(self):
     self.request.finish()
     
-class MultipleUseChannel(Channel):
-  def write(self, data):
-    self.request.write(data)
-    return True
-
 class AJAXEngine(resource.Resource):
   isLeaf = True
   
@@ -163,32 +193,33 @@ class AJAXEngine(resource.Resource):
     self.__connect_hit = HitCounter()
     self.__total_hit = HitCounter()
     
-  @jsondump
   def render_POST(self, request):
     path = request.path[len(self.prefix):]
     if path[0] == "/":
       handler = self.COMMANDS.get(path[1:])
       if handler is not None:
-        return handler(self, request)
-        
-    raise PassthruException, http_error.NoResource().render(request)
+        try:
+          return handler(self, request)
+        except AJAXException, e:
+          return json.dumps((False, e[0]))  # error replies are (False, message)
+
+    return resource.NoResource().render(request)  # unknown command: real 404 status and page, not a bare "404" body
 
-  #def render_GET(self, request):
-    #return self.render_POST(request)
-  
   def newConnection(self, request):
     ticket = login_optional(request)
     
-    _, ip, port = request.transport.getPeer()
+    ip = request.getClientIP()
 
     nick = request.args.get("nick")
     if not nick:
       raise AJAXException, "Nickname not supplied."
     nick = ircclient.irc_decode(nick[0])
 
-    ident, realname = "webchat", config.REALNAME
-    
-    for i in xrange(10):
+    password = request.args.get("password")
+    if password is not None:
+      password = ircclient.irc_decode(password[0])
+      
+    for i in range(10):
       id = get_session_id()
       if not Sessions.get(id):
         break
@@ -201,16 +232,43 @@ class AJAXEngine(resource.Resource):
     if qticket is None:
       perform = None
     else:
-      perform = ["PRIVMSG %s :TICKETAUTH %s" % (config.QBOT, qticket)]
+      service_mask = config.AUTH_SERVICE
+      msg_mask = service_mask.split("!")[0] + "@" + service_mask.split("@", 1)[1]  # nick!user@host -> nick@host
+      perform = ["PRIVMSG %s :TICKETAUTH %s" % (msg_mask, qticket)]
+
+    ident, realname = config.IDENT, config.REALNAME
+    if ident is config_options.IDENT_HEX or ident is None: # None is the legacy spelling of IDENT_HEX
+      ident = socket.inet_aton(ip).encode("hex")  # NOTE(review): inet_aton only accepts IPv4 addresses
+    elif ident is config_options.IDENT_NICKNAME:
+      ident = nick
 
     self.__connect_hit()
-    client = ircclient.createIRC(session, nick=nick, ident=ident, ip=ip, realname=realname, perform=perform)
-    session.client = client
-    
+
+    def proceed(hostname):
+      kwargs = dict(nick=nick, ident=ident, ip=ip, realname=realname, perform=perform, hostname=hostname)
+      if password is not None:
+        kwargs["password"] = password
+        
+      client = ircclient.createIRC(session, **kwargs)
+      session.client = client
+
+    if not hasattr(config, "WEBIRC_MODE") or config.WEBIRC_MODE == "hmac":
+      proceed(None)  # hmac (default) mode: no hostname lookup here
+    else:  # any other WEBIRC_MODE: look up the client's hostname before connecting
+      notice = lambda x: session.event(connect_notice(x))
+      notice("Looking up your hostname...")
+      def callback(hostname):
+        notice("Found your hostname.")
+        proceed(hostname)
+      def errback(failure):
+        notice("Couldn't look up your hostname!")
+        proceed(ip)  # fall back to using the IP address as the hostname
+      qdns.lookupAndVerifyPTR(ip, timeout=[config.DNS_TIMEOUT]).addCallbacks(callback, errback)
+
     Sessions[id] = session
     
-    return id
-  
+    return json.dumps((True, id, TRANSPORTS))  # success: (True, session id, transports the client may use)
+
   def getSession(self, request):
     bad_session_message = "Invalid session, this most likely means the server has restarted; close this dialog and then try refreshing the page."
     
@@ -224,26 +282,51 @@ class AJAXEngine(resource.Resource):
     return session
     
   def subscribe(self, request):
-    request.channel.cancelTimeout()
-    self.getSession(request).subscribe(SingleUseChannel(request), request.notifyFinish())
-    return NOT_DONE_YET
+    request.channel.setTimeout(None)  # long-poll: disable the HTTP channel's idle timeout while we hold the request
+
+    channel = RequestChannel(request)
+    session = self.getSession(request)
+    notifier = request.notifyFinish()  # requested before subscribe(), which may write to and finish the request
+
+    seq_no = request.args.get("n")  # optional client-supplied batch sequence number
+    try:
+      if seq_no is not None:
+        seq_no = int(seq_no[0])
+        if seq_no < 0 or seq_no > MAX_SEQNO:
+          raise ValueError
+    except ValueError:
+      raise AJAXException, "Bad sequence number"  # AJAXException: the error type render_POST turns into a reply
+
+    session.subscribe(channel, seq_no)
+
+    timeout_entry = reactor.callLater(config.HTTP_AJAX_REQUEST_TIMEOUT, session.timeout, channel)
+    def cancel_timeout(result):  # runs when the request finishes or fails
+      try:
+        timeout_entry.cancel()
+      except error.AlreadyCalled:
+        pass
+      session.unsubscribe(channel)
+    notifier.addCallbacks(cancel_timeout, cancel_timeout)
+    return server.NOT_DONE_YET
 
   def push(self, request):
     command = request.args.get("c")
     if command is None:
       raise AJAXException, "No command specified."
     self.__total_hit()
-    
-    decoded = ircclient.irc_decode(command[0])
-    
-    session = self.getSession(request)
 
-    if len(decoded) > config.MAXLINELEN:
-      session.disconnect()
-      raise AJAXException, "Line too long."
+    seq_no = request.args.get("n")  # optional push sequence number; IRCSession.push ignores already-seen values
+    try:
+      if seq_no is not None:
+        seq_no = int(seq_no[0])
+        if seq_no < 0 or seq_no > MAX_SEQNO:
+          raise ValueError
+    except ValueError:
+      raise AJAXException("Bad sequence number %r" % seq_no)  # must be AJAXException so render_POST reports it
+
+    session = self.getSession(request)
     try:
-      session.push(decoded)
+      session.push(ircclient.irc_decode(command[0]), seq_no)
     except AttributeError: # occurs when we haven't noticed an error
       session.disconnect()
       raise AJAXException, "Connection closed by server; try reconnecting by reloading the page."
@@ -252,7 +335,7 @@ class AJAXEngine(resource.Resource):
       traceback.print_exc(file=sys.stderr)
       raise AJAXException, "Unknown error."
   
-    return True
+    return json.dumps((True, True))  # success reply for a push
   
   def closeById(self, k):
     s = Sessions.get(k)
@@ -269,4 +352,121 @@ class AJAXEngine(resource.Resource):
     }
     
   COMMANDS = dict(p=push, n=newConnection, s=subscribe)
-  
\ No newline at end of file
+  
+if has_websocket:  # websocket transport classes exist only when autobahn is usable
+  class WebSocketChannel(object):  # session channel (write/close) backed by a websocket protocol
+    def __init__(self, channel):
+      self.channel = channel  # the WebSocketEngineProtocol to send through
+
+    def write(self, data, seqNo):
+      self.channel.send("c", "%d,%s" % (seqNo, data))  # frame: "c<seqNo>,<encoded batch>"
+      return True  # the connection stays open, so flush() keeps this channel subscribed
+
+    def close(self):
+      self.channel.close()
+
+  class WebSocketEngineProtocol(autobahn.twisted.websocket.WebSocketServerProtocol):
+    AWAITING_AUTH, AUTHED = 0, 1  # before / after a valid "s" (subscribe) message
+
+    def __init__(self, *args, **kwargs):
+      super(WebSocketEngineProtocol, self).__init__(*args, **kwargs)
+      self.__state = self.AWAITING_AUTH
+      self.__session = None
+      self.__channel = None
+      self.__timeout = None
+
+    def onOpen(self):
+      self.__timeout = reactor.callLater(5, self.close, "Authentication timeout")  # must subscribe within 5 seconds
+
+    def onClose(self, wasClean, code, reason):
+      self.__cancelTimeout()
+      if self.__session:
+        self.__session.unsubscribe(self.__channel)
+        self.__session = None
+
+    def onMessage(self, msg, isBinary):
+      # we don't bother checking the Origin header, as if you can auth then you've been able to pass the browser's
+      # normal origin handling (POSTed the new connection request and managed to get the session id)
+      state = self.__state
+      message_type, message = msg[:1], msg[1:]  # first character is the message type
+      if state == self.AWAITING_AUTH:
+        if message_type == "s":  # subscribe: "s<seqNo>,<session id>"
+          tokens = message.split(",", 1)
+          if len(tokens) != 2:
+            self.close("Bad tokens")
+            return
+
+          seq_no, message = tokens[0], tokens[1]
+          try:
+            seq_no = int(seq_no)
+            if seq_no < 0 or seq_no > MAX_SEQNO:
+              raise ValueError
+          except ValueError:
+            self.close("Bad value")
+            return  # don't go on to subscribe with an unparsed sequence number
+          session = Sessions.get(message)
+          if not session:
+            self.close(BAD_SESSION_MESSAGE)
+            return
+
+          self.__cancelTimeout()
+          self.__session = session
+          self.send("s", "True")
+          self.__state = self.AUTHED
+          self.__channel = WebSocketChannel(self)
+          session.subscribe(self.__channel, seq_no)
+          return
+      elif state == self.AUTHED:
+        if message_type == "p":  # push: "p<seqNo>,<line>"
+          tokens = message.split(",", 1)
+          if len(tokens) != 2:
+            self.close("Bad tokens")
+            return
+          seq_no, message = tokens[0], tokens[1]
+          try:
+            seq_no = int(seq_no)
+            if seq_no < 0 or seq_no > MAX_SEQNO:
+              raise ValueError
+          except ValueError:
+            self.close("Bad value")
+            return  # don't push after closing the connection
+          self.__session.push(ircclient.irc_decode(message))  # NOTE(review): seq_no is validated but not passed on
+          return
+
+      self.close("Bad message type")
+
+    def __cancelTimeout(self):
+      if self.__timeout is not None:
+        try:
+          self.__timeout.cancel()
+        except error.AlreadyCalled:
+          pass
+        self.__timeout = None
+
+    def close(self, reason=None):  # close code 4999 carries the reason text, 4998 has none
+      self.__cancelTimeout()
+      if reason:
+        self.sendClose(4999, unicode(reason))
+      else:
+        self.sendClose(4998)
+
+      if self.__session:
+        self.__session.unsubscribe(self.__channel)
+        self.__session = None
+
+    def send(self, message_type, message):
+      self.sendMessage(message_type + message)
+
+  class WebSocketResource(autobahn.twisted.resource.WebSocketResource):
+    def render(self, request):
+      request.channel.setTimeout(None)  # disable the HTTP channel's idle timeout before autobahn takes over
+      return autobahn.twisted.resource.WebSocketResource.render(self, request)
+
+  def WebSocketEngine(path=None):  # builds the websocket resource; `path` is currently unused
+    factory = autobahn.twisted.websocket.WebSocketServerFactory("ws://localhost")
+    factory.externalPort = None
+    factory.protocol = WebSocketEngineProtocol
+    factory.setProtocolOptions(maxMessagePayloadSize=512, maxFramePayloadSize=512, tcpNoDelay=False)
+    resource = WebSocketResource(factory)
+    return resource
+