added secure api sql injection server

This commit is contained in:
Vomitblood 2024-11-11 17:34:37 +08:00
parent 549073dd95
commit 18eef141d3
6 changed files with 256 additions and 81 deletions

22
README
View file

@ -28,3 +28,25 @@ PGPORT=5432
PGDATABASE=asdfdb PGDATABASE=asdfdb
PGUSER=asdfuser PGUSER=asdfuser
PGPASSWORD=asdfpassword PGPASSWORD=asdfpassword
## Server
### SQL Injection
- `/sql-execute`
- `/secure-sql-execute`
- `/secure-get-user`
#### 1. Parameterization of Queries
Used `pool.Query()` with a parameterized query, instead of dynamically constructing the SQL query by directly inserting the user input.
Parameterized queries separate the SQL code from the data, so user input is never directly put into the query's structure. Placeholders are used instead, and the data is passed as parameters. The DB will treat them as data, not executable code.
#### 2. Input Validation and Query Type Restriction
Only allow `SELECT` statement by verifying that the input query starts with it.
Sanitized the input to ensure that no other types of statements could be executed.
#### 3. Controller JSON Input for Parameters
Instead of using raw SQL strings, we restructured the input to ONLY expect JSON data with `query` and `params` fields.

View file

@ -1,4 +1,4 @@
module cspj-server module github.com/Vomitblood/cspj-application/server
go 1.23.2 go 1.23.2

62
server/internal/db/db.go Normal file
View file

@ -0,0 +1,62 @@
package db
import (
"context"
"fmt"
"log"
"github.com/jackc/pgx/v5/pgxpool"
)
// db connection info
// !MIGHT CHANGE
const (
host = "localhost"
port = 5432
user = "asdfuser"
password = "asdfpassword"
dbname = "asdfdb"
)
var DbPool *pgxpool.Pool
var allowedUsernames map[string]bool
// initialize connection to db
func ConnectToDb() (*pgxpool.Pool, error) {
// this server is intended to be ran on the same system as the db
dbUrl := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", user, password, host, port, dbname)
config, err := pgxpool.ParseConfig((dbUrl))
if err != nil {
return nil, fmt.Errorf("unable to parse data URL: %w", err)
}
pool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil {
return nil, fmt.Errorf("unable to create connection pool: %w", err)
}
log.Println("Connected to DB :)")
return pool, nil
}
// fetch existing usernames from db
func FetchUsernames() (map[string]bool, error) {
usernames := make(map[string]bool)
rows, err := DbPool.Query(context.Background(), "SELECT username FROM users")
if err != nil {
return nil, fmt.Errorf("error querying users: %w", err)
}
defer rows.Close()
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return nil, fmt.Errorf("error scanning username: %w", err)
}
usernames[username] = true
}
log.Println("Fetched usernames:", usernames)
return usernames, nil
}

View file

@ -0,0 +1,19 @@
package http_server
import (
"log"
"net/http"
"github.com/Vomitblood/cspj-application/server/internal/sql_injection"
)
// setup the http server
func ServeApi() {
http.HandleFunc("/execute-sql", sql_injection.ExecuteSql)
http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql)
http.HandleFunc("/secure-get-user", sql_injection.SecureExecuteSql)
log.Println("Server is running on http://localhost:3001")
if err := http.ListenAndServe(":3001", nil); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
}

View file

