277 lines
5.6 KiB
Go
277 lines
5.6 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
rs "github.com/danilsolovyov/reflectgostructv1"
|
|
"github.com/jackc/pgx/v4/pgxpool"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"penahub.gitlab.yandexcloud.net/backend/templategen/dal/model"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type User struct {
|
|
conn *pgxpool.Pool
|
|
}
|
|
|
|
func InitUser(ctx context.Context, conn *pgxpool.Pool) *User {
|
|
d := &User{conn: conn}
|
|
err := d.init(ctx)
|
|
if err != nil {
|
|
//glg.Error("ErrInitUser:", err)
|
|
return nil
|
|
}
|
|
|
|
return d
|
|
}
|
|
|
|
func (d *User) init(ctx context.Context) error {
|
|
s := rs.PsqlTagToSql(&model.User{})
|
|
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"user\" (%v);", s)
|
|
_, err := d.conn.Exec(ctx, sql)
|
|
return err
|
|
}
|
|
|
|
func (d *User) Insert(ctx context.Context, record *model.User) (string, error) {
|
|
now := time.Now().UTC()
|
|
record.CreatedAt = now
|
|
record.UpdatedAt = now
|
|
record.IsDeleted = false
|
|
|
|
conn, err := d.conn.Acquire(ctx)
|
|
|
|
if err != nil {
|
|
conn.Release()
|
|
return "", err
|
|
}
|
|
|
|
// Find users by email
|
|
sql := fmt.Sprintf("SELECT id FROM \"user\" WHERE email = '%v'", record.Email)
|
|
|
|
var foundID int
|
|
err = conn.QueryRow(ctx, sql).Scan(&foundID)
|
|
if err != nil {
|
|
if err.Error() != ErrorNotFound.Error() {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
if foundID > 0 {
|
|
fmt.Println("user already exists", foundID)
|
|
err = errors.New("user already exists")
|
|
return "", err
|
|
}
|
|
|
|
gpass, err := bcrypt.GenerateFromPassword([]byte(record.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
record.Password = string(gpass)
|
|
|
|
tags, values := rs.GetPsqlTagsAndValues(record)
|
|
sql = fmt.Sprintf("INSERT INTO \"user\" (%v) VALUES (%v) RETURNING id;", tags, values)
|
|
|
|
var id int
|
|
err = conn.QueryRow(ctx, sql).Scan(&id)
|
|
conn.Release()
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return strconv.Itoa(id), nil
|
|
}
|
|
|
|
func (d *User) GetByID(ctx context.Context, id int) (*model.User, error) {
|
|
var result model.User
|
|
|
|
conn, err := d.conn.Acquire(ctx)
|
|
if err != nil {
|
|
conn.Release()
|
|
return nil, err
|
|
}
|
|
|
|
sql := fmt.Sprintf("SELECT * FROM \"user\" WHERE id = %v", id)
|
|
rows, err := conn.Query(ctx, sql)
|
|
conn.Release()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for rows.Next() {
|
|
err = rows.Scan(&result.ID, &result.FullName, &result.Email, &result.Password,
|
|
&result.IsActivated, &result.JwtToken, &result.IsDeleted, &result.CreatedAt, &result.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return &result, err
|
|
}
|
|
|
|
func (d *User) GetByEmail(ctx context.Context, email string) (*model.User, error) {
|
|
var result model.User
|
|
|
|
conn, err := d.conn.Acquire(ctx)
|
|
if err != nil {
|
|
conn.Release()
|
|
return nil, err
|
|
}
|
|
|
|
sql := fmt.Sprintf("SELECT * FROM \"user\" WHERE email = '%v'", email)
|
|
rows, err := conn.Query(ctx, sql)
|
|
conn.Release()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
id := 0
|
|
|
|
for rows.Next() {
|
|
err = rows.Scan(&id, &result.FullName, &result.Email, &result.Password,
|
|
&result.IsActivated, &result.RoleID, &result.JwtToken, &result.IsDeleted, &result.CreatedAt,
|
|
&result.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
result.ID = strconv.Itoa(id)
|
|
|
|
return &result, err
|
|
}
|
|
|
|
func (d *User) GetByFilter(ctx context.Context, start, count int, needle map[string]string) ([]model.User, error) {
|
|
var result []model.User
|
|
|
|
conn, err := d.conn.Acquire(ctx)
|
|
if err != nil {
|
|
conn.Release()
|
|
return nil, err
|
|
}
|
|
|
|
needleToSql := ""
|
|
if len(needle) > 0 {
|
|
needleToSql += " AND "
|
|
i := 0
|
|
for k, v := range needle {
|
|
v = strings.ReplaceAll(v, "'", "''")
|
|
needleToSql += fmt.Sprintf("%v = %v", k, v)
|
|
if i < len(needle) {
|
|
needleToSql += " AND "
|
|
}
|
|
i++
|
|
}
|
|
}
|
|
|
|
sql := fmt.Sprintf("SELECT * FROM \"user\" WHERE (id >= %v %v) ORDER BY id LIMIT %v;",
|
|
start,
|
|
needleToSql,
|
|
count,
|
|
)
|
|
rows, err := conn.Query(ctx, sql)
|
|
conn.Release()
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if rows == nil {
|
|
err = ErrorGotEmptyRow
|
|
return nil, err
|
|
}
|
|
|
|
for rows.Next() {
|
|
var u model.User
|
|
err = rows.Scan(&u.ID, &u.FullName, &u.Email, &u.Password,
|
|
&u.IsActivated, &u.RoleID, &u.JwtToken, &u.IsDeleted, &u.CreatedAt, &u.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result = append(result, u)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (d *User) UpdateByID(ctx context.Context, record *model.User) error {
|
|
record.UpdatedAt = time.Now()
|
|
|
|
conn, err := d.conn.Acquire(ctx)
|
|
if err != nil {
|
|
conn.Release()
|
|
return err
|
|
}
|
|
|
|
tags, values := rs.GetPsqlTagsAndValues(record)
|
|
sql := fmt.Sprintf("UPDATE \"user\" SET (%v) = (%v) WHERE id = %v;", tags, values, record.ID)
|
|
err = conn.QueryRow(ctx, sql).Scan()
|
|
conn.Release()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *User) DeleteByID(ctx context.Context, id int) error {
|
|
conn, err := d.conn.Acquire(ctx)
|
|
if err != nil {
|
|
conn.Release()
|
|
return err
|
|
}
|
|
|
|
sql := fmt.Sprintf("UPDATE \"user\" SET is_deleted = true WHERE id = %v", id)
|
|
err = conn.QueryRow(ctx, sql).Scan()
|
|
conn.Release()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *User) SetJwtToken(ctx context.Context, id int, jwtToken string) error {
|
|
conn, err := d.conn.Acquire(ctx)
|
|
if err != nil {
|
|
conn.Release()
|
|
return err
|
|
}
|
|
|
|
sql := fmt.Sprintf("UPDATE \"user\" SET jwt_token = '%v' WHERE id = %v", jwtToken, id)
|
|
err = conn.QueryRow(ctx, sql).Scan()
|
|
conn.Release()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *User) ChangePassword(ctx context.Context, id int, password string) error {
|
|
gpass, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
conn, err := d.conn.Acquire(ctx)
|
|
if err != nil {
|
|
conn.Release()
|
|
return err
|
|
}
|
|
|
|
sql := fmt.Sprintf("UPDATE \"user\" SET password = '%v' WHERE id = %v", gpass, id)
|
|
err = conn.QueryRow(ctx, sql).Scan()
|
|
conn.Release()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|