From 53c0ada12e9f0140182f8ca51f48e12cc464d4ef Mon Sep 17 00:00:00 2001 From: Vomitblood Date: Tue, 14 Jan 2025 03:18:08 +0800 Subject: [PATCH] login --- README.md | 5 +- .../Pages/SqlInjection/SqlInjectionLogin.tsx | 128 ++++-- .../SqlInjection/SqlInjectionRegister.tsx | 54 +-- server/internal/db/db.go | 18 +- server/internal/http_server/http_server.go | 6 +- .../internal/sql_injection/sql_injection.go | 368 ++++++++---------- 6 files changed, 304 insertions(+), 275 deletions(-) diff --git a/README.md b/README.md index 0ed8ccd..f6aebcc 100644 --- a/README.md +++ b/README.md @@ -41,11 +41,8 @@ PGPASSWORD=asdfpassword ### SQL Injection -- `/sql-execute` -- `/login-sql` -- `/secure-sql-execute` +- `/secure-register-sql` - `/secure-login-sql` -- `/secure-get-user` #### 1. Parameterization of Queries diff --git a/client/src/components/Pages/SqlInjection/SqlInjectionLogin.tsx b/client/src/components/Pages/SqlInjection/SqlInjectionLogin.tsx index 032d137..4e73d86 100644 --- a/client/src/components/Pages/SqlInjection/SqlInjectionLogin.tsx +++ b/client/src/components/Pages/SqlInjection/SqlInjectionLogin.tsx @@ -1,12 +1,66 @@ -import { Box, LinearProgress, TextField, Typography } from "@mui/material"; +import { Box, TextField, Typography } from "@mui/material"; +import { useAtom } from "jotai"; import { useState } from "react"; +import { useNotification } from "../../../contexts/NotificationContext"; +import { serverUrlAtom } from "../../../lib/jotai"; import { HeaderLogo } from "../../Generic/HeaderLogo"; +import { LoadingButton } from "@mui/lab"; export const SqlInjectionLogin = () => { + // contexts + const { openNotification } = useNotification(); + + // atoms + const [serverUrl, setServerUrl] = useAtom(serverUrlAtom); + // states const [emailValueRaw, setEmailValueRaw] = useState(""); const [passwordValueRaw, setPasswordValueRaw] = useState(""); - const [passwordErrorMsg, setPasswordErrorMsg] = useState(""); + const [errorMsg, setErrorMsg] = useState(""); + const [loginLoading, setLoginLoading] = useState(false); + + const nextClickEvent = async () => { + // reset the error messages + setErrorMsg(""); + + // ensure that the server url does not end with a trailing slash + setServerUrl(serverUrl.replace(/\/$/, "")); + + // construct the request body + const requestBody = { + email: emailValueRaw, + password: passwordValueRaw, + }; + + // start loading indicator + setLoginLoading(true); + + try { + // make request good + const response = await fetch(serverUrl + "/register-sql", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(requestBody), + }); + + // check if login was successful + if (!response.ok) { + const errorMessage = await response.text(); + console.log("Login failed:", errorMessage); + setErrorMsg(errorMessage); + } else { + openNotification("Login successful"); + } + } catch (e) { + // log the error and handle failure + console.log("Request failed", e); + } finally { + // stop loading indicator regardless of success/failure + setLoginLoading(false); + } + }; return ( <> @@ -34,30 +88,52 @@ export const SqlInjectionLogin = () => { flexDirection: "column", }} > - setEmailValueRaw(e.target.value)} - size="small" - type="email" - value={emailValueRaw} - sx={{ mb: 2 }} - variant="outlined" - /> - setPasswordValueRaw(e.target.value)} - size="small" - type="password" - value={passwordValueRaw} - sx={{ mb: 2 }} - variant="outlined" - /> +
{ + e.preventDefault(); + nextClickEvent(); + }} + > + setEmailValueRaw(e.target.value)} + size="small" + type="email" + value={emailValueRaw} + sx={{ mb: 2 }} + variant="outlined" + /> + setPasswordValueRaw(e.target.value)} + size="small" + type="password" + value={passwordValueRaw} + sx={{ mb: 2 }} + variant="outlined" + /> + + + Next + + + ); diff --git a/client/src/components/Pages/SqlInjection/SqlInjectionRegister.tsx b/client/src/components/Pages/SqlInjection/SqlInjectionRegister.tsx index 5633a92..88c3c86 100644 --- a/client/src/components/Pages/SqlInjection/SqlInjectionRegister.tsx +++ b/client/src/components/Pages/SqlInjection/SqlInjectionRegister.tsx @@ -107,7 +107,7 @@ export const SqlInjectionRegister = () => { try { // make request good - const response = await fetch(serverUrl + "/register-sql", { + const response = await fetch(serverUrl + "/secure-register-sql", { method: "POST", headers: { "Content-Type": "application/json", @@ -124,10 +124,10 @@ export const SqlInjectionRegister = () => { openNotification("Registration successful"); } } catch (e) { - // Log the error and handle failure + // log the error and handle failure console.log("Request failed", e); } finally { - // Stop loading indicator regardless of success/failure + // stop loading indicator regardless of success/failure setRegisterLoading(false); } }; @@ -152,16 +152,16 @@ export const SqlInjectionRegister = () => {  Register -
{ - e.preventDefault(); - nextClickEvent(); + - { + e.preventDefault(); + nextClickEvent(); }} > {
  • At least one special character
  • -
    - - - Next - - - + + Next + +
    + + ); }; diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 8f2dec0..88d1369 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -135,25 +135,25 @@ func NukeDb(w http.ResponseWriter, r *http.Request) { } // fetch existing usernames from db -func FetchUsernames() (map[string]bool, error) { - usernames := make(map[string]bool) +func FetchEmails() (map[string]bool, error) { + emails := make(map[string]bool) - rows, err := DbPool.Query(context.Background(), "SELECT username FROM users") + rows, err := DbPool.Query(context.Background(), "SELECT email FROM users") if err != nil { return nil, fmt.Errorf("error querying users: %w", err) } defer rows.Close() for rows.Next() { - var username string - if err := rows.Scan(&username); err != nil { - return nil, fmt.Errorf("error scanning username: %w", err) + var email string + if err := rows.Scan(&email); err != nil { + return nil, fmt.Errorf("error scanning email: %w", err) } - usernames[username] = true + emails[email] = true } - log.Println("Fetched usernames:", usernames) - return usernames, nil + log.Println("Fetched emails:", emails) + return emails, nil } // fetch all users for demo diff --git a/server/internal/http_server/http_server.go b/server/internal/http_server/http_server.go index 3fa596f..3772679 100644 --- a/server/internal/http_server/http_server.go +++ b/server/internal/http_server/http_server.go @@ -21,12 +21,8 @@ func ServeApi() { http.HandleFunc("/setup-demo-db", db.SetupDemoDb) http.HandleFunc("/nuke-db", db.NukeDb) http.HandleFunc("/fetch-all-users", db.FetchAllUsers) - http.HandleFunc("/execute-sql", sql_injection.ExecuteSql) - http.HandleFunc("/login-sql", sql_injection.LoginSql) - http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql) - http.HandleFunc("/register-sql", sql_injection.RegisterSql) + http.HandleFunc("/secure-register-sql", sql_injection.SecureRegisterSql) http.HandleFunc("/secure-login-sql", sql_injection.SecureLoginSql) - http.HandleFunc("/secure-get-user", sql_injection.SecureGetUser) log.Println("Server is running on http://localhost:5000") if err := http.ListenAndServe(":5000", nil); err != nil { log.Fatalf("Failed to start server: %v", err) diff --git a/server/internal/sql_injection/sql_injection.go b/server/internal/sql_injection/sql_injection.go index b34ae0b..273ce6d 100644 --- a/server/internal/sql_injection/sql_injection.go +++ b/server/internal/sql_injection/sql_injection.go @@ -4,153 +4,16 @@ import ( "context" "encoding/json" "fmt" - "io" "log" "net/http" "regexp" - "strings" "github.com/Vomitblood/cspj-application/server/internal/db" "golang.org/x/crypto/bcrypt" ) -// unsecure version -// take http reqeust body as raw sql and pass to db -func ExecuteSql(w http.ResponseWriter, r *http.Request) { - // read the request body - sqlQuery, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // execute the sql query without any sanitization - rows, err := db.DbPool.Query(context.Background(), string(sqlQuery)) - if err != nil { - http.Error(w, "Query execution error", http.StatusInternalServerError) - return - } - defer rows.Close() - - // prepare the response by iterating over the returned rows - var response string - for rows.Next() { - values, err := rows.Values() - if err != nil { - http.Error(w, "Error reading query result", http.StatusInternalServerError) - return - } - response += fmt.Sprintf("%v\n", values) - } - - // send the response to the client - w.Write([]byte(response)) -} - -// unsecure login -// login endpoint with sql injection vulnerability -func LoginSql(w http.ResponseWriter, r *http.Request) { - // parse the request body to get username and password - var credentials struct { - Username string `json:"username"` - Password string `json:"password"` - } - - // decode the json body - if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil { - http.Error(w, "Invalid request format", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // construct the unsafe query - query := fmt.Sprintf( - "SELECT id, username FROM users WHERE username = '%s' AND password = '&s'", - credentials.Username, - credentials.Password, - ) - - // execute the query without sanitizing the input - var id int - var username string - err := db.DbPool.QueryRow(context.Background(), query).Scan(&id, &username) - if err != nil { - http.Error(w, "Invalid credentials", http.StatusUnauthorized) - return - } - - // if the user is found, return success response - response := map[string]interface{}{ - "message": "Login successful", - "user_id": id, - "username": username, - } - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - log.Printf("JSON encoding error: %v", err) - } -} - -// secure version -// only allow parameterized queries with validation -func SecureExecuteSql(w http.ResponseWriter, r *http.Request) { - var input struct { - Query string `json:"query"` - Params []interface{} `json:"params"` - } - - // parse json request body - if err := json.NewDecoder(r.Body).Decode(&input); err != nil { - http.Error(w, "Invalid JSON format", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // simple validation, only allow select statements - if !strings.HasPrefix(strings.ToUpper(input.Query), "SELECT") { - http.Error(w, "Only SELECT queries are allowed", http.StatusForbidden) - return - } - - // execute the query as a parameterized statement - rows, err := db.DbPool.Query(context.Background(), input.Query, input.Params...) - if err != nil { - http.Error(w, "Query execution error", http.StatusInternalServerError) - return - } - defer rows.Close() - - // format the response - var response []map[string]interface{} - for rows.Next() { - values, err := rows.Values() - if err != nil { - http.Error(w, "Error reading query result", http.StatusInternalServerError) - return - } - - rowMap := make(map[string]interface{}) - fieldDescriptions := rows.FieldDescriptions() - for i, fd := range fieldDescriptions { - rowMap[string(fd.Name)] = values[i] - } - response = append(response, rowMap) - } - - // return json response - jsonResp, err := json.Marshal(response) - if err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Write(jsonResp) -} - -// register endpoint -func RegisterSql(w http.ResponseWriter, r *http.Request) { +// secure register endpoint +func UnsecureRegisterSql(w http.ResponseWriter, r *http.Request) { // read the request body var credentials struct { Email string `json:"email"` @@ -179,6 +42,75 @@ func RegisterSql(w http.ResponseWriter, r *http.Request) { } // get the number of emails that matches + // construct the sql query using concatenation, BAD + emailCheckSQL := fmt.Sprintf("SELECT COUNT(*) FROM users WHERE email = '%s'", credentials.Email) + var existingUserCount int + err := db.DbPool.QueryRow(context.Background(), emailCheckSQL).Scan(&existingUserCount) + if err != nil { + http.Error(w, "Error checking email in the database", http.StatusInternalServerError) + log.Printf("Error checking email: %v", err) + return + } + // if there is more than 0 matches, that means email already exists, reject + if existingUserCount > 0 { + http.Error(w, "Email already exists", http.StatusConflict) + return + } + + // hash the password + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(credentials.Password), bcrypt.DefaultCost) + if err != nil { + http.Error(w, "Error hashing password", http.StatusInternalServerError) + return + } + + // over here the validations has passed, so insert into the db + // also use concatenation here + insertSQL := fmt.Sprintf("INSERT INTO users (email, password, role) VALUES ('%s', '%s', 'user')", credentials.Email, string(hashedPassword)) + _, err = db.DbPool.Exec(context.Background(), insertSQL) + if err != nil { + http.Error(w, "Error inserting user into the database", http.StatusInternalServerError) + log.Printf("Error inserting user: %v", err) + return + } + + // send back status ok + w.WriteHeader(http.StatusOK) + w.Write([]byte("User registered successfully")) + log.Println("User registered successfully:", credentials.Email) +} + +// secure register endpoint +func SecureRegisterSql(w http.ResponseWriter, r *http.Request) { + // read the request body + var credentials struct { + Email string `json:"email"` + Password string `json:"password"` + RePassword string `json:"rePassword"` + } + + if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil { + http.Error(w, "Invalid request format", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // check if the email is an email using regex, if not reject + emailRegex := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$` + re := regexp.MustCompile(emailRegex) + if !re.MatchString(credentials.Email) { + http.Error(w, "Invalid email format", http.StatusBadRequest) + return + } + + // check if the password matches, if not reject + if credentials.Password != credentials.RePassword { + http.Error(w, "Passwords do not match", http.StatusBadRequest) + return + } + + // get the number of emails that matches + // use parameterization var existingUserCount int emailCheckSQL := `SELECT COUNT(*) FROM users WHERE email = $1` err := db.DbPool.QueryRow(context.Background(), emailCheckSQL, credentials.Email).Scan(&existingUserCount) @@ -201,6 +133,7 @@ func RegisterSql(w http.ResponseWriter, r *http.Request) { } // over here the validations has passed, so insert into the db + // use parameterization insertSQL := `INSERT INTO users (email, password, role) VALUES ($1, $2, $3)` _, err = db.DbPool.Exec(context.Background(), insertSQL, credentials.Email, hashedPassword, "user") if err != nil { @@ -215,46 +148,12 @@ func RegisterSql(w http.ResponseWriter, r *http.Request) { log.Println("User registered successfully:", credentials.Email) } -// secure login -func SecureLoginSql(w http.ResponseWriter, r *http.Request) { - var credentials struct { - Username string `json:"username"` - Password string `json:"password"` - } - - if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil { - http.Error(w, "Invalid request format", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // secure version using parameterized queries - query := "SELECT id, username FROM users WHERE username = $1 AND password = $2" - var id int - var username string - err := db.DbPool.QueryRow(context.Background(), query, credentials.Username, credentials.Password).Scan(&id, &username) - if err != nil { - http.Error(w, "Invalid credentials", http.StatusUnauthorized) - return - } - - // send back the response if great success - response := map[string]interface{}{ - "message": "Login successful", - "user_id": id, - "username": username, - } - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - log.Printf("JSON encoding error: %v", err) - } -} - -// even more secure -func SecureGetUser(w http.ResponseWriter, r *http.Request) { +// very secure login endpoint +func UnsecureLoginSql(w http.ResponseWriter, r *http.Request) { // decode the json body var requestData struct { - Username string `json:"username"` + Email string `json:"email"` + Password string `json:"password"` } // declare new json decoder with custom property @@ -269,43 +168,104 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) { } // validate that user is provided - if requestData.Username == "" { + if requestData.Email == "" || requestData.Password == "" { http.Error(w, "Invalid request format", http.StatusBadRequest) - log.Println("Username is missing in the request body") return } - // retrieve list of existing usernames from db - existingUsernames, err := db.FetchUsernames() - if err != nil { - http.Error(w, "Internal server error", http.StatusInternalServerError) - log.Printf("Failed to fetch usernames: %v", err) - return - } - - // check if the username exists in the allowed list - // this step is crucial - // server will reject ANYTHING that does not match the list - if !existingUsernames[requestData.Username] { - http.Error(w, "Invalid username", http.StatusBadRequest) - return - } - - // construct the query - query := "SELECT id, username, email FROM users WHERE username = $1" + // retrieve the user data + // construct sql command using concatenation + query := fmt.Sprintf("SELECT id, email, password FROM users WHERE email = '%s'", requestData.Email) var id int - var dbUsername, email string - err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email) + var email string + var hashedPassword string + err := db.DbPool.QueryRow(context.Background(), query).Scan(&id, &email, &hashedPassword) if err != nil { - http.Error(w, "User not found", http.StatusNotFound) + http.Error(w, "Invalid email or password", http.StatusNotFound) return } - // send back the user data as a json response + // compare the provided password with the stored hashed password + err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(requestData.Password)) + if err != nil { + http.Error(w, "Invalid email or password", http.StatusUnauthorized) + return + } + + // send back the user data as a JSON response response := map[string]interface{}{ - "id": id, - "username": dbUsername, - "email": email, + "id": id, + "email": email, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError) + log.Printf("JSON encoding error: %v", err) + } +} + +// very secure login endpoint +func SecureLoginSql(w http.ResponseWriter, r *http.Request) { + // decode the json body + var requestData struct { + Email string `json:"email"` + Password string `json:"password"` + } + + // declare new json decoder with custom property + jsonDecoder := json.NewDecoder(r.Body) + // rejects any unknown fields in the json, more strict + jsonDecoder.DisallowUnknownFields() + + if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil { + http.Error(w, "Invalid request format", http.StatusBadRequest) + log.Println("Failed to decode JSON body or extra fields present:", err) + return + } + + // validate that user is provided + if requestData.Email == "" || requestData.Password == "" { + http.Error(w, "Invalid request format", http.StatusBadRequest) + return + } + + // retrieve list of existing emails from db + existingEmails, err := db.FetchEmails() + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + log.Printf("Failed to fetch emails: %v", err) + return + } + + // check if the email exists in the allowed list + // this step is crucial + // server will reject ANYTHING that does not match the list + if !existingEmails[requestData.Email] { + http.Error(w, "Invalid email or password", http.StatusBadRequest) + return + } + + // retrieve the email and password from the database + var id int + var email string + var hashedPassword string + query := "SELECT id, email, password FROM users WHERE email = $1" + err = db.DbPool.QueryRow(context.Background(), query, requestData.Email).Scan(&id, &email, &hashedPassword) + if err != nil { + http.Error(w, "Invalid email or password", http.StatusNotFound) + return + } + + // compare the provided password and the stored password hashes + err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(requestData.Password)) + if err != nil { + http.Error(w, "Invalid email or password", http.StatusUnauthorized) + return + } + + // send back the user data as a json response + response := map[string]interface{}{ + "id": id, + "email": email, } if err := json.NewEncoder(w).Encode(response); err != nil { http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)