// Package database owns the application's MySQL connection. It creates the
// configured database and its tables on startup and exposes the shared
// *sql.DB handle.
package database

import (
	"database/sql"
	"fmt"
	"log"
	"os"
	"regexp"
	"time"

	_ "github.com/go-sql-driver/mysql"
)

// DB is the shared connection pool. It is populated by InitDB and released
// by CloseDB.
var DB *sql.DB

// databaseNamePattern restricts the names ensureDatabaseExists will
// interpolate into CREATE DATABASE to ASCII letters, digits, '_' and '-'.
// It is compiled once at package init instead of on every validation.
var databaseNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)

// maxDatabaseNameLen caps accepted database names (64 matches MySQL's
// documented identifier length limit).
const maxDatabaseNameLen = 64

// InitDB prepares the database for use.
//
// It does three things, in order:
//   - creates the database named by DB_NAME if it does not exist;
//   - opens and verifies the shared DB pool;
//   - creates the application tables.
//
// If table creation fails after the pool was opened, the pool is closed and
// DB is reset to nil so callers never see a half-initialized handle.
func InitDB() error {
	if err := ensureDatabaseExists(); err != nil {
		return err
	}
	if err := connectToDatabase(); err != nil {
		return err
	}
	if err := ensureTablesExist(); err != nil {
		CloseDB()
		DB = nil
		return err
	}
	log.Println("Database initialized successfully")
	return nil
}

// ensureDatabaseExists connects to the server without selecting a database.
// It creates the database named by DB_NAME if the database is missing.
// The name is validated before it is spliced into the CREATE statement,
// because identifiers cannot be passed as query parameters.
func ensureDatabaseExists() error {
	dbName := os.Getenv("DB_NAME")
	if dbName == "" {
		return fmt.Errorf("invalid database name: DB_NAME is not set")
	}

	db, err := sql.Open("mysql", getDSN(""))
	if err != nil {
		return fmt.Errorf("failed to connect to database server: %w", err)
	}
	defer db.Close()

	var exists bool
	err = db.QueryRow(
		"SELECT EXISTS(SELECT 1 FROM information_schema.schemata WHERE schema_name = ?)",
		dbName,
	).Scan(&exists)
	if err != nil {
		return fmt.Errorf("failed to check database existence: %w", err)
	}
	if exists {
		return nil
	}

	if !isValidDatabaseName(dbName) {
		return fmt.Errorf("invalid database name: %s", dbName)
	}
	// IF NOT EXISTS tolerates another instance creating the database between
	// the existence check above and this statement. The pre-check is kept so
	// that accounts without CREATE privilege still work when the database
	// already exists.
	_, err = db.Exec("CREATE DATABASE IF NOT EXISTS `" + dbName + "` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
	if err != nil {
		return fmt.Errorf("failed to create database: %w", err)
	}
	return nil
}

// connectToDatabase opens the pool for the DB_NAME database, applies the
// pool limits and verifies connectivity. DB is assigned only on success.
func connectToDatabase() error {
	db, err := sql.Open("mysql", getDSN(os.Getenv("DB_NAME")))
	if err != nil {
		return fmt.Errorf("failed to connect to database: %w", err)
	}
	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(25)
	db.SetConnMaxLifetime(5 * time.Minute)

	// sql.Open only validates its arguments. Ping establishes a real
	// connection, so an unreachable server or bad credentials are reported
	// here rather than later by the first query.
	if err := db.Ping(); err != nil {
		_ = db.Close() // best effort; the Ping error is the one worth reporting
		return fmt.Errorf("failed to connect to database: %w", err)
	}
	DB = db
	return nil
}

// ensureTablesExist creates the contacts table on DB if it is not present.
func ensureTablesExist() error {
	_, err := DB.Exec(`
		CREATE TABLE IF NOT EXISTS contacts (
			id INT AUTO_INCREMENT PRIMARY KEY,
			name VARCHAR(120) NOT NULL,
			company VARCHAR(120) NOT NULL,
			phone VARCHAR(15) NOT NULL
		) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
	`)
	if err != nil {
		return fmt.Errorf("failed to create contacts table: %w", err)
	}
	return nil
}

// CloseDB closes the shared pool if one is open. A close error is logged
// rather than returned, to keep the existing no-result signature.
func CloseDB() {
	if DB == nil {
		return
	}
	if err := DB.Close(); err != nil {
		log.Printf("error closing database: %v", err)
	}
}

// getDSN builds a MySQL driver DSN from the DB_USER, DB_PASSWORD, DB_HOST and
// DB_PORT environment variables. An empty dbName yields a DSN with no default
// database, which is used to reach the server before the database exists.
func getDSN(dbName string) string {
	return fmt.Sprintf(
		"%s:%s@tcp(%s:%s)/%s?parseTime=true&charset=utf8mb4&collation=utf8mb4_unicode_ci",
		os.Getenv("DB_USER"),
		os.Getenv("DB_PASSWORD"),
		os.Getenv("DB_HOST"),
		os.Getenv("DB_PORT"),
		dbName,
	)
}

// isValidDatabaseName reports whether name is non-empty, at most
// maxDatabaseNameLen bytes, and made only of ASCII letters, digits, '_' and
// '-'. Such a name is safe to wrap in backticks in a CREATE DATABASE
// statement.
func isValidDatabaseName(name string) bool {
	return len(name) <= maxDatabaseNameLen && databaseNamePattern.MatchString(name)
}