package postgres import ( "context" "errors" "fmt" "github.com/Pena-Co-Ltd/amocrm_templategen_back/dal/model" rs "github.com/danilsolovyov/reflectgostructv1" "github.com/jackc/pgx/v4/pgxpool" "golang.org/x/crypto/bcrypt" "strconv" "strings" "time" ) type User struct { conn *pgxpool.Pool } func InitUser(ctx context.Context, conn *pgxpool.Pool) *User { d := &User{conn: conn} err := d.init(ctx) if err != nil { //glg.Error("ErrInitUser:", err) return nil } return d } func (d *User) init(ctx context.Context) error { s := rs.PsqlTagToSql(&model.User{}) sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"user\" (%v);", s) _, err := d.conn.Exec(ctx, sql) return err } func (d *User) Insert(ctx context.Context, record *model.User) (string, error) { now := time.Now().UTC() record.CreatedAt = now record.UpdatedAt = now record.IsDeleted = false conn, err := d.conn.Acquire(ctx) if err != nil { conn.Release() return "", err } // Find users by email sql := fmt.Sprintf("SELECT id FROM \"user\" WHERE email = '%v'", record.Email) var foundID int err = conn.QueryRow(ctx, sql).Scan(&foundID) if err != nil { if err.Error() != ErrorNotFound.Error() { return "", err } } if foundID > 0 { fmt.Println("user already exists", foundID) err = errors.New("user already exists") return "", err } gpass, err := bcrypt.GenerateFromPassword([]byte(record.Password), bcrypt.DefaultCost) if err != nil { return "", err } record.Password = string(gpass) tags, values := rs.GetPsqlTagsAndValues(record) sql = fmt.Sprintf("INSERT INTO \"user\" (%v) VALUES (%v) RETURNING id;", tags, values) var id int err = conn.QueryRow(ctx, sql).Scan(&id) conn.Release() if err != nil { return "", err } return strconv.Itoa(id), nil } func (d *User) GetByID(ctx context.Context, id int) (*model.User, error) { var result model.User conn, err := d.conn.Acquire(ctx) if err != nil { conn.Release() return nil, err } sql := fmt.Sprintf("SELECT * FROM \"user\" WHERE id = %v", id) rows, err := conn.Query(ctx, sql) conn.Release() if err != nil { return nil, err } for rows.Next() { err = rows.Scan(&result.ID, &result.FullName, &result.Email, &result.Password, &result.IsActivated, &result.JwtToken, &result.IsDeleted, &result.CreatedAt, &result.UpdatedAt) if err != nil { return nil, err } } return &result, err } func (d *User) GetByEmail(ctx context.Context, email string) (*model.User, error) { var result model.User conn, err := d.conn.Acquire(ctx) if err != nil { conn.Release() return nil, err } sql := fmt.Sprintf("SELECT * FROM \"user\" WHERE email = '%v'", email) rows, err := conn.Query(ctx, sql) conn.Release() if err != nil { return nil, err } id := 0 for rows.Next() { err = rows.Scan(&id, &result.FullName, &result.Email, &result.Password, &result.IsActivated, &result.RoleID, &result.JwtToken, &result.IsDeleted, &result.CreatedAt, &result.UpdatedAt) if err != nil { return nil, err } } result.ID = strconv.Itoa(id) return &result, err } func (d *User) GetByFilter(ctx context.Context, start, count int, needle map[string]string) ([]model.User, error) { var result []model.User conn, err := d.conn.Acquire(ctx) if err != nil { conn.Release() return nil, err } needleToSql := "" if len(needle) > 0 { needleToSql += " AND " i := 0 for k, v := range needle { v = strings.ReplaceAll(v, "'", "''") needleToSql += fmt.Sprintf("%v = %v", k, v) if i < len(needle) { needleToSql += " AND " } i++ } } sql := fmt.Sprintf("SELECT * FROM \"user\" WHERE (id >= %v %v) ORDER BY id LIMIT %v;", start, needleToSql, count, ) rows, err := conn.Query(ctx, sql) conn.Release() if err != nil { return nil, err } if rows == nil { err = ErrorGotEmptyRow return nil, err } for rows.Next() { var u model.User err = rows.Scan(&u.ID, &u.FullName, &u.Email, &u.Password, &u.IsActivated, &u.RoleID, &u.JwtToken, &u.IsDeleted, &u.CreatedAt, &u.UpdatedAt) if err != nil { return nil, err } result = append(result, u) } return result, nil } func (d *User) UpdateByID(ctx context.Context, record *model.User) error { record.UpdatedAt = time.Now() conn, err := d.conn.Acquire(ctx) if err != nil { conn.Release() return err } tags, values := rs.GetPsqlTagsAndValues(record) sql := fmt.Sprintf("UPDATE \"user\" SET (%v) = (%v) WHERE id = %v;", tags, values, record.ID) err = conn.QueryRow(ctx, sql).Scan() conn.Release() if err != nil { return err } return nil } func (d *User) DeleteByID(ctx context.Context, id int) error { conn, err := d.conn.Acquire(ctx) if err != nil { conn.Release() return err } sql := fmt.Sprintf("UPDATE \"user\" SET is_deleted = true WHERE id = %v", id) err = conn.QueryRow(ctx, sql).Scan() conn.Release() if err != nil { return err } return nil } func (d *User) SetJwtToken(ctx context.Context, id int, jwtToken string) error { conn, err := d.conn.Acquire(ctx) if err != nil { conn.Release() return err } sql := fmt.Sprintf("UPDATE \"user\" SET jwt_token = '%v' WHERE id = %v", jwtToken, id) err = conn.QueryRow(ctx, sql).Scan() conn.Release() if err != nil { return err } return nil } func (d *User) ChangePassword(ctx context.Context, id int, password string) error { gpass, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return err } conn, err := d.conn.Acquire(ctx) if err != nil { conn.Release() return err } sql := fmt.Sprintf("UPDATE \"user\" SET password = '%v' WHERE id = %v", gpass, id) err = conn.QueryRow(ctx, sql).Scan() conn.Release() if err != nil { return err } return nil }