server with helper functions

This commit is contained in:
Vomitblood 2024-11-11 18:47:15 +08:00
parent a66aa29275
commit 3faf2fd929
4 changed files with 111 additions and 13 deletions

6
README
View file

@ -31,6 +31,9 @@ PGPASSWORD=asdfpassword
## Server ## Server
- `/SetupDemoDb`
- `/NukeDb`
### SQL Injection ### SQL Injection
- `/sql-execute` - `/sql-execute`
@ -46,7 +49,8 @@ Parameterized queries separate the SQL code from the data, so user input is neve
Only allow `SELECT` statement by verifying that the input query starts with it. Only allow `SELECT` statement by verifying that the input query starts with it.
Sanitized the input to ensure that no other types of statements could be executed. Sanitized the input to ensure that no other types of statements could be executed.
The input is checked against a list of allowed query terms, and if it doesn't match, the query is rejected.
#### 3. Controller JSON Input for Parameters #### 3. Controlled JSON Input for Parameters
Instead of using raw SQL strings, we restructured the input to ONLY expect JSON data with `query` and `params` fields. Instead of using raw SQL strings, we restructured the input to ONLY expect JSON data with `query` and `params` fields.

View file

@ -2,6 +2,7 @@ package db
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
@ -47,16 +48,17 @@ func SetupDemoDb(w http.ResponseWriter, r *http.Request) {
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL, username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) NOT NULL email VARCHAR(100) NOT NULL,
password VARCHAR(100) NOT NULL
);` );`
// also avoid duplicate entries // also avoid duplicate entries
insertDataSQL := ` insertDataSQL := `
INSERT INTO users (username, email) VALUES INSERT INTO users (username, email, password) VALUES
('alice', 'alice@example.com'), ('alice', 'alice@example.com', 'asdfalicepassword'),
('bob', 'bob@example.com'), ('bob', 'bob@example.com', 'asdfbobpassword'),
('charlie', 'charlie@example.com') ('charlie', 'charlie@example.com', 'asdfcharliepassword')
ON CONFLICT (username) DO NOTHING;` ON CONFLICT (username) DO NOTHING;`
// execute create table // execute create table
_, err := DbPool.Exec(context.Background(), createTableSQL) _, err := DbPool.Exec(context.Background(), createTableSQL)
@ -95,6 +97,8 @@ func NukeDb(w http.ResponseWriter, r *http.Request) {
return return
} }
w.WriteHeader(http.StatusOK)
w.Write([]byte("Bye bye"))
log.Println("Database nuked") log.Println("Database nuked")
} }
@ -119,3 +123,69 @@ func FetchUsernames() (map[string]bool, error) {
log.Println("Fetched usernames:", usernames) log.Println("Fetched usernames:", usernames)
return usernames, nil return usernames, nil
} }
// fetch all users for demo
func FetchAllUsers(w http.ResponseWriter, r *http.Request) {
// construct sql query to select all users
query := "SELECT * FROM users"
// execute the query
rows, err := DbPool.Query(context.Background(), query)
if err != nil {
http.Error(w, "Failed to retrieve users", http.StatusInternalServerError)
log.Printf("Error executing query: %v", err)
return
}
defer rows.Close()
// define a slice to hold user data
users := []map[string]interface{}{}
// get column names
columnNames := rows.FieldDescriptions()
// iterate over the rows and build the result set
for rows.Next() {
// create a slice to hold the values for each row
values := make([]interface{}, len(columnNames))
valuePointers := make([]interface{}, len(columnNames))
for i := range values {
valuePointers[i] = &values[i]
}
// scan the row into slice of interfaces
if err := rows.Scan(valuePointers...); err != nil {
http.Error(w, "Failed to scan user data", http.StatusInternalServerError)
log.Printf("Error scanning row: %v", err)
return
}
// create a map for the row data
user := make(map[string]interface{})
for i, col := range columnNames {
user[string(col.Name)] = values[i]
}
// append the user map to the users slice
users = append(users, user)
}
// check for any errors encountered during the iteration
if err = rows.Err(); err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
log.Printf("Error encountered during iteration of rows: %v", err)
return
}
log.Printf("All Users Data: %v", users)
// Encode the users slice as JSON and write it to the response
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(users); err != nil {
http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
log.Printf("Error encoding response: %v", err)
return
}
log.Println("Response successfully written to client")
}

