diff options
Diffstat (limited to 'endpoint_auth.go')
-rw-r--r-- | endpoint_auth.go | 116 |
1 files changed, 40 insertions, 76 deletions
diff --git a/endpoint_auth.go b/endpoint_auth.go index 85acab8..4a9c770 100644 --- a/endpoint_auth.go +++ b/endpoint_auth.go @@ -81,48 +81,31 @@ func generateAuthorizationURL() (string, error) { * Expects JSON Web Keys to be already set up correctly; if myKeyfunc is null, * a null pointer is dereferenced and the thread panics. */ -func handleAuth(w http.ResponseWriter, req *http.Request) { +func handleAuth(w http.ResponseWriter, req *http.Request) (string, int, error) { if req.Method != http.MethodPost { - wstr( - w, - http.StatusMethodNotAllowed, - "Only POST is supported on the authentication endpoint", - ) - return + return "", http.StatusMethodNotAllowed, errPostOnly } err := req.ParseForm() if err != nil { - wstr(w, http.StatusBadRequest, "Malformed form data") - return + return "", http.StatusBadRequest, wrapError(errMalformedForm, err) } returnedError := req.PostFormValue("error") if returnedError != "" { returnedErrorDescription := req.PostFormValue("error_description") - if returnedErrorDescription == "" { - wstr( - w, - http.StatusBadRequest, - fmt.Sprintf( - "authorize endpoint returned error: %v", - returnedErrorDescription, - ), - ) - return - } - wstr(w, http.StatusBadRequest, fmt.Sprintf( - "%s: %s", - returnedError, - returnedErrorDescription, - )) - return + return "", http.StatusUnauthorized, wrapAny( + errAuthorizeEndpointError, + returnedError+": "+returnedErrorDescription, + ) } idTokenString := req.PostFormValue("id_token") if idTokenString == "" { - wstr(w, http.StatusBadRequest, "Missing id_token") - return + return "", http.StatusBadRequest, wrapAny( + errInsufficientFields, + "id_token", + ) } claimsTemplate := &msclaimsT{} //exhaustruct:ignore @@ -132,79 +115,68 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { myKeyfunc.Keyfunc, ) if err != nil { - wstr(w, http.StatusBadRequest, "Cannot parse claims") - return + return "", http.StatusBadRequest, wrapError( + errCannotParseClaims, + err, + ) } switch { case token.Valid: break case errors.Is(err, jwt.ErrTokenMalformed): - wstr(w, http.StatusBadRequest, "Malformed JWT token") - return + return "", http.StatusBadRequest, wrapError( + errJWTMalformed, + err, + ) case errors.Is(err, jwt.ErrTokenSignatureInvalid): - wstr(w, http.StatusBadRequest, "Invalid JWS signature") - return + return "", http.StatusBadRequest, wrapError( + errJWTSignatureInvalid, + err, + ) case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet): - wstr( - w, - http.StatusBadRequest, - "JWT token expired or not yet valid", + return "", http.StatusBadRequest, wrapError( + errJWTExpired, + err, ) - return default: - wstr(w, http.StatusBadRequest, "Unhandled JWT token error") - return + return "", http.StatusBadRequest, wrapError( + errJWTInvalid, + err, + ) } claims, claimsOk := token.Claims.(*msclaimsT) if !claimsOk { - wstr(w, http.StatusBadRequest, "Cannot unpack claims") - return + return "", http.StatusBadRequest, errCannotUnpackClaims } authorizationCode := req.PostFormValue("code") accessToken, err := getAccessToken(req.Context(), authorizationCode) if err != nil { - wstr( - w, - http.StatusInternalServerError, - fmt.Sprintf("Unable to fetch access token: %v", err), - ) - return + return "", http.StatusInternalServerError, err } department, err := getDepartment(req.Context(), *(accessToken.Content)) if err != nil { - wstr(w, http.StatusInternalServerError, err.Error()) - return + return "", http.StatusInternalServerError, err } switch { - case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || - department == "High School Teaching & Learning 高中教学部门": + case department == "SJ Co-Curricular Activities Office 松江课外项目办公室": department = staffDepartment case department == "Y9" || department == "Y10" || department == "Y11" || department == "Y12": default: - wstr( - w, - http.StatusForbidden, - fmt.Sprintf( - "Your department \"%s\" is unknown.\nWe currently only allow Y9, Y10, Y11, Y12, and the CCA office.", - department, - ), - ) - return + return "", http.StatusForbidden, errUnknownDepartment } cookieValue, err := randomString(tokenLength) if err != nil { - wstr(w, http.StatusInternalServerError, err.Error()) - return + return "", http.StatusInternalServerError, err } now := time.Now() @@ -246,24 +218,16 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { claims.Oid, ) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Database error while updating account.", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } } else { - wstr( - w, - http.StatusInternalServerError, - "Database error while writing account info.", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } } http.Redirect(w, req, "/", http.StatusSeeOther) + + return "", -1, nil } func setupJwks() error { |