server with helper functions
This commit is contained in:
		
							parent
							
								
									a66aa29275
								
							
						
					
					
						commit
						3faf2fd929
					
				
							
								
								
									
										6
									
								
								README
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								README
									
									
									
									
									
								
							|  | @ -31,6 +31,9 @@ PGPASSWORD=asdfpassword | ||||||
| 
 | 
 | ||||||
| ## Server | ## Server | ||||||
| 
 | 
 | ||||||
|  | - `/SetupDemoDb` | ||||||
|  | - `/NukeDb` | ||||||
|  | 
 | ||||||
| ### SQL Injection | ### SQL Injection | ||||||
| 
 | 
 | ||||||
| - `/sql-execute` | - `/sql-execute` | ||||||
|  | @ -46,7 +49,8 @@ Parameterized queries separate the SQL code from the data, so user input is neve | ||||||
| 
 | 
 | ||||||
| Only allow `SELECT` statement by verifying that the input query starts with it.   | Only allow `SELECT` statement by verifying that the input query starts with it.   | ||||||
| Sanitized the input to ensure that no other types of statements could be executed.   | Sanitized the input to ensure that no other types of statements could be executed.   | ||||||
|  | The input is checked against a list of allowed query terms, and if it doesn't match, the query is rejected. | ||||||
| 
 | 
 | ||||||
| #### 3. Controller JSON Input for Parameters | #### 3. Controlled JSON Input for Parameters | ||||||
| 
 | 
 | ||||||
