// Package sql_injection provides HTTP handlers that demonstrate SQL injection:
// deliberately insecure endpoints (ExecuteSql, LoginSql) alongside hardened
// counterparts (SecureExecuteSql, SecureLoginSql, SecureGetUser).
package sql_injection

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"

	"github.com/Vomitblood/cspj-application/server/internal/db"
)

// ExecuteSql runs the raw request body as SQL and writes every returned row
// back as one "%v"-formatted line of plain text.
//
// WARNING: intentionally insecure. The request body is passed to the database
// verbatim, so any statement the database role may run is accepted. It exists
// to demonstrate SQL injection; see SecureExecuteSql for the hardened version.
func ExecuteSql(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	// read the request body
	sqlQuery, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "Failed to read request body", http.StatusBadRequest)
		return
	}

	// execute the sql query without any sanitization (deliberately vulnerable)
	rows, err := db.DbPool.Query(context.Background(), string(sqlQuery))
	if err != nil {
		http.Error(w, "Query execution error", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// collect the rows in a Builder; repeated string += is quadratic in the
	// number of rows
	var response strings.Builder
	for rows.Next() {
		values, err := rows.Values()
		if err != nil {
			http.Error(w, "Error reading query result", http.StatusInternalServerError)
			return
		}
		fmt.Fprintf(&response, "%v\n", values)
	}
	// NOTE(review): db.DbPool looks like a pgx pool (rows.Values /
	// FieldDescriptions); pgx may defer query errors until Next returns false,
	// so without this check a failed query would be sent as an empty 200.
	if err := rows.Err(); err != nil {
		http.Error(w, "Query execution error", http.StatusInternalServerError)
		return
	}

	// send the response to the client; a write error cannot be reported back
	// to the client at this point
	w.Write([]byte(response.String()))
}

// LoginSql checks a JSON body {"username": ..., "password": ...} against the
// users table and, when a row matches, responds with the user's id and
// username. Any decode or lookup failure yields 400 or 401 respectively.
//
// WARNING: intentionally insecure. The credentials are spliced into the SQL
// text with fmt.Sprintf, so a quote in either field can rewrite the WHERE
// clause (e.g. a username of `' OR '1'='1' --`). See SecureLoginSql for the
// parameterized version.
func LoginSql(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	// parse the request body to get username and password
	var credentials struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
		http.Error(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	// construct the unsafe query (deliberately vulnerable). Both placeholders
	// must be %s: a wrong verb leaves the password argument unconsumed, fmt
	// then appends "%!(EXTRA ...)" to the SQL and every normal login fails.
	query := fmt.Sprintf(
		"SELECT id, username FROM users WHERE username = '%s' AND password = '%s'",
		credentials.Username,
		credentials.Password,
	)

	// execute the query without sanitizing the input
	var id int
	var username string
	err := db.DbPool.QueryRow(context.Background(), query).Scan(&id, &username)
	if err != nil {
		http.Error(w, "Invalid credentials", http.StatusUnauthorized)
		return
	}

	// the user was found: return a success response
	response := map[string]interface{}{
		"message":  "Login successful",
		"user_id":  id,
		"username": username,
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		http.Error(w, "Failed to encode response", http.StatusInternalServerError)
		log.Printf("JSON encoding error: %v", err)
	}
}

// SecureExecuteSql runs a caller-supplied SELECT statement from a JSON body
// {"query": "...", "params": [...]} and returns the result as a JSON array of
// objects keyed by column name.
//
// User values travel only as bind parameters, never as SQL text. The SELECT
// prefix test is a coarse guard, not a full SQL validator.
func SecureExecuteSql(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	var input struct {
		Query  string        `json:"query"`
		Params []interface{} `json:"params"`
	}

	// parse json request body
	if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
		http.Error(w, "Invalid JSON format", http.StatusBadRequest)
		return
	}

	// simple validation, only allow select statements; leading whitespace is
	// ignored so that e.g. "\n  SELECT ..." is not wrongly rejected
	if !strings.HasPrefix(strings.ToUpper(strings.TrimSpace(input.Query)), "SELECT") {
		http.Error(w, "Only SELECT queries are allowed", http.StatusForbidden)
		return
	}

	// execute the query as a parameterized statement
	rows, err := db.DbPool.Query(context.Background(), input.Query, input.Params...)
	if err != nil {
		http.Error(w, "Query execution error", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// format the response: one map per row, keyed by column name
	var response []map[string]interface{}
	for rows.Next() {
		values, err := rows.Values()
		if err != nil {
			http.Error(w, "Error reading query result", http.StatusInternalServerError)
			return
		}
		fieldDescriptions := rows.FieldDescriptions()
		rowMap := make(map[string]interface{}, len(fieldDescriptions))
		for i, fd := range fieldDescriptions {
			rowMap[string(fd.Name)] = values[i]
		}
		response = append(response, rowMap)
	}
	// query/stream errors may only be reported once iteration stops; without
	// this check a failed query would be returned as a successful result
	if err := rows.Err(); err != nil {
		http.Error(w, "Query execution error", http.StatusInternalServerError)
		return
	}

	// return json response
	jsonResp, err := json.Marshal(response)
	if err != nil {
		http.Error(w, "Failed to encode response", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.Write(jsonResp)
}

// SecureLoginSql checks a JSON body {"username": ..., "password": ...} using a
// parameterized query and, when a row matches, responds with the user's id
// and username. Any decode or lookup failure yields 400 or 401 respectively.
func SecureLoginSql(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	var credentials struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
		http.Error(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	// secure version using parameterized queries: the credentials are sent as
	// bind values and can never alter the SQL text.
	// NOTE(review): the password column is compared directly with the
	// submitted value, which suggests passwords are stored unhashed — confirm.
	query := "SELECT id, username FROM users WHERE username = $1 AND password = $2"
	var id int
	var username string
	err := db.DbPool.QueryRow(context.Background(), query, credentials.Username, credentials.Password).Scan(&id, &username)
	if err != nil {
		http.Error(w, "Invalid credentials", http.StatusUnauthorized)
		return
	}

	// send back the success response
	response := map[string]interface{}{
		"message":  "Login successful",
		"user_id":  id,
		"username": username,
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		http.Error(w, "Failed to encode response", http.StatusInternalServerError)
		log.Printf("JSON encoding error: %v", err)
	}
}

// SecureGetUser returns the id, username and email of the user named in a
// JSON body {"username": ...}.
//
// Besides using a parameterized query it rejects bodies with unknown fields
// and only accepts usernames present in the list returned by
// db.FetchUsernames, so arbitrary input never reaches the lookup query.
func SecureGetUser(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	var requestData struct {
		Username string `json:"username"`
	}

	// the decoder must be the one configured with DisallowUnknownFields;
	// decoding through a fresh decoder would silently accept extra fields
	jsonDecoder := json.NewDecoder(r.Body)
	jsonDecoder.DisallowUnknownFields()
	if err := jsonDecoder.Decode(&requestData); err != nil {
		http.Error(w, "Invalid request format", http.StatusBadRequest)
		log.Println("Failed to decode JSON body or extra fields present:", err)
		return
	}

	// validate that a username is provided
	if requestData.Username == "" {
		http.Error(w, "Invalid request format", http.StatusBadRequest)
		log.Println("Username is missing in the request body")
		return
	}

	// retrieve the set of existing usernames from the db
	existingUsernames, err := db.FetchUsernames()
	if err != nil {
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		log.Printf("Failed to fetch usernames: %v", err)
		return
	}

	// allow-list check: anything that is not an existing username is rejected
	// before it reaches the lookup query
	if !existingUsernames[requestData.Username] {
		http.Error(w, "Invalid username", http.StatusBadRequest)
		return
	}

	// look the user up with a parameterized query
	query := "SELECT id, username, email FROM users WHERE username = $1"
	var id int
	var dbUsername, email string
	err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email)
	if err != nil {
		http.Error(w, "User not found", http.StatusNotFound)
		return
	}

	// send back the user data as a json response
	response := map[string]interface{}{
		"id":       id,
		"username": dbUsername,
		"email":    email,
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
		log.Printf("JSON encoding error: %v", err)
	}
}