mysql selector: Always check return value of escape_str()
[paraslash.git] / mysql_selector.c
index ff94d8d7c26e5cc8fcd1095a6e1eb85e17c3a348..45a10bac0e9264e96b38dda81b7233370c02958e 100644 (file)
@@ -32,7 +32,6 @@
 #include "net.h"
 #include "string.h"
 
-extern struct gengetopt_args_info conf;
 /** pointer to the shared memory area */
 extern struct misc_meta_data *mmd;
 
@@ -663,7 +662,7 @@ out:
        return ret;
 }
 
-static char *escape_blob(char* old, int size)
+static char *escape_blob(const char* old, int size)
 {
        char *new;
 
@@ -674,7 +673,7 @@ static char *escape_blob(char* old, int size)
        return new;
 }
 
-static char *escape_str(char* old)
+static char *escape_str(const char* old)
 {
        return escape_blob(old, strlen(old));
 }
@@ -1159,12 +1158,17 @@ static char *get_query(char *streamname, char *filename, int with_path)
        char *select_clause = NULL;
        if (!streamname)
                tmp = get_current_stream();
-       else
+       else {
                tmp = escape_str(streamname);
+               if (!tmp)
+                       return NULL;
+       }
        if (!strcmp(tmp, "(none)")) {
                free(tmp);
                if (filename) {
                        char *ret, *ebn = escaped_basename(filename);
+                       if (!ebn)
+                               return NULL;
                        ret = make_message("select to_days(now()) - "
                                "to_days(lastplayed) from data "
                                "where name = '%s'", ebn);
@@ -1840,10 +1844,10 @@ static int update_audio_file(char *name)
        ret = real_query(q);
        free(q);
 out:
-       if (ebn)
-               free(ebn);
+       free(ebn);
        return ret;
 }
+
 /* If called as child, mmd_lock must be held */
 static void update_mmd(char *info)
 {
@@ -2096,14 +2100,21 @@ static int com_sl(int fd, int argc, char *argv[])
        num = atoi(argv[1]);
        if (!num)
                return -E_MYSQL_SYNTAX;
-       stream = (argc == 2)?  get_current_stream() : escape_str(argv[2]);
+       if (argc == 2) {
+               stream = get_current_stream();
+               if (!stream)
+                       return -E_GET_STREAM;
+       } else {
+               stream = escape_str(argv[2]);
+               if (!stream)
+                       return -E_ESCAPE;
+       }
        tmp = get_query(stream, NULL, 0);
+       free(stream);
+       if (!tmp)
+               return -E_GET_QUERY;
        query = make_message("%s limit %d", tmp, num);
        free(tmp);
-       ret = -E_GET_QUERY;
-       free(stream);
-       if (!query)
-               goto out;
        ret = -E_NORESULT;
        result = get_result(query);
        free(query);
@@ -2321,7 +2332,6 @@ static int mysql_write_tmp_file(const char *dir, const char *name)
 {
        int ret = -E_TMPFILE;
        char *msg = make_message("%s\t%s\n", dir, name);
-
        if (fputs(msg, out_file) != EOF)
                ret = 1;
        free(msg);
@@ -2366,7 +2376,7 @@ static int com_upd(int fd, int argc, __a_unused char *argv[])
                goto out;
        if ((ret = real_query("delete from dir")) < 0)
                goto out;
-       query = make_message("load data infile '%s' into table dir "
+       query = make_message("load data infile '%s' ignore into table dir "
                "fields terminated by '\t' lines terminated by '\n' "
                "(dir, name)", tempname);
        ret = real_query(query);
@@ -2384,12 +2394,17 @@ static int com_upd(int fd, int argc, __a_unused char *argv[])
                goto out;
        }
        while ((row = mysql_fetch_row(result))) {
+               char *erow;
                ret = -E_NOROW;
                if (!row[0])
                        goto out;
                send_va_buffer(fd, "new entry: %s\n", row[0]);
+               erow = escape_str(row[0]);
+               if (!erow)
+                       goto out;
                query = make_message("insert into data (name, pic_id) values "
-                       "('%s','%s')", row[0], "1");
+                       "('%s','%s')", erow, "1");
+               free(erow);
                ret = real_query(query);
                free(query);
                if (ret < 0)
@@ -2418,11 +2433,12 @@ static char **server_get_audio_file_list(unsigned int num)
 
        tmp = get_query(stream, NULL, 1);
        free(stream);
+       if (!tmp)
+               goto err_out;
        query = make_message("%s limit %d", tmp, num);
        free(tmp);
-       if (!query)
-               goto err_out;
        result = get_result(query);
+       free(query);
        if (!result)
                goto err_out;
        num_rows = mysql_num_rows(result);
@@ -2444,8 +2460,6 @@ err_out:
        free(list);
        list = NULL;
 success:
-       if (query)
-               free(query);
        if (result)
                mysql_free_result(result);
        return list;
@@ -2511,8 +2525,12 @@ static int com_cdb(int fd, int argc, char *argv[])
                goto out;
        if (argc < 2)
                conf.mysql_database_arg = para_strdup("paraslash");
-       else
+       else {
+               ret = -E_ESCAPE;
                conf.mysql_database_arg = escape_str(argv[1]);
+               if (!conf.mysql_database_arg)
+                       goto out;
+       }
        query = make_message("create database %s", conf.mysql_database_arg);
        ret = real_query(query);
        free(query);