package db import ( "context" "encoding/json" "fmt" "log" "net/http" "github.com/jackc/pgx/v5/pgxpool" ) // db connection info // !MIGHT CHANGE const ( host = "localhost" port = 3335 user = "asdfuser" password = "asdfpassword" dbname = "asdfdb" ) var DbPool *pgxpool.Pool // initialize connection to db func ConnectToDb() (*pgxpool.Pool, error) { // this server is intended to be ran on the same system as the db dbUrl := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", user, password, host, port, dbname) config, err := pgxpool.ParseConfig(dbUrl) if err != nil { return nil, fmt.Errorf("unable to parse data URL: %w", err) } pool, err := pgxpool.NewWithConfig(context.Background(), config) if err != nil { return nil, fmt.Errorf("unable to create connection pool: %w", err) } // validate connection with a simple query conn, err := pool.Acquire(context.Background()) if err != nil { return nil, fmt.Errorf("unable to acquire a connection from the pool: %w", err) } defer conn.Release() // run a test query err = conn.QueryRow(context.Background(), "SELECT 1").Scan(new(int)) if err != nil { return nil, fmt.Errorf("unable to validate database connection: %w", err) } log.Printf("Connected to DB at port %d :)", port) return pool, nil } // ping the database to check health func DbHealthCheck(w http.ResponseWriter, r *http.Request) { // define the health check query healthCheckSQL := `SELECT 1;` // execute the query var result int err := DbPool.QueryRow(context.Background(), healthCheckSQL).Scan(&result) if err != nil { http.Error(w, "Database is unhealthy", http.StatusServiceUnavailable) log.Printf("Database health check failed: %v", err) return } // send success response w.WriteHeader(http.StatusOK) w.Write([]byte("Database is healthy")) log.Println("Database health check passed") } // setup demo db func SetupDemoDb(w http.ResponseWriter, r *http.Request) { // create table and insert demo data createTableSQL := ` CREATE TABLE IF NOT EXISTS users ( id SERIAL PRIMARY KEY, email VARCHAR(100) NOT NULL, password VARCHAR(100) NOT NULL, role VARCHAR(50) DEFAULT 'user' );` // avoid duplicate entries and specify roles insertDataSQL := ` INSERT INTO users (email, password, role) VALUES ('asdf@gmail.com', 'asdf', 'user'), ('alice@example.com', 'asdfalicepassword', 'user'), ('bob@example.com', 'asdfbobpassword', 'user'), ('charlie@example.com', 'asdfcharliepassword', 'admin') ` // execute create table _, err := DbPool.Exec(context.Background(), createTableSQL) if err != nil { http.Error(w, "Failed to create table", http.StatusInternalServerError) log.Printf("Error creating table: %v", err) return } // execute insert demo data _, err = DbPool.Exec(context.Background(), insertDataSQL) if err != nil { http.Error(w, "Failed to insert demo data", http.StatusInternalServerError) log.Printf("Error inserting demo data: %v", err) return } // response back to client w.WriteHeader(http.StatusOK) w.Write([]byte("Database setup complete with demo data")) log.Println("Demo database setup completed successfully") } // nuke the db func NukeDb(w http.ResponseWriter, r *http.Request) { // drop user table dropTableSQL := ` DROP TABLE IF EXISTS users CASCADE; ` // execute the command _, err := DbPool.Exec(context.Background(), dropTableSQL) if err != nil { http.Error(w, "Failed to drop table", http.StatusInternalServerError) log.Printf("Error dropping table: %v", err) return } w.WriteHeader(http.StatusOK) w.Write([]byte("Bye bye")) log.Println("Database nuked") } // fetch existing usernames from db func FetchEmails() (map[string]bool, error) { emails := make(map[string]bool) rows, err := DbPool.Query(context.Background(), "SELECT email FROM users") if err != nil { return nil, fmt.Errorf("error querying users: %w", err) } defer rows.Close() for rows.Next() { var email string if err := rows.Scan(&email); err != nil { return nil, fmt.Errorf("error scanning email: %w", err) } emails[email] = true } return emails, nil } // fetch all users for demo func FetchAllUsers(w http.ResponseWriter, r *http.Request) { // construct sql query to select all users query := "SELECT * FROM users" // execute the query rows, err := DbPool.Query(context.Background(), query) if err != nil { http.Error(w, "Failed to retrieve users", http.StatusInternalServerError) log.Printf("Error executing query: %v", err) return } defer rows.Close() // define a slice to hold user data users := []map[string]interface{}{} // get column names columnNames := rows.FieldDescriptions() // iterate over the rows and build the result set for rows.Next() { // create a slice to hold the values for each row values := make([]interface{}, len(columnNames)) valuePointers := make([]interface{}, len(columnNames)) for i := range values { valuePointers[i] = &values[i] } // scan the row into slice of interfaces if err := rows.Scan(valuePointers...); err != nil { http.Error(w, "Failed to scan user data", http.StatusInternalServerError) log.Printf("Error scanning row: %v", err) return } // create a map for the row data user := make(map[string]interface{}) for i, col := range columnNames { user[string(col.Name)] = values[i] } // append the user map to the users slice users = append(users, user) } // check for any errors encountered during the iteration if err = rows.Err(); err != nil { http.Error(w, "Internal server error", http.StatusInternalServerError) log.Printf("Error encountered during iteration of rows: %v", err) return } log.Printf("All Users Data: %v", users) // Encode the users slice as JSON and write it to the response w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(users); err != nil { http.Error(w, "Failed to encode response as JSON", http.StatusInternalServerError) log.Printf("Error encoding response: %v", err) return } log.Println("Response successfully written to client") }