Avoid lacking fds
[project/luci.git] / libs / lucid-http / luasrc / lucid / http / server.lua
index 67e917892ea9d243ffb55467e1601d2ac9c15f66..fd5f7cdd8bf9d97c3be01a1b80040ae9076177a4 100644 (file)
@@ -14,6 +14,7 @@ $Id$
 local ipairs, pairs = ipairs, pairs
 local tostring, tonumber = tostring, tonumber
 local pcall, assert, type = pcall, assert, type
+local set_memory_limit = set_memory_limit
 
 local os = require "os"
 local nixio = require "nixio"
@@ -23,6 +24,8 @@ local proto = require "luci.http.protocol"
 local table = require "table"
 local date = require "luci.http.protocol.date"
 
+--- HTTP Daemon
+-- @cstyle instance
 module "luci.lucid.http.server"
 
 VERSION = "1.0"
@@ -46,7 +49,11 @@ statusmsg = {
        [503] = "Server Unavailable",
 }
 
--- File Resource
+--- Create a new IO resource response.
+-- @class function
+-- @param fd File descriptor
+-- @param len Length of data
+-- @return IO resource
 IOResource = util.class()
 
 function IOResource.__init__(self, fd, len)
@@ -54,19 +61,26 @@ function IOResource.__init__(self, fd, len)
 end
 
 
--- Server handler implementation
+--- Create a server handler.
+-- @class function
+-- @param name Name
+-- @return Handler
 Handler = util.class()
 
 function Handler.__init__(self, name)
        self.name = name or tostring(self)
 end
 
--- Creates a failure reply
+--- Create a failure reply.
+-- @param code HTTP status code
+-- @param msg Status message
+-- @return status code, header table, response source
 function Handler.failure(self, code, msg)      
        return code, { ["Content-Type"] = "text/plain" }, ltn12.source.string(msg)
 end
 
--- Access Restrictions
+--- Add an access restriction.
+-- @param restriction Restriction specification
 function Handler.restrict(self, restriction)
        if not self.restrictions then
                self.restrictions = {restriction}
@@ -75,7 +89,9 @@ function Handler.restrict(self, restriction)
        end
 end
 
--- Check restrictions
+--- Enforce access restrictions.
+-- @param request Request object
+-- @return nil or HTTP statuscode, table of headers, response source
 function Handler.checkrestricted(self, request)
        if not self.restrictions then
                return
@@ -116,6 +132,7 @@ function Handler.checkrestricted(self, request)
                end
                
                if stat then
+                       request.env.HTTP_AUTH_USER, request.env.HTTP_AUTH_PASS = user, pass
                        return
                end
        end
@@ -126,7 +143,10 @@ function Handler.checkrestricted(self, request)
        }, ltn12.source.string("Unauthorized")
 end
 
--- Processes a request
+--- Process a request.
+-- @param request Request object
+-- @param sourcein Request data source
+-- @return HTTP statuscode, table of headers, response source
 function Handler.process(self, request, sourcein)
        local stat, code, hdr, sourceout
        
@@ -153,12 +173,19 @@ function Handler.process(self, request, sourcein)
 end
 
 
+--- Create a Virtual Host.
+-- @class function
+-- @return Virtual Host
 VHost = util.class()
 
 function VHost.__init__(self)
        self.handlers = {}
 end
 
+--- Process a request and invoke the appropriate handler. 
+-- @param request Request object
+-- @param ... Additional parameters passed to the handler
+-- @return HTTP statuscode, table of headers, response source 
 function VHost.process(self, request, ...)
        local handler
        local hlen = -1
@@ -171,6 +198,10 @@ function VHost.process(self, request, ...)
        -- Call URI part
        request.env.PATH_INFO = uri
        