View file

@ -4,14 +4,18 @@ import (
"log" "log"
"net/http" "net/http"
"github.com/Vomitblood/cspj-application/server/internal/db"
"github.com/Vomitblood/cspj-application/server/internal/sql_injection" "github.com/Vomitblood/cspj-application/server/internal/sql_injection"
) )
// setup the http server // setup the http server
func ServeApi() { func ServeApi() {
http.HandleFunc("/setup-demo-db", db.SetupDemoDb)
http.HandleFunc("/nuke-db", db.NukeDb)
http.HandleFunc("/fetch-all-users", db.FetchAllUsers)
http.HandleFunc("/execute-sql", sql_injection.ExecuteSql) http.HandleFunc("/execute-sql", sql_injection.ExecuteSql)
http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql) http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql)
http.HandleFunc("/secure-get-user", sql_injection.SecureExecuteSql) http.HandleFunc("/secure-get-user", sql_injection.SecureGetUser)
log.Println("Server is running on http://localhost:3001") log.Println("Server is running on http://localhost:3001")
if err := http.ListenAndServe(":3001", nil); err != nil { if err := http.ListenAndServe(":3001", nil); err != nil {
log.Fatalf("Failed to start server: %v", err) log.Fatalf("Failed to start server: %v", err)

View file

@ -105,6 +105,29 @@ func SecureExecuteSql(w http.ResponseWriter, r *http.Request) {
// even more secure // even more secure
func SecureGetUser(w http.ResponseWriter, r *http.Request) { func SecureGetUser(w http.ResponseWriter, r *http.Request) {
// decode the json body
var requestData struct {
Username string `json:"username"`
}
// declare new json decoder with custom property
jsonDecoder := json.NewDecoder(r.Body)
// rejects any unknown fields in the json, more strict
jsonDecoder.DisallowUnknownFields()
if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil {
http.Error(w, "Invalid request format", http.StatusBadRequest)
log.Println("Failed to decode JSON body or extra fields present:", err)
return
}
// validate that user is provided
if requestData.Username == "" {
http.Error(w, "Invalid request format", http.StatusBadRequest)
log.Println("Username is missing in the request body")
return
}
// retrieve list of existing usernames from db // retrieve list of existing usernames from db
existingUsernames, err := db.FetchUsernames() existingUsernames, err := db.FetchUsernames()
if err != nil { if err != nil {
@ -113,13 +136,10 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) {
return return
} }
// get the username from the query parameter
username := r.URL.Query().Get("username")
// check if the username exists in the allowed list // check if the username exists in the allowed list
// this step is crucial // this step is crucial
// server will reject ANYTHING that does not match the list // server will reject ANYTHING that does not match the list
if !existingUsernames[username] { if !existingUsernames[requestData.Username] {
http.Error(w, "Invalid username", http.StatusBadRequest) http.Error(w, "Invalid username", http.StatusBadRequest)
return return
} }
@ -128,7 +148,7 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) {
query := "SELECT id, username, email FROM users WHERE username = $1" query := "SELECT id, username, email FROM users WHERE username = $1"
var id int var id int
var dbUsername, email string var dbUsername, email string
err = db.DbPool.QueryRow(context.Background(), query, username).Scan(&id, &dbUsername, &email) err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email)
if err != nil { if err != nil {
http.Error(w, "User not found", http.StatusNotFound) http.Error(w, "User not found", http.StatusNotFound)
return return