middleware.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. package middleware
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/pkg/errors"
  6. "github.com/sirupsen/logrus"
  7. "go-micro.dev/v4"
  8. "go-micro.dev/v4/auth"
  9. "go-micro.dev/v4/metadata"
  10. "go-micro.dev/v4/server"
  11. "sghgogs.com/micro/common"
  12. "sghgogs.com/micro/common/errorcode"
  13. "sghgogs.com/micro/shopping-service/utils/authutil"
  14. "sort"
  15. "strings"
  16. )
  17. const (
  18. // BearerScheme used for Authorization header.
  19. BearerScheme = "Bearer "
  20. // ScopePublic is the scope applied to a rule to allow access to the public.
  21. ScopePublic = ""
  22. // ScopeAccount is the scope applied to a rule to limit to users with any valid account.
  23. ScopeAccount = "*"
  24. )
  25. var (
  26. // ErrInvalidToken is when the token provided is not valid.
  27. ErrInvalidToken = errors.New("invalid token provided")
  28. // ErrForbidden is when a user does not have the necessary scope to access a resource.
  29. ErrForbidden = errors.New("resource forbidden")
  30. )
  31. // type Access int
  32. const (
  33. // AccessGranted to a resource.
  34. AccessGranted auth.Access = iota
  35. // AccessDenied to a resource.
  36. AccessDenied
  37. )
  38. func Verify(rules []*auth.Rule, acc *auth.Account, res *auth.Resource) error {
  39. // the rule is only to be applied if the type matches the resource or is catch-all (*)
  40. validTypes := []string{"*", res.Type}
  41. // the rule is only to be applied if the name matches the resource or is catch-all (*)
  42. validNames := []string{"*", res.Name}
  43. // rules can have wildcard excludes on endpoints since this can also be a path for web services,
  44. // e.g. /foo/* would include /foo/bar. We also want to check for wildcards and the exact endpoint
  45. validEndpoints := []string{"*", res.Endpoint}
  46. if comps := strings.Split(res.Endpoint, "/"); len(comps) > 1 {
  47. for i := 1; i < len(comps)+1; i++ {
  48. wildcard := fmt.Sprintf("%v/*", strings.Join(comps[0:i], "/"))
  49. validEndpoints = append(validEndpoints, wildcard)
  50. }
  51. }
  52. // filter the rules to the ones which match the criteria above
  53. filteredRules := make([]*auth.Rule, 0)
  54. for _, rule := range rules {
  55. if !include(validTypes, rule.Resource.Type) {
  56. continue
  57. }
  58. if !include(validNames, rule.Resource.Name) {
  59. continue
  60. }
  61. if !include(validEndpoints, rule.Resource.Endpoint) {
  62. continue
  63. }
  64. filteredRules = append(filteredRules, rule)
  65. }
  66. // sort the filtered rules by priority, highest to lowest
  67. sort.SliceStable(filteredRules, func(i, j int) bool {
  68. return filteredRules[i].Priority > filteredRules[j].Priority
  69. })
  70. // loop through the rules and check for a rule which applies to this account
  71. for _, rule := range filteredRules {
  72. // a blank scope indicates the rule applies to everyone, even nil accounts
  73. if rule.Scope == ScopePublic && rule.Access == AccessDenied {
  74. return ErrForbidden
  75. } else if rule.Scope == ScopePublic && rule.Access == AccessGranted {
  76. return nil
  77. }
  78. // all further checks require an account
  79. if acc == nil {
  80. continue
  81. }
  82. // this rule applies to any account
  83. if rule.Scope == ScopeAccount && rule.Access == AccessDenied {
  84. return ErrForbidden
  85. } else if rule.Scope == ScopeAccount && rule.Access == AccessGranted {
  86. return nil
  87. }
  88. // 去掉首尾的方括号
  89. // if the account has the necessary scope
  90. if include(acc.Scopes, rule.Scope) && rule.Access == AccessDenied {
  91. return ErrForbidden
  92. } else if include(acc.Scopes, rule.Scope) && rule.Access == AccessGranted {
  93. return nil
  94. }
  95. }
  96. // if no rules matched then return forbidden
  97. return ErrForbidden
  98. }
  99. // include is a helper function which checks to see if the slice contains the value. includes is
  100. // not case sensitive.
  101. func include(slice []string, val string) bool {
  102. // str := slice
  103. if len(slice) > 0 {
  104. if strings.Contains(slice[0], ",") {
  105. data := strings.Split(slice[0], ",")
  106. // 打印结果
  107. for _, s := range data {
  108. if s == "super_admin" {
  109. return true
  110. }
  111. if strings.EqualFold(s, val) {
  112. return true
  113. }
  114. }
  115. // 判断超级管理员
  116. } else {
  117. // 判断超级管理员
  118. for _, s := range slice {
  119. if s == "super_admin" {
  120. return true
  121. }
  122. if strings.EqualFold(s, val) {
  123. return true
  124. }
  125. }
  126. }
  127. return false
  128. }
  129. return false
  130. }
  131. // var (
  132. // // catchallResource = &auth.Resource{
  133. // // Type: "*",
  134. // // Name: "*",
  135. // // Endpoint: "*",
  136. // // }
  137. // //
  138. //
  139. // getAuthentication = &auth.Resource{
  140. // Type: "user",
  141. // Name: name,
  142. // Endpoint: "AuthenticationService.GetAuthentication",
  143. // }
  144. // // catchallResource
  145. // rulesItems = []*auth.Rule{
  146. // // {Scope: "*", Resource: catchallResource}, toggleAdminRole
  147. // {Scope: "kubernetes", Resource: getAuthentication, ID: uuid.New().String(), Priority: 1},
  148. // }
  149. // )
  150. func NewAuthWrapper(service micro.Service, namespace string) server.HandlerWrapper {
  151. return func(h server.HandlerFunc) server.HandlerFunc {
  152. return func(ctx context.Context, req server.Request, rsp interface{}) error {
  153. logrus.Infof("[wrapper] server request: %v", req.Endpoint())
  154. if req.Endpoint() == "ShoppingAuthService.Login" {
  155. return h(ctx, req, rsp)
  156. }
  157. // Fetch metadata from context (request headers).
  158. md, b := metadata.FromContext(ctx)
  159. if !b {
  160. return errorcode.Unauthorized(namespace, common.ErrorMessage[common.UnauthorizedErrorCode])
  161. // errors.New("no metadata found")
  162. }
  163. // local ip of service
  164. fmt.Println("local ip is", md["Local"])
  165. // remote ip of caller
  166. fmt.Println("remote ip is", md["Remote"])
  167. // Get auth header.
  168. authHeader, ok := md["Authorization"]
  169. if !ok || !strings.HasPrefix(authHeader, auth.BearerScheme) {
  170. logrus.Error("no auth token provided")
  171. return errorcode.Unauthorized(namespace, common.ErrorMessage[common.UnauthorizedErrorCode])
  172. }
  173. // Extract auth token.
  174. token := strings.TrimPrefix(authHeader, auth.BearerScheme)
  175. // Extract account from token.
  176. token = strings.TrimSpace(token)
  177. a := service.Options().Auth
  178. acc, err := a.Inspect(token)
  179. fmt.Println("acc", acc)
  180. if err != nil {
  181. return errorcode.Unauthorized(namespace, common.ErrorMessage[common.TokenInvalidErrorCode])
  182. }
  183. // 校验redis 存储数据
  184. blacklisted, err := authutil.JWTAuthService.IsBlacklisted(token)
  185. if err == nil && blacklisted {
  186. return errorcode.Unauthorized(namespace, common.ErrorMessage[common.ExpiredLonInAgainErrorCode])
  187. }
  188. // // Create resource for current endpoint from request headers.
  189. currentResource := auth.Resource{
  190. Type: "user",
  191. Name: md["Micro-Service"],
  192. Endpoint: md["Micro-Endpoint"],
  193. }
  194. // fmt.Println("acc.Scopes", acc.Scopes)
  195. // // Verify if account has access. 验证帐户是否具有访问权限。
  196. if err = Verify(authutil.JWTAuthService.GetRuleItems(), acc, &currentResource); err != nil {
  197. return errorcode.Unauthorized("authorization service", common.ErrorMessage[common.NoAccessErrorCode])
  198. }
  199. // 验证通过后记录操作日志x
  200. logrus.Infof("User %s is performing operation %s body %v", acc.ID, req.Endpoint(), req.Body())
  201. return h(ctx, req, rsp)
  202. }
  203. }
  204. }