cgi-io: use O_TMPFILE for uploads and attempt to directly link target file
[project/cgi-io.git] / src / main.c
index 7cf8d7b23dd097b8c7e74b5f6724f6ac4ad8088b..d45c67b85f345110a54a30e8d077abcbea559c8b 100644 (file)
@@ -38,6 +38,7 @@
 #include "multipart_parser.h"
 
 #define READ_BLOCK 4096
+#define POST_LIMIT 131072
 
 enum part {
        PART_UNKNOWN,
@@ -223,55 +224,86 @@ urldecode(char *buf)
        return true;
 }
 
-static bool
+static char *
 postdecode(char **fields, int n_fields)
 {
-       char *p;
        const char *var;
-       static char buf[1024];
-       int i, len, field, found = 0;
+       char *p, *postbuf;
+       int i, field, found = 0;
+       ssize_t len = 0, rlen = 0, content_length = 0;
 
        var = getenv("CONTENT_TYPE");
 
        if (!var || strncmp(var, "application/x-www-form-urlencoded", 33))
-               return false;
+               return NULL;
+
+       var = getenv("CONTENT_LENGTH");
+
+       if (!var)
+               return NULL;
+
+       content_length = strtol(var, &p, 10);
+
+       if (p == var || content_length <= 0 || content_length >= POST_LIMIT)
+               return NULL;
 
-       memset(buf, 0, sizeof(buf));
+       postbuf = calloc(1, content_length + 1);
 
-       if ((len = read(0, buf, sizeof(buf) - 1)) > 0)
+       if (postbuf == NULL)
+               return NULL;
+
+       for (len = 0; len < content_length; )
        {
-               for (p = buf, i = 0; i <= len; i++)
+               rlen = read(0, postbuf + len, content_length - len);
+
+               if (rlen <= 0)
+                       break;
+
+               len += rlen;
+       }
+
+       if (len < content_length)
+       {
+               free(postbuf);
+               return NULL;
+       }
+
+       for (p = postbuf, i = 0; i <= len; i++)
+       {
+               if (postbuf[i] == '=')
                {
-                       if (buf[i] == '=')
-                       {
-                               buf[i] = 0;
+                       postbuf[i] = 0;
 
-                               for (field = 0; field < (n_fields * 2); field += 2)
+                       for (field = 0; field < (n_fields * 2); field += 2)
+                       {
+                               if (!strcmp(p, fields[field]))
                                {
-                                       if (!strcmp(p, fields[field]))
-                                       {
-                                               fields[field + 1] = buf + i + 1;
-                                               found++;
-                                       }
+                                       fields[field + 1] = postbuf + i + 1;
+                                       found++;
                                }
                        }
-                       else if (buf[i] == '&' || buf[i] == '\0')
-                       {
-                               buf[i] = 0;
+               }
+               else if (postbuf[i] == '&' || postbuf[i] == '\0')
+               {
+                       postbuf[i] = 0;
 
-                               if (found >= n_fields)
-                                       break;
+                       if (found >= n_fields)
+                               break;
 
-                               p = buf + i + 1;
-                       }
+                       p = postbuf + i + 1;
                }
        }
 
        for (field = 0; field < (n_fields * 2); field += 2)
+       {
                if (!urldecode(fields[field + 1]))
-                       return false;
+               {
+                       free(postbuf);
+                       return NULL;
+               }
+       }
 
-       return (found >= n_fields);
+       return postbuf;
 }
 
 static char *
@@ -404,32 +436,44 @@ filecopy(void)
                return response(false, "No file data received");
        }
 
-       if (lseek(st.tempfd, 0, SEEK_SET) < 0)
-       {
-               close(st.tempfd);
-               return response(false, "Failed to rewind temp file");
-       }
-
-       st.filefd = open(st.filename, O_CREAT | O_TRUNC | O_WRONLY, 0600);
+       snprintf(buf, sizeof(buf), "/proc/self/fd/%d", st.tempfd);
 
-       if (st.filefd < 0)
+       if (unlink(st.filename) < 0 && errno != ENOENT)
        {
                close(st.tempfd);
-               return response(false, "Failed to open target file");
+               return response(false, "Failed to unlink existing file");
        }
 
-       while ((len = read(st.tempfd, buf, sizeof(buf))) > 0)
+       if (linkat(AT_FDCWD, buf, AT_FDCWD, st.filename, AT_SYMLINK_FOLLOW) < 0)
        {
-               if (write(st.filefd, buf, len) != len)
+               if (lseek(st.tempfd, 0, SEEK_SET) < 0)
                {
                        close(st.tempfd);
-                       close(st.filefd);
-                       return response(false, "I/O failure while writing target file");
+                       return response(false, "Failed to rewind temp file");
                }
+
+               st.filefd = open(st.filename, O_CREAT | O_TRUNC | O_WRONLY, 0600);
+
+               if (st.filefd < 0)
+               {
+                       close(st.tempfd);
+                       return response(false, "Failed to open target file");
+               }
+
+               while ((len = read(st.tempfd, buf, sizeof(buf))) > 0)
+               {
+                       if (write(st.filefd, buf, len) != len)
+                       {
+                               close(st.tempfd);
+                               close(st.filefd);
+                               return response(false, "I/O failure while writing target file");
+                       }
+               }
+
+               close(st.filefd);
        }
 
        close(st.tempfd);
