login
This commit is contained in:
parent
9e466635de
commit
53c0ada12e
|
@ -41,11 +41,8 @@ PGPASSWORD=asdfpassword
|
||||||
|
|
||||||
### SQL Injection
|
### SQL Injection
|
||||||
|
|
||||||
- `/sql-execute`
|
- `/secure-register-sql`
|
||||||
- `/login-sql`
|
|
||||||
- `/secure-sql-execute`
|
|
||||||
- `/secure-login-sql`
|
- `/secure-login-sql`
|
||||||
- `/secure-get-user`
|
|
||||||
|
|
||||||
#### 1. Parameterization of Queries
|
#### 1. Parameterization of Queries
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,66 @@
|
||||||
import { Box, LinearProgress, TextField, Typography } from "@mui/material";
|
import { Box, TextField, Typography } from "@mui/material";
|
||||||
|
import { useAtom } from "jotai";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
|
import { useNotification } from "../../../contexts/NotificationContext";
|
||||||
|
import { serverUrlAtom } from "../../../lib/jotai";
|
||||||
import { HeaderLogo } from "../../Generic/HeaderLogo";
|
import { HeaderLogo } from "../../Generic/HeaderLogo";
|
||||||
|
import { LoadingButton } from "@mui/lab";
|
||||||
|
|
||||||
export const SqlInjectionLogin = () => {
|
export const SqlInjectionLogin = () => {
|
||||||
|
// contexts
|
||||||
|
const { openNotification } = useNotification();
|
||||||
|
|
||||||
|
// atoms
|
||||||
|
const [serverUrl, setServerUrl] = useAtom(serverUrlAtom);
|
||||||
|
|
||||||
// states
|
// states
|
||||||
const [emailValueRaw, setEmailValueRaw] = useState<string>("");
|
const [emailValueRaw, setEmailValueRaw] = useState<string>("");
|
||||||
const [passwordValueRaw, setPasswordValueRaw] = useState<string>("");
|
const [passwordValueRaw, setPasswordValueRaw] = useState<string>("");
|
||||||
const [passwordErrorMsg, setPasswordErrorMsg] = useState<string>("");
|
const [errorMsg, setErrorMsg] = useState<string>("");
|
||||||
|
const [loginLoading, setLoginLoading] = useState<boolean>(false);
|
||||||
|
|
||||||
|
const nextClickEvent = async () => {
|
||||||
|
// reset the error messages
|
||||||
|
setErrorMsg("");
|
||||||
|
|
||||||
|
// ensure that the server url does not end with a trailing slash
|
||||||
|
setServerUrl(serverUrl.replace(/\/$/, ""));
|
||||||
|
|
||||||
|
// construct the request body
|
||||||
|
const requestBody = {
|
||||||
|
email: emailValueRaw,
|
||||||
|
password: passwordValueRaw,
|
||||||
|
};
|
||||||
|
|
||||||
|
// start loading indicator
|
||||||
|
setLoginLoading(true);
|
||||||
|
|
||||||
|
try {
|
||||||
|
// make request good
|
||||||
|
const response = await fetch(serverUrl + "/register-sql", {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify(requestBody),
|
||||||
|
});
|
||||||
|
|
||||||
|
// check if login was successful
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorMessage = await response.text();
|
||||||
|
console.log("Login failed:", errorMessage);
|
||||||
|
setErrorMsg(errorMessage);
|
||||||
|
} else {
|
||||||
|
openNotification("Login successful");
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
// log the error and handle failure
|
||||||
|
console.log("Request failed", e);
|
||||||
|
} finally {
|
||||||
|
// stop loading indicator regardless of success/failure
|
||||||
|
setLoginLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
|
@ -33,6 +87,12 @@ export const SqlInjectionLogin = () => {
|
||||||
display: "flex",
|
display: "flex",
|
||||||
flexDirection: "column",
|
flexDirection: "column",
|
||||||
}}
|
}}
|
||||||
|
>
|
||||||
|
<form
|
||||||
|
onSubmit={(e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
nextClickEvent();
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
<TextField
|
<TextField
|
||||||
fullWidth
|
fullWidth
|
||||||
|
@ -46,9 +106,9 @@ export const SqlInjectionLogin = () => {
|
||||||
variant="outlined"
|
variant="outlined"
|
||||||
/>
|
/>
|
||||||
<TextField
|
<TextField
|
||||||
error={Boolean(passwordErrorMsg)}
|
error={Boolean(errorMsg)}
|
||||||
fullWidth
|
fullWidth
|
||||||
helperText={Boolean(passwordErrorMsg) ? passwordErrorMsg : ""}
|
helperText={Boolean(errorMsg) ? errorMsg : ""}
|
||||||
id="password"
|
id="password"
|
||||||
label="Password"
|
label="Password"
|
||||||
onChange={(e: { target: { value: string } }) => setPasswordValueRaw(e.target.value)}
|
onChange={(e: { target: { value: string } }) => setPasswordValueRaw(e.target.value)}
|
||||||
|
@ -58,6 +118,22 @@ export const SqlInjectionLogin = () => {
|
||||||
sx={{ mb: 2 }}
|
sx={{ mb: 2 }}
|
||||||
variant="outlined"
|
variant="outlined"
|
||||||
/>
|
/>
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
display: "flex",
|
||||||
|
flexDirection: "row",
|
||||||
|
justifyContent: "end",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<LoadingButton
|
||||||
|
loading={loginLoading}
|
||||||
|
type="submit"
|
||||||
|
variant="contained"
|
||||||
|
>
|
||||||
|
Next
|
||||||
|
</LoadingButton>
|
||||||
|
</Box>
|
||||||
|
</form>
|
||||||
</Box>
|
</Box>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
|
|
@ -107,7 +107,7 @@ export const SqlInjectionRegister = () => {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// make request good
|
// make request good
|
||||||
const response = await fetch(serverUrl + "/register-sql", {
|
const response = await fetch(serverUrl + "/secure-register-sql", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
@ -124,10 +124,10 @@ export const SqlInjectionRegister = () => {
|
||||||
openNotification("Registration successful");
|
openNotification("Registration successful");
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
// Log the error and handle failure
|
// log the error and handle failure
|
||||||
console.log("Request failed", e);
|
console.log("Request failed", e);
|
||||||
} finally {
|
} finally {
|
||||||
// Stop loading indicator regardless of success/failure
|
// stop loading indicator regardless of success/failure
|
||||||
setRegisterLoading(false);
|
setRegisterLoading(false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -152,17 +152,17 @@ export const SqlInjectionRegister = () => {
|
||||||
Register
|
Register
|
||||||
</Typography>
|
</Typography>
|
||||||
</Box>
|
</Box>
|
||||||
<form
|
|
||||||
onSubmit={(e) => {
|
|
||||||
e.preventDefault();
|
|
||||||
nextClickEvent();
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
display: "flex",
|
display: "flex",
|
||||||
flexDirection: "column",
|
flexDirection: "column",
|
||||||
}}
|
}}
|
||||||
|
>
|
||||||
|
<form
|
||||||
|
onSubmit={(e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
nextClickEvent();
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
<TextField
|
<TextField
|
||||||
error={Boolean(errorMsg)}
|
error={Boolean(errorMsg)}
|
||||||
|
@ -233,7 +233,6 @@ export const SqlInjectionRegister = () => {
|
||||||
<li>At least one special character</li>
|
<li>At least one special character</li>
|
||||||
</ul>
|
</ul>
|
||||||
</Typography>
|
</Typography>
|
||||||
</Box>
|
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
display: "flex",
|
display: "flex",
|
||||||
|
@ -250,6 +249,7 @@ export const SqlInjectionRegister = () => {
|
||||||
</LoadingButton>
|
</LoadingButton>
|
||||||
</Box>
|
</Box>
|
||||||
</form>
|
</form>
|
||||||
|
</Box>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
|
@ -135,25 +135,25 @@ func NukeDb(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetch existing usernames from db
|
// fetch existing usernames from db
|
||||||
func FetchUsernames() (map[string]bool, error) {
|
func FetchEmails() (map[string]bool, error) {
|
||||||
usernames := make(map[string]bool)
|
emails := make(map[string]bool)
|
||||||
|
|
||||||
rows, err := DbPool.Query(context.Background(), "SELECT username FROM users")
|
rows, err := DbPool.Query(context.Background(), "SELECT email FROM users")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error querying users: %w", err)
|
return nil, fmt.Errorf("error querying users: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var username string
|
var email string
|
||||||
if err := rows.Scan(&username); err != nil {
|
if err := rows.Scan(&email); err != nil {
|
||||||
return nil, fmt.Errorf("error scanning username: %w", err)
|
return nil, fmt.Errorf("error scanning email: %w", err)
|
||||||
}
|
}
|
||||||
usernames[username] = true
|
emails[email] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Fetched usernames:", usernames)
|
log.Println("Fetched emails:", emails)
|
||||||
return usernames, nil
|
return emails, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetch all users for demo
|
// fetch all users for demo
|
||||||
|
|
|
@ -21,12 +21,8 @@ func ServeApi() {
|
||||||
http.HandleFunc("/setup-demo-db", db.SetupDemoDb)
|
http.HandleFunc("/setup-demo-db", db.SetupDemoDb)
|
||||||
http.HandleFunc("/nuke-db", db.NukeDb)
|
http.HandleFunc("/nuke-db", db.NukeDb)
|
||||||
http.HandleFunc("/fetch-all-users", db.FetchAllUsers)
|
http.HandleFunc("/fetch-all-users", db.FetchAllUsers)
|
||||||
http.HandleFunc("/execute-sql", sql_injection.ExecuteSql)
|
http.HandleFunc("/secure-register-sql", sql_injection.SecureRegisterSql)
|
||||||
http.HandleFunc("/login-sql", sql_injection.LoginSql)
|
|
||||||
http.HandleFunc("/secure-execute-sql", sql_injection.SecureExecuteSql)
|
|
||||||
http.HandleFunc("/register-sql", sql_injection.RegisterSql)
|
|
||||||
http.HandleFunc("/secure-login-sql", sql_injection.SecureLoginSql)
|
http.HandleFunc("/secure-login-sql", sql_injection.SecureLoginSql)
|
||||||
http.HandleFunc("/secure-get-user", sql_injection.SecureGetUser)
|
|
||||||
log.Println("Server is running on http://localhost:5000")
|
log.Println("Server is running on http://localhost:5000")
|
||||||
if err := http.ListenAndServe(":5000", nil); err != nil {
|
if err := http.ListenAndServe(":5000", nil); err != nil {
|
||||||
log.Fatalf("Failed to start server: %v", err)
|
log.Fatalf("Failed to start server: %v", err)
|
||||||
|
|
|
@ -4,153 +4,16 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/Vomitblood/cspj-application/server/internal/db"
|
"github.com/Vomitblood/cspj-application/server/internal/db"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// unsecure version
|
// secure register endpoint
|
||||||
// take http reqeust body as raw sql and pass to db
|
func UnsecureRegisterSql(w http.ResponseWriter, r *http.Request) {
|
||||||
func ExecuteSql(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// read the request body
|
|
||||||
sqlQuery, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
|
|
||||||
// execute the sql query without any sanitization
|
|
||||||
rows, err := db.DbPool.Query(context.Background(), string(sqlQuery))
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Query execution error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// prepare the response by iterating over the returned rows
|
|
||||||
var response string
|
|
||||||
for rows.Next() {
|
|
||||||
values, err := rows.Values()
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Error reading query result", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response += fmt.Sprintf("%v\n", values)
|
|
||||||
}
|
|
||||||
|
|
||||||
// send the response to the client
|
|
||||||
w.Write([]byte(response))
|
|
||||||
}
|
|
||||||
|
|
||||||
// unsecure login
|
|
||||||
// login endpoint with sql injection vulnerability
|
|
||||||
func LoginSql(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// parse the request body to get username and password
|
|
||||||
var credentials struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// decode the json body
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
|
|
||||||
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
|
|
||||||
// construct the unsafe query
|
|
||||||
query := fmt.Sprintf(
|
|
||||||
"SELECT id, username FROM users WHERE username = '%s' AND password = '&s'",
|
|
||||||
credentials.Username,
|
|
||||||
credentials.Password,
|
|
||||||
)
|
|
||||||
|
|
||||||
// execute the query without sanitizing the input
|
|
||||||
var id int
|
|
||||||
var username string
|
|
||||||
err := db.DbPool.QueryRow(context.Background(), query).Scan(&id, &username)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the user is found, return success response
|
|
||||||
response := map[string]interface{}{
|
|
||||||
"message": "Login successful",
|
|
||||||
"user_id": id,
|
|
||||||
"username": username,
|
|
||||||
}
|
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
||||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
|
||||||
log.Printf("JSON encoding error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// secure version
|
|
||||||
// only allow parameterized queries with validation
|
|
||||||
func SecureExecuteSql(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var input struct {
|
|
||||||
Query string `json:"query"`
|
|
||||||
Params []interface{} `json:"params"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse json request body
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
|
||||||
http.Error(w, "Invalid JSON format", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
|
|
||||||
// simple validation, only allow select statements
|
|
||||||
if !strings.HasPrefix(strings.ToUpper(input.Query), "SELECT") {
|
|
||||||
http.Error(w, "Only SELECT queries are allowed", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// execute the query as a parameterized statement
|
|
||||||
rows, err := db.DbPool.Query(context.Background(), input.Query, input.Params...)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Query execution error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// format the response
|
|
||||||
var response []map[string]interface{}
|
|
||||||
for rows.Next() {
|
|
||||||
values, err := rows.Values()
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Error reading query result", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rowMap := make(map[string]interface{})
|
|
||||||
fieldDescriptions := rows.FieldDescriptions()
|
|
||||||
for i, fd := range fieldDescriptions {
|
|
||||||
rowMap[string(fd.Name)] = values[i]
|
|
||||||
}
|
|
||||||
response = append(response, rowMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
// return json response
|
|
||||||
jsonResp, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Write(jsonResp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// register endpoint
|
|
||||||
func RegisterSql(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// read the request body
|
// read the request body
|
||||||
var credentials struct {
|
var credentials struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
@ -179,6 +42,75 @@ func RegisterSql(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the number of emails that matches
|
// get the number of emails that matches
|
||||||
|
// construct the sql query using concatenation, BAD
|
||||||
|
emailCheckSQL := fmt.Sprintf("SELECT COUNT(*) FROM users WHERE email = '%s'", credentials.Email)
|
||||||
|
var existingUserCount int
|
||||||
|
err := db.DbPool.QueryRow(context.Background(), emailCheckSQL).Scan(&existingUserCount)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Error checking email in the database", http.StatusInternalServerError)
|
||||||
|
log.Printf("Error checking email: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// if there is more than 0 matches, that means email already exists, reject
|
||||||
|
if existingUserCount > 0 {
|
||||||
|
http.Error(w, "Email already exists", http.StatusConflict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// hash the password
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(credentials.Password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Error hashing password", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// over here the validations has passed, so insert into the db
|
||||||
|
// also use concatenation here
|
||||||
|
insertSQL := fmt.Sprintf("INSERT INTO users (email, password, role) VALUES ('%s', '%s', 'user')", credentials.Email, string(hashedPassword))
|
||||||
|
_, err = db.DbPool.Exec(context.Background(), insertSQL)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Error inserting user into the database", http.StatusInternalServerError)
|
||||||
|
log.Printf("Error inserting user: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// send back status ok
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("User registered successfully"))
|
||||||
|
log.Println("User registered successfully:", credentials.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// secure register endpoint
|
||||||
|
func SecureRegisterSql(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// read the request body
|
||||||
|
var credentials struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
RePassword string `json:"rePassword"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
|
||||||
|
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
// check if the email is an email using regex, if not reject
|
||||||
|
emailRegex := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
|
||||||
|
re := regexp.MustCompile(emailRegex)
|
||||||
|
if !re.MatchString(credentials.Email) {
|
||||||
|
http.Error(w, "Invalid email format", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the password matches, if not reject
|
||||||
|
if credentials.Password != credentials.RePassword {
|
||||||
|
http.Error(w, "Passwords do not match", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the number of emails that matches
|
||||||
|
// use parameterization
|
||||||
var existingUserCount int
|
var existingUserCount int
|
||||||
emailCheckSQL := `SELECT COUNT(*) FROM users WHERE email = $1`
|
emailCheckSQL := `SELECT COUNT(*) FROM users WHERE email = $1`
|
||||||
err := db.DbPool.QueryRow(context.Background(), emailCheckSQL, credentials.Email).Scan(&existingUserCount)
|
err := db.DbPool.QueryRow(context.Background(), emailCheckSQL, credentials.Email).Scan(&existingUserCount)
|
||||||
|
@ -201,6 +133,7 @@ func RegisterSql(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// over here the validations has passed, so insert into the db
|
// over here the validations has passed, so insert into the db
|
||||||
|
// use parameterization
|
||||||
insertSQL := `INSERT INTO users (email, password, role) VALUES ($1, $2, $3)`
|
insertSQL := `INSERT INTO users (email, password, role) VALUES ($1, $2, $3)`
|
||||||
_, err = db.DbPool.Exec(context.Background(), insertSQL, credentials.Email, hashedPassword, "user")
|
_, err = db.DbPool.Exec(context.Background(), insertSQL, credentials.Email, hashedPassword, "user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -215,46 +148,12 @@ func RegisterSql(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Println("User registered successfully:", credentials.Email)
|
log.Println("User registered successfully:", credentials.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
// secure login
|
// very secure login endpoint
|
||||||
func SecureLoginSql(w http.ResponseWriter, r *http.Request) {
|
func UnsecureLoginSql(w http.ResponseWriter, r *http.Request) {
|
||||||
var credentials struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&credentials); err != nil {
|
|
||||||
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
|
|
||||||
// secure version using parameterized queries
|
|
||||||
query := "SELECT id, username FROM users WHERE username = $1 AND password = $2"
|
|
||||||
var id int
|
|
||||||
var username string
|
|
||||||
err := db.DbPool.QueryRow(context.Background(), query, credentials.Username, credentials.Password).Scan(&id, &username)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// send back the response if great success
|
|
||||||
response := map[string]interface{}{
|
|
||||||
"message": "Login successful",
|
|
||||||
"user_id": id,
|
|
||||||
"username": username,
|
|
||||||
}
|
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
||||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
|
||||||
log.Printf("JSON encoding error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// even more secure
|
|
||||||
func SecureGetUser(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// decode the json body
|
// decode the json body
|
||||||
var requestData struct {
|
var requestData struct {
|
||||||
Username string `json:"username"`
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// declare new json decoder with custom property
|
// declare new json decoder with custom property
|
||||||
|
@ -269,42 +168,103 @@ func SecureGetUser(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate that user is provided
|
// validate that user is provided
|
||||||
if requestData.Username == "" {
|
if requestData.Email == "" || requestData.Password == "" {
|
||||||
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
||||||
log.Println("Username is missing in the request body")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve list of existing usernames from db
|
// retrieve the user data
|
||||||
existingUsernames, err := db.FetchUsernames()
|
// construct sql command using concatenation
|
||||||
if err != nil {
|
query := fmt.Sprintf("SELECT id, email, password FROM users WHERE email = '%s'", requestData.Email)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
||||||
log.Printf("Failed to fetch usernames: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if the username exists in the allowed list
|
|
||||||
// this step is crucial
|
|
||||||
// server will reject ANYTHING that does not match the list
|
|
||||||
if !existingUsernames[requestData.Username] {
|
|
||||||
http.Error(w, "Invalid username", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// construct the query
|
|
||||||
query := "SELECT id, username, email FROM users WHERE username = $1"
|
|
||||||
var id int
|
var id int
|
||||||
var dbUsername, email string
|
var email string
|
||||||
err = db.DbPool.QueryRow(context.Background(), query, requestData.Username).Scan(&id, &dbUsername, &email)
|
var hashedPassword string
|
||||||
|
err := db.DbPool.QueryRow(context.Background(), query).Scan(&id, &email, &hashedPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "User not found", http.StatusNotFound)
|
http.Error(w, "Invalid email or password", http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// send back the user data as a json response
|
// compare the provided password with the stored hashed password
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(requestData.Password))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid email or password", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// send back the user data as a JSON response
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": id,
|
||||||
|
"email": email,
|
||||||
|
}
|
||||||
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
|
http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
|
||||||
|
log.Printf("JSON encoding error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// very secure login endpoint
|
||||||
|
func SecureLoginSql(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// decode the json body
|
||||||
|
var requestData struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// declare new json decoder with custom property
|
||||||
|
jsonDecoder := json.NewDecoder(r.Body)
|
||||||
|
// rejects any unknown fields in the json, more strict
|
||||||
|
jsonDecoder.DisallowUnknownFields()
|
||||||
|
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil {
|
||||||
|
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
||||||
|
log.Println("Failed to decode JSON body or extra fields present:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate that user is provided
|
||||||
|
if requestData.Email == "" || requestData.Password == "" {
|
||||||
|
http.Error(w, "Invalid request format", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// retrieve list of existing emails from db
|
||||||
|
existingEmails, err := db.FetchEmails()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
|
log.Printf("Failed to fetch emails: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the email exists in the allowed list
|
||||||
|
// this step is crucial
|
||||||
|
// server will reject ANYTHING that does not match the list
|
||||||
|
if !existingEmails[requestData.Email] {
|
||||||
|
http.Error(w, "Invalid email or password", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// retrieve the email and password from the database
|
||||||
|
var id int
|
||||||
|
var email string
|
||||||
|
var hashedPassword string
|
||||||
|
query := "SELECT id, email, password FROM users WHERE email = $1"
|
||||||
|
err = db.DbPool.QueryRow(context.Background(), query, requestData.Email).Scan(&id, &email, &hashedPassword)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid email or password", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// compare the provided password and the stored password hashes
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(requestData.Password))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid email or password", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// send back the user data as a json response
|
||||||
response := map[string]interface{}{
|
response := map[string]interface{}{
|
||||||
"id": id,
|
"id": id,
|
||||||
"username": dbUsername,
|
|
||||||
"email": email,
|
"email": email,
|
||||||
}
|
}
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
|
|
Loading…
Reference in a new issue