236 lines
5.9 KiB
Go
236 lines
5.9 KiB
Go
package gigachat
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"gitea.pena/SQuiz/common/model"
|
||
"gitea.pena/SQuiz/worker/internal/senders"
|
||
"github.com/go-redis/redis/v8"
|
||
"github.com/go-resty/resty/v2"
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
"time"
|
||
)
|
||
|
||
// Deps bundles the collaborators and settings that NewGigaChatClient copies
// into a GigaChatClient.
type Deps struct {
	// Logger receives the client's error, info and debug messages.
	Logger *zap.Logger
	// Client is the HTTP client used for the OAuth, chat-completions and
	// balance requests.
	Client *resty.Client
	// BaseURL is the API root; "/chat/completions" and "/balance" are
	// appended to it.
	BaseURL string
	// AuthKey is sent as the Basic authorization credential when requesting
	// an access token.
	AuthKey string
	// RedisClient stores the current access token under the
	// "gigachat_token" key.
	RedisClient *redis.Client
	// TgSender delivers the low-balance alert to Telegram.
	TgSender *senders.TgSender
	// TgChatID is the Telegram chat that receives the low-balance alert.
	TgChatID int64
}
|
||
|
||
// GigaChatClient talks to the GigaChat HTTP API. The access token is not kept
// in memory: it is read from Redis (key "gigachat_token") before each API call
// and written there by updateToken.
type GigaChatClient struct {
	// logger receives error, info and debug messages.
	logger *zap.Logger
	// client performs every HTTP request (OAuth, completions, balance).
	client *resty.Client
	// baseURL is the API root that endpoint paths are appended to.
	baseURL string
	// authKey is the Basic credential for the OAuth token request.
	authKey string
	// redisClient holds the access token under "gigachat_token".
	redisClient *redis.Client
	// tgSender and tgChatID are used for the low-balance Telegram alert.
	tgSender *senders.TgSender
	tgChatID int64
}
|
||
|
||
func NewGigaChatClient(ctx context.Context, deps Deps) (*GigaChatClient, error) {
|
||
client := &GigaChatClient{
|
||
logger: deps.Logger,
|
||
client: deps.Client,
|
||
baseURL: deps.BaseURL,
|
||
authKey: deps.AuthKey,
|
||
redisClient: deps.RedisClient,
|
||
|
||
tgSender: deps.TgSender,
|
||
tgChatID: deps.TgChatID,
|
||
}
|
||
|
||
if err := client.updateToken(ctx); err != nil {
|
||
return nil, fmt.Errorf("failed to get access token: %w", err)
|
||
}
|
||
|
||
return client, nil
|
||
}
|
||
|
||
func (r *GigaChatClient) SendMsg(ctx context.Context, audience model.GigaChatAudience, question model.Question) (string, error) {
|
||
gender := ""
|
||
|
||
switch audience.Sex {
|
||
case 0:
|
||
gender = "женский"
|
||
case 1:
|
||
gender = "мужской"
|
||
case 2:
|
||
gender = "не имеет значения"
|
||
}
|
||
|
||
userInput := fmt.Sprintf(model.ReworkQuestionPrompt, audience.Age, gender, question.Title, question.Description)
|
||
|
||
token, err := r.redisClient.Get(ctx, "gigachat_token").Result()
|
||
if err != nil {
|
||
r.logger.Error("failed to get token from redis", zap.Error(err))
|
||
return "", err
|
||
}
|
||
|
||
reqBody := model.GigaChatRequest{
|
||
Model: "GigaChat-2-Max",
|
||
Stream: false,
|
||
UpdateInterval: 0,
|
||
Messages: []model.GigaChatMessage{
|
||
{Role: "system", Content: model.CreatePrompt},
|
||
{Role: "user", Content: userInput},
|
||
},
|
||
}
|
||
|
||
var response model.GigaChatResponse
|
||
|
||
resp, err := r.client.R().
|
||
SetHeader("Content-Type", "application/json").
|
||
SetHeader("Authorization", "Bearer "+token).
|
||
SetBody(reqBody).
|
||
SetResult(&response).
|
||
Post(r.baseURL + "/chat/completions")
|
||
|
||
if err != nil {
|
||
r.logger.Error("failed send request to GigaChat", zap.Error(err))
|
||
return "", err
|
||
}
|
||
|
||
if resp.IsError() {
|
||
errMsg := fmt.Sprintf("error GigaChat API: %s", resp.Status())
|
||
r.logger.Error(errMsg)
|
||
return "", errors.New(errMsg)
|
||
}
|
||
|
||
if len(response.Choices) == 0 || response.Choices[0].Message.Content == "" {
|
||
// когда возникает такая ошибка то значит еще траим отправить запрос
|
||
return "", model.EmptyResponseErrorGigaChat
|
||
}
|
||
|
||
return response.Choices[0].Message.Content, nil
|
||
}
|
||
|
||
func (r *GigaChatClient) TokenResearch(ctx context.Context) {
|
||
ticker := time.NewTicker(time.Minute)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
ttl, err := r.redisClient.TTL(ctx, "gigachat_token").Result()
|
||
fmt.Println("GGCHATtoken", ttl, err, ttl < 2*time.Minute)
|
||
if err != nil || ttl < 2*time.Minute {
|
||
if err := r.updateToken(ctx); err != nil {
|
||
r.logger.Error("failed to update GigaChat token", zap.Error(err))
|
||
} else {
|
||
r.logger.Info("successfully updated GigaChat token")
|
||
}
|
||
}
|
||
case <-ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
func (r *GigaChatClient) updateToken(ctx context.Context) error {
|
||
formData := "scope=GIGACHAT_API_B2B"
|
||
|
||
var respData struct {
|
||
AccessToken string `json:"access_token"`
|
||
ExpiresAt int64 `json:"expires_at"`
|
||
}
|
||
|
||
resp, err := r.client.R().
|
||
SetHeader("Authorization", "Basic "+r.authKey).
|
||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||
SetBody(formData).
|
||
SetHeader("RqUID", uuid.New().String()).
|
||
SetResult(&respData).
|
||
Post("https://ngw.devices.sberbank.ru:9443/api/v2/oauth")
|
||
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if resp.IsError() {
|
||
return fmt.Errorf("token request failed: %s", resp.Status())
|
||
}
|
||
|
||
ttl := time.Until(time.Unix(int64(respData.ExpiresAt/1000), 0))
|
||
fmt.Println("GGCTOKENEXP", respData.ExpiresAt, ttl, ttl < 2*time.Minute, time.Now())
|
||
err = r.redisClient.Set(ctx, "gigachat_token", respData.AccessToken, ttl).Err()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to save token to redis: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *GigaChatClient) getBalance(ctx context.Context) (int, error) {
|
||
var respData struct {
|
||
Balance []struct {
|
||
Usage string `json:"usage"`
|
||
Value int `json:"value"`
|
||
} `json:"balance"`
|
||
}
|
||
|
||
token, err := r.redisClient.Get(ctx, "gigachat_token").Result()
|
||
if err != nil {
|
||
r.logger.Error("failed to get token from redis", zap.Error(err))
|
||
return 0, err
|
||
}
|
||
|
||
resp, err := r.client.R().
|
||
SetHeader("Authorization", "Bearer "+token).
|
||
SetResult(&respData).
|
||
Get(r.baseURL + "/balance")
|
||
|
||
if err != nil {
|
||
return 0, fmt.Errorf("failed to fetch balance: %w", err)
|
||
}
|
||
if resp.IsError() {
|
||
return 0, fmt.Errorf("balance request failed: %s", resp.Status())
|
||
}
|
||
|
||
// прверяем то что используем для переформулирования
|
||
for _, item := range respData.Balance {
|
||
if item.Usage == "GigaChat-Max" {
|
||
return item.Value, nil
|
||
}
|
||
}
|
||
|
||
return 0, errors.New("no used models found")
|
||
}
|
||
|
||
func (r *GigaChatClient) MonitorTokenBalance(ctx context.Context) {
|
||
ticker := time.NewTicker(5 * time.Minute)
|
||
defer ticker.Stop()
|
||
|
||
alert := false // чтоб не спамить каждые 5 минут
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
balance, err := r.getBalance(ctx)
|
||
if err != nil {
|
||
r.logger.Error("failed to get GigaChat token", zap.Error(err))
|
||
continue
|
||
}
|
||
|
||
if balance < 500_000 && !alert {
|
||
msg := fmt.Sprintf("Остаток токенов в GigaChat упал ниже 500000.\nТекущий баланс: %d токенов.", balance)
|
||
if err := r.tgSender.SendMessage(r.tgChatID, msg); err != nil {
|
||
r.logger.Error("failed to send Telegram alert", zap.Error(err))
|
||
} else {
|
||
alert = true
|
||
}
|
||
}
|
||
|
||
if balance >= 500_000 && alert {
|
||
alert = false
|
||
}
|
||
case <-ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|