summaryrefslogtreecommitdiff
path: root/auth.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--auth.go97
1 files changed, 77 insertions, 20 deletions
diff --git a/auth.go b/auth.go
index 9ef1254..1fd6086 100644
--- a/auth.go
+++ b/auth.go
@@ -96,7 +96,11 @@ func generateAuthorizationURL() (string, error) {
*/
func handleAuth(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
- wstr(w, http.StatusMethodNotAllowed, "Only POST is supported on the authentication endpoint")
+ wstr(
+ w,
+ http.StatusMethodNotAllowed,
+ "Only POST is supported on the authentication endpoint",
+ )
return
}
@@ -110,7 +114,14 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
if returnedError != "" {
returnedErrorDescription := req.PostFormValue("error_description")
if returnedErrorDescription == "" {
- wstr(w, http.StatusBadRequest, fmt.Sprintf("authorize endpoint returned error: %v", returnedErrorDescription))
+ wstr(
+ w,
+ http.StatusBadRequest,
+ fmt.Sprintf(
+ "authorize endpoint returned error: %v",
+ returnedErrorDescription,
+ ),
+ )
return
}
wstr(w, http.StatusBadRequest, fmt.Sprintf(
@@ -149,7 +160,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
return
case errors.Is(err, jwt.ErrTokenExpired) ||
errors.Is(err, jwt.ErrTokenNotValidYet):
- wstr(w, http.StatusBadRequest, "JWT token expired or not yet valid")
+ wstr(
+ w,
+ http.StatusBadRequest,
+ "JWT token expired or not yet valid",
+ )
return
default:
wstr(w, http.StatusBadRequest, "Unhandled JWT token error")
@@ -167,7 +182,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
accessToken, err := getAccessToken(req.Context(), authorizationCode)
if err != nil {
- wstr(w, http.StatusInternalServerError, fmt.Sprintf("Unable to fetch access token: %v", err))
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ fmt.Sprintf("Unable to fetch access token: %v", err),
+ )
return
}
@@ -178,9 +197,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
}
switch {
- case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || department == "High School Teaching & Learning 高中教学部门":
+ case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" ||
+ department == "High School Teaching & Learning 高中教学部门":
department = "Staff"
- case department == "Y9" || department == "Y10" || department == "Y11" || department == "Y12":
+ case department == "Y9" || department == "Y10" ||
+ department == "Y11" || department == "Y12":
default:
wstr(
w,
@@ -245,11 +266,19 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
claims.Oid,
)
if err != nil {
- wstr(w, http.StatusInternalServerError, "Database error while updating account.")
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ "Database error while updating account.",
+ )
return
}
} else {
- wstr(w, http.StatusInternalServerError, "Database error while writing account info.")
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ "Database error while writing account info.",
+ )
return
}
}
@@ -312,7 +341,11 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) {
* "department" field, which hopefully doesn't occur as we
* have specified $select=department in the OData query.
*/
- return "", fmt.Errorf("error getting department: %w", errInsufficientFields)
+ return "",
+ fmt.Errorf(
+ "error getting department: %w",
+ errInsufficientFields,
+ )
}
return *(departmentWrap.Department), nil
@@ -322,7 +355,7 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) {
* TODO: Access token expiration is not checked anywhere.
*/
type accessTokenT struct {
- OriginalExpiresIn *int `json:"expires_in"` /* Original time to expiration */
+ OriginalExpiresIn *int `json:"expires_in"` /* Original time to expr */
Expiration time.Time
Content *string `json:"access_token"`
Error *string `json:"error"`
@@ -334,7 +367,10 @@ type accessTokenT struct {
* Obtain an access token from the token endpoint with an existing
* authorization code.
*/
-func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT, error) {
+func getAccessToken(
+ ctx context.Context,
+ authorizationCode string,
+) (accessTokenT, error) {
var accessToken accessTokenT
t := time.Now()
v := url.Values{}
@@ -344,31 +380,52 @@ func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT
v.Set("redirect_uri", config.URL+"/auth")
v.Set("grant_type", "authorization_code")
v.Set("client_secret", config.Auth.Secret)
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, config.Auth.Token, strings.NewReader(v.Encode()))
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodPost,
+ config.Auth.Token,
+ strings.NewReader(v.Encode()),
+ )
if err != nil {
- return accessToken, fmt.Errorf("error making access token request: %w", err)
+ return accessToken,
+ fmt.Errorf("error making access token request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
- return accessToken, fmt.Errorf("error requesting access token: %w", err)
+ return accessToken,
+ fmt.Errorf("error requesting access token: %w", err)
}
defer resp.Body.Close()
decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(&accessToken)
if err != nil {
- return accessToken, fmt.Errorf("error decoding access token: %w", err)
+ return accessToken,
+ fmt.Errorf("error decoding access token: %w", err)
}
- if accessToken.Error != nil || accessToken.ErrorCodes != nil || accessToken.ErrorDescription != nil {
- if accessToken.Error == nil || accessToken.ErrorCodes == nil || accessToken.ErrorDescription == nil {
+ if accessToken.Error != nil || accessToken.ErrorCodes != nil ||
+ accessToken.ErrorDescription != nil {
+ if accessToken.Error == nil || accessToken.ErrorCodes == nil ||
+ accessToken.ErrorDescription == nil {
return accessToken, errAccessTokenIncompleteError
}
- return accessToken, fmt.Errorf("%w: %v", errTokenEndpointReturnedError, *accessToken.ErrorDescription)
+ return accessToken,
+ fmt.Errorf(
+ "%w: %v",
+ errTokenEndpointReturnedError,
+ *accessToken.ErrorDescription,
+ )
}
if accessToken.Content == nil || accessToken.OriginalExpiresIn == nil {
- return accessToken, fmt.Errorf("error extracting access token: %w", errInsufficientFields)
+ return accessToken,
+ fmt.Errorf(
+ "error extracting access token: %w",
+ errInsufficientFields,
+ )
}
- accessToken.Expiration = t.Add(time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second)
+ accessToken.Expiration = t.Add(
+ time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second,
+ )
return accessToken, nil
}