2024-11-11 17:34:37 +08:00
|
|
|
package db
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2024-11-11 18:47:15 +08:00
|
|
|
"encoding/json"
|
2024-11-11 17:34:37 +08:00
|
|
|
"fmt"
|
|
|
|
"log"
|
2024-11-11 17:46:30 +08:00
|
|
|
"net/http"
|
2024-11-11 17:34:37 +08:00
|
|
|
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Connection settings for the PostgreSQL instance this server talks to.
// The values are expected to change between deployments.
//
// NOTE(review): credentials committed to source are a risk; consider loading
// them from the environment or a secrets store before deploying anywhere real.
const (
	// Network location of the database server.
	host = "localhost"
	port = 5432

	// Database to open and the role to authenticate as.
	dbname   = "asdfdb"
	user     = "asdfuser"
	password = "asdfpassword"
)
|
|
|
|
|
|
|
|
var DbPool *pgxpool.Pool
|
|
|
|
var allowedUsernames map[string]bool
|
|
|
|
|
|
|
|
// initialize connection to db
|
|
|
|
func ConnectToDb() (*pgxpool.Pool, error) {
|
|
|
|
// this server is intended to be ran on the same system as the db
|
|
|
|
dbUrl := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", user, password, host, port, dbname)
|
|
|
|
config, err := pgxpool.ParseConfig((dbUrl))
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("unable to parse data URL: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
pool, err := pgxpool.NewWithConfig(context.Background(), config)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("unable to create connection pool: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println("Connected to DB :)")
|
|
|
|
return pool, nil
|
|
|
|
}
|
|
|
|
|
2024-11-11 17:46:30 +08:00
|
|
|
// setup demo db
|
|
|
|
func SetupDemoDb(w http.ResponseWriter, r *http.Request) {
|
|
|
|
// create table and insert demo data
|
|
|
|
createTableSQL := `
|
2024-11-12 11:53:55 +08:00
|
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
|
|
id SERIAL PRIMARY KEY,
|
|
|
|
username VARCHAR(50) UNIQUE NOT NULL,
|
|
|
|
email VARCHAR(100) NOT NULL,
|
|
|
|
password VARCHAR(100) NOT NULL
|
|
|
|
);`
|
2024-11-11 17:46:30 +08:00
|
|
|
|
|
|
|
// also avoid duplicate entries
|
|
|
|
insertDataSQL := `
|
2024-11-12 11:53:55 +08:00
|
|
|
INSERT INTO users (username, email, password) VALUES
|
|
|
|
('alice', 'alice@example.com', 'asdfalicepassword'),
|
|
|
|
('bob', 'bob@example.com', 'asdfbobpassword'),
|
|
|
|
('charlie', 'charlie@example.com', 'asdfcharliepassword')
|
|
|
|
ON CONFLICT (username) DO NOTHING;`
|
2024-11-11 17:46:30 +08:00
|
|
|
|
|
|
|
// execute create table
|
|
|
|
_, err := DbPool.Exec(context.Background(), createTableSQL)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Failed to create table", http.StatusInternalServerError)
|
|
|
|
log.Printf("Error creating table: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// execute insert demo data
|
|
|
|
_, err = DbPool.Exec(context.Background(), insertDataSQL)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Failed to insert demo data", http.StatusInternalServerError)
|
|
|
|
log.Printf("Error inserting demo data: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// response back to client
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
w.Write([]byte("Database setup complete with demo data"))
|
|
|
|
log.Println("Demo database setup completed successfully")
|
|
|
|
}
|
|
|
|
|
|
|
|
// nuke the db
|
|
|
|
func NukeDb(w http.ResponseWriter, r *http.Request) {
|
|
|
|
// drop user table
|
|
|
|
dropTableSQL := `
|
|
|
|
DROP TABLE IF EXISTS users CASCADE;
|
|
|
|
`
|
|
|
|
|
|
|
|
// execute the command
|
|
|
|
_, err := DbPool.Exec(context.Background(), dropTableSQL)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Failed to drop table", http.StatusInternalServerError)
|
|
|
|
log.Printf("Error dropping table: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2024-11-11 18:47:15 +08:00
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
w.Write([]byte("Bye bye"))
|
2024-11-11 17:46:30 +08:00
|
|
|
log.Println("Database nuked")
|
|
|
|
}
|
|
|
|
|
2024-11-11 17:34:37 +08:00
|
|
|
// fetch existing usernames from db
|
|
|
|
func FetchUsernames() (map[string]bool, error) {
|
|
|
|
usernames := make(map[string]bool)
|
|
|
|
|
|
|
|
rows, err := DbPool.Query(context.Background(), "SELECT username FROM users")
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error querying users: %w", err)
|
|
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
|
|
|
|
for rows.Next() {
|
|
|
|
var username string
|
|
|
|
if err := rows.Scan(&username); err != nil {
|
|
|
|
return nil, fmt.Errorf("error scanning username: %w", err)
|
|
|
|
}
|
|
|
|
usernames[username] = true
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println("Fetched usernames:", usernames)
|
|
|
|
return usernames, nil
|
|
|
|
}
|
2024-11-11 18:47:15 +08:00
|
|
|
|
|
|
|
// fetch all users for demo
|
|
|
|
func FetchAllUsers(w http.ResponseWriter, r *http.Request) {
|
|
|
|
// construct sql query to select all users
|
|
|
|
query := "SELECT * FROM users"
|
|
|
|
|
|
|
|
// execute the query
|
|
|
|
rows, err := DbPool.Query(context.Background(), query)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, "Failed to retrieve users", http.StatusInternalServerError)
|
|
|
|
log.Printf("Error executing query: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
|
|
|
|
// define a slice to hold user data
|
|
|
|
users := []map[string]interface{}{}
|
|
|
|
|
|
|
|
// get column names
|
|
|
|
columnNames := rows.FieldDescriptions()
|
|
|
|
|
|
|
|
// iterate over the rows and build the result set
|
|
|
|
for rows.Next() {
|
|
|
|
// create a slice to hold the values for each row
|
|
|
|
values := make([]interface{}, len(columnNames))
|
|
|
|
valuePointers := make([]interface{}, len(columnNames))
|
|
|
|
for i := range values {
|
|
|
|
valuePointers[i] = &values[i]
|
|
|
|
}
|
|
|
|
|
|
|
|
// scan the row into slice of interfaces
|
|
|
|
if err := rows.Scan(valuePointers...); err != nil {
|
|
|
|
http.Error(w, "Failed to scan user data", http.StatusInternalServerError)
|
|
|
|
log.Printf("Error scanning row: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// create a map for the row data
|
|
|
|
user := make(map[string]interface{})
|
|
|
|
for i, col := range columnNames {
|
|
|
|
user[string(col.Name)] = values[i]
|
|
|
|
}
|
|
|
|
|
|
|
|
// append the user map to the users slice
|
|
|
|
users = append(users, user)
|
|
|
|
}
|
|
|
|
|
|
|
|
// check for any errors encountered during the iteration
|
|
|
|
if err = rows.Err(); err != nil {
|
|
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
|
|
log.Printf("Error encountered during iteration of rows: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Printf("All Users Data: %v", users)
|
|
|
|
|
|
|
|
// Encode the users slice as JSON and write it to the response
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
if err := json.NewEncoder(w).Encode(users); err != nil {
|
|
|
|
http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
|
|
|
|
log.Printf("Error encoding response: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println("Response successfully written to client")
|
|
|
|
}
|