package middleware

import (
	"context"
	"fmt"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"go-micro.dev/v4"
	"go-micro.dev/v4/auth"
	"go-micro.dev/v4/metadata"
	"go-micro.dev/v4/server"
	"sghgogs.com/micro/common"
	"sghgogs.com/micro/common/errorcode"
	"sghgogs.com/micro/k8s-service/utils/authutil"
	"sort"
	"strings"
)

const (
	// BearerScheme used for Authorization header.
	BearerScheme = "Bearer "
	// ScopePublic is the scope applied to a rule to allow access to the public.
	ScopePublic = ""
	// ScopeAccount is the scope applied to a rule to limit to users with any valid account.
	ScopeAccount = "*"
	name         = "kubernetesservice"
)

var (
	// ErrInvalidToken is when the token provided is not valid.
	ErrInvalidToken = errors.New("invalid token provided")
	// ErrForbidden is when a user does not have the necessary scope to access a resource.
	ErrForbidden = errors.New("resource forbidden")
)

// type Access int

const (
	// AccessGranted to a resource.
	AccessGranted auth.Access = iota
	// AccessDenied to a resource.
	AccessDenied
)

// Verify decides whether the account acc may access the resource res under the
// given rules.
//
// Rules are first filtered to those whose resource type, name and endpoint
// match res (exactly, ignoring case, or via the "*" catch-all; endpoints also
// match path-prefix wildcards such as "/foo/*"). The remaining rules are
// evaluated from highest to lowest Priority, and the first rule that applies to
// the account decides the outcome:
//
//   - a rule with scope ScopePublic applies to everyone, including a nil acc;
//   - a rule with scope ScopeAccount applies to any non-nil acc;
//   - any other rule applies when include(acc.Scopes, rule.Scope) reports true.
//
// Verify returns nil when the deciding rule grants access and ErrForbidden when
// it denies access, when no rule applies, or when res is nil.
func Verify(rules []*auth.Rule, acc *auth.Account, res *auth.Resource) error {
	// Without a resource nothing can match; fail closed instead of panicking.
	if res == nil {
		return ErrForbidden
	}

	// A rule is only applied if its type/name matches the resource or is the
	// catch-all (*).
	validTypes := []string{"*", res.Type}
	validNames := []string{"*", res.Name}

	// Rules can have wildcard endpoints since an endpoint can also be a path for
	// web services, e.g. /foo/* would include /foo/bar. Accept the catch-all,
	// the exact endpoint, and a "<prefix>/*" entry for every path prefix.
	validEndpoints := []string{"*", res.Endpoint}
	if comps := strings.Split(res.Endpoint, "/"); len(comps) > 1 {
		for i := 1; i < len(comps)+1; i++ {
			wildcard := fmt.Sprintf("%v/*", strings.Join(comps[0:i], "/"))
			validEndpoints = append(validEndpoints, wildcard)
		}
	}

	// Keep only the rules whose resource matches the criteria above. Rules
	// without a resource cannot match anything and are skipped.
	filteredRules := make([]*auth.Rule, 0, len(rules))
	for _, rule := range rules {
		if rule == nil || rule.Resource == nil {
			continue
		}
		if !resourceFieldMatches(validTypes, rule.Resource.Type) ||
			!resourceFieldMatches(validNames, rule.Resource.Name) ||
			!resourceFieldMatches(validEndpoints, rule.Resource.Endpoint) {
			continue
		}
		filteredRules = append(filteredRules, rule)
	}

	// Sort the filtered rules by priority, highest to lowest. The sort is stable
	// so rules of equal priority keep their original relative order.
	sort.SliceStable(filteredRules, func(i, j int) bool {
		return filteredRules[i].Priority > filteredRules[j].Priority
	})

	// The first rule that applies to this account decides the result.
	for _, rule := range filteredRules {
		var applies bool
		switch {
		case rule.Scope == ScopePublic:
			// A blank scope applies to everyone, even nil accounts.
			applies = true
		case acc == nil:
			// All remaining scope checks require an account.
		case rule.Scope == ScopeAccount:
			// This rule applies to any account.
			applies = true
		default:
			// The account must carry the scope the rule asks for.
			applies = include(acc.Scopes, rule.Scope)
		}
		if !applies {
			continue
		}

		switch rule.Access {
		case AccessDenied:
			return ErrForbidden
		case AccessGranted:
			return nil
		}
	}

	// If no rule matched then access is forbidden.
	return ErrForbidden
}