-       close(st.filefd);
 
        if (chmod(st.filename, st.filemode))
                return response(false, "Failed to chmod target file");
@@ -478,8 +522,6 @@ header_value(multipart_parser *p, const char *data, size_t len)
 static int
 data_begin_cb(multipart_parser *p)
 {
-       char tmpname[24] = "/tmp/luci-upload.XXXXXX";
-
        if (st.parttype == PART_FILEDATA)
        {
                if (!st.sessionid)
@@ -491,12 +533,10 @@ data_begin_cb(multipart_parser *p)
                if (!session_access(st.sessionid, "file", st.filename, "write"))
                        return response(false, "Access to path denied by ACL");
 
-               st.tempfd = mkstemp(tmpname);
+               st.tempfd = open("/tmp", O_TMPFILE | O_RDWR, S_IRUSR | S_IWUSR);
 
                if (st.tempfd < 0)
                        return response(false, "Failed to create temporary file");
-
-               unlink(tmpname);
        }
 
        return 0;
@@ -658,6 +698,14 @@ main_upload(int argc, char *argv[])
        return 0;
 }
 
+static void
+free_charp(char **ptr)
+{
+       free(*ptr);
+}
+
+#define autochar __attribute__((__cleanup__(free_charp))) char
+
 static int
 main_download(int argc, char **argv)
 {
@@ -668,7 +716,7 @@ main_download(int argc, char **argv)
        struct stat s;
        int rfd;
 
-       postdecode(fields, 4);
+       autochar *post = postdecode(fields, 4);
 
        if (!fields[1] || !session_access(fields[1], "cgi-io", "download", "read"))
                return failure(403, 0, "Download permission denied");
@@ -706,29 +754,39 @@ main_download(int argc, char **argv)
        if (fields[5])
                printf("Content-Disposition: attachment; filename=\"%s\"\r\n", fields[5]);
 
-       printf("Content-Length: %llu\r\n\r\n", size);
-       fflush(stdout);
+       if (size > 0) {
+               printf("Content-Length: %llu\r\n\r\n", size);
+               fflush(stdout);
 
-       while (size > 0) {
-               len = sendfile(1, rfd, NULL, size);
+               while (size > 0) {
+                       len = sendfile(1, rfd, NULL, size);
 
-               if (len == -1) {
-                       if (errno == ENOSYS || errno == EINVAL) {
-                               while ((len = read(rfd, buf, sizeof(buf))) > 0)
-                                       fwrite(buf, len, 1, stdout);
+                       if (len == -1) {
+                               if (errno == ENOSYS || errno == EINVAL) {
+                                       while ((len = read(rfd, buf, sizeof(buf))) > 0)
+                                               fwrite(buf, len, 1, stdout);
 
-                               fflush(stdout);
-                               break;
+                                       fflush(stdout);
+                                       break;
+                               }
+
+                               if (errno == EINTR || errno == EAGAIN)
+                                       continue;
                        }
 
-                       if (errno == EINTR || errno == EAGAIN)
-                               continue;
+                       if (len <= 0)
+                               break;
+
+                       size -= len;
                }
+       }
+       else {
+               printf("\r\n");
 
-               if (len <= 0)
-                       break;
+               while ((len = read(rfd, buf, sizeof(buf))) > 0)
+                       fwrite(buf, len, 1, stdout);
 
-               size -= len;
+               fflush(stdout);
        }
 
        close(rfd);
@@ -749,7 +807,9 @@ main_backup(int argc, char **argv)
        char hostname[64] = { 0 };
        char *fields[] = { "sessionid", NULL };
 
-       if (!postdecode(fields, 1) || !session_access(fields[1], "cgi-io", "backup", "read"))
+       autochar *post = postdecode(fields, 1);
+
+       if (!fields[1] || !session_access(fields[1], "cgi-io", "backup", "read"))
                return failure(403, 0, "Backup permission denied");
 
        if (pipe(fds))
@@ -929,7 +989,7 @@ main_exec(int argc, char **argv)
        char *p, **args;
        pid_t pid;
 
-       postdecode(fields, 4);
+       autochar *post = postdecode(fields, 4);
 
        if (!fields[1] || !session_access(fields[1], "cgi-io", "exec", "read"))
                return failure(403, 0, "Exec permission denied");