user.go 5.8 KB


  1. package repository
  2. import (
  3. "fmt"
  4. "github.com/pkg/errors"
  5. "gorm.io/gorm"
  6. "sghgogs.com/micro/shopping-service/domain/model/base"
  7. req "sghgogs.com/micro/shopping-service/domain/model/request"
  8. pb "sghgogs.com/micro/shopping-service/proto"
  9. "sghgogs.com/micro/shopping-service/utils"
  10. )
  // UserGroupID is the group_id value associated with regular users.
// NOTE(review): the queries in this file filter on utils.UserUserGroupID rather
// than on this constant; confirm the two values agree and consider keeping one.
const UserGroupID = 2
  14. func (u *Repository) IsUserExists(identifier interface{}) (bool, error) {
  15. var user req.User
  16. if err := u.db.
  17. Where("id = ? OR username = ?", identifier, identifier).
  18. Where("group_id", utils.UserUserGroupID).
  19. Select("id, username").
  20. First(&user).Error; err != nil {
  21. if errors.Is(err, gorm.ErrRecordNotFound) {
  22. return false, nil // 记录不存在,账号不存在
  23. }
  24. return false, err // 发生其他错误
  25. }
  26. return true, nil
  27. }
  28. func (u *Repository) GetUserList(query *pb.GetUserListRequest) ([]req.User, int64, error) {
  29. tx := u.db.Model(&req.User{}).
  30. Where("group_id = ?", utils.UserUserGroupID).
  31. Select("id, username, phone_number, email, avatar, created_at, created_by, updated_at, updated_by, status").
  32. Order("id desc")
  33. if query.Keyword != "" {
  34. tx.Where("username = ? OR phone_number = ?", query.Keyword, query.Keyword)
  35. }
  36. if base.IsStatusEnum(query.Status) {
  37. tx.Where("status = ?", query.Status)
  38. }
  39. var totalCount int64
  40. tx.Count(&totalCount)
  41. users := make([]req.User, 0)
  42. return users, totalCount, tx.Limit(int(query.PageSize)).Offset(int((query.Page - 1) * query.PageSize)).Find(&users).Error
  43. }
  44. func (u *Repository) GetUser(userID int64) (req.User, error) {
  45. var user req.User
  46. return user, u.db.
  47. Where("id = ? AND group_id = ?", userID, utils.UserUserGroupID).
  48. Select("id, username, phone_number, email, avatar, created_at, created_by, updated_at, updated_by, status").
  49. // Preload("Roles").
  50. Find(&user).Error
  51. }
  52. func (u *Repository) CreateUser(user *req.User) error {
  53. // 开始事务
  54. tx := u.db.Begin()
  55. // 错误处理
  56. defer func() {
  57. if r := recover(); r != nil {
  58. tx.Rollback()
  59. }
  60. }()
  61. if err := tx.Model(&req.User{}).Create(user).Error; err != nil {
  62. tx.Rollback()
  63. return err
  64. }
  65. var role req.Role
  66. if err := tx.Model(&req.Role{}).Where("name = ?", "user").First(&role).Error; err != nil {
  67. tx.Rollback()
  68. return err
  69. }
  70. if err := tx.Model(&user).Association("Roles").Append(&req.Role{ID: role.ID}); err != nil {
  71. tx.Rollback()
  72. return err
  73. }
  74. return tx.Commit().Error
  75. }
  76. func (u *Repository) UpdateUser(userID int64, user map[string]interface{}) error {
  77. // 开始事务
  78. tx := u.db.Begin()
  79. // 错误处理
  80. defer func() {
  81. if r := recover(); r != nil {
  82. tx.Rollback()
  83. }
  84. }()
  85. if err := tx.Model(&req.User{}).Where("id = ?", userID).Updates(user).Error; err != nil {
  86. tx.Rollback()
  87. return err
  88. }
  89. return tx.Commit().Error
  90. }
  91. func (u *Repository) DeleteUser(userID int64) error {
  92. // 开始事务
  93. tx := u.db.Begin()
  94. // 错误处理
  95. defer func() {
  96. if r := recover(); r != nil {
  97. tx.Rollback()
  98. }
  99. }()
  100. var user req.User
  101. if err := tx.First(&user, userID).Error; err != nil {
  102. tx.Rollback()
  103. return err
  104. }
  105. if err := tx.Model(&user).Association("Roles").Clear(); err != nil {
  106. tx.Rollback()
  107. return err
  108. }
  109. if err := tx.Model(&req.User{}).Unscoped().Delete(&req.User{ID: userID}).Error; err != nil {
  110. tx.Rollback()
  111. return err
  112. }
  113. return tx.Commit().Error
  114. }
  115. func (u *Repository) ToggleUser(userID int64, enum pb.StatusEnum, data map[string]interface{}) error {
  116. // pb.StatusEnum_DELETED {}
  117. fmt.Println("data", data)
  118. // 1. 开启事务
  119. tx := u.db.Begin()
  120. // 错误处理
  121. defer func() {
  122. if r := recover(); r != nil {
  123. tx.Rollback()
  124. }
  125. }()
  126. var user req.User
  127. // 1.查询角色
  128. if err := tx.First(&user, userID).Error; err != nil {
  129. tx.Rollback()
  130. return err
  131. }
  132. if enum == pb.StatusEnum_DELETED {
  133. // 1.1删除关联角色
  134. if err := tx.Model(&user).Association("Roles").Clear(); err != nil {
  135. tx.Rollback()
  136. return err
  137. }
  138. }
  139. // 2. 更新状态
  140. if err := tx.Model(&req.User{}).Where("id = ?", userID).Updates(data).Error; err != nil {
  141. tx.Rollback()
  142. return err
  143. }
  144. return tx.Commit().Error
  145. }
  146. // func (u *Repository) ToggleUser(userID int64, user map[string]interface{}) error {
  147. // // 开始事务
  148. // tx := u.db.Begin()
  149. // // 错误处理
  150. // defer func() {
  151. // if r := recover(); r != nil {
  152. // tx.Rollback()
  153. // }
  154. // }()
  155. // if err := tx.Model(&req.User{}).Where("id = ?", userID).Updates(&user).Error; err != nil {
  156. // tx.Rollback()
  157. // return err
  158. // }
  159. // return tx.Commit().Error
  160. // }
  161. //
  162. // func (u *Repository) DeleteUser(userID int64) error {
  163. // // 开始事务
  164. // tx := u.db.Begin()
  165. // // 错误处理
  166. // defer func() {
  167. // if r := recover(); r != nil {
  168. // tx.Rollback()
  169. // }
  170. // }()
  171. // var user req.User
  172. // if err := tx.First(&user, userID).Error; err != nil {
  173. // tx.Rollback()
  174. // return err
  175. // }
  176. // if err := tx.Model(&user).Association("Roles").Clear(); err != nil {
  177. // tx.Rollback()
  178. // return err
  179. // }
  180. // if err := tx.Model(&req.User{}).Delete(&req.User{
  181. // ID: userID,
  182. // }).Error; err != nil {
  183. // tx.Rollback()
  184. // return err
  185. // }
  186. // return tx.Commit().Error
  187. // }
  188. // 检测用户ID如果不存在会抛出错误
  189. func (u *Repository) getUserValidIDs(ids []int64) ([]int64, error) {
  190. var validIDs []int64
  191. result := u.db.Model(&req.User{}).Where("id IN ?", ids).Pluck("id", &validIDs)
  192. if result.Error != nil {
  193. return nil, result.Error
  194. }
  195. // 检查是否有未找到的 ID
  196. if result.RowsAffected != int64(len(ids)) {
  197. // 找到的 ID 数量和传入的数量不一致,说明有不存在的 ID
  198. missingIDs := findMissingIDs(ids, validIDs)
  199. errMsg := fmt.Sprintf("Some IDs do not exist: %v", missingIDs)
  200. return nil, errors.New(errMsg)
  201. }
  202. return validIDs, nil
  203. }
  // findMissingIDs returns the entries of allIDs that do not occur in foundIDs,
// preserving their order in allIDs. A missing ID is reported once per
// occurrence in allIDs. The result is nil when nothing is missing.
func findMissingIDs(allIDs, foundIDs []int64) []int64 {
	present := make(map[int64]struct{}, len(foundIDs))
	for _, id := range foundIDs {
		present[id] = struct{}{}
	}
	var missing []int64
	for _, id := range allIDs {
		if _, ok := present[id]; ok {
			continue
		}
		missing = append(missing, id)
	}
	return missing
}