// resourceFieldMatches reports whether val equals one of candidates, ignoring
// case. It is used for resource type/name/endpoint matching instead of include
// so that the "super_admin" shortcut, which is meant for account scopes, cannot
// be triggered by a request-supplied resource value.
func resourceFieldMatches(candidates []string, val string) bool {
	for _, c := range candidates {
		if strings.EqualFold(c, val) {
			return true
		}
	}
	return false
}

// include is a helper function which checks to see if the slice contains the
// value. The comparison is case-insensitive.
func include(slice []string, val string) bool { // str := slice if len(slice) > 0 { if strings.Contains(slice[0], ",") { data := strings.Split(slice[0], ",") // 打印结果 for _, s := range data { if s == "super_admin" { return true } if strings.EqualFold(s, val) { return true } } // 判断超级管理员 } else { // 判断超级管理员 for _, s := range slice { if s == "super_admin" { return true } if strings.EqualFold(s, val) { return true } } } return false } return false } // var ( // // catchallResource = &auth.Resource{ // // Type: "*", // // Name: "*", // // Endpoint: "*", // // } // // // // getAuthentication = &auth.Resource{ // Type: "user", // Name: name, // Endpoint: "AuthenticationService.GetAuthentication", // } // // catchallResource // rulesItems = []*auth.Rule{ // // {Scope: "*", Resource: catchallResource}, toggleAdminRole // {Scope: "kubernetes", Resource: getAuthentication, ID: uuid.New().String(), Priority: 1}, // } // ) func NewAuthWrapper(service micro.Service) server.HandlerWrapper { return func(h server.HandlerFunc) server.HandlerFunc { return func(ctx context.Context, req server.Request, rsp interface{}) error { logrus.Infof("[wrapper] server request: %v", req.Endpoint()) if req.Endpoint() == "CommonService.AdminLogin" { return h(ctx, req, rsp) } if req.Endpoint() == "AdminUserService.GetAdminUserAssociatedRoles" { return h(ctx, req, rsp) } // Fetch metadata from context (request headers). md, b := metadata.FromContext(ctx) if !b { return errorcode.Unauthorized("authorization service", common.ErrorMessage[common.UnauthorizedErrorCode]) // errors.New("no metadata found") } // local ip of service fmt.Println("local ip is", md["Local"]) // remote ip of caller fmt.Println("remote ip is", md["Remote"]) // Get auth header. 
authHeader, ok := md["Authorization"] if !ok || !strings.HasPrefix(authHeader, auth.BearerScheme) { logrus.Error("no auth token provided") return errorcode.Unauthorized("authorization service", common.ErrorMessage[common.UnauthorizedErrorCode]) } // Extract auth token. token := strings.TrimPrefix(authHeader, auth.BearerScheme) // Extract account from token. token = strings.TrimSpace(token) a := service.Options().Auth acc, err := a.Inspect(token) fmt.Println("acc", acc) if err != nil { return errorcode.Unauthorized("authorization service", common.ErrorMessage[common.TokenInvalidErrorCode]) } // 校验redis 存储数据 blacklisted, err := authutil.JWTAuthService.IsBlacklisted(token) if err == nil && blacklisted { return errorcode.Unauthorized("authorization service", common.ErrorMessage[common.ExpiredLonInAgainErrorCode]) } // Create resource for current endpoint from request headers. currentResource := auth.Resource{ Type: "user", Name: md["Micro-Service"], Endpoint: md["Micro-Endpoint"], } fmt.Println("acc.Scopes", acc.Scopes) // Verify if account has access. 验证帐户是否具有访问权限。 if err = Verify(authutil.JWTAuthService.GetRuleItems(), acc, ¤tResource); err != nil { return errorcode.Unauthorized("authorization service", common.ErrorMessage[common.NoAccessErrorCode]) } // 验证通过后记录操作日志x logrus.Infof("User %s is performing operation %s body %v", acc.ID, req.Endpoint(), req.Body()) return h(ctx, req, rsp) } } }