From 3faf2fd92946089e7f9fb22d2f2e4b5aeba2f2d2 Mon Sep 17 00:00:00 2001 From: Vomitblood Date: Mon, 11 Nov 2024 18:47:15 +0800 Subject: [PATCH] server with helper functions --- README | 6 +- server/internal/db/db.go | 82 +++++++++++++++++-- server/internal/http_server/http_server.go | 6 +- .../internal/sql_injection/sql_injection.go | 30 +++++-- 4 files changed, 111 insertions(+), 13 deletions(-) diff --git a/README b/README index 2a0a542..f1313fa 100644 --- a/README +++ b/README @@ -31,6 +31,9 @@ PGPASSWORD=asdfpassword ## Server +- `/SetupDemoDb` +- `/NukeDb` + ### SQL Injection - `/sql-execute` @@ -46,7 +49,8 @@ Parameterized queries separate the SQL code from the data, so user input is neve Only allow `SELECT` statement by verifying that the input query starts with it. Sanitized the input to ensure that no other types of statements could be executed. +The input is checked against a list of allowed query terms, and if it doesn't match, the query is rejected. -#### 3. Controller JSON Input for Parameters +#### 3. Controlled JSON Input for Parameters Instead of using raw SQL strings, we restructured the input to ONLY expect JSON data with `query` and `params` fields. diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 2fe1c1e..7b3cca9 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -2,6 +2,7 @@ package db import ( "context" + "encoding/json" "fmt" "log" "net/http" @@ -47,16 +48,17 @@ func SetupDemoDb(w http.ResponseWriter, r *http.Request) { CREATE TABLE IF NOT EXISTS users ( id SERIAL PRIMARY KEY, username VARCHAR(50) UNIQUE NOT NULL, - email VARCHAR(100) NOT NULL + email VARCHAR(100) NOT NULL, + password VARCHAR(100) NOT NULL );` // also avoid duplicate entries insertDataSQL := ` - INSERT INTO users (username, email) VALUES - ('alice', 'alice@example.com'), - ('bob', 'bob@example.com'), - ('charlie', 'charlie@example.com') - ON CONFLICT (username) DO NOTHING;` + INSERT INTO users (username, email, password) VALUES + ('alice', 'alice@example.com', 'asdfalicepassword'), + ('bob', 'bob@example.com', 'asdfbobpassword'), + ('charlie', 'charlie@example.com', 'asdfcharliepassword') + ON CONFLICT (username) DO NOTHING;` // execute create table _, err := DbPool.Exec(context.Background(), createTableSQL) @@ -95,6 +97,8 @@ func NukeDb(w http.ResponseWriter, r *http.Request) { return } + w.WriteHeader(http.StatusOK) + w.Write([]byte("Bye bye")) log.Println("Database nuked") } @@ -119,3 +123,69 @@ func FetchUsernames() (map[string]bool, error) { log.Println("Fetched usernames:", usernames) return usernames, nil } + +// fetch all users for demo +func FetchAllUsers(w http.ResponseWriter, r *http.Request) { + // construct sql query to select all users + query := "SELECT * FROM users" + + // execute the query + rows, err := DbPool.Query(context.Background(), query) + if err != nil { + http.Error(w, "Failed to retrieve users", http.StatusInternalServerError) + log.Printf("Error executing query: %v", err) + return + } + defer rows.Close() + + // define a slice to hold user data + users := []map[string]interface{}{} + + // get column names + columnNames := rows.FieldDescriptions() + + // iterate over the rows and build the result set + for rows.Next() { + // create a slice to hold the values for each row + values := make([]interface{}, len(columnNames)) + valuePointers := make([]interface{}, len(columnNames)) + for i := range values { + valuePointers[i] = &values[i] + } + + // scan the row into slice of interfaces + if err := rows.Scan(valuePointers...); err != nil { + http.Error(w, "Failed to scan user data", http.StatusInternalServerError) + log.Printf("Error scanning row: %v", err) + return + } + + // create a map for the row data + user := make(map[string]interface{}) + for i, col := range columnNames { + user[string(col.Name)] = values[i] + } + + // append the user map to the users slice + users = append(users, user) + } + + // check for any errors encountered during the iteration + if err = rows.Err(); err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + log.Printf("Error encountered during iteration of rows: %v", err) + return + } + + log.Printf("All Users Data: %v", users) + + // Encode the users slice as JSON and write it to the response + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(users); err != nil { + http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError) + log.Printf("Error encoding response: %v", err) + return + } + + log.Println("Response successfully written to client") +} diff --git a/server/internal/http_server/http_server.go b/server/internal/http_server/http_server.go index bcb1e29..06bc217 100644 --- a/server/internal/http_server/http_server.go +++ b/server/internal/http_server/http_server.go @@ -4,14 +4,18 @@ import ( "log" "net/http" + "github.com/Vomitblood/cspj-application/server/internal/db" "github.com/Vomitblood/cspj-application/server/internal/sql_injection" ) // setup the http server func ServeApi() { + http.HandleFunc("/setup-demo-db", db.SetupDemoDb) + http.HandleFunc("/nuke-db", db.NukeDb) + http.HandleFunc("/fetch-all-users", db.FetchAllUsers) http.HandleFunc("/execute-sql", sql_injection.ExecuteSql) http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql) - http.HandleFunc("/secure-get-user", sql_injection.SecureExecuteSql) + http.HandleFunc("/secure-get-user", sql_injection.SecureGetUser) log.Println("Server is running on http://localhost:3001") if err := http.ListenAndServe(":3001", nil); err != nil { log.Fatalf("Failed to start server: %v", err) diff --git a/server/internal/sql_injection/sql_injection.go b/server/internal/sql_injection/sql_injection.go index 765ead0..adf3522 100644 --- a/server/internal/sql_injection/sql_injection.go +++ b/server/internal/sql_injection/sql_injection.go @@ -105,6 +105,29 @@ func SecureExecuteSql(w http.ResponseWriter, r *http.Request) { // even more secure func SecureGetUser(w http.ResponseWriter, r *http.Request) { + // decode the json body + var requestData struct { + Username string `json:"username"` + } + + // declare new json decoder with custom property + jsonDecoder := json.NewDecoder(r.Body) + // rejects any unknown fields in the json, more strict + jsonDecoder.DisallowUnknownFields() + + if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil { + http.Error(w, "Invalid request format", http.StatusBadRequest) + log.Println("Failed to decode JSON body or extra fields present:", err) + return + } + + // validate that user is provided + if requestData.Username == "" { + http.Error(w, "Invalid request format", http.StatusBadRequest) + log.Println("Username is missing in the request body") + return + } + // retrieve list of existing usernames from db existingUsernames, err := db.FetchUsernames() if err != nil { @@ -113,13 +136,10 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) { return } - // get the username from the query parameter - username := r.URL.Query().Get("username") - // check if the username exists in the allowed list // this step is crucial // server will reject ANYTHING that does not match the list - if !existingUsernames[username] { + if !existingUsernames[requestData.Username] { http.Error(w, "Invalid username", http.StatusBadRequest) return } @@ -128,7 +148,7 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) { query := "SELECT id, username, email FROM users WHERE username = $1" var id int var dbUsername, email string - err = db.DbPool.QueryRow(context.Background(), query, username).Scan(&id, &dbUsername, &email) + err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email) if err != nil { http.Error(w, "User not found", http.StatusNotFound) return