cspj-application/server/internal/db/db.go
2025-01-16 03:18:48 +08:00

224 lines
6 KiB
Go

package db
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"github.com/jackc/pgx/v5/pgxpool"
)
// Connection settings for the PostgreSQL server this process talks to.
// They are compiled in for now and may change (!MIGHT CHANGE).

// Network location of the database. The server is meant to run on the same
// machine as the DB, hence localhost.
const (
	host = "localhost"
	port = 3335
)

// Login credentials and the database to open.
const (
	user     = "asdfuser"
	password = "asdfpassword"
	dbname   = "asdfdb"
)
// DbPool is the package-wide connection pool used by the HTTP handlers and by
// FetchEmails in this file. ConnectToDb returns a pool but does not assign it
// here, so the caller must set DbPool before any of those functions run; until
// then it is nil.
var DbPool *pgxpool.Pool
// ConnectToDb creates a pgx connection pool for the database described by the
// package-level host/port/user/password/dbname constants. Before returning, it
// verifies the pool by acquiring a connection and running "SELECT 1".
//
// This server is intended to run on the same system as the database.
// The returned pool is NOT stored in DbPool; the caller must assign it.
// On any failure the pool (if one was created) is closed and a wrapped error is
// returned, so no connections are leaked.
func ConnectToDb() (*pgxpool.Pool, error) {
	dbURL := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", user, password, host, port, dbname)
	config, err := pgxpool.ParseConfig(dbURL)
	if err != nil {
		return nil, fmt.Errorf("unable to parse database URL: %w", err)
	}

	ctx := context.Background()
	pool, err := pgxpool.NewWithConfig(ctx, config)
	if err != nil {
		return nil, fmt.Errorf("unable to create connection pool: %w", err)
	}

	// Validate the connection up front so an unreachable database is reported
	// at startup rather than on first use.
	conn, err := pool.Acquire(ctx)
	if err != nil {
		pool.Close()
		return nil, fmt.Errorf("unable to acquire a connection from the pool: %w", err)
	}
	err = conn.QueryRow(ctx, "SELECT 1").Scan(new(int))
	// Release explicitly (not via defer): pgxpool's Close waits for acquired
	// connections to be returned, so the conn must be released before Close.
	conn.Release()
	if err != nil {
		pool.Close()
		return nil, fmt.Errorf("unable to validate database connection: %w", err)
	}

	log.Printf("Connected to DB at port %d :)", port)
	return pool, nil
}
// DbHealthCheck is an HTTP handler that runs "SELECT 1" against DbPool.
// On success it responds 200 "Database is healthy". If the query fails it
// responds 503 "Database is unhealthy".
//
// The query is bound to the request's context, so it is abandoned if the
// client goes away instead of running on unattended.
func DbHealthCheck(w http.ResponseWriter, r *http.Request) {
	const healthCheckSQL = `SELECT 1;`

	var result int
	if err := DbPool.QueryRow(r.Context(), healthCheckSQL).Scan(&result); err != nil {
		http.Error(w, "Database is unhealthy", http.StatusServiceUnavailable)
		log.Printf("Database health check failed: %v", err)
		return
	}

	w.WriteHeader(http.StatusOK)
	if _, err := w.Write([]byte("Database is healthy")); err != nil {
		log.Printf("Error writing health check response: %v", err)
	}
	log.Println("Database health check passed")
}
// SetupDemoDb is an HTTP handler that creates the users table if it does not
// exist and seeds it with three demo accounts (two with role "user", one with
// role "admin").
//
// Seeding is idempotent. A demo row is inserted only when no row with the same
// email exists yet, so calling the endpoint repeatedly does not create
// duplicate users. Two truly concurrent calls could still race, because the
// email column has no UNIQUE constraint.
//
// If either statement fails it responds 500 with a plain-text message.
// Otherwise it responds 200 "Database setup complete with demo data".
func SetupDemoDb(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	createTableSQL := `
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
email VARCHAR(100) NOT NULL,
password VARCHAR(100) NOT NULL,
role VARCHAR(50) DEFAULT 'user'
);`

	// Insert each demo user only if its email is not already present.
	insertDataSQL := `
INSERT INTO users (email, password, role)
SELECT v.email, v.password, v.role
FROM (VALUES
('alice@example.com', 'asdfalicepassword', 'user'),
('bob@example.com', 'asdfbobpassword', 'user'),
('charlie@example.com', 'asdfcharliepassword', 'admin')
) AS v (email, password, role)
WHERE NOT EXISTS (
SELECT 1 FROM users u WHERE u.email = v.email
);`

	if _, err := DbPool.Exec(ctx, createTableSQL); err != nil {
		http.Error(w, "Failed to create table", http.StatusInternalServerError)
		log.Printf("Error creating table: %v", err)
		return
	}

	if _, err := DbPool.Exec(ctx, insertDataSQL); err != nil {
		http.Error(w, "Failed to insert demo data", http.StatusInternalServerError)
		log.Printf("Error inserting demo data: %v", err)
		return
	}

	w.WriteHeader(http.StatusOK)
	if _, err := w.Write([]byte("Database setup complete with demo data")); err != nil {
		log.Printf("Error writing setup response: %v", err)
	}
	log.Println("Demo database setup completed successfully")
}
// NukeDb is an HTTP handler that drops the users table, if it exists, and
// responds 200 "Bye bye". CASCADE also drops any objects that depend on the
// table. If the drop fails it responds 500 "Failed to drop table".
//
// NOTE(review): there is no authentication/authorization check here;
// presumably it is enforced by whatever routes to this handler. Confirm.
func NukeDb(w http.ResponseWriter, r *http.Request) {
	const dropTableSQL = `
DROP TABLE IF EXISTS users CASCADE;
`

	// Tie the statement to the request so it is cancelled if the client leaves.
	if _, err := DbPool.Exec(r.Context(), dropTableSQL); err != nil {
		http.Error(w, "Failed to drop table", http.StatusInternalServerError)
		log.Printf("Error dropping table: %v", err)
		return
	}

	w.WriteHeader(http.StatusOK)
	if _, err := w.Write([]byte("Bye bye")); err != nil {
		log.Printf("Error writing nuke response: %v", err)
	}
	log.Println("Database nuked")
}
// FetchEmails returns the set of email addresses currently stored in the users
// table. The map's keys are the emails and every value is true.
//
// It returns a wrapped error if the query, a row scan, or row iteration fails.
// No partial set is returned in that case.
func FetchEmails() (map[string]bool, error) {
	rows, err := DbPool.Query(context.Background(), "SELECT email FROM users")
	if err != nil {
		return nil, fmt.Errorf("error querying users: %w", err)
	}
	defer rows.Close()

	emails := make(map[string]bool)
	for rows.Next() {
		var email string
		if err := rows.Scan(&email); err != nil {
			return nil, fmt.Errorf("error scanning email: %w", err)
		}
		emails[email] = true
	}
	// Next returns false both when rows are exhausted and when reading failed;
	// Err distinguishes the two so a truncated result is not reported as success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating users: %w", err)
	}

	log.Println("Fetched emails:", emails)
	return emails, nil
}
// FetchAllUsers is a demo HTTP handler that returns every row of the users
// table as a JSON array of objects keyed by column name. An empty table yields
// "[]".
//
// It responds 500 with a plain-text message if the query, a row scan, row
// iteration, or JSON encoding fails. The JSON is fully encoded before anything
// is written to w, so an encoding failure can still produce a proper 500
// instead of a half-written 200 response.
func FetchAllUsers(w http.ResponseWriter, r *http.Request) {
	rows, err := DbPool.Query(r.Context(), "SELECT * FROM users")
	if err != nil {
		http.Error(w, "Failed to retrieve users", http.StatusInternalServerError)
		log.Printf("Error executing query: %v", err)
		return
	}
	defer rows.Close()

	// Column names are identical for every row; convert them once.
	fields := rows.FieldDescriptions()
	names := make([]string, len(fields))
	for i, fd := range fields {
		names[i] = string(fd.Name)
	}

	// Non-nil so an empty result encodes as [] rather than null.
	users := []map[string]any{}
	for rows.Next() {
		values := make([]any, len(names))
		dests := make([]any, len(names))
		for i := range values {
			dests[i] = &values[i]
		}
		if err := rows.Scan(dests...); err != nil {
			http.Error(w, "Failed to scan user data", http.StatusInternalServerError)
			log.Printf("Error scanning row: %v", err)
			return
		}

		user := make(map[string]any, len(names))
		for i, name := range names {
			user[name] = values[i]
		}
		users = append(users, user)
	}
	if err := rows.Err(); err != nil {
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		log.Printf("Error encountered during iteration of rows: %v", err)
		return
	}

	log.Printf("All Users Data: %v", users)

	body, err := json.Marshal(users)
	if err != nil {
		http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError)
		log.Printf("Error encoding response: %v", err)
		return
	}
	// Match json.Encoder.Encode's output, which ends with a newline.
	body = append(body, '\n')

	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(body); err != nil {
		log.Printf("Error writing response: %v", err)
		return
	}
	log.Println("Response successfully written to client")
}