@ -0,0 +1,147 @@
package sql_injection
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"github.com/Vomitblood/cspj-application/server/internal/db"
)
// unsecure version
// take http reqeust body as raw sql and pass to db
func ExecuteSql(w http.ResponseWriter, r *http.Request) {
// read the request body
sqlQuery, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
defer r.Body.Close()
// execute the sql query without any sanitization
rows, err := db.DbPool.Query(context.Background(), string(sqlQuery))
if err != nil {
http.Error(w, "Query execution error", http.StatusInternalServerError)
return
}
defer rows.Close()
// prepare the response by iterating over the returned rows
var response string
for rows.Next() {
values, err := rows.Values()
if err != nil {
http.Error(w, "Error reading query result", http.StatusInternalServerError)
return
}
response += fmt.Sprintf("%v\n", values)
}
// send the response to the client
w.Write([]byte(response))
}
// secure version
// only allow parameterized queries with validation
func SecureExecuteSql(w http.ResponseWriter, r *http.Request) {
var input struct {
Query string `json:"query"`
Params []interface{} `json:"params"`
}
// parse json request body
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
http.Error(w, "Invalid JSON format", http.StatusBadRequest)
return
}
defer r.Body.Close()
// simple validation, only allow select statements
if !strings.HasPrefix(strings.ToUpper(input.Query), "SELECT") {
http.Error(w, "Only SELECT queries are allowed", http.StatusForbidden)
return
}
// execute the query as a parameterized statement
rows, err := db.DbPool.Query(context.Background(), input.Query, input.Params...)
if err != nil {
http.Error(w, "Query execution error", http.StatusInternalServerError)
return
}
defer rows.Close()
// format the response
var response []map[string]interface{}
for rows.Next() {
values, err := rows.Values()
if err != nil {
http.Error(w, "Error reading query result", http.StatusInternalServerError)
return
}
rowMap := make(map[string]interface{})
fieldDescriptions := rows.FieldDescriptions()
for i, fd := range fieldDescriptions {
rowMap[string(fd.Name)] = values[i]
}
response = append(response, rowMap)
}
// return json response
jsonResp, err := json.Marshal(response)
if err != nil {
http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(jsonResp)
}
// even more secure
func SecureGetUser(w http.ResponseWriter, r *http.Request) {
// retrieve list of existing usernames from db
existingUsernames, err := db.FetchUsernames()
if err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
log.Printf("Failed to fetch usernames: %v", err)
return
}
// get the username from the query parameter
username := r.URL.Query().Get("username")
// check if the username exists in the allowed list
// this step is crucial
// server will reject ANYTHING that does not match the list
if !existingUsernames[username] {
http.Error(w, "Invalid username", http.StatusBadRequest)
return
}
// construct the query
query := "SELECT id, username, email FROM users WHERE username = $1"
var id int
var dbUsername, email string
err = db.DbPool.QueryRow(context.Background(), query, username).Scan(&id, &dbUsername, &email)
if err != nil {
http.Error(w, "User not found", http.StatusNotFound)
return
}
// send back the user data as a json response
response := map[string]interface{}{
"id": id,
"username": dbUsername,
"email": email,
}
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
log.Printf("JSON encoding error: %v", err)
}
}

View file

@ -1,94 +1,19 @@
package main package main
import ( import (
"context"
"fmt"
"io"
"log" "log"
"net/http"
"github.com/jackc/pgx/v5/pgxpool" "github.com/Vomitblood/cspj-application/server/internal/db"
"github.com/Vomitblood/cspj-application/server/internal/http_server"
) )
// db connection info
// !MIGHT CHANGE
const (
host = "localhost"
port = 5432
user = "asdfuser"
password = "asdfpassword"
dbname = "asdfdb"
)
var pool *pgxpool.Pool
// initialize connection to db
func connectToDb() (*pgxpool.Pool, error) {
// this server is intended to be ran on the same system as the db
dbUrl := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", user, password, host, port, dbname)
config, err := pgxpool.ParseConfig((dbUrl))
if err != nil {
return nil, fmt.Errorf("unable to parse data URL: %w", err)
}
pool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil {
return nil, fmt.Errorf("unable to create connection pool: %w", err)
}
log.Println("Connected to DB :)")
return pool, nil
}
// take http reqeust body as raw sql and pass to db
func executeSql(w http.ResponseWriter, r *http.Request) {
// read the request body
sqlQuery, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
defer r.Body.Close()
// execute the sql query without any sanitization
rows, err := pool.Query(context.Background(), string(sqlQuery))
if err != nil {
http.Error(w, "Query execution error", http.StatusInternalServerError)
return
}
defer rows.Close()
// prepare the response by iterating over the returned rows
var response string
for rows.Next() {
values, err := rows.Values()
if err != nil {
http.Error(w, "Error reading query result", http.StatusInternalServerError)
return
}
response += fmt.Sprintf("%v\n", values)
}
// send the response to the client
w.Write([]byte(response))
}
// setup the http server
func serveApi() {
http.HandleFunc("/executeSql", executeSql)
log.Println("Unsecure server is running on http://localhost:3001")
if err := http.ListenAndServe(":3001", nil); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
}
func main() { func main() {
var err error var err error
pool, err = connectToDb() db.DbPool, err = db.ConnectToDb()
if err != nil { if err != nil {
log.Fatalf("Failed to connect to db: %v", err) log.Fatalf("Failed to connect to db: %v", err)
} }
defer pool.Close() defer db.DbPool.Close()
serveApi() http_server.ServeApi()
} }