server with helper functions

This commit is contained in:
Vomitblood 2024-11-11 18:47:15 +08:00
parent a66aa29275
commit 3faf2fd929
4 changed files with 111 additions and 13 deletions

6
README
View file

@ -31,6 +31,9 @@ PGPASSWORD=asdfpassword
## Server
- `/SetupDemoDb`
- `/NukeDb`
### SQL Injection
- `/sql-execute`
@ -46,7 +49,8 @@ Parameterized queries separate the SQL code from the data, so user input is neve
Only allow `SELECT` statement by verifying that the input query starts with it.
Sanitized the input to ensure that no other types of statements could be executed.
The input is checked against a list of allowed query terms, and if it doesn't match, the query is rejected.
#### 3. Controller JSON Input for Parameters
#### 3. Controlled JSON Input for Parameters
Instead of using raw SQL strings, we restructured the input to ONLY expect JSON data with `query` and `params` fields.

View file

@ -2,6 +2,7 @@ package db
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
@ -47,16 +48,17 @@ func SetupDemoDb(w http.ResponseWriter, r *http.Request) {
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) NOT NULL
email VARCHAR(100) NOT NULL,
password VARCHAR(100) NOT NULL
);`
// also avoid duplicate entries
insertDataSQL := `
INSERT INTO users (username, email) VALUES
('alice', 'alice@example.com'),
('bob', 'bob@example.com'),
('charlie', 'charlie@example.com')
ON CONFLICT (username) DO NOTHING;`
INSERT INTO users (username, email, password) VALUES
('alice', 'alice@example.com', 'asdfalicepassword'),
('bob', 'bob@example.com', 'asdfbobpassword'),
('charlie', 'charlie@example.com', 'asdfcharliepassword')
ON CONFLICT (username) DO NOTHING;`
// execute create table
_, err := DbPool.Exec(context.Background(), createTableSQL)
@ -95,6 +97,8 @@ func NukeDb(w http.ResponseWriter, r *http.Request) {
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("Bye bye"))
log.Println("Database nuked")
}
@ -119,3 +123,69 @@ func FetchUsernames() (map[string]bool, error) {
log.Println("Fetched usernames:", usernames)
return usernames, nil
}
// fetch all users for demo
func FetchAllUsers(w http.ResponseWriter, r *http.Request) {
// construct sql query to select all users
query := "SELECT * FROM users"
// execute the query
rows, err := DbPool.Query(context.Background(), query)
if err != nil {
http.Error(w, "Failed to retrieve users", http.StatusInternalServerError)
log.Printf("Error executing query: %v", err)
return
}
defer rows.Close()
// define a slice to hold user data
users := []map[string]interface{}{}
// get column names
columnNames := rows.FieldDescriptions()
// iterate over the rows and build the result set
for rows.Next() {
// create a slice to hold the values for each row
values := make([]interface{}, len(columnNames))
valuePointers := make([]interface{}, len(columnNames))
for i := range values {
valuePointers[i] = &values[i]
}
// scan the row into slice of interfaces
if err := rows.Scan(valuePointers...); err != nil {
http.Error(w, "Failed to scan user data", http.StatusInternalServerError)
log.Printf("Error scanning row: %v", err)
return
}
// create a map for the row data
user := make(map[string]interface{})
for i, col := range columnNames {
user[string(col.Name)] = values[i]
}
// append the user map to the users slice
users = append(users, user)
}
// check for any errors encountered during the iteration
if err = rows.Err(); err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
log.Printf("Error encountered during iteration of rows: %v", err)
return
}
log.Printf("All Users Data: %v", users)
// Encode the users slice as JSON and write it to the response
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(users); err != nil {
http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
log.Printf("Error encoding response: %v", err)
return
}
log.Println("Response successfully written to client")
}

View file

@ -4,14 +4,18 @@ import (
"log"
"net/http"
"github.com/Vomitblood/cspj-application/server/internal/db"
"github.com/Vomitblood/cspj-application/server/internal/sql_injection"
)
// setup the http server
func ServeApi() {
http.HandleFunc("/setup-demo-db", db.SetupDemoDb)
http.HandleFunc("/nuke-db", db.NukeDb)
http.HandleFunc("/fetch-all-users", db.FetchAllUsers)
http.HandleFunc("/execute-sql", sql_injection.ExecuteSql)
http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql)
http.HandleFunc("/secure-get-user", sql_injection.SecureExecuteSql)
http.HandleFunc("/secure-get-user", sql_injection.SecureGetUser)
log.Println("Server is running on http://localhost:3001")
if err := http.ListenAndServe(":3001", nil); err != nil {
log.Fatalf("Failed to start server: %v", err)

View file

@ -105,6 +105,29 @@ func SecureExecuteSql(w http.ResponseWriter, r *http.Request) {
// even more secure
func SecureGetUser(w http.ResponseWriter, r *http.Request) {
// decode the json body
var requestData struct {
Username string `json:"username"`
}
// declare new json decoder with custom property
jsonDecoder := json.NewDecoder(r.Body)
// rejects any unknown fields in the json, more strict
jsonDecoder.DisallowUnknownFields()
if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil {
http.Error(w, "Invalid request format", http.StatusBadRequest)
log.Println("Failed to decode JSON body or extra fields present:", err)
return
}
// validate that user is provided
if requestData.Username == "" {
http.Error(w, "Invalid request format", http.StatusBadRequest)
log.Println("Username is missing in the request body")
return
}
// retrieve list of existing usernames from db
existingUsernames, err := db.FetchUsernames()
if err != nil {
@ -113,13 +136,10 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) {
return
}
// get the username from the query parameter
username := r.URL.Query().Get("username")
// check if the username exists in the allowed list
// this step is crucial
// server will reject ANYTHING that does not match the list
if !existingUsernames[username] {
if !existingUsernames[requestData.Username] {
http.Error(w, "Invalid username", http.StatusBadRequest)
return
}
@ -128,7 +148,7 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) {
query := "SELECT id, username, email FROM users WHERE username = $1"
var id int
var dbUsername, email string
err = db.DbPool.QueryRow(context.Background(), query, username).Scan(&id, &dbUsername, &email)
err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email)
if err != nil {
http.Error(w, "User not found", http.StatusNotFound)
return