+       if self.default and uri == "/" then
+               return 302, {Location = self.default}
+       end
+
        for k, h in pairs(self.handlers) do
                if #k > hlen then
                        if uri == k or (uri:sub(1, #k) == k and uri:byte(#k+1) == sc) then
@@ -189,15 +220,20 @@ function VHost.process(self, request, ...)
        end
 end
 
+--- Get a list of registered handlers.
+-- @return Table of handlers
 function VHost.get_handlers(self)
        return self.handlers
 end
 
+--- Register handler with a given URI prefix.
+-- @oaram match URI prefix
+-- @param handler Handler object
 function VHost.set_handler(self, match, handler)
        self.handlers[match] = handler
 end
 
-
+-- Remap IPv6-IPv4-compatibility addresses back to IPv4 addresses.
 local function remapipv6(adr)
        local map = "::ffff:"
        if adr:sub(1, #map) == map then
@@ -207,6 +243,7 @@ local function remapipv6(adr)
        end 
 end
 
+-- Create a source that decodes chunked-encoded data from a socket.
 local function chunksource(sock, buffer)
        buffer = buffer or ""
        return function()
@@ -250,30 +287,44 @@ local function chunksource(sock, buffer)
        end
 end
 
+-- Create a sink that chunk-encodes data and writes it on a given socket.
 local function chunksink(sock)
        return function(chunk, err)
                if not chunk then
                        return sock:writeall("0\r\n\r\n")
                else
-                       return sock:writeall(("%X\r\n%s\r\n"):format(#chunk, chunk))
+                       return sock:writeall(("%X\r\n%s\r\n"):format(#chunk, tostring(chunk)))
                end
        end
 end
 
+
+--- Create a server object.
+-- @class function
+-- @return Server object
 Server = util.class()
 
 function Server.__init__(self)
        self.vhosts = {}
 end
 
+--- Get a list of registered virtual hosts.
+-- @return Table of virtual hosts
 function Server.get_vhosts(self)
        return self.vhosts
 end
 
+--- Register a virtual host with a given name.
+-- @param name Hostname
+-- @param vhost Virtual host object
 function Server.set_vhost(self, name, vhost)
        self.vhosts[name] = vhost
 end
 
+--- Send a fatal error message to given client and close the connection.
+-- @param client Client socket
+-- @param code HTTP status code
+-- @param msg status message
 function Server.error(self, client, code, msg)
        hcode = tostring(code)
        
@@ -304,6 +355,9 @@ local hdr2env = {
        ["User-Agent"] = "HTTP_USER_AGENT"
 }
 
+--- Parse the request headers and prepare the environment.
+-- @param source line-based input source
+-- @return Request object
 function Server.parse_headers(self, source)
        local env = {}
        local req = {env = env, headers = {}}
@@ -348,7 +402,9 @@ function Server.parse_headers(self, source)
        return req
 end
 
-
+--- Handle a new client connection.
+-- @param client client socket
+-- @param env superserver environment
 function Server.process(self, client, env)
        local sourcein  = function() end
        local sourcehdr = client:linesource()
@@ -358,6 +414,11 @@ function Server.process(self, client, env)
        local close = false
        local stat, code, msg, message, err
        
+       env.config.memlimit = tonumber(env.config.memlimit)
+       if env.config.memlimit and set_memory_limit then
+               set_memory_limit(env.config.memlimit)
+       end
+
        client:setsockopt("socket", "rcvtimeo", 5)
        client:setsockopt("socket", "sndtimeo", 5)
        
@@ -432,6 +493,8 @@ function Server.process(self, client, env)
                        else
                                return self:error(client, 411, statusmsg[411])
                        end
+
+                       close = true
                else
                        return self:error(client, 405, statusmsg[405])
                end
@@ -452,7 +515,7 @@ function Server.process(self, client, env)
                                        headers["Content-Length"] = sourceout.len
                                end
                        end
-                       if not headers["Content-Length"] then
+                       if not headers["Content-Length"] and not close then
                                if message.env.SERVER_PROTOCOL == "HTTP/1.1" then
                                        headers["Transfer-Encoding"] = "chunked"
                                        sinkout = chunksink(client)
@@ -466,9 +529,10 @@ function Server.process(self, client, env)
                
                if close then
                        headers["Connection"] = "close"
-               elseif message.env.SERVER_PROTOCOL == "HTTP/1.0" then
+               else
                        headers["Connection"] = "Keep-Alive"
-               end 
+                       headers["Keep-Alive"] = "timeout=5, max=50"
+               end
 
                headers["Date"] = date.to_http(os.time())
                local header = {
@@ -495,11 +559,25 @@ function Server.process(self, client, env)
                stat, code, msg = client:writeall(table.concat(header, "\r\n"))
 
                if sourceout and stat then
+                       local closefd
                        if util.instanceof(sourceout, IOResource) then
-                               stat, code, msg = sourceout.fd:copyz(client, sourceout.len)
-                       else
+                               if not headers["Transfer-Encoding"] then
+                                       stat, code, msg = sourceout.fd:copyz(client, sourceout.len)
+                                       closefd = sourceout.fd
+                                       sourceout = nil
+                               else
+                                       closefd = sourceout.fd
+                                       sourceout = sourceout.fd:blocksource(nil, sourceout.len)
+                               end
+                       end
+
+                       if sourceout then
                                stat, msg = ltn12.pump.all(sourceout, sinkout)
                        end
+
+                       if closefd then
+                               closefd:close()
+                       end
                end