| Instead of using raw SQL strings, we restructured the input to ONLY expect JSON data with `query` and `params` fields. | Instead of using raw SQL strings, we restructured the input to ONLY expect JSON data with `query` and `params` fields. | ||||||
|  |  | ||||||
|  | @ -2,6 +2,7 @@ package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"log" | 	"log" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | @ -47,16 +48,17 @@ func SetupDemoDb(w http.ResponseWriter, r *http.Request) { | ||||||
|     CREATE TABLE IF NOT EXISTS users ( |     CREATE TABLE IF NOT EXISTS users ( | ||||||
|         id SERIAL PRIMARY KEY, |         id SERIAL PRIMARY KEY, | ||||||
|         username VARCHAR(50) UNIQUE NOT NULL, |         username VARCHAR(50) UNIQUE NOT NULL, | ||||||
|         email VARCHAR(100) NOT NULL |         email VARCHAR(100) NOT NULL, | ||||||
|  |         password VARCHAR(100) NOT NULL | ||||||
|     );` |     );` | ||||||
| 
 | 
 | ||||||
| 	// also avoid duplicate entries
 | 	// also avoid duplicate entries
 | ||||||
| 	insertDataSQL := ` | 	insertDataSQL := ` | ||||||
|     INSERT INTO users (username, email) VALUES |     INSERT INTO users (username, email, password) VALUES | ||||||
|         ('alice', 'alice@example.com'), |         ('alice', 'alice@example.com', 'asdfalicepassword'), | ||||||
|         ('bob', 'bob@example.com'), |         ('bob', 'bob@example.com', 'asdfbobpassword'), | ||||||
|         ('charlie', 'charlie@example.com') |         ('charlie', 'charlie@example.com', 'asdfcharliepassword') | ||||||
| 		ON CONFLICT (username) DO NOTHING;` |     ON CONFLICT (username) DO NOTHING;` | ||||||
| 
 | 
 | ||||||
| 	// execute create table
 | 	// execute create table
 | ||||||
| 	_, err := DbPool.Exec(context.Background(), createTableSQL) | 	_, err := DbPool.Exec(context.Background(), createTableSQL) | ||||||
|  | @ -95,6 +97,8 @@ func NukeDb(w http.ResponseWriter, r *http.Request) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	w.WriteHeader(http.StatusOK) | ||||||
|  | 	w.Write([]byte("Bye bye")) | ||||||
| 	log.Println("Database nuked") | 	log.Println("Database nuked") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -119,3 +123,69 @@ func FetchUsernames() (map[string]bool, error) { | ||||||
| 	log.Println("Fetched usernames:", usernames) | 	log.Println("Fetched usernames:", usernames) | ||||||
| 	return usernames, nil | 	return usernames, nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // fetch all users for demo
 | ||||||
|  | func FetchAllUsers(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	// construct sql query to select all users
 | ||||||
|  | 	query := "SELECT * FROM users" | ||||||
|  | 
 | ||||||
|  | 	// execute the query
 | ||||||
|  | 	rows, err := DbPool.Query(context.Background(), query) | ||||||
|  | 	if err != nil { | ||||||
|  | 		http.Error(w, "Failed to retrieve users", http.StatusInternalServerError) | ||||||
|  | 		log.Printf("Error executing query: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	defer rows.Close() | ||||||
|  | 
 | ||||||
|  | 	// define a slice to hold user data
 | ||||||
|  | 	users := []map[string]interface{}{} | ||||||
|  | 
 | ||||||
|  | 	// get column names
 | ||||||
|  | 	columnNames := rows.FieldDescriptions() | ||||||
|  | 
 | ||||||
|  | 	// iterate over the rows and build the result set
 | ||||||
|  | 	for rows.Next() { | ||||||
|  | 		// create a slice to hold the values for each row
 | ||||||
|  | 		values := make([]interface{}, len(columnNames)) | ||||||
|  | 		valuePointers := make([]interface{}, len(columnNames)) | ||||||
|  | 		for i := range values { | ||||||
|  | 			valuePointers[i] = &values[i] | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// scan the row into slice of interfaces
 | ||||||
|  | 		if err := rows.Scan(valuePointers...); err != nil { | ||||||
|  | 			http.Error(w, "Failed to scan user data", http.StatusInternalServerError) | ||||||
|  | 			log.Printf("Error scanning row: %v", err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// create a map for the row data
 | ||||||
|  | 		user := make(map[string]interface{}) | ||||||
|  | 		for i, col := range columnNames { | ||||||
|  | 			user[string(col.Name)] = values[i] | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// append the user map to the users slice
 | ||||||
|  | 		users = append(users, user) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// check for any errors encountered during the iteration
 | ||||||
|  | 	if err = rows.Err(); err != nil { | ||||||
|  | 		http.Error(w, "Internal server error", http.StatusInternalServerError) | ||||||
|  | 		log.Printf("Error encountered during iteration of rows: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	log.Printf("All Users Data: %v", users) | ||||||
|  | 
 | ||||||
|  | 	// Encode the users slice as JSON and write it to the response
 | ||||||
|  | 	w.Header().Set("Content-Type", "application/json") | ||||||
|  | 	if err := json.NewEncoder(w).Encode(users); err != nil { | ||||||
|  | 		http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError) | ||||||
|  | 		log.Printf("Error encoding response: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	log.Println("Response successfully written to client") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -4,14 +4,18 @@ import ( | ||||||
| 	"log" | 	"log" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/Vomitblood/cspj-application/server/internal/db" | ||||||
| 	"github.com/Vomitblood/cspj-application/server/internal/sql_injection" | 	"github.com/Vomitblood/cspj-application/server/internal/sql_injection" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // setup the http server
 | // setup the http server
 | ||||||
| func ServeApi() { | func ServeApi() { | ||||||
|  | 	http.HandleFunc("/setup-demo-db", db.SetupDemoDb) | ||||||
|  | 	http.HandleFunc("/nuke-db", db.NukeDb) | ||||||
|  | 	http.HandleFunc("/fetch-all-users", db.FetchAllUsers) | ||||||
| 	http.HandleFunc("/execute-sql", sql_injection.ExecuteSql) | 	http.HandleFunc("/execute-sql", sql_injection.ExecuteSql) | ||||||
| 	http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql) | 	http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql) | ||||||
| 	http.HandleFunc("/secure-get-user", sql_injection.SecureExecuteSql) | 	http.HandleFunc("/secure-get-user", sql_injection.SecureGetUser) | ||||||
| 	log.Println("Server is running on http://localhost:3001") | 	log.Println("Server is running on http://localhost:3001") | ||||||
| 	if err := http.ListenAndServe(":3001", nil); err != nil { | 	if err := http.ListenAndServe(":3001", nil); err != nil { | ||||||
| 		log.Fatalf("Failed to start server: %v", err) | 		log.Fatalf("Failed to start server: %v", err) | ||||||
|  |  | ||||||
|  | @ -105,6 +105,29 @@ func SecureExecuteSql(w http.ResponseWriter, r *http.Request) { | ||||||
| 
 | 
 | ||||||
| // even more secure
 | // even more secure
 | ||||||
| func SecureGetUser(w http.ResponseWriter, r *http.Request) { | func SecureGetUser(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	// decode the json body
 | ||||||
|  | 	var requestData struct { | ||||||
|  | 		Username string `json:"username"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// declare new json decoder with custom property
 | ||||||
|  | 	jsonDecoder := json.NewDecoder(r.Body) | ||||||
|  | 	// rejects any unknown fields in the json, more strict
 | ||||||
|  | 	jsonDecoder.DisallowUnknownFields() | ||||||
|  | 
 | ||||||
|  | 	if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil { | ||||||
|  | 		http.Error(w, "Invalid request format", http.StatusBadRequest) | ||||||
|  | 		log.Println("Failed to decode JSON body or extra fields present:", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// validate that user is provided
 | ||||||
|  | 	if requestData.Username == "" { | ||||||
|  | 		http.Error(w, "Invalid request format", http.StatusBadRequest) | ||||||
|  | 		log.Println("Username is missing in the request body") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// retrieve list of existing usernames from db
 | 	// retrieve list of existing usernames from db
 | ||||||
| 	existingUsernames, err := db.FetchUsernames() | 	existingUsernames, err := db.FetchUsernames() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -113,13 +136,10 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// get the username from the query parameter
 |  | ||||||
| 	username := r.URL.Query().Get("username") |  | ||||||
| 
 |  | ||||||
| 	// check if the username exists in the allowed list
 | 	// check if the username exists in the allowed list
 | ||||||
| 	// this step is crucial
 | 	// this step is crucial
 | ||||||
| 	// server will reject ANYTHING that does not match the list
 | 	// server will reject ANYTHING that does not match the list
 | ||||||
| 	if !existingUsernames[username] { | 	if !existingUsernames[requestData.Username] { | ||||||
| 		http.Error(w, "Invalid username", http.StatusBadRequest) | 		http.Error(w, "Invalid username", http.StatusBadRequest) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | @ -128,7 +148,7 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 	query := "SELECT id, username, email FROM users WHERE username = $1" | 	query := "SELECT id, username, email FROM users WHERE username = $1" | ||||||
| 	var id int | 	var id int | ||||||
| 	var dbUsername, email string | 	var dbUsername, email string | ||||||
| 	err = db.DbPool.QueryRow(context.Background(), query, username).Scan(&id, &dbUsername, &email) | 	err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		http.Error(w, "User not found", http.StatusNotFound) | 		http.Error(w, "User not found", http.StatusNotFound) | ||||||
| 		return | 		return | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue