diff --git a/internal/controller/tariff_external/controller.go b/internal/controller/tariff_external/controller.go index a95b9ad..327f845 100644 --- a/internal/controller/tariff_external/controller.go +++ b/internal/controller/tariff_external/controller.go @@ -9,6 +9,7 @@ import ( our_errors "hub_admin_backend_service/internal/errors" "hub_admin_backend_service/internal/models" "hub_admin_backend_service/internal/repository/tariff" + "hub_admin_backend_service/internal/tools" ) type Deps struct { @@ -81,8 +82,13 @@ func (t *TariffExternal) Create(ctx *fiber.Ctx) error { if err := ctx.BodyParser(&req); err != nil { return ctx.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request payload"}) } - req.UserID = userID + + err := tools.ValidateTariff(req) + if err != nil { + return ctx.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) + } + result, err := t.repo.Create(ctx.Context(), req) if err != nil { return ctx.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()}) diff --git a/internal/controller/tariff_internal/controller.go b/internal/controller/tariff_internal/controller.go index f42b607..04d2f12 100644 --- a/internal/controller/tariff_internal/controller.go +++ b/internal/controller/tariff_internal/controller.go @@ -9,6 +9,7 @@ import ( our_errors "hub_admin_backend_service/internal/errors" "hub_admin_backend_service/internal/models" "hub_admin_backend_service/internal/repository/tariff" + "hub_admin_backend_service/internal/tools" ) // todo middleware jwt @@ -73,6 +74,11 @@ func (t *TariffInternal) Create(ctx *fiber.Ctx) error { return ctx.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request payload"}) } + err := tools.ValidateTariff(req) + if err != nil { + return ctx.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) + } + result, err := t.repo.Create(ctx.Context(), req) if err != nil { return ctx.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()}) @@ -118,6 +124,11 @@ func (t *TariffInternal) Update(ctx *fiber.Ctx) error { return ctx.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request payload"}) } + err := tools.ValidateTariff(req) + if err != nil { + return ctx.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) + } + id := ctx.Params("id") objID, err := primitive.ObjectIDFromHex(id) if err != nil { diff --git a/internal/initialize/repository.go b/internal/initialize/repository.go index 63bfac3..45a8487 100644 --- a/internal/initialize/repository.go +++ b/internal/initialize/repository.go @@ -24,8 +24,9 @@ func NewRepository(deps RepositoryDeps) *Repository { Logger: deps.Logger, }), TariffRepo: tariff.NewTariffRepo(tariff.Deps{ - Mdb: deps.Mdb.Collection("tariffs"), - Logger: deps.Logger, + Mdb: deps.Mdb.Collection("tariffs"), + Logger: deps.Logger, + PrivilegeDB: deps.Mdb.Collection("privileges"), }), } } diff --git a/internal/repository/tariff/tariff.go b/internal/repository/tariff/tariff.go index 116d70c..cbfac2a 100644 --- a/internal/repository/tariff/tariff.go +++ b/internal/repository/tariff/tariff.go @@ -14,19 +14,22 @@ import ( ) type Deps struct { - Mdb *mongo.Collection - Logger *zap.Logger + Mdb *mongo.Collection + Logger *zap.Logger + PrivilegeDB *mongo.Collection } type Tariff struct { - mdb *mongo.Collection - logger *zap.Logger + mdb *mongo.Collection + logger *zap.Logger + privilegeDB *mongo.Collection } func NewTariffRepo(deps Deps) *Tariff { return &Tariff{ - mdb: deps.Mdb, - logger: deps.Logger, + mdb: deps.Mdb, + logger: deps.Logger, + privilegeDB: deps.PrivilegeDB, } } @@ -121,5 +124,67 @@ func (t *Tariff) SoftDelete(ctx context.Context, tariffID primitive.ObjectID) (m func (t *Tariff) Update(ctx context.Context, tariffID primitive.ObjectID, req models.Tariff) (models.Tariff, error) { var tariff models.Tariff + err := t.mdb.FindOne(ctx, bson.M{"_id": tariffID}).Decode(&tariff) + if err == mongo.ErrNoDocuments { + return tariff, errors.ErrNotFound + } else if err != nil { + t.logger.Error("failed find tariff", zap.Error(err)) + return tariff, err + } + + privilegeIDs := make([]string, len(req.Privileges)) + for i, privilege := range req.Privileges { + privilegeIDs[i] = privilege.PrivilegeID + } + + cursor, err := t.privilegeDB.Find(ctx, bson.M{"privilegeId": bson.M{"$in": privilegeIDs}}) + if err != nil { + t.logger.Error("failed find privileges", zap.Error(err)) + return tariff, err + } + defer cursor.Close(ctx) + + privilegeMap := make(map[string]models.Privilege) + for cursor.Next(ctx) { + var privilege models.Privilege + if err := cursor.Decode(&privilege); err != nil { + t.logger.Error("failed decode privilege", zap.Error(err)) + return tariff, err + } + privilegeMap[privilege.PrivilegeID] = privilege + } + + clean := make([]models.Privilege, len(req.Privileges)) + for i, privilege := range req.Privileges { + origPrivilege := privilegeMap[privilege.PrivilegeID] + clean[i] = models.Privilege{ + Name: origPrivilege.Name, + PrivilegeID: origPrivilege.PrivilegeID, + ServiceKey: origPrivilege.ServiceKey, + Description: origPrivilege.Description, + Type: origPrivilege.Type, + Value: origPrivilege.Value, + Price: origPrivilege.Price, + } + } + + update := bson.M{ + "$set": bson.M{ + "order": req.Order, + "name": req.Name, + "price": req.Price, + "isCustom": req.IsCustom, + "privileges": clean, + }, + } + + err = t.mdb.FindOneAndUpdate(ctx, bson.M{"_id": tariffID}, update).Decode(&tariff) + if err == mongo.ErrNoDocuments { + return tariff, errors.ErrNotFound + } else if err != nil { + t.logger.Error("failed update tariff", zap.Error(err)) + return tariff, err + } + return tariff, nil } diff --git a/internal/tools/validate.go b/internal/tools/validate.go index f517607..053bb40 100644 --- a/internal/tools/validate.go +++ b/internal/tools/validate.go @@ -1,6 +1,7 @@ package tools import ( + "errors" "hub_admin_backend_service/internal/models" ) @@ -11,3 +12,21 @@ func Validate(req models.CreateUpdateReq) bool { return true } + +func ValidateTariff(tariff models.Tariff) error { + if tariff.Name == "" { + return errors.New("name is required") + } + if tariff.Price < 0 { + return errors.New("invalid price value") + } + if len(tariff.Privileges) == 0 { + return errors.New("privileges are required") + } + for _, privilege := range tariff.Privileges { + if privilege.PrivilegeID == "" { + return errors.New("privilegeID is required in privileges") + } + } + return nil +}