diff --git a/internal/repository/base_repository.go b/internal/repository/base_repository.go index 472898e..118a567 100644 --- a/internal/repository/base_repository.go +++ b/internal/repository/base_repository.go @@ -10,7 +10,7 @@ type baseRepository[T any] struct { tableName string } -func NewBaseRepository[T any](db *sql.DB, tableName string) Repository[T] { +func NewBaseRepository[T any](db *sql.DB, tableName string) *baseRepository[T] { return &baseRepository[T]{ db: db, tableName: tableName, @@ -25,14 +25,34 @@ func (r *baseRepository[T]) GetDB() *sql.DB { return r.db } +func (r *baseRepository[T]) GetAll() ([]T, error) { + query := r.BuildQuery("SELECT * FROM %s ORDER BY id DESC") + + rows, err := r.db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + entities := make([]T, 0) + rowsErr := ScanRows(rows, &entities) + + if rowsErr != nil { + return nil, err + } + + return entities, nil +} + func (r *baseRepository[T]) GetByID(id int) (*T, error) { var entity T query := r.BuildQuery("SELECT * FROM %s WHERE id = ?") - err := r.db.QueryRow( + row := r.db.QueryRow( query, id, - ).Scan(entity) + ) + err := scanRow(row, &entity) if err != nil { if err == sql.ErrNoRows { @@ -44,33 +64,8 @@ func (r *baseRepository[T]) GetByID(id int) (*T, error) { return &entity, nil } -func (r *baseRepository[T]) GetAll() ([]T, error) { - query := r.BuildQuery("SELECT * FROM %s ORDER BY id DESC") - - rows, err := r.db.Query(query) - if err != nil { - return nil, err - } - defer rows.Close() - - entities := make([]T, 0) - for rows.Next() { - var entity T - if err := rows.Scan(entity); err != nil { - return nil, err - } - entities = append(entities, entity) - } - - if err = rows.Err(); err != nil { - return nil, err - } - - return entities, nil -} - func (r *baseRepository[T]) Delete(id int) (int64, error) { - query := r.BuildQuery("DELETE %s WHERE id = ? LIMIT 1") + query := r.BuildQuery("DELETE FROM %s WHERE id = ?") res, err := r.db.Exec(query, id) if err != nil { return 0, err diff --git a/internal/repository/contact_repository.go b/internal/repository/contact_repository.go index 9f34a5b..5863bb9 100644 --- a/internal/repository/contact_repository.go +++ b/internal/repository/contact_repository.go @@ -2,6 +2,7 @@ package repository import ( "database/sql" + "fmt" "gitea.gabilandia.com/gabdlr/agenda-web-go/internal/models" ) @@ -10,7 +11,7 @@ type ContactRepository struct { baseRepository[models.Contact] } -func NewContactRepository(db *sql.DB) *ContactRepository { +func NewContactRepository(db *sql.DB) Repository[models.Contact] { return &ContactRepository{ baseRepository[models.Contact]{ db: db, @@ -38,20 +39,45 @@ func (r *ContactRepository) Create(contact *models.Contact) (int64, error) { return id, nil } -func (r *ContactRepository) Update(contact *models.Contact) (int64, error) { - query := r.BuildQuery("UPDATE %s SET name = ?, company = ?, phone = ? WHERE id = ?") - result, err := r.db.Exec(query, - contact.Name, contact.Company, contact.Phone, contact.ID, +func (r *ContactRepository) Update(contact *models.Contact) error { + query := r.BuildQuery("UPDATE %s SET") + fieldsToUpdate := make([]string, 0, 4) + fields := make([]any, 0) + + if contact.Name != "" { + fieldsToUpdate = append(fieldsToUpdate, "name") + fields = append(fields, &contact.Name) + } + + if contact.Company != "" { + fieldsToUpdate = append(fieldsToUpdate, "company") + fields = append(fields, &contact.Company) + } + + if contact.Phone != "" { + fieldsToUpdate = append(fieldsToUpdate, "phone") + fields = append(fields, &contact.Phone) + } + + fields = append(fields, &contact.ID) + + fieldsToUpdatelen := len(fieldsToUpdate) + for i, field := range fieldsToUpdate { + query += fmt.Sprintf(" %s = ?", field) + if i != fieldsToUpdatelen-1 { + query += "," + } + } + + query += " WHERE id = ?" + + _, err := r.db.Exec(query, + fields..., ) if err != nil { - return 0, err + return err } - rowsAffected, err := result.RowsAffected() - if err != nil { - return 0, err - } - - return rowsAffected, nil + return nil } diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index 8043df5..710900f 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -1,8 +1,9 @@ package repository type Repository[T any] interface { - BuildQuery(s string) string - GetByID(id int) (*T, error) - GetAll() ([]T, error) + Create(T *T) (int64, error) Delete(id int) (int64, error) + GetAll() ([]T, error) + GetByID(id int) (*T, error) + Update(contact *T) error } diff --git a/internal/repository/scanner_helper.go b/internal/repository/scanner_helper.go new file mode 100644 index 0000000..5592375 --- /dev/null +++ b/internal/repository/scanner_helper.go @@ -0,0 +1,55 @@ +package repository + +import ( + "database/sql" + "fmt" + "reflect" +) + +func scanRow(row *sql.Row, dest any) error { + destValue := reflect.ValueOf(dest).Elem() + fields := make([]any, destValue.NumField()) + + for i := 0; i < destValue.NumField(); i++ { + fields[i] = destValue.Field(i).Addr().Interface() + } + + return row.Scan(fields...) +} + +func ScanRows(rows *sql.Rows, destSlice any) error { + sliceValue := reflect.ValueOf(destSlice) + if sliceValue.Kind() != reflect.Pointer || sliceValue.Elem().Kind() != reflect.Slice { + return fmt.Errorf("destSlice must be a pointer to a slice") + } + + sliceElem := sliceValue.Elem() + structType := sliceElem.Type().Elem() + + for rows.Next() { + newStruct := reflect.New(structType).Elem() + + fields := make([]any, newStruct.NumField()) + for i := 0; i < newStruct.NumField(); i++ { + fields[i] = newStruct.Field(i).Addr().Interface() + } + + if err := rows.Scan(fields...); err != nil { + return err + } + + sliceElem.Set(reflect.Append(sliceElem, newStruct)) + } + + return rows.Err() +} + +func GetStructFieldsPtr(object any) []any { + destValue := reflect.ValueOf(object).Elem() + fields := make([]any, destValue.NumField()) + + for i := 0; i < destValue.NumField(); i++ { + fields[i] = destValue.Field(i).Addr().Interface() + } + return fields +}