docxTemplater/dal/postgres/user.go
2022-12-31 14:46:28 +00:00

277 lines
5.6 KiB
Go

package postgres
import (
"context"
"errors"
"fmt"
rs "github.com/danilsolovyov/reflectgostructv1"
"github.com/jackc/pgx/v4/pgxpool"
"golang.org/x/crypto/bcrypt"
"penahub.gitlab.yandexcloud.net/backend/templategen/dal/model"
"strconv"
"strings"
"time"
)
// User is the data-access object for the PostgreSQL "user" table; every
// method in this file issues its statements through the wrapped pool.
type User struct {
	conn *pgxpool.Pool // connection pool used for all queries
}
// InitUser wraps conn in a User DAO and makes sure the "user" table exists.
//
// It returns nil when the table cannot be created. The underlying error is
// discarded, so callers must nil-check the result before using it.
func InitUser(ctx context.Context, conn *pgxpool.Pool) *User {
	dao := &User{conn: conn}
	if initErr := dao.init(ctx); initErr != nil {
		// NOTE(review): the error is dropped (logging was commented out). The
		// signature cannot report it without breaking callers.
		return nil
	}
	return dao
}
// init runs CREATE TABLE IF NOT EXISTS for the "user" table. The column list
// comes from rs.PsqlTagToSql(&model.User{}), presumably derived from the
// struct's psql tags.
func (d *User) init(ctx context.Context) error {
	columnDefs := rs.PsqlTagToSql(&model.User{})
	ddl := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"user\" (%v);", columnDefs)
	if _, execErr := d.conn.Exec(ctx, ddl); execErr != nil {
		return execErr
	}
	return nil
}
// Insert stores record as a new row in "user" and returns the generated id as
// a decimal string.
//
// Before inserting, it modifies the caller's record in place:
//   - CreatedAt and UpdatedAt are set to the current UTC time.
//   - IsDeleted is cleared.
//   - After the duplicate check, Password is replaced by its bcrypt hash.
//
// It returns the error "user already exists" when a row with the same email
// (and a positive id) is already present.
//
// Statements go through the pool directly, so connections are always returned
// to it (the previous Acquire/Release pairing leaked on early returns and
// panicked on acquire failure).
func (d *User) Insert(ctx context.Context, record *model.User) (string, error) {
	now := time.Now().UTC()
	record.CreatedAt = now
	record.UpdatedAt = now
	record.IsDeleted = false

	// Look for an existing user with this email. The email is bound as a
	// parameter so it cannot alter the statement.
	// NOTE(review): check-then-insert is racy; a UNIQUE constraint on email
	// would be needed to make this airtight.
	var foundID int
	err := d.conn.QueryRow(ctx, `SELECT id FROM "user" WHERE email = $1`, record.Email).Scan(&foundID)
	if err != nil && err.Error() != ErrorNotFound.Error() {
		return "", err
	}
	if foundID > 0 {
		return "", errors.New("user already exists")
	}

	hashed, err := bcrypt.GenerateFromPassword([]byte(record.Password), bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}
	record.Password = string(hashed)

	// Column and value lists are produced by rs; values arrive as a SQL
	// literal list, so they cannot be bound here.
	tags, values := rs.GetPsqlTagsAndValues(record)
	insertSQL := fmt.Sprintf("INSERT INTO \"user\" (%v) VALUES (%v) RETURNING id;", tags, values)

	var id int
	if err := d.conn.QueryRow(ctx, insertSQL).Scan(&id); err != nil {
		return "", err
	}
	return strconv.Itoa(id), nil
}
// GetByID loads the user whose id equals id.
//
// Columns are scanned in the same order as GetByEmail and GetByFilter (the
// previous version omitted RoleID, so the destination count did not match
// SELECT *). Soft-deleted rows are not excluded.
//
// When no row matches, it returns a zero-value *model.User (empty ID) and a
// nil error.
func (d *User) GetByID(ctx context.Context, id int) (*model.User, error) {
	rows, err := d.conn.Query(ctx, `SELECT * FROM "user" WHERE id = $1`, id)
	if err != nil {
		return nil, err
	}
	// Closing rows hands the connection back to the pool. It must not be
	// released while rows are still being read.
	defer rows.Close()

	var result model.User
	for rows.Next() {
		// The id column is an integer, so scan it into an int and convert,
		// like GetByEmail does.
		var rowID int
		if err := rows.Scan(&rowID, &result.FullName, &result.Email, &result.Password,
			&result.IsActivated, &result.RoleID, &result.JwtToken, &result.IsDeleted,
			&result.CreatedAt, &result.UpdatedAt); err != nil {
			return nil, err
		}
		result.ID = strconv.Itoa(rowID)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return &result, nil
}
// GetByEmail loads the user whose email equals email. The email is bound as a
// query parameter rather than spliced into the SQL text. Soft-deleted rows are
// not excluded.
//
// When no row matches, the returned user has ID "0" and the error is nil.
// When several rows match, the last one scanned wins.
func (d *User) GetByEmail(ctx context.Context, email string) (*model.User, error) {
	rows, err := d.conn.Query(ctx, `SELECT * FROM "user" WHERE email = $1`, email)
	if err != nil {
		return nil, err
	}
	// Closing rows hands the connection back to the pool. It must not be
	// released while rows are still being read.
	defer rows.Close()

	var result model.User
	id := 0
	for rows.Next() {
		if err := rows.Scan(&id, &result.FullName, &result.Email, &result.Password,
			&result.IsActivated, &result.RoleID, &result.JwtToken, &result.IsDeleted,
			&result.CreatedAt, &result.UpdatedAt); err != nil {
			return nil, err
		}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	// With no matching row, id is still 0, so callers see ID "0".
	result.ID = strconv.Itoa(id)
	return &result, nil
}
// GetByFilter returns up to count users with id >= start, ordered by id, whose
// columns equal every needle entry (entries are ANDed together).
//
// needle keys are column names and must be plain identifiers (ASCII letters,
// digits, underscore, not starting with a digit); any other key is rejected.
// Values are bound as string parameters, which PostgreSQL converts to the
// column's type.
//
// The previous version had three problems:
//   - it always emitted a trailing " AND ", which is invalid SQL;
//   - it left values unquoted;
//   - it spliced keys in verbatim.
//
// Soft-deleted rows are not excluded.
func (d *User) GetByFilter(ctx context.Context, start, count int, needle map[string]string) ([]model.User, error) {
	query, args, err := userFilterQuery(start, count, needle)
	if err != nil {
		return nil, err
	}

	rows, err := d.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	// Closing rows hands the connection back to the pool. It must not be
	// released while rows are still being read.
	defer rows.Close()

	var result []model.User
	for rows.Next() {
		var u model.User
		// The id column is an integer, so scan it into an int and convert,
		// like GetByEmail does.
		var rowID int
		if err := rows.Scan(&rowID, &u.FullName, &u.Email, &u.Password,
			&u.IsActivated, &u.RoleID, &u.JwtToken, &u.IsDeleted, &u.CreatedAt, &u.UpdatedAt); err != nil {
			return nil, err
		}
		u.ID = strconv.Itoa(rowID)
		result = append(result, u)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}

// userFilterQuery builds the GetByFilter statement and its bind arguments.
// $1 is start, then one placeholder per needle entry, and the last placeholder
// is count.
func userFilterQuery(start, count int, needle map[string]string) (string, []interface{}, error) {
	args := make([]interface{}, 0, len(needle)+2)
	args = append(args, start)

	var b strings.Builder
	b.WriteString(`SELECT * FROM "user" WHERE (id >= $1`)
	for column, value := range needle {
		if !isSafeUserFilterColumn(column) {
			return "", nil, fmt.Errorf("invalid filter column %q", column)
		}
		args = append(args, value)
		fmt.Fprintf(&b, " AND %s = $%d", column, len(args))
	}
	args = append(args, count)
	fmt.Fprintf(&b, ") ORDER BY id LIMIT $%d;", len(args))
	return b.String(), args, nil
}

// isSafeUserFilterColumn reports whether s is a non-empty unquoted identifier
// made of ASCII letters, digits and underscores that does not start with a
// digit. Only such keys may be spliced into the filter SQL.
func isSafeUserFilterColumn(s string) bool {
	if s == "" {
		return false
	}
	for i, r := range s {
		switch {
		case r == '_', r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z':
		case r >= '0' && r <= '9' && i > 0:
		default:
			return false
		}
	}
	return true
}
// UpdateByID overwrites the row whose id equals record.ID with the columns
// and values produced by rs.GetPsqlTagsAndValues(record). It first sets
// record.UpdatedAt to the current UTC time, matching Insert.
//
// The statement runs via Exec. The previous QueryRow(...).Scan() on an UPDATE
// without RETURNING always returned "no rows in result set", even on success.
// An id that matches no row is not reported as an error.
func (d *User) UpdateByID(ctx context.Context, record *model.User) error {
	record.UpdatedAt = time.Now().UTC()
	tags, values := rs.GetPsqlTagsAndValues(record)
	// NOTE(review): the SET values are a SQL literal list built by rs; whether
	// rs escapes them is not visible here. Only the id is bound as a parameter.
	// NOTE(review): record.Password is written as-is (no hashing) — confirm
	// callers pass an already-hashed value.
	updateSQL := fmt.Sprintf("UPDATE \"user\" SET (%v) = (%v) WHERE id = $1;", tags, values)
	_, err := d.conn.Exec(ctx, updateSQL, record.ID)
	return err
}
// DeleteByID soft-deletes the user with the given id by setting is_deleted to
// true; the row itself is kept.
//
// The statement runs via Exec. The previous QueryRow(...).Scan() always
// returned "no rows in result set" for an UPDATE without RETURNING. An id that
// matches no row is not reported as an error.
func (d *User) DeleteByID(ctx context.Context, id int) error {
	_, err := d.conn.Exec(ctx, `UPDATE "user" SET is_deleted = true WHERE id = $1`, id)
	return err
}
// SetJwtToken stores jwtToken in the jwt_token column of the user with the
// given id. The token is bound as a parameter instead of being spliced into
// the SQL, so quotes in it cannot break or alter the statement.
//
// An id that matches no row is not reported as an error.
func (d *User) SetJwtToken(ctx context.Context, id int, jwtToken string) error {
	_, err := d.conn.Exec(ctx, `UPDATE "user" SET jwt_token = $1 WHERE id = $2`, jwtToken, id)
	return err
}
// ChangePassword bcrypt-hashes password (bcrypt.DefaultCost) and stores the
// hash in the password column of the user with the given id.
//
// The hash is stored as a string, as Insert does. The previous version
// formatted the []byte with %v, which wrote "[36 50 97 ...]" instead of the
// hash. An id that matches no row is not reported as an error.
func (d *User) ChangePassword(ctx context.Context, id int, password string) error {
	hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return err
	}
	_, err = d.conn.Exec(ctx, `UPDATE "user" SET password = $1 WHERE id = $2`, string(hashed), id)
	return err
}