diff options
Diffstat (limited to 'endpoint_auth.go')
-rw-r--r-- | endpoint_auth.go | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/endpoint_auth.go b/endpoint_auth.go new file mode 100644 index 0000000..58eb46b --- /dev/null +++ b/endpoint_auth.go @@ -0,0 +1,395 @@ +/* + * Custom OAUTH 2.0 implementation for the CCA Selection Service + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/MicahParks/keyfunc/v3" + "github.com/golang-jwt/jwt/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +var myKeyfunc keyfunc.Keyfunc + +const tokenLength = 20 + +/* + * These are the claims in the JSON Web Token received from the client, after + * it redirects from the authorize endpoint. Some of these fields must be + * explicitly selected in the Azure app registration and might appear as + * zero strings if it hasn't been configured correctly. + */ +type msclaimsT struct { + Name string `json:"name"` /* Scope: profile */ + Email string `json:"email"` /* Scope: email */ + Oid string `json:"oid"` /* Scope: profile */ + jwt.RegisteredClaims +} + +func generateAuthorizationURL() (string, error) { + nonce, err := randomString(tokenLength) + if err != nil { + return "", err + } + /* + * Note that here we use a hybrid authentication flow to obtain an + * id_token for authentication and an authorization code. The + * authorization code may be used like any other; i.e., it may be used + * to obtain an access token directly, or the refresh token may be used + * to gain persistent access to the upstream API. Sometimes I wish that + * the JWT in id_token could have more claims. The only reason we + * presently use a hybrid flow is to use the authorization code to + * obtain an access code to call the user info endpoint to fetch the + * user's department information. + */ + return fmt.Sprintf( + "https://login.microsoftonline.com/ddd3d26c-b197-4d00-a32d-1ffd84c0c295/oauth2/authorize?client_id=%s&response_type=id_token%%20code&redirect_uri=%s%%2Fauth&response_mode=form_post&scope=openid+profile+email+User.Read&nonce=%s", + config.Auth.Client, + config.URL, + nonce, + ), nil +} + +/* + * Handles redirects to the /auth endpoint from the authorize endpoint. + * Expects JSON Web Keys to be already set up correctly; if myKeyfunc is null, + * a null pointer is dereferenced and the thread panics. + */ +func handleAuth(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + wstr( + w, + http.StatusMethodNotAllowed, + "Only POST is supported on the authentication endpoint", + ) + return + } + + err := req.ParseForm() + if err != nil { + wstr(w, http.StatusBadRequest, "Malformed form data") + return + } + + returnedError := req.PostFormValue("error") + if returnedError != "" { + returnedErrorDescription := req.PostFormValue("error_description") + if returnedErrorDescription == "" { + wstr( + w, + http.StatusBadRequest, + fmt.Sprintf( + "authorize endpoint returned error: %v", + returnedErrorDescription, + ), + ) + return + } + wstr(w, http.StatusBadRequest, fmt.Sprintf( + "%s: %s", + returnedError, + returnedErrorDescription, + )) + return + } + + idTokenString := req.PostFormValue("id_token") + if idTokenString == "" { + wstr(w, http.StatusBadRequest, "Missing id_token") + return + } + + claimsTemplate := &msclaimsT{} //exhaustruct:ignore + token, err := jwt.ParseWithClaims( + idTokenString, + claimsTemplate, + myKeyfunc.Keyfunc, + ) + if err != nil { + wstr(w, http.StatusBadRequest, "Cannot parse claims") + return + } + + switch { + case token.Valid: + break + case errors.Is(err, jwt.ErrTokenMalformed): + wstr(w, http.StatusBadRequest, "Malformed JWT token") + return + case errors.Is(err, jwt.ErrTokenSignatureInvalid): + wstr(w, http.StatusBadRequest, "Invalid JWS signature") + return + case errors.Is(err, jwt.ErrTokenExpired) || + errors.Is(err, jwt.ErrTokenNotValidYet): + wstr( + w, + http.StatusBadRequest, + "JWT token expired or not yet valid", + ) + return + default: + wstr(w, http.StatusBadRequest, "Unhandled JWT token error") + return + } + + claims, claimsOk := token.Claims.(*msclaimsT) + + if !claimsOk { + wstr(w, http.StatusBadRequest, "Cannot unpack claims") + return + } + + authorizationCode := req.PostFormValue("code") + + accessToken, err := getAccessToken(req.Context(), authorizationCode) + if err != nil { + wstr( + w, + http.StatusInternalServerError, + fmt.Sprintf("Unable to fetch access token: %v", err), + ) + return + } + + department, err := getDepartment(req.Context(), *(accessToken.Content)) + if err != nil { + wstr(w, http.StatusInternalServerError, err.Error()) + return + } + + switch { + case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || + department == "High School Teaching & Learning 高中教学部门": + department = staffDepartment + case department == "Y9" || department == "Y10" || + department == "Y11" || department == "Y12": + default: + wstr( + w, + http.StatusForbidden, + fmt.Sprintf( + "Your department \"%s\" is unknown.\nWe currently only allow Y9, Y10, Y11, Y12, and the CCA office.", + department, + ), + ) + return + } + + cookieValue, err := randomString(tokenLength) + if err != nil { + wstr(w, http.StatusInternalServerError, err.Error()) + return + } + + now := time.Now() + expr := now.Add(time.Duration(config.Auth.Expr) * time.Second) + exprU := expr.Unix() + + cookie := http.Cookie{ + Name: "session", + Value: cookieValue, + SameSite: http.SameSiteLaxMode, + HttpOnly: true, + Secure: config.Prod, + Expires: expr, + } //exhaustruct:ignore + + http.SetCookie(w, &cookie) + + _, err = db.Exec( + req.Context(), + "INSERT INTO users (id, name, email, department, session, expr, confirmed) VALUES ($1, $2, $3, $4, $5, $6, false)", + claims.Oid, + claims.Name, + claims.Email, + department, + cookieValue, + exprU, + ) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == pgErrUniqueViolation { + _, err := db.Exec( + req.Context(), + "UPDATE users SET (name, email, department, session, expr) = ($1, $2, $3, $4, $5) WHERE id = $6", + claims.Name, + claims.Email, + department, + cookieValue, + exprU, + claims.Oid, + ) + if err != nil { + wstr( + w, + http.StatusInternalServerError, + "Database error while updating account.", + ) + return + } + } else { + wstr( + w, + http.StatusInternalServerError, + "Database error while writing account info.", + ) + return + } + } + + http.Redirect(w, req, "/", http.StatusSeeOther) +} + +func setupJwks() error { + var err error + myKeyfunc, err = keyfunc.NewDefault([]string{config.Auth.Jwks}) + if err != nil { + return fmt.Errorf("%w: %w", errCannotSetupJwks, err) + } + return nil +} + +/* + * Fetch the department name of the user, mostly to identify which grade + * a student is in. This expects an accessToken obtained from the OAUTH 2.0 + * token endpoint obtained via an authorization code. It might also be able + * to use this as part of a hybrid flow that directly provides access tokens, + * but this flow seems to be only usable for single-page applications according + * to the Azure portal. + */ +func getDepartment(ctx context.Context, accessToken string) (string, error) { + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + "https://graph.microsoft.com/v1.0/me?$select=department", + nil, + ) + if err != nil { + return "", fmt.Errorf("%w: %w", errCannotGetDepartment, err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + client := &http.Client{} //exhaustruct:ignore + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("%w: %w", errCannotGetDepartment, err) + } + defer resp.Body.Close() + + var departmentWrap struct { + Department *string `json:"department"` + } + + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&departmentWrap) + if err != nil { + return "", fmt.Errorf("%w: %w", errCannotGetDepartment, err) + } + + if departmentWrap.Department == nil { + /* + * This is probably because the response does not contain a + * "department" field, which hopefully doesn't occur as we + * have specified $select=department in the OData query. + */ + return "", fmt.Errorf( + "%w: %w", + errCannotGetDepartment, + errInsufficientFields, + ) + } + + return *(departmentWrap.Department), nil +} + +type accessTokenT struct { + Content *string `json:"access_token"` + Error *string `json:"error"` + ErrorDescription *string `json:"error_description"` + ErrorCodes *[]int `json:"error_codes"` +} + +func getAccessToken( + ctx context.Context, + authorizationCode string, +) (accessTokenT, error) { + var accessToken accessTokenT + v := url.Values{} + v.Set("client_id", config.Auth.Client) + v.Set("scope", "https://graph.microsoft.com/User.Read") + v.Set("code", authorizationCode) + v.Set("redirect_uri", config.URL+"/auth") + v.Set("grant_type", "authorization_code") + v.Set("client_secret", config.Auth.Secret) + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + config.Auth.Token, + strings.NewReader(v.Encode()), + ) + if err != nil { + return accessToken, + fmt.Errorf("%w: %w", errCannotFetchAccessToken, err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return accessToken, + fmt.Errorf("%w: %w", errCannotFetchAccessToken, err) + } + defer resp.Body.Close() + + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&accessToken) + if err != nil { + return accessToken, + fmt.Errorf("%w: %w", errCannotFetchAccessToken, err) + } + if accessToken.Error != nil || accessToken.ErrorCodes != nil || + accessToken.ErrorDescription != nil { + if accessToken.Error == nil || accessToken.ErrorCodes == nil || + accessToken.ErrorDescription == nil { + return accessToken, errCannotFetchAccessToken + } + return accessToken, + fmt.Errorf( + "%w: %v", + errTokenEndpointReturnedError, + *accessToken.ErrorDescription, + ) + } + if accessToken.Content == nil { + return accessToken, + fmt.Errorf( + "error extracting access token: %w", + errInsufficientFields, + ) + } + + return accessToken, nil +} |