2024-11-11 17:34:37 +08:00
|
|
|
package sql_injection
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"log"
|
|
|
|
"net/http"
|
2025-01-14 02:39:24 +08:00
|
|
|
"regexp"
|
2024-11-11 17:34:37 +08:00
|
|
|
"strings"
|
|
|
|
|
|
|
|
"github.com/Vomitblood/cspj-application/server/internal/db"
|
2025-01-14 02:39:24 +08:00
|
|
|
"golang.org/x/crypto/bcrypt"
|
2024-11-11 17:34:37 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
// unsecure version
|
|
|
|
// take http reqeust body as raw sql and pass to db
|
|
|
|
func ExecuteSql(w http.ResponseWriter, r *http.Request) {
|
|
|
|
// read the request body
|
|
|
|
sqlQuery, err := io.ReadAll(r.Body)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer r.Body.Close()
|
|
|
|
|
|
|
|
// execute the sql query without any sanitization
|
|
|
|
rows, err := db.DbPool.Query(context.Background(), string(sqlQuery))
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Query execution error", http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
|
|
|
|
// prepare the response by iterating over the returned rows
|
|
|
|
var response string
|
|
|
|
for rows.Next() {
|
|
|
|
values, err := rows.Values()
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Error reading query result", http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
response += fmt.Sprintf("%v\n", values)
|
|
|
|
}
|
|
|
|
|
|
|
|
// send the response to the client
|
|
|
|
w.Write([]byte(response))
|
|
|
|
}
|
|
|
|
|
2024-11-12 11:53:55 +08:00
|
|
|
// unsecure login
|
|
|
|
// login endpoint with sql injection vulnerability
|
|
|
|
func LoginSql(w http.ResponseWriter, r *http.Request) {
|
|
|
|
// parse the request body to get username and password
|
|
|
|
var credentials struct {
|
|
|
|
Username string `json:"username"`
|
|
|
|
Password string `json:"password"`
|
|
|
|
}
|
|
|
|
|
|
|
|
// decode the json body
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
|
|
|
|
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer r.Body.Close()
|
|
|
|
|
|
|
|
// construct the unsafe query
|
|
|
|
query := fmt.Sprintf(
|
|
|
|
"SELECT id, username FROM users WHERE username = '%s' AND password = '&s'",
|
|
|
|
credentials.Username,
|
|
|
|
credentials.Password,
|
|
|
|
)
|
|
|
|
|
|
|
|
// execute the query without sanitizing the input
|
|
|
|
var id int
|
|
|
|
var username string
|
|
|
|
err := db.DbPool.QueryRow(context.Background(), query).Scan(&id, &username)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// if the user is found, return success response
|
|
|
|
response := map[string]interface{}{
|
|
|
|
"message": "Login successful",
|
|
|
|
"user_id": id,
|
|
|
|
"username": username,
|
|
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
|
|
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
|
|
|
log.Printf("JSON encoding error: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-11-11 17:34:37 +08:00
|
|
|
// secure version
|
|
|
|
// only allow parameterized queries with validation
|
|
|
|
func SecureExecuteSql(w http.ResponseWriter, r *http.Request) {
|
|
|
|
var input struct {
|
|
|
|
Query string `json:"query"`
|
|
|
|
Params []interface{} `json:"params"`
|
|
|
|
}
|
|
|
|
|
|
|
|
// parse json request body
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
|
|
|
http.Error(w, "Invalid JSON format", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer r.Body.Close()
|
|
|
|
|
|
|
|
// simple validation, only allow select statements
|
|
|
|
if !strings.HasPrefix(strings.ToUpper(input.Query), "SELECT") {
|
|
|
|
http.Error(w, "Only SELECT queries are allowed", http.StatusForbidden)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// execute the query as a parameterized statement
|
|
|
|
rows, err := db.DbPool.Query(context.Background(), input.Query, input.Params...)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Query execution error", http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
|
|
|
|
// format the response
|
|
|
|
var response []map[string]interface{}
|
|
|
|
for rows.Next() {
|
|
|
|
values, err := rows.Values()
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Error reading query result", http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
rowMap := make(map[string]interface{})
|
|
|
|
fieldDescriptions := rows.FieldDescriptions()
|
|
|
|
for i, fd := range fieldDescriptions {
|
|
|
|
rowMap[string(fd.Name)] = values[i]
|
|
|
|
}
|
|
|
|
response = append(response, rowMap)
|
|
|
|
}
|
|
|
|
|
|
|
|
// return json response
|
|
|
|
jsonResp, err := json.Marshal(response)
|
|
|
|
if err != nil {
|
2024-11-12 11:53:55 +08:00
|
|
|
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
2024-11-11 17:34:37 +08:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
w.Write(jsonResp)
|
|
|
|
}
|
|
|
|
|
2025-01-14 02:39:24 +08:00
|
|
|
// register endpoint
|
|
|
|
func RegisterSql(w http.ResponseWriter, r *http.Request) {
|
|
|
|
// read the request body
|
|
|
|
var credentials struct {
|
|
|
|
Email string `json:"email"`
|
|
|
|
Password string `json:"password"`
|
|
|
|
RePassword string `json:"rePassword"`
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
|
|
|
|
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer r.Body.Close()
|
|
|
|
|
|
|
|
// check if the email is an email using regex, if not reject
|
|
|
|
emailRegex := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
|
|
|
|
re := regexp.MustCompile(emailRegex)
|
|
|
|
if !re.MatchString(credentials.Email) {
|
|
|
|
http.Error(w, "Invalid email format", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// check if the password matches, if not reject
|
|
|
|
if credentials.Password != credentials.RePassword {
|
|
|
|
http.Error(w, "Passwords do not match", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// get the number of emails that matches
|
|
|
|
var existingUserCount int
|
|
|
|
emailCheckSQL := `SELECT COUNT(*) FROM users WHERE email = $1`
|
|
|
|
err := db.DbPool.QueryRow(context.Background(), emailCheckSQL, credentials.Email).Scan(&existingUserCount)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Error checking email in the database", http.StatusInternalServerError)
|
|
|
|
log.Printf("Error checking email: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
// if there is more than 0 matches, that means email already exists, reject
|
|
|
|
if existingUserCount > 0 {
|
|
|
|
http.Error(w, "Email already exists", http.StatusConflict)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// hash the password
|
|
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(credentials.Password), bcrypt.DefaultCost)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Error hashing password", http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// over here the validations has passed, so insert into the db
|
|
|
|
insertSQL := `INSERT INTO users (email, password, role) VALUES ($1, $2, $3)`
|
|
|
|
_, err = db.DbPool.Exec(context.Background(), insertSQL, credentials.Email, hashedPassword, "user")
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Error inserting user into the database", http.StatusInternalServerError)
|
|
|
|
log.Printf("Error inserting user: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// send back status ok
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
w.Write([]byte("User registered successfully"))
|
|
|
|
log.Println("User registered successfully:", credentials.Email)
|
|
|
|
}
|
|
|
|
|
2024-11-12 11:53:55 +08:00
|
|
|
// secure login
|
|
|
|
func SecureLoginSql(w http.ResponseWriter, r *http.Request) {
|
|
|
|
var credentials struct {
|
|
|
|
Username string `json:"username"`
|
|
|
|
Password string `json:"password"`
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
|
|
|
|
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer r.Body.Close()
|
|
|
|
|
|
|
|
// secure version using parameterized queries
|
|
|
|
query := "SELECT id, username FROM users WHERE username = $1 AND password = $2"
|
|
|
|
var id int
|
|
|
|
var username string
|
|
|
|
err := db.DbPool.QueryRow(context.Background(), query, credentials.Username, credentials.Password).Scan(&id, &username)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// send back the response if great success
|
|
|
|
response := map[string]interface{}{
|
|
|
|
"message": "Login successful",
|
|
|
|
"user_id": id,
|
|
|
|
"username": username,
|
|
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
|
|
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
|
|
|
log.Printf("JSON encoding error: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-11-11 17:34:37 +08:00
|
|
|
// even more secure
|
|
|
|
func SecureGetUser(w http.ResponseWriter, r *http.Request) {
|
2024-11-11 18:47:15 +08:00
|
|
|
// decode the json body
|
|
|
|
var requestData struct {
|
|
|
|
Username string `json:"username"`
|
|
|
|
}
|
|
|
|
|
|
|
|
// declare new json decoder with custom property
|
|
|
|
jsonDecoder := json.NewDecoder(r.Body)
|
|
|
|
// rejects any unknown fields in the json, more strict
|
|
|
|
jsonDecoder.DisallowUnknownFields()
|
|
|
|
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil {
|
|
|
|
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
|
|
|
log.Println("Failed to decode JSON body or extra fields present:", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// validate that user is provided
|
|
|
|
if requestData.Username == "" {
|
|
|
|
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
|
|
|
log.Println("Username is missing in the request body")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2024-11-11 17:34:37 +08:00
|
|
|
// retrieve list of existing usernames from db
|
|
|
|
existingUsernames, err := db.FetchUsernames()
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
|
|
log.Printf("Failed to fetch usernames: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// check if the username exists in the allowed list
|
|
|
|
// this step is crucial
|
|
|
|
// server will reject ANYTHING that does not match the list
|
2024-11-11 18:47:15 +08:00
|
|
|
if !existingUsernames[requestData.Username] {
|
2024-11-11 17:34:37 +08:00
|
|
|
http.Error(w, "Invalid username", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// construct the query
|
|
|
|
query := "SELECT id, username, email FROM users WHERE username = $1"
|
|
|
|
var id int
|
|
|
|
var dbUsername, email string
|
2024-11-11 18:47:15 +08:00
|
|
|
err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email)
|
2024-11-11 17:34:37 +08:00
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "User not found", http.StatusNotFound)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// send back the user data as a json response
|
|
|
|
response := map[string]interface{}{
|
|
|
|
"id": id,
|
|
|
|
"username": dbUsername,
|
|
|
|
"email": email,
|
|
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
|
|
http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
|
|
|
|
log.Printf("JSON encoding error: %v", err)
|
|
|
|
}
|
|
|
|
}
|