diff options
Diffstat (limited to '')
-rw-r--r-- | auth.go | 97 |
1 files changed, 77 insertions, 20 deletions
@@ -96,7 +96,11 @@ func generateAuthorizationURL() (string, error) { */ func handleAuth(w http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { - wstr(w, http.StatusMethodNotAllowed, "Only POST is supported on the authentication endpoint") + wstr( + w, + http.StatusMethodNotAllowed, + "Only POST is supported on the authentication endpoint", + ) return } @@ -110,7 +114,14 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { if returnedError != "" { returnedErrorDescription := req.PostFormValue("error_description") if returnedErrorDescription == "" { - wstr(w, http.StatusBadRequest, fmt.Sprintf("authorize endpoint returned error: %v", returnedErrorDescription)) + wstr( + w, + http.StatusBadRequest, + fmt.Sprintf( + "authorize endpoint returned error: %v", + returnedErrorDescription, + ), + ) return } wstr(w, http.StatusBadRequest, fmt.Sprintf( @@ -149,7 +160,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { return case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet): - wstr(w, http.StatusBadRequest, "JWT token expired or not yet valid") + wstr( + w, + http.StatusBadRequest, + "JWT token expired or not yet valid", + ) return default: wstr(w, http.StatusBadRequest, "Unhandled JWT token error") @@ -167,7 +182,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { accessToken, err := getAccessToken(req.Context(), authorizationCode) if err != nil { - wstr(w, http.StatusInternalServerError, fmt.Sprintf("Unable to fetch access token: %v", err)) + wstr( + w, + http.StatusInternalServerError, + fmt.Sprintf("Unable to fetch access token: %v", err), + ) return } @@ -178,9 +197,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { } switch { - case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || department == "High School Teaching & Learning 高中教学部门": + case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || + department == "High School Teaching & Learning 高中教学部门": department = "Staff" - case department == "Y9" || department == "Y10" || department == "Y11" || department == "Y12": + case department == "Y9" || department == "Y10" || + department == "Y11" || department == "Y12": default: wstr( w, @@ -245,11 +266,19 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { claims.Oid, ) if err != nil { - wstr(w, http.StatusInternalServerError, "Database error while updating account.") + wstr( + w, + http.StatusInternalServerError, + "Database error while updating account.", + ) return } } else { - wstr(w, http.StatusInternalServerError, "Database error while writing account info.") + wstr( + w, + http.StatusInternalServerError, + "Database error while writing account info.", + ) return } } @@ -312,7 +341,11 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) { * "department" field, which hopefully doesn't occur as we * have specified $select=department in the OData query. */ - return "", fmt.Errorf("error getting department: %w", errInsufficientFields) + return "", + fmt.Errorf( + "error getting department: %w", + errInsufficientFields, + ) } return *(departmentWrap.Department), nil @@ -322,7 +355,7 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) { * TODO: Access token expiration is not checked anywhere. */ type accessTokenT struct { - OriginalExpiresIn *int `json:"expires_in"` /* Original time to expiration */ + OriginalExpiresIn *int `json:"expires_in"` /* Original time to expr */ Expiration time.Time Content *string `json:"access_token"` Error *string `json:"error"` @@ -334,7 +367,10 @@ type accessTokenT struct { * Obtain an access token from the token endpoint with an existing * authorization code. */ -func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT, error) { +func getAccessToken( + ctx context.Context, + authorizationCode string, +) (accessTokenT, error) { var accessToken accessTokenT t := time.Now() v := url.Values{} @@ -344,31 +380,52 @@ func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT v.Set("redirect_uri", config.URL+"/auth") v.Set("grant_type", "authorization_code") v.Set("client_secret", config.Auth.Secret) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, config.Auth.Token, strings.NewReader(v.Encode())) + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + config.Auth.Token, + strings.NewReader(v.Encode()), + ) if err != nil { - return accessToken, fmt.Errorf("error making access token request: %w", err) + return accessToken, + fmt.Errorf("error making access token request: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { - return accessToken, fmt.Errorf("error requesting access token: %w", err) + return accessToken, + fmt.Errorf("error requesting access token: %w", err) } defer resp.Body.Close() decoder := json.NewDecoder(resp.Body) err = decoder.Decode(&accessToken) if err != nil { - return accessToken, fmt.Errorf("error decoding access token: %w", err) + return accessToken, + fmt.Errorf("error decoding access token: %w", err) } - if accessToken.Error != nil || accessToken.ErrorCodes != nil || accessToken.ErrorDescription != nil { - if accessToken.Error == nil || accessToken.ErrorCodes == nil || accessToken.ErrorDescription == nil { + if accessToken.Error != nil || accessToken.ErrorCodes != nil || + accessToken.ErrorDescription != nil { + if accessToken.Error == nil || accessToken.ErrorCodes == nil || + accessToken.ErrorDescription == nil { return accessToken, errAccessTokenIncompleteError } - return accessToken, fmt.Errorf("%w: %v", errTokenEndpointReturnedError, *accessToken.ErrorDescription) + return accessToken, + fmt.Errorf( + "%w: %v", + errTokenEndpointReturnedError, + *accessToken.ErrorDescription, + ) } if accessToken.Content == nil || accessToken.OriginalExpiresIn == nil { - return accessToken, fmt.Errorf("error extracting access token: %w", errInsufficientFields) + return accessToken, + fmt.Errorf( + "error extracting access token: %w", + errInsufficientFields, + ) } - accessToken.Expiration = t.Add(time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second) + accessToken.Expiration = t.Add( + time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second, + ) return accessToken, nil } |