// File: cspj-application/server/internal/sql_injection/sql_injection.go
// 247 lines, 6.8 KiB, Go
// 2024-11-11 17:34:37 +08:00

package sql_injection
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"github.com/Vomitblood/cspj-application/server/internal/db"
)
// ExecuteSql is the deliberately insecure endpoint: it treats the entire
// request body as a raw SQL statement and runs it against the database with
// no validation or parameterisation. It exists only to demonstrate SQL
// injection and must never be exposed outside a controlled lab environment.
//
// Every returned row is written to the response as one line in Go's %v
// format. A body that cannot be read yields 400; a query or row-scanning
// failure yields 500.
func ExecuteSql(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	// read the request body
	sqlQuery, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "Failed to read request body", http.StatusBadRequest)
		return
	}

	// INTENTIONALLY VULNERABLE: the statement is executed exactly as received
	rows, err := db.DbPool.Query(context.Background(), string(sqlQuery))
	if err != nil {
		http.Error(w, "Query execution error", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// a Builder avoids the quadratic cost of string += over large result sets
	var response strings.Builder
	for rows.Next() {
		values, err := rows.Values()
		if err != nil {
			http.Error(w, "Error reading query result", http.StatusInternalServerError)
			return
		}
		fmt.Fprintf(&response, "%v\n", values)
	}
	// NOTE(review): assumes DbPool is a pgx pool, whose Query may defer
	// execution errors to rows.Err(); without this check a failed query
	// would be reported as an empty success
	if err := rows.Err(); err != nil {
		http.Error(w, "Query execution error", http.StatusInternalServerError)
		return
	}

	// send the response to the client
	w.Write([]byte(response.String()))
}
// LoginSql is the deliberately insecure login endpoint. It decodes a JSON
// body of the form {"username": "...", "password": "..."} and splices both
// values directly into the SQL text, so input such as `admin' --` in the
// username bypasses the password check. It exists only to demonstrate SQL
// injection and must never be exposed outside a controlled lab environment.
//
// A malformed body yields 400. If the query returns a row, the response is a
// JSON object with "message", "user_id" and "username"; otherwise 401.
func LoginSql(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	// parse the request body to get username and password
	var credentials struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
		http.Error(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	// INTENTIONALLY VULNERABLE: user input is concatenated into the query.
	// Both verbs must be %s so that fmt consumes both arguments; a stray
	// verb leaves the password unused and appends "%!(EXTRA ...)" to the SQL.
	query := fmt.Sprintf(
		"SELECT id, username FROM users WHERE username = '%s' AND password = '%s'",
		credentials.Username,
		credentials.Password,
	)

	var id int
	var username string
	if err := db.DbPool.QueryRow(context.Background(), query).Scan(&id, &username); err != nil {
		http.Error(w, "Invalid credentials", http.StatusUnauthorized)
		return
	}

	// if the user is found, return success response
	response := map[string]interface{}{
		"message":  "Login successful",
		"user_id":  id,
		"username": username,
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		http.Error(w, "Failed to encode response", http.StatusInternalServerError)
		log.Printf("JSON encoding error: %v", err)
	}
}
2024-11-11 17:34:37 +08:00
// secure version
// only allow parameterized queries with validation
func SecureExecuteSql(w http.ResponseWriter, r *http.Request) {
var input struct {
Query string `json:"query"`
Params []interface{} `json:"params"`
}
// parse json request body
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
http.Error(w, "Invalid JSON format", http.StatusBadRequest)
return
}
defer r.Body.Close()
// simple validation, only allow select statements
if !strings.HasPrefix(strings.ToUpper(input.Query), "SELECT") {
http.Error(w, "Only SELECT queries are allowed", http.StatusForbidden)
return
}
// execute the query as a parameterized statement
rows, err := db.DbPool.Query(context.Background(), input.Query, input.Params...)
if err != nil {
http.Error(w, "Query execution error", http.StatusInternalServerError)
return
}
defer rows.Close()
// format the response
var response []map[string]interface{}
for rows.Next() {
values, err := rows.Values()
if err != nil {
http.Error(w, "Error reading query result", http.StatusInternalServerError)
return
}
rowMap := make(map[string]interface{})
fieldDescriptions := rows.FieldDescriptions()
for i, fd := range fieldDescriptions {
rowMap[string(fd.Name)] = values[i]
}
response = append(response, rowMap)
}
// return json response
jsonResp, err := json.Marshal(response)
if err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
2024-11-11 17:34:37 +08:00
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(jsonResp)
}
// SecureLoginSql is the hardened counterpart of LoginSql. It decodes a JSON
// body of the form {"username": "...", "password": "..."} and passes both
// values as bound parameters ($1, $2), so they are never parsed as SQL.
//
// A malformed body yields 400. If the query returns a row, the response is a
// JSON object with "message", "user_id" and "username"; any lookup failure
// yields 401.
//
// NOTE(review): the query compares the password column directly, which
// suggests passwords are stored in plaintext. Consider storing a salted
// hash and verifying it in Go instead.
func SecureLoginSql(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	var credentials struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
		http.Error(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	// secure version using parameterized queries
	query := "SELECT id, username FROM users WHERE username = $1 AND password = $2"
	var id int
	var username string
	err := db.DbPool.QueryRow(context.Background(), query, credentials.Username, credentials.Password).Scan(&id, &username)
	if err != nil {
		http.Error(w, "Invalid credentials", http.StatusUnauthorized)
		return
	}

	// send back the response if great success
	response := map[string]interface{}{
		"message":  "Login successful",
		"user_id":  id,
		"username": username,
	}
	// without an explicit type, net/http sniffs JSON bodies as text/plain
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		http.Error(w, "Failed to encode response", http.StatusInternalServerError)
		log.Printf("JSON encoding error: %v", err)
	}
}
2024-11-11 17:34:37 +08:00
// even more secure
func SecureGetUser(w http.ResponseWriter, r *http.Request) {
2024-11-11 18:47:15 +08:00
// decode the json body
var requestData struct {
Username string `json:"username"`
}
// declare new json decoder with custom property
jsonDecoder := json.NewDecoder(r.Body)
// rejects any unknown fields in the json, more strict
jsonDecoder.DisallowUnknownFields()
if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil {
http.Error(w, "Invalid request format", http.StatusBadRequest)
log.Println("Failed to decode JSON body or extra fields present:", err)
return
}
// validate that user is provided
if requestData.Username == "" {
http.Error(w, "Invalid request format", http.StatusBadRequest)
log.Println("Username is missing in the request body")
return
}
2024-11-11 17:34:37 +08:00
// retrieve list of existing usernames from db
existingUsernames, err := db.FetchUsernames()
if err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
log.Printf("Failed to fetch usernames: %v", err)
return
}
// check if the username exists in the allowed list
// this step is crucial
// server will reject ANYTHING that does not match the list
2024-11-11 18:47:15 +08:00
if !existingUsernames[requestData.Username] {
2024-11-11 17:34:37 +08:00
http.Error(w, "Invalid username", http.StatusBadRequest)
return
}
// construct the query
query := "SELECT id, username, email FROM users WHERE username = $1"
var id int
var dbUsername, email string
2024-11-11 18:47:15 +08:00
err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email)
2024-11-11 17:34:37 +08:00
if err != nil {
http.Error(w, "User not found", http.StatusNotFound)
return
}
// send back the user data as a json response
response := map[string]interface{}{
"id": id,
"username": dbUsername,
"email": email,
}
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
log.Printf("JSON encoding error: %v", err)
}
}