server with helper functions
This commit is contained in:
parent
a66aa29275
commit
3faf2fd929
6
README
6
README
|
@ -31,6 +31,9 @@ PGPASSWORD=asdfpassword
|
|||
|
||||
## Server
|
||||
|
||||
- `/SetupDemoDb`
|
||||
- `/NukeDb`
|
||||
|
||||
### SQL Injection
|
||||
|
||||
- `/sql-execute`
|
||||
|
@ -46,7 +49,8 @@ Parameterized queries separate the SQL code from the data, so user input is neve
|
|||
|
||||
Only allow `SELECT` statement by verifying that the input query starts with it.
|
||||
Sanitized the input to ensure that no other types of statements could be executed.
|
||||
The input is checked against a list of allowed query terms, and if it doesn't match, the query is rejected.
|
||||
|
||||
#### 3. Controller JSON Input for Parameters
|
||||
#### 3. Controlled JSON Input for Parameters
|
||||
|
||||
Instead of using raw SQL strings, we restructured the input to ONLY expect JSON data with `query` and `params` fields.
|
||||
|
|
|
@ -2,6 +2,7 @@ package db
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
|
@ -47,16 +48,17 @@ func SetupDemoDb(w http.ResponseWriter, r *http.Request) {
|
|||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(50) UNIQUE NOT NULL,
|
||||
email VARCHAR(100) NOT NULL
|
||||
email VARCHAR(100) NOT NULL,
|
||||
password VARCHAR(100) NOT NULL
|
||||
);`
|
||||
|
||||
// also avoid duplicate entries
|
||||
insertDataSQL := `
|
||||
INSERT INTO users (username, email) VALUES
|
||||
('alice', 'alice@example.com'),
|
||||
('bob', 'bob@example.com'),
|
||||
('charlie', 'charlie@example.com')
|
||||
ON CONFLICT (username) DO NOTHING;`
|
||||
INSERT INTO users (username, email, password) VALUES
|
||||
('alice', 'alice@example.com', 'asdfalicepassword'),
|
||||
('bob', 'bob@example.com', 'asdfbobpassword'),
|
||||
('charlie', 'charlie@example.com', 'asdfcharliepassword')
|
||||
ON CONFLICT (username) DO NOTHING;`
|
||||
|
||||
// execute create table
|
||||
_, err := DbPool.Exec(context.Background(), createTableSQL)
|
||||
|
@ -95,6 +97,8 @@ func NukeDb(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Bye bye"))
|
||||
log.Println("Database nuked")
|
||||
}
|
||||
|
||||
|
@ -119,3 +123,69 @@ func FetchUsernames() (map[string]bool, error) {
|
|||
log.Println("Fetched usernames:", usernames)
|
||||
return usernames, nil
|
||||
}
|
||||
|
||||
// fetch all users for demo
|
||||
func FetchAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
// construct sql query to select all users
|
||||
query := "SELECT * FROM users"
|
||||
|
||||
// execute the query
|
||||
rows, err := DbPool.Query(context.Background(), query)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to retrieve users", http.StatusInternalServerError)
|
||||
log.Printf("Error executing query: %v", err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// define a slice to hold user data
|
||||
users := []map[string]interface{}{}
|
||||
|
||||
// get column names
|
||||
columnNames := rows.FieldDescriptions()
|
||||
|
||||
// iterate over the rows and build the result set
|
||||
for rows.Next() {
|
||||
// create a slice to hold the values for each row
|
||||
values := make([]interface{}, len(columnNames))
|
||||
valuePointers := make([]interface{}, len(columnNames))
|
||||
for i := range values {
|
||||
valuePointers[i] = &values[i]
|
||||
}
|
||||
|
||||
// scan the row into slice of interfaces
|
||||
if err := rows.Scan(valuePointers...); err != nil {
|
||||
http.Error(w, "Failed to scan user data", http.StatusInternalServerError)
|
||||
log.Printf("Error scanning row: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// create a map for the row data
|
||||
user := make(map[string]interface{})
|
||||
for i, col := range columnNames {
|
||||
user[string(col.Name)] = values[i]
|
||||
}
|
||||
|
||||
// append the user map to the users slice
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
// check for any errors encountered during the iteration
|
||||
if err = rows.Err(); err != nil {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
log.Printf("Error encountered during iteration of rows: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("All Users Data: %v", users)
|
||||
|
||||
// Encode the users slice as JSON and write it to the response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(users); err != nil {
|
||||
http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
|
||||
log.Printf("Error encoding response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("Response successfully written to client")
|
||||
}
|
||||
|
|
|
@ -4,14 +4,18 @@ import (
|
|||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/Vomitblood/cspj-application/server/internal/db"
|
||||
"github.com/Vomitblood/cspj-application/server/internal/sql_injection"
|
||||
)
|
||||
|
||||
// setup the http server
|
||||
func ServeApi() {
|
||||
http.HandleFunc("/setup-demo-db", db.SetupDemoDb)
|
||||
http.HandleFunc("/nuke-db", db.NukeDb)
|
||||
http.HandleFunc("/fetch-all-users", db.FetchAllUsers)
|
||||
http.HandleFunc("/execute-sql", sql_injection.ExecuteSql)
|
||||
http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql)
|
||||
http.HandleFunc("/secure-get-user", sql_injection.SecureExecuteSql)
|
||||
http.HandleFunc("/secure-get-user", sql_injection.SecureGetUser)
|
||||
log.Println("Server is running on http://localhost:3001")
|
||||
if err := http.ListenAndServe(":3001", nil); err != nil {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
|
|
|
@ -105,6 +105,29 @@ func SecureExecuteSql(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// even more secure
|
||||
func SecureGetUser(w http.ResponseWriter, r *http.Request) {
|
||||
// decode the json body
|
||||
var requestData struct {
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
// declare new json decoder with custom property
|
||||
jsonDecoder := json.NewDecoder(r.Body)
|
||||
// rejects any unknown fields in the json, more strict
|
||||
jsonDecoder.DisallowUnknownFields()
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil {
|
||||
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
||||
log.Println("Failed to decode JSON body or extra fields present:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// validate that user is provided
|
||||
if requestData.Username == "" {
|
||||
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
||||
log.Println("Username is missing in the request body")
|
||||
return
|
||||
}
|
||||
|
||||
// retrieve list of existing usernames from db
|
||||
existingUsernames, err := db.FetchUsernames()
|
||||
if err != nil {
|
||||
|
@ -113,13 +136,10 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// get the username from the query parameter
|
||||
username := r.URL.Query().Get("username")
|
||||
|
||||
// check if the username exists in the allowed list
|
||||
// this step is crucial
|
||||
// server will reject ANYTHING that does not match the list
|
||||
if !existingUsernames[username] {
|
||||
if !existingUsernames[requestData.Username] {
|
||||
http.Error(w, "Invalid username", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
@ -128,7 +148,7 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) {
|
|||
query := "SELECT id, username, email FROM users WHERE username = $1"
|
||||
var id int
|
||||
var dbUsername, email string
|
||||
err = db.DbPool.QueryRow(context.Background(), query, username).Scan(&id, &dbUsername, &email)
|
||||
err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email)
|
||||
if err != nil {
|
||||
http.Error(w, "User not found", http.StatusNotFound)
|
||||
return
|
||||
|
|
Loading…
Reference in a new issue