From 99d035c9151a87e0ee3c35a80af33875316f3e7f Mon Sep 17 00:00:00 2001 From: Gabriel De Los Rios Date: Mon, 3 Nov 2025 00:20:19 -0300 Subject: [PATCH] refactor: enhance logic to handle missing db or table scenarios --- cmd/server/main.go | 2 +- internal/database/database.go | 111 +++++++++++++++++++++++++++++----- 2 files changed, 98 insertions(+), 15 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index ddfc5cc..966f057 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -25,7 +25,7 @@ import ( // @host localhost:8080 // @BasePath / func main() { - err := database.InitDB(database.Conn_string) + err := database.InitDB() defer database.CloseDB() if err != nil { log.Fatal("Database connection failed:", err) diff --git a/internal/database/database.go b/internal/database/database.go index 779ab83..f18d5f0 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -3,37 +3,100 @@ package database import ( "database/sql" "fmt" + "log" "os" + "regexp" "time" _ "github.com/go-sql-driver/mysql" ) var DB *sql.DB -var dbName = os.Getenv("DB_NAME") -var dbUser = os.Getenv("DB_USER") -var dbPassword = os.Getenv("DB_PASSWORD") -var dbHost = os.Getenv("DB_HOST") -var dbPort = os.Getenv("DB_PORT") -var Conn_string = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true&charset=utf8mb4&collation=utf8mb4_unicode_ci", dbUser, dbPassword, dbHost, dbPort, dbName) +func InitDB() error { -func InitDB(dataSourceName string) error { - var err error - - DB, err = sql.Open("mysql", dataSourceName) - if err != nil { + if err := ensureDatabaseExists(); err != nil { return err } + if err := connectToDatabase(); err != nil { + return err + } + + if err := ensureTablesExist(); err != nil { + return err + } + + log.Println("Database initialized successfully") + return nil +} + +func ensureDatabaseExists() error { + dsn := getDSN("") + + db, err := sql.Open("mysql", dsn) + if err != nil { + return fmt.Errorf("failed to connect to database server: %w", err) + } + defer db.Close() + + // Get database name from environment + dbName := os.Getenv("DB_NAME") + + // Check if database exists, create if not + var exists bool + err = db.QueryRow("SELECT EXISTS(SELECT 1 FROM information_schema.schemata WHERE schema_name = ?)", dbName).Scan(&exists) + if err != nil { + return fmt.Errorf("failed to check database existence: %w", err) + } + + if !exists { + if !isValidDatabaseName(dbName) { + return fmt.Errorf("invalid database name: %s", dbName) + } + _, err = db.Exec("CREATE DATABASE `" + dbName + "` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci") + + if err != nil { + return fmt.Errorf("failed to create database: %w", err) + } + } + + return nil +} + +func connectToDatabase() error { + dsn := getDSN(os.Getenv("DB_NAME")) + + var err error + DB, err = sql.Open("mysql", dsn) + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + DB.SetMaxOpenConns(25) DB.SetMaxIdleConns(25) DB.SetConnMaxLifetime(5 * time.Minute) - if err = DB.Ping(); err != nil { - return err + return nil +} + +func ensureTablesExist() error { + + _, err := DB.Exec(` + CREATE TABLE IF NOT EXISTS contacts ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(120) NOT NULL, + company VARCHAR(120) NOT NULL, + phone VARCHAR(15) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + `) + + if err != nil { + return fmt.Errorf("failed to create contacts table: %w", err) } - println("DB connected") + return nil } @@ -42,3 +105,23 @@ func CloseDB() { DB.Close() } } + +func getDSN(dbName string) string { + user := os.Getenv("DB_USER") + password := os.Getenv("DB_PASSWORD") + host := os.Getenv("DB_HOST") + port := os.Getenv("DB_PORT") + + dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/", user, password, host, port) + if dbName != "" { + dsn += dbName + } + dsn += "?parseTime=true&charset=utf8mb4&collation=utf8mb4_unicode_ci" + + return dsn +} + +func isValidDatabaseName(name string) bool { + matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, name) + return matched && len(name) > 0 && len(name) <= 64 +}