package repository import ( "fmt" "github.com/pkg/errors" "gorm.io/gorm" "sghgogs.com/micro/shopping-service/domain/model/base" req "sghgogs.com/micro/shopping-service/domain/model/request" pb "sghgogs.com/micro/shopping-service/proto" "time" ) func (u *Repository) IsRoleExists(identifier interface{}) (bool, error) { var user req.Role if err := u.db. Where("id = ? OR name = ?", identifier, identifier). Select("id, name"). First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return false, nil // 记录不存在,账号不存在 } return false, err // 发生其他错误 } return true, nil } func (u *Repository) GetAllRoles() ([]*req.Role, error) { roles := make([]*req.Role, 0) return roles, u.db.Model(&req.Role{}). Where("status = ?", pb.StatusEnum_ENABLED). Select("id, name"). Order("id desc"). Find(&roles).Error } func (u *Repository) GetRoleList(query *pb.GetRoleListRequest) ([]*req.Role, int64, error) { var totalCount int64 tx := u.db.Model(&req.Role{}).Order("id desc"). Select("id, name, description, created_by, created_at, updated_at, updated_by, status, is_reserved") if query.Keyword != "" { tx.Where("name = ?", query.Keyword) } if base.IsStatusEnum(query.Status) { tx.Where("status = ?", query.Status) } tx.Count(&totalCount) roles := make([]*req.Role, 0) return roles, totalCount, tx.Limit(int(query.PageSize)).Offset(int((query.Page - 1) * query.PageSize)).Find(&roles).Error } func (u *Repository) GetRole(roleID int64) (*req.Role, error) { var role req.Role return &role, u.db.Where("id = ? ", roleID). Where("status = ?", pb.StatusEnum_ENABLED). Preload("Permissions", "status = ?", pb.StatusEnum_ENABLED). Preload("Users", "status = ?", pb.StatusEnum_ENABLED). First(&role).Error } func (u *Repository) AllRoles() ([]*req.Role, error) { roles := make([]*req.Role, 0) return roles, u.db.Model(&req.Role{}). Where("status = ?", pb.StatusEnum_ENABLED). Preload("Permissions", "status = ?", pb.StatusEnum_ENABLED). Find(&roles).Error } func (u *Repository) CreateRole(data *pb.CreateRoleRequest) error { // 开始事务 tx := u.db.Begin() // 错误处理 defer func() { if r := recover(); r != nil { tx.Rollback() } }() role := req.Role{ Name: data.Name, Description: data.Description, CreatedBy: data.CreatedAt, CreatedAt: time.Now(), Status: pb.StatusEnum_ENABLED, IsReserved: false, } if err := tx.Model(&req.Role{}).Create(&role).Error; err != nil { tx.Rollback() return err } if len(data.Users) > 0 { if _, err := u.getUserValidIDs(data.Users); err != nil { tx.Rollback() return err } users := make([]req.User, 0) for _, ID := range data.Users { users = append(users, req.User{ID: ID}) } if err := tx.Model(&role).Association("Users").Append(&users); err != nil { tx.Rollback() return err } } if len(data.Permissions) > 0 { if _, err := u.getPermissionValidIDs(data.Permissions); err != nil { tx.Rollback() return err } permissions := make([]req.Permission, 0) for _, ID := range data.Permissions { permissions = append(permissions, req.Permission{ID: ID}) } if err := tx.Model(&role).Association("Permissions").Append(&permissions); err != nil { tx.Rollback() return err } } return tx.Commit().Error } func (u *Repository) UpdateRole(query *pb.UpdateRoleRequest) error { // 开始事务 tx := u.db.Begin() // 错误处理 defer func() { if r := recover(); r != nil { tx.Rollback() } }() var oldRole req.Role if err := tx.First(&oldRole, query.RoleId).Error; err != nil { tx.Rollback() return err } if err := tx.Model(&oldRole).Association("Users").Clear(); err != nil { tx.Rollback() return err } if err := tx.Model(&oldRole).Association("Permissions").Clear(); err != nil { tx.Rollback() return err } if len(query.Users) > 0 { if _, err := u.getUserValidIDs(query.Users); err != nil { tx.Rollback() return err } users := make([]req.User, 0) for _, ID := range query.Users { users = append(users, req.User{ID: ID}) } if err := tx.Model(&oldRole).Association("Users").Append(&users); err != nil { tx.Rollback() return err } } if len(query.Permissions) > 0 { if _, err := u.getPermissionValidIDs(query.Permissions); err != nil { tx.Rollback() return err } permissions := make([]req.Permission, 0) for _, ID := range query.Permissions { permissions = append(permissions, req.Permission{ID: ID}) } if err := tx.Model(&oldRole).Association("Permissions").Append(&permissions); err != nil { tx.Rollback() return err } } if err := tx.Model(&oldRole).Updates(map[string]interface{}{ "description": query.Description, "updated_at": time.Now(), "updated_by": query.UpdatedBy, }).Error; err != nil { tx.Rollback() return err } return tx.Commit().Error } func (u *Repository) DeleteRole(roleID int64) error { // 开始事务 tx := u.db.Begin() // 错误处理 defer func() { if r := recover(); r != nil { tx.Rollback() } }() var role req.Role if err := tx.First(&role, roleID).Error; err != nil { tx.Rollback() return err } if err := tx.Model(&role).Association("Users").Clear(); err != nil { tx.Rollback() return err } if err := tx.Model(&role).Association("Permissions").Clear(); err != nil { tx.Rollback() return err } if err := tx.Model(&req.Role{}).Delete(&req.Role{ ID: roleID, }).Error; err != nil { tx.Rollback() return err } return tx.Commit().Error } func (u *Repository) ToggleRole(toggle *pb.ToggleRoleRequest) error { // 开始事务 tx := u.db.Begin() // 错误处理 defer func() { if r := recover(); r != nil { tx.Rollback() } }() var role req.Role // 1.查询角色 if err := tx.First(&role, toggle.RoleId).Error; err != nil { tx.Rollback() return err } if toggle.Status == pb.StatusEnum_DELETED { // 1.1 删除关联用户 if err := tx.Model(&role).Association("Users").Clear(); err != nil { tx.Rollback() return err } // 1.2 移除关联权限 if err := tx.Model(&role).Association("Permissions").Clear(); err != nil { tx.Rollback() return err } } // 2. 更新状态 if err := tx.Model(&role).Updates(map[string]interface{}{ "status": toggle.Status, "updated_at": time.Now(), "updated_by": toggle.UpdatedBy, }).Error; err != nil { tx.Rollback() return err } return tx.Commit().Error } // 检测角色ID如果不存在会抛出错误 func (u *Repository) getRoleValidIDs(ids []int64) ([]int64, error) { var validIDs []int64 result := u.db.Model(&req.Role{}).Where("id IN ?", ids).Pluck("id", &validIDs) if result.Error != nil { return nil, result.Error } // 检查是否有未找到的 ID if result.RowsAffected != int64(len(ids)) { // 找到的 ID 数量和传入的数量不一致,说明有不存在的 ID missingIDs := findMissingIDs(ids, validIDs) errMsg := fmt.Sprintf("Some IDs do not exist: %v", missingIDs) return nil, errors.New(errMsg) } return validIDs, nil }