aboutsummaryrefslogtreecommitdiff
path: root/endpoint_auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'endpoint_auth.go')
-rw-r--r--endpoint_auth.go162
1 files changed, 24 insertions, 138 deletions
diff --git a/endpoint_auth.go b/endpoint_auth.go
index 0ff3316..e27f74f 100644
--- a/endpoint_auth.go
+++ b/endpoint_auth.go
@@ -21,13 +21,9 @@
package main
import (
- "context"
- "encoding/json"
"errors"
"fmt"
"net/http"
- "net/url"
- "strings"
"time"
"github.com/MicahParks/keyfunc/v3"
@@ -46,9 +42,10 @@ const tokenLength = 20
* zero strings if it hasn't been configured correctly.
*/
type msclaimsT struct {
- Name string `json:"name"` /* Scope: profile */
- Email string `json:"email"` /* Scope: email */
- Oid string `json:"oid"` /* Scope: profile */
+ Name string `json:"name"`
+ Email string `json:"email"`
+ Oid string `json:"oid"`
+ Groups []string `json:"groups"`
jwt.RegisteredClaims
}
@@ -153,25 +150,14 @@ func handleAuth(w http.ResponseWriter, req *http.Request) (string, int, error) {
return "", http.StatusBadRequest, errCannotUnpackClaims
}
- authorizationCode := req.PostFormValue("code")
-
- accessToken, err := getAccessToken(req.Context(), authorizationCode)
- if err != nil {
- return "", -1, err
- }
-
- department, err := getDepartment(req.Context(), *(accessToken.Content))
- if err != nil {
- return "", -1, err
- }
-
- switch {
- case department == "SJ Co-Curricular Activities Office 松江课外项目办公室":
- department = staffDepartment
- case department == "Y9" || department == "Y10" ||
- department == "Y11" || department == "Y12":
- default:
- return "", http.StatusForbidden, errUnknownDepartment
+ var department string
+ var ok bool
+ department, ok = getDepartmentByUserIDOverride(claims.Oid)
+ if !ok {
+ department, ok = getDepartmentByGroups(claims.Groups)
+ if !ok {
+ return "", http.StatusBadRequest, errUnknownDepartment
+ }
}
cookieValue, err := randomString(tokenLength)
@@ -239,120 +225,20 @@ func setupJwks() error {
return nil
}
-/*
- * Fetch the department name of the user, mostly to identify which grade
- * a student is in. This expects an accessToken obtained from the OAUTH 2.0
- * token endpoint obtained via an authorization code. It might also be able
- * to use this as part of a hybrid flow that directly provides access tokens,
- * but this flow seems to be only usable for single-page applications according
- * to the Azure portal.
- */
-func getDepartment(ctx context.Context, accessToken string) (string, error) {
- req, err := http.NewRequestWithContext(
- ctx,
- http.MethodGet,
- "https://graph.microsoft.com/v1.0/me?$select=department",
- nil,
- )
- if err != nil {
- return "", wrapError(errCannotGetDepartment, err)
- }
- req.Header.Set("Authorization", "Bearer "+accessToken)
-
- client := &http.Client{} //exhaustruct:ignore
- resp, err := client.Do(req)
- if err != nil {
- return "", wrapError(errCannotGetDepartment, err)
- }
- defer resp.Body.Close()
-
- var departmentWrap struct {
- Department *string `json:"department"`
- }
-
- decoder := json.NewDecoder(resp.Body)
- err = decoder.Decode(&departmentWrap)
- if err != nil {
- return "", wrapError(errCannotGetDepartment, err)
- }
-
- if departmentWrap.Department == nil {
- /*
- * This is probably because the response does not contain a
- * "department" field, which hopefully doesn't occur as we
- * have specified $select=department in the OData query.
- */
- return "", wrapError(
- errCannotGetDepartment,
- errInsufficientFields,
- )
+func getDepartmentByGroups(groups []string) (string, bool) {
+ for _, g := range groups {
+ d, ok := config.Auth.Departments[g]
+ if ok {
+ return d, true
+ }
}
-
- return *(departmentWrap.Department), nil
-}
-
-type accessTokenT struct {
- Content *string `json:"access_token"`
- Error *string `json:"error"`
- ErrorDescription *string `json:"error_description"`
- ErrorCodes *[]int `json:"error_codes"`
+ return "", false
}
-func getAccessToken(
- ctx context.Context,
- authorizationCode string,
-) (accessTokenT, error) {
- var accessToken accessTokenT
- v := url.Values{}
- v.Set("client_id", config.Auth.Client)
- v.Set("scope", "https://graph.microsoft.com/User.Read")
- v.Set("code", authorizationCode)
- v.Set("redirect_uri", config.URL+"/auth")
- v.Set("grant_type", "authorization_code")
- v.Set("client_secret", config.Auth.Secret)
- req, err := http.NewRequestWithContext(
- ctx,
- http.MethodPost,
- config.Auth.Token,
- strings.NewReader(v.Encode()),
- )
- if err != nil {
- return accessToken,
- wrapError(errCannotFetchAccessToken, err)
- }
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- return accessToken,
- wrapError(errCannotFetchAccessToken, err)
- }
- defer resp.Body.Close()
-
- decoder := json.NewDecoder(resp.Body)
- err = decoder.Decode(&accessToken)
- if err != nil {
- return accessToken,
- wrapError(errCannotFetchAccessToken, err)
+func getDepartmentByUserIDOverride(userID string) (string, bool) {
+ d, ok := config.Auth.Udepts[userID]
+ if ok {
+ return d, true
}
- if accessToken.Error != nil || accessToken.ErrorCodes != nil ||
- accessToken.ErrorDescription != nil {
- if accessToken.Error == nil || accessToken.ErrorCodes == nil ||
- accessToken.ErrorDescription == nil {
- return accessToken, errCannotFetchAccessToken
- }
- return accessToken,
- fmt.Errorf(
- "%w: %v",
- errTokenEndpointReturnedError,
- *accessToken.ErrorDescription,
- )
- }
- if accessToken.Content == nil {
- return accessToken,
- fmt.Errorf(
- "error extracting access token: %w",
- errInsufficientFields,
- )
- }
-
- return accessToken, nil
+ return "", false
}