diff options
Diffstat (limited to '')
-rw-r--r-- | config.go | 47 | ||||
-rw-r--r-- | docs/cca.scfg.example | 18 | ||||
-rw-r--r-- | endpoint_auth.go | 162 | ||||
-rw-r--r-- | errors.go | 3 |
4 files changed, 69 insertions, 161 deletions
@@ -54,12 +54,13 @@ var configWithPointers struct { Conn *string `scfg:"conn"` } `scfg:"db"` Auth struct { - Client *string `scfg:"client"` - Authorize *string `scfg:"authorize"` - Jwks *string `scfg:"jwks"` - Token *string `scfg:"token"` - Secret *string `scfg:"secret"` - Expr *int `scfg:"expr"` + Client *string `scfg:"client"` + Authorize *string `scfg:"authorize"` + Jwks *string `scfg:"jwks"` + Token *string `scfg:"token"` + Expr *int `scfg:"expr"` + Departments *map[string]string `scfg:"depts"` + Udepts *map[string]string `scfg:"udepts"` } `scfg:"auth"` Perf struct { SendQ *int `scfg:"sendq"` @@ -107,12 +108,13 @@ var config struct { Conn string } Auth struct { - Client string - Authorize string - Jwks string - Token string - Secret string - Expr int + Client string + Authorize string + Jwks string + Token string + Expr int + Departments map[string]string + Udepts map[string]string } Perf struct { SendQ int @@ -237,16 +239,27 @@ func fetchConfig(path string) (retErr error) { } config.Auth.Token = *(configWithPointers.Auth.Token) - if configWithPointers.Auth.Secret == nil { - return fmt.Errorf("%w: auth.secret", errMissingConfigValue) - } - config.Auth.Secret = *(configWithPointers.Auth.Secret) - if configWithPointers.Auth.Expr == nil { return fmt.Errorf("%w: auth.expr", errMissingConfigValue) } config.Auth.Expr = *(configWithPointers.Auth.Expr) + if configWithPointers.Auth.Departments == nil { + return fmt.Errorf("%w: auth.depts", errMissingConfigValue) + } + config.Auth.Departments = *(configWithPointers.Auth.Departments) + if config.Auth.Departments == nil { + return fmt.Errorf("%w: auth.depts", errMissingConfigValue) + } + + if configWithPointers.Auth.Udepts == nil { + return fmt.Errorf("%w: auth.udepts", errMissingConfigValue) + } + config.Auth.Udepts = *(configWithPointers.Auth.Udepts) + if config.Auth.Udepts == nil { + return fmt.Errorf("%w: auth.udepts", errMissingConfigValue) + } + if configWithPointers.Perf.SendQ == nil { return fmt.Errorf( "%w: perf.sendq", diff --git a/docs/cca.scfg.example b/docs/cca.scfg.example index 5ce05d3..4ba7f84 100644 --- a/docs/cca.scfg.example +++ b/docs/cca.scfg.example @@ -60,12 +60,24 @@ auth { # What is the URL to the JSON Web Key Set? jwks https://login.microsoftonline.com/common/discovery/keys - - # What is the client secret? Certificates are not supported yet. - secret something # How long, in seconds, should cookies last? expr 604800 + + # Which group IDs mean which departments? + depts { + dc3ab000-6352-4596-9f15-771e0b17f6f1 Y12 + b006d3b8-2ab7-4038-9887-a8276f7ba8e6 Y11 + a51fb4ab-704e-4c7a-b639-b84de0516e57 Y10 + 4bae8dbe-ce80-4b5e-994f-d42f0307bd13 Y9 + } + + # User department overrides + udepts { + fa1f6b2b-0424-41db-bda0-13962abdadf4 Staff + a1a735c0-1ba8-4f08-b4d0-4c6f85552ac7 Staff + 34d4ee3c-6515-4e13-9679-57ccb9ca2835 Staff + } } # The following block contains some tweaks for performance. diff --git a/endpoint_auth.go b/endpoint_auth.go index 0ff3316..e27f74f 100644 --- a/endpoint_auth.go +++ b/endpoint_auth.go @@ -21,13 +21,9 @@ package main import ( - "context" - "encoding/json" "errors" "fmt" "net/http" - "net/url" - "strings" "time" "github.com/MicahParks/keyfunc/v3" @@ -46,9 +42,10 @@ const tokenLength = 20 * zero strings if it hasn't been configured correctly. */ type msclaimsT struct { - Name string `json:"name"` /* Scope: profile */ - Email string `json:"email"` /* Scope: email */ - Oid string `json:"oid"` /* Scope: profile */ + Name string `json:"name"` + Email string `json:"email"` + Oid string `json:"oid"` + Groups []string `json:"groups"` jwt.RegisteredClaims } @@ -153,25 +150,14 @@ func handleAuth(w http.ResponseWriter, req *http.Request) (string, int, error) { return "", http.StatusBadRequest, errCannotUnpackClaims } - authorizationCode := req.PostFormValue("code") - - accessToken, err := getAccessToken(req.Context(), authorizationCode) - if err != nil { - return "", -1, err - } - - department, err := getDepartment(req.Context(), *(accessToken.Content)) - if err != nil { - return "", -1, err - } - - switch { - case department == "SJ Co-Curricular Activities Office 松江课外项目办公室": - department = staffDepartment - case department == "Y9" || department == "Y10" || - department == "Y11" || department == "Y12": - default: - return "", http.StatusForbidden, errUnknownDepartment + var department string + var ok bool + department, ok = getDepartmentByUserIDOverride(claims.Oid) + if !ok { + department, ok = getDepartmentByGroups(claims.Groups) + if !ok { + return "", http.StatusBadRequest, errUnknownDepartment + } } cookieValue, err := randomString(tokenLength) @@ -239,120 +225,20 @@ func setupJwks() error { return nil } -/* - * Fetch the department name of the user, mostly to identify which grade - * a student is in. This expects an accessToken obtained from the OAUTH 2.0 - * token endpoint obtained via an authorization code. It might also be able - * to use this as part of a hybrid flow that directly provides access tokens, - * but this flow seems to be only usable for single-page applications according - * to the Azure portal. - */ -func getDepartment(ctx context.Context, accessToken string) (string, error) { - req, err := http.NewRequestWithContext( - ctx, - http.MethodGet, - "https://graph.microsoft.com/v1.0/me?$select=department", - nil, - ) - if err != nil { - return "", wrapError(errCannotGetDepartment, err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - - client := &http.Client{} //exhaustruct:ignore - resp, err := client.Do(req) - if err != nil { - return "", wrapError(errCannotGetDepartment, err) - } - defer resp.Body.Close() - - var departmentWrap struct { - Department *string `json:"department"` - } - - decoder := json.NewDecoder(resp.Body) - err = decoder.Decode(&departmentWrap) - if err != nil { - return "", wrapError(errCannotGetDepartment, err) - } - - if departmentWrap.Department == nil { - /* - * This is probably because the response does not contain a - * "department" field, which hopefully doesn't occur as we - * have specified $select=department in the OData query. - */ - return "", wrapError( - errCannotGetDepartment, - errInsufficientFields, - ) +func getDepartmentByGroups(groups []string) (string, bool) { + for _, g := range groups { + d, ok := config.Auth.Departments[g] + if ok { + return d, true + } } - - return *(departmentWrap.Department), nil -} - -type accessTokenT struct { - Content *string `json:"access_token"` - Error *string `json:"error"` - ErrorDescription *string `json:"error_description"` - ErrorCodes *[]int `json:"error_codes"` + return "", false } -func getAccessToken( - ctx context.Context, - authorizationCode string, -) (accessTokenT, error) { - var accessToken accessTokenT - v := url.Values{} - v.Set("client_id", config.Auth.Client) - v.Set("scope", "https://graph.microsoft.com/User.Read") - v.Set("code", authorizationCode) - v.Set("redirect_uri", config.URL+"/auth") - v.Set("grant_type", "authorization_code") - v.Set("client_secret", config.Auth.Secret) - req, err := http.NewRequestWithContext( - ctx, - http.MethodPost, - config.Auth.Token, - strings.NewReader(v.Encode()), - ) - if err != nil { - return accessToken, - wrapError(errCannotFetchAccessToken, err) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return accessToken, - wrapError(errCannotFetchAccessToken, err) - } - defer resp.Body.Close() - - decoder := json.NewDecoder(resp.Body) - err = decoder.Decode(&accessToken) - if err != nil { - return accessToken, - wrapError(errCannotFetchAccessToken, err) +func getDepartmentByUserIDOverride(userID string) (string, bool) { + d, ok := config.Auth.Udepts[userID] + if ok { + return d, true } - if accessToken.Error != nil || accessToken.ErrorCodes != nil || - accessToken.ErrorDescription != nil { - if accessToken.Error == nil || accessToken.ErrorCodes == nil || - accessToken.ErrorDescription == nil { - return accessToken, errCannotFetchAccessToken - } - return accessToken, - fmt.Errorf( - "%w: %v", - errTokenEndpointReturnedError, - *accessToken.ErrorDescription, - ) - } - if accessToken.Content == nil { - return accessToken, - fmt.Errorf( - "error extracting access token: %w", - errInsufficientFields, - ) - } - - return accessToken, nil + return "", false } @@ -28,10 +28,7 @@ import ( var ( errCannotSetupJwks = errors.New("cannot set up jwks") errInsufficientFields = errors.New("insufficient fields") - errCannotGetDepartment = errors.New("cannot get department") errUnknownDepartment = errors.New("unknown department") - errCannotFetchAccessToken = errors.New("cannot fetch access token") - errTokenEndpointReturnedError = errors.New("token endpoint returned error") errCannotProcessConfig = errors.New("cannot process configuration file") errCannotOpenConfig = errors.New("cannot open configuration file") errCannotDecodeConfig = errors.New("cannot decode configuration file") |