diff options
-rw-r--r-- | endpoint_auth.go | 116 | ||||
-rw-r--r-- | endpoint_export_choices.go | 73 | ||||
-rw-r--r-- | endpoint_export_students.go | 60 | ||||
-rw-r--r-- | endpoint_index.go | 31 | ||||
-rw-r--r-- | endpoint_newcourses.go | 217 | ||||
-rw-r--r-- | endpoint_state.go | 31 | ||||
-rw-r--r-- | endpoint_ws.go | 7 | ||||
-rw-r--r-- | errors.go | 80 | ||||
-rw-r--r-- | main.go | 12 | ||||
-rw-r--r-- | sethandler.go | 46 |
10 files changed, 249 insertions, 424 deletions
diff --git a/endpoint_auth.go b/endpoint_auth.go index 85acab8..4a9c770 100644 --- a/endpoint_auth.go +++ b/endpoint_auth.go @@ -81,48 +81,31 @@ func generateAuthorizationURL() (string, error) { * Expects JSON Web Keys to be already set up correctly; if myKeyfunc is null, * a null pointer is dereferenced and the thread panics. */ -func handleAuth(w http.ResponseWriter, req *http.Request) { +func handleAuth(w http.ResponseWriter, req *http.Request) (string, int, error) { if req.Method != http.MethodPost { - wstr( - w, - http.StatusMethodNotAllowed, - "Only POST is supported on the authentication endpoint", - ) - return + return "", http.StatusMethodNotAllowed, errPostOnly } err := req.ParseForm() if err != nil { - wstr(w, http.StatusBadRequest, "Malformed form data") - return + return "", http.StatusBadRequest, wrapError(errMalformedForm, err) } returnedError := req.PostFormValue("error") if returnedError != "" { returnedErrorDescription := req.PostFormValue("error_description") - if returnedErrorDescription == "" { - wstr( - w, - http.StatusBadRequest, - fmt.Sprintf( - "authorize endpoint returned error: %v", - returnedErrorDescription, - ), - ) - return - } - wstr(w, http.StatusBadRequest, fmt.Sprintf( - "%s: %s", - returnedError, - returnedErrorDescription, - )) - return + return "", http.StatusUnauthorized, wrapAny( + errAuthorizeEndpointError, + returnedError+": "+returnedErrorDescription, + ) } idTokenString := req.PostFormValue("id_token") if idTokenString == "" { - wstr(w, http.StatusBadRequest, "Missing id_token") - return + return "", http.StatusBadRequest, wrapAny( + errInsufficientFields, + "id_token", + ) } claimsTemplate := &msclaimsT{} //exhaustruct:ignore @@ -132,79 +115,68 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { myKeyfunc.Keyfunc, ) if err != nil { - wstr(w, http.StatusBadRequest, "Cannot parse claims") - return + return "", http.StatusBadRequest, wrapError( + errCannotParseClaims, + err, + ) } switch { case token.Valid: break case errors.Is(err, jwt.ErrTokenMalformed): - wstr(w, http.StatusBadRequest, "Malformed JWT token") - return + return "", http.StatusBadRequest, wrapError( + errJWTMalformed, + err, + ) case errors.Is(err, jwt.ErrTokenSignatureInvalid): - wstr(w, http.StatusBadRequest, "Invalid JWS signature") - return + return "", http.StatusBadRequest, wrapError( + errJWTSignatureInvalid, + err, + ) case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet): - wstr( - w, - http.StatusBadRequest, - "JWT token expired or not yet valid", + return "", http.StatusBadRequest, wrapError( + errJWTExpired, + err, ) - return default: - wstr(w, http.StatusBadRequest, "Unhandled JWT token error") - return + return "", http.StatusBadRequest, wrapError( + errJWTInvalid, + err, + ) } claims, claimsOk := token.Claims.(*msclaimsT) if !claimsOk { - wstr(w, http.StatusBadRequest, "Cannot unpack claims") - return + return "", http.StatusBadRequest, errCannotUnpackClaims } authorizationCode := req.PostFormValue("code") accessToken, err := getAccessToken(req.Context(), authorizationCode) if err != nil { - wstr( - w, - http.StatusInternalServerError, - fmt.Sprintf("Unable to fetch access token: %v", err), - ) - return + return "", http.StatusInternalServerError, err } department, err := getDepartment(req.Context(), *(accessToken.Content)) if err != nil { - wstr(w, http.StatusInternalServerError, err.Error()) - return + return "", http.StatusInternalServerError, err } switch { - case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || - department == "High School Teaching & Learning 高中教学部门": + case department == "SJ Co-Curricular Activities Office 松江课外项目办公室": department = staffDepartment case department == "Y9" || department == "Y10" || department == "Y11" || department == "Y12": default: - wstr( - w, - http.StatusForbidden, - fmt.Sprintf( - "Your department \"%s\" is unknown.\nWe currently only allow Y9, Y10, Y11, Y12, and the CCA office.", - department, - ), - ) - return + return "", http.StatusForbidden, errUnknownDepartment } cookieValue, err := randomString(tokenLength) if err != nil { - wstr(w, http.StatusInternalServerError, err.Error()) - return + return "", http.StatusInternalServerError, err } now := time.Now() @@ -246,24 +218,16 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { claims.Oid, ) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Database error while updating account.", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } } else { - wstr( - w, - http.StatusInternalServerError, - "Database error while writing account info.", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } } http.Redirect(w, req, "/", http.StatusSeeOther) + + return "", -1, nil } func setupJwks() error { diff --git a/endpoint_export_choices.go b/endpoint_export_choices.go index 890f170..80734b2 100644 --- a/endpoint_export_choices.go +++ b/endpoint_export_choices.go @@ -27,7 +27,7 @@ import ( "strings" ) -func handleExportChoices(w http.ResponseWriter, req *http.Request) { +func handleExportChoices(w http.ResponseWriter, req *http.Request) (string, int, error) { _, _, department, err := getUserInfoFromRequest(req) if err != nil { wstr( @@ -37,12 +37,7 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) { ) } if department != staffDepartment { - wstr( - w, - http.StatusForbidden, - "You are not authorized to view this page", - ) - return + return "", http.StatusForbidden, errStaffOnly } type userCacheT struct { @@ -54,24 +49,14 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) { rows, err := db.Query(req.Context(), "SELECT userid, courseid FROM choices") if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } output := make([][]string, 0) for { if !rows.Next() { err := rows.Err() if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } break } @@ -79,12 +64,7 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) { var currentCourseID int err := rows.Scan(¤tUserID, ¤tCourseID) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } currentUserCache, ok := userCacheMap[currentUserID] if ok { @@ -103,12 +83,7 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) { ¤tDepartment, ) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } before, _, found := strings.Cut(currentUserEmail, "@") if found { @@ -125,24 +100,14 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) { _course, ok := courses.Load(currentCourseID) if !ok { - wstr( - w, - http.StatusInternalServerError, - "Reference to non-existent course", - ) - return + return "", http.StatusInternalServerError, wrapAny(errNoSuchCourse, currentCourseID) } course, ok := _course.(*courseT) if !ok { panic("courses map has non-\"*courseT\" items") } if course == nil { - wstr( - w, - http.StatusInternalServerError, - "Course is nil", - ) - return + return "", http.StatusInternalServerError, wrapAny(errNoSuchCourse, currentCourseID) } output = append( output, @@ -171,29 +136,15 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) { "Course ID", }) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Error writing output", - ) - return + return "", http.StatusInternalServerError, wrapError(errHTTPWrite, err) } err = csvWriter.WriteAll(output) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Error writing output", - ) - return + return "", http.StatusInternalServerError, wrapError(errHTTPWrite, err) } csvWriter.Flush() if csvWriter.Error() != nil { - wstr( - w, - http.StatusInternalServerError, - "Error occurred flushing output", - ) - return + return "", http.StatusInternalServerError, wrapError(errHTTPWrite, err) } + return "", -1, nil } diff --git a/endpoint_export_students.go b/endpoint_export_students.go index 75f9150..932f7fd 100644 --- a/endpoint_export_students.go +++ b/endpoint_export_students.go @@ -22,49 +22,29 @@ package main import ( "encoding/csv" - "fmt" "net/http" "strconv" ) -func handleExportStudents(w http.ResponseWriter, req *http.Request) { +func handleExportStudents(w http.ResponseWriter, req *http.Request) (string, int, error) { _, _, department, err := getUserInfoFromRequest(req) if err != nil { - wstr( - w, - http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err), - ) + return "", http.StatusInternalServerError, err } if department != staffDepartment { - wstr( - w, - http.StatusForbidden, - "You are not authorized to view this page", - ) - return + return "", http.StatusInternalServerError, errStaffOnly } rows, err := db.Query(req.Context(), "SELECT name, email, department, confirmed FROM users") if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } output := make([][]string, 0) for { if !rows.Next() { err := rows.Err() if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } break } @@ -77,12 +57,7 @@ func handleExportStudents(w http.ResponseWriter, req *http.Request) { ¤tConfirmed, ) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) - return + return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err) } if currentDepartment == staffDepartment { @@ -113,29 +88,16 @@ func handleExportStudents(w http.ResponseWriter, req *http.Request) { "Course ID", }) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Error writing output", - ) - return + return "", http.StatusInternalServerError, errHTTPWrite } err = csvWriter.WriteAll(output) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Error writing output", - ) - return + return "", http.StatusInternalServerError, errHTTPWrite } csvWriter.Flush() if csvWriter.Error() != nil { - wstr( - w, - http.StatusInternalServerError, - "Error occurred flushing output", - ) - return + return "", http.StatusInternalServerError, errHTTPWrite } + + return "", -1, nil } diff --git a/endpoint_index.go b/endpoint_index.go index dc81275..7376f76 100644 --- a/endpoint_index.go +++ b/endpoint_index.go @@ -22,23 +22,17 @@ package main import ( "errors" - "fmt" "log" "net/http" "sync/atomic" ) -func handleIndex(w http.ResponseWriter, req *http.Request) { +func handleIndex(w http.ResponseWriter, req *http.Request) (string, int, error) { _, username, department, err := getUserInfoFromRequest(req) if errors.Is(err, errNoCookie) || errors.Is(err, errNoSuchUser) { authURL, err2 := generateAuthorizationURL() if err2 != nil { - wstr( - w, - http.StatusInternalServerError, - "Cannot generate authorization URL", - ) - return + return "", -1, err2 } var noteString string if errors.Is(err, errNoSuchUser) { @@ -57,10 +51,11 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { ) if err2 != nil { log.Println(err2) + return "", -1, wrapError(errCannotWriteTemplate, err2) } - return + return "", -1, nil } else if err != nil { - wstr(w, http.StatusInternalServerError, fmt.Sprintf("Error: %v", err)) + return "", -1, err } /* TODO: The below should be completed on-update. */ @@ -106,9 +101,9 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { }, ) if err != nil { - log.Println(err) + return "", -1, wrapError(errCannotWriteTemplate, err) } - return + return "", -1, nil } if atomic.LoadUint32(&state) == 0 { @@ -124,17 +119,17 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { }, ) if err != nil { - log.Println(err) + return "", -1, wrapError(errCannotWriteTemplate, err) } - return + return "", -1, nil } sportRequired, err := getCourseTypeMinimumForYearGroup(department, sport) if err != nil { - wstr(w, http.StatusInternalServerError, "Failed to get sport requirement") + return "", -1, err } nonSportRequired, err := getCourseTypeMinimumForYearGroup(department, nonSport) if err != nil { - wstr(w, http.StatusInternalServerError, "Failed to get non-sport requirement") + return "", -1, err } err = tmpl.ExecuteTemplate( @@ -159,7 +154,7 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { }, ) if err != nil { - log.Println(err) - return + return "", -1, wrapError(errCannotWriteTemplate, err) } + return "", -1, nil } diff --git a/endpoint_newcourses.go b/endpoint_newcourses.go index 9f26b73..e11ed1e 100644 --- a/endpoint_newcourses.go +++ b/endpoint_newcourses.go @@ -33,84 +33,44 @@ import ( "github.com/jackc/pgx/v5" ) -func handleNewCourses(w http.ResponseWriter, req *http.Request) { +func handleNewCourses(w http.ResponseWriter, req *http.Request) (string, int, error) { if req.Method != http.MethodPost { - wstr(w, http.StatusMethodNotAllowed, "Only POST is allowed here") - return + return "", http.StatusMethodNotAllowed, errPostOnly } _, _, department, err := getUserInfoFromRequest(req) if err != nil { - wstr( - w, - http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err), - ) + return "", -1, err } if department != staffDepartment { - wstr( - w, - http.StatusForbidden, - "You are not authorized to view this page", - ) - return + return "", http.StatusForbidden, errStaffOnly } if atomic.LoadUint32(&state) != 0 { - wstr( - w, - http.StatusBadRequest, - "Uploading the course table is only supported when student-access is disabled", - ) - return + return "", http.StatusBadRequest, errDisableStudentAccessFirst } /* TODO: Potential race. The global state may need to be write-locked. */ file, fileHeader, err := req.FormFile("coursecsv") if err != nil { - wstr( - w, - http.StatusBadRequest, - "Failed loading file from request... did you select a file before hitting that red button?", - ) - return + return "", http.StatusBadRequest, wrapError(errFormNoFile, err) } if fileHeader.Header.Get("Content-Type") != "text/csv" { - wstr( - w, - http.StatusBadRequest, - "Does not look like a proper CSV file", - ) - return + return "", http.StatusBadRequest, errNotACSV } csvReader := csv.NewReader(file) titleLine, err := csvReader.Read() if err != nil { - wstr( - w, - http.StatusBadRequest, - "Error reading CSV", - ) - return + return "", http.StatusBadRequest, wrapError(errCannotReadCSV, err) } if titleLine == nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected nil titleLine slice", - ) - return + return "", -1, errUnexpectedNilCSVLine } if len(titleLine) != 8 { - wstr( - w, - http.StatusBadRequest, - "First line has more than 8 elements", - ) - return + return "", -1, wrapAny(errBadCSVFormat, "expecting 8 fields on the first line") } var titleIndex, maxIndex, teacherIndex, locationIndex, typeIndex, groupIndex, sectionIDIndex, @@ -136,66 +96,41 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) { } } - { - check := func(indexName string, indexNum int) bool { - if indexNum == -1 { - wstr( - w, - http.StatusBadRequest, - fmt.Sprintf( - "Missing column \"%s\"", - indexName, - ), - ) - return true - } - return false - } - - if check("Title", titleIndex) { - return - } - if check("Max", maxIndex) { - return - } - if check("Teacher", teacherIndex) { - return - } - if check("Location", locationIndex) { - return - } - if check("Type", typeIndex) { - return - } - if check("Group", groupIndex) { - return - } - if check("Course ID", courseIDIndex) { - return - } - if check("Section ID", sectionIDIndex) { - return - } + if titleIndex == -1 { + return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Title") + } + if maxIndex == -1 { + return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Max") + } + if teacherIndex == -1 { + return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Teacher") + } + if locationIndex == -1 { + return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Location") + } + if typeIndex == -1 { + return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Type") + } + if groupIndex == -1 { + return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Group") + } + if courseIDIndex == -1 { + return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Course ID") + } + if sectionIDIndex == -1 { + return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Section ID") } lineNumber := 1 - ok := func(ctx context.Context) bool { + ok, statusCode, err := func(ctx context.Context) (retBool bool, retStatus int, retErr error) { tx, err := db.Begin(ctx) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) + return false, -1, wrapError(errUnexpectedDBError, err) } defer func() { err := tx.Rollback(ctx) if err != nil && (!errors.Is(err, pgx.ErrTxClosed)) { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) + retBool, retStatus, retErr = false, -1, wrapError(errUnexpectedDBError, err) return } }() @@ -204,22 +139,14 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) { "DELETE FROM choices", ) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) + return false, -1, wrapError(errUnexpectedDBError, err) } _, err = tx.Exec( ctx, "DELETE FROM courses", ) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) + return false, -1, wrapError(errUnexpectedDBError, err) } for { @@ -229,57 +156,36 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) { if errors.Is(err, io.EOF) { break } - wstr( - w, - http.StatusInternalServerError, - "Error reading CSV", - ) - return false + return false, -1, wrapError(errCannotReadCSV, err) } if line == nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected nil line", - ) - return false + return false, -1, wrapError(errCannotReadCSV, errUnexpectedNilCSVLine) } if len(line) != 8 { - wstr( - w, - http.StatusBadRequest, - fmt.Sprintf( - "Line %d has insufficient items", - lineNumber, - ), - ) - return false + return false, -1, wrapAny(errInsufficientFields, fmt.Sprintf( + "line %d has insufficient items", + lineNumber, + )) } if !checkCourseType(line[typeIndex]) { - wstr( - w, - http.StatusBadRequest, + return false, -1, wrapAny(errInvalidCourseType, fmt.Sprintf( - "Line %d has invalid course type \"%s\"\nAllowed course types: %s", + "line %d has invalid course type \"%s\"\nallowed course types: %s", lineNumber, line[typeIndex], strings.Join(getKeysOfMap(courseTypes), ", "), ), ) - return false } if !checkCourseGroup(line[groupIndex]) { - wstr( - w, - http.StatusBadRequest, + return false, -1, wrapAny(errInvalidCourseGroup, fmt.Sprintf( - "Line %d has invalid course group \"%s\"\nAllowed course groups: %s", + "line %d has invalid course group \"%s\"\nallowed course groups: %s", lineNumber, line[groupIndex], strings.Join(getKeysOfMap(courseGroups), ", "), ), ) - return false } _, err = tx.Exec( ctx, @@ -294,39 +200,26 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) { line[courseIDIndex], ) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) - return false + return false, -1, wrapError(errUnexpectedDBError, err) } } err = tx.Commit(ctx) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Unexpected database error", - ) - return false + return false, -1, wrapError(errUnexpectedDBError, err) } - return true + return true, -1, nil }(req.Context()) if !ok { - return + return "", statusCode, err } courses.Clear() err = setupCourses(req.Context()) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Error setting up course table again, the data might be corrupted!", - ) - return + return "", -1, wrapError(errWhileSetttingUpCourseTablesAgain, err) } http.Redirect(w, req, "/", http.StatusSeeOther) + + return "", -1, nil } diff --git a/endpoint_state.go b/endpoint_state.go index 4e0ee56..ecba26a 100644 --- a/endpoint_state.go +++ b/endpoint_state.go @@ -21,48 +21,29 @@ package main import ( - "fmt" "net/http" "strconv" ) -func handleState(w http.ResponseWriter, req *http.Request) { +func handleState(w http.ResponseWriter, req *http.Request) (string, int, error) { _, _, department, err := getUserInfoFromRequest(req) if err != nil { - wstr( - w, - http.StatusInternalServerError, - fmt.Sprintf("Error: %v", err), - ) + return "", http.StatusUnauthorized, err } if department != staffDepartment { - wstr( - w, - http.StatusForbidden, - "You are not authorized to view this page", - ) - return + return "", http.StatusForbidden, errStaffOnly } basePath := req.PathValue("s") newState, err := strconv.ParseUint(basePath, 10, 32) if err != nil { - wstr( - w, - http.StatusBadRequest, - "State must be an unsigned 32-bit integer", - ) - return + return "", http.StatusBadRequest, wrapError(errInvalidState, err) } err = setState(req.Context(), uint32(newState)) if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Failed setting state, please return to previous page; are you sure it's within limits?", - ) - return + return "", http.StatusBadRequest, wrapError(errCannotSetState, err) } http.Redirect(w, req, "/", http.StatusSeeOther) + return "", -1, nil } diff --git a/endpoint_ws.go b/endpoint_ws.go index a46ea27..27dead6 100644 --- a/endpoint_ws.go +++ b/endpoint_ws.go @@ -45,7 +45,7 @@ func handleWs(w http.ResponseWriter, req *http.Request) { wstr( w, http.StatusBadRequest, - "This endpoint only supports valid WebSocket connections.", + "this endpoint only supports valid WebSocket connections: "+err.Error(), ) return } @@ -64,7 +64,10 @@ func handleWs(w http.ResponseWriter, req *http.Request) { err = handleConn(req.Context(), c, userID, department) if err != nil { - log.Println(err) + err := writeText(req.Context(), c, "E :"+err.Error()) + if err != nil { + log.Println(err) + } return } } @@ -26,31 +26,54 @@ import ( ) var ( - errCannotSetupJwks = errors.New("cannot set up jwks") - errInsufficientFields = errors.New("insufficient fields") - errCannotGetDepartment = errors.New("cannot get department") - errCannotFetchAccessToken = errors.New("cannot fetch access token") - errTokenEndpointReturnedError = errors.New("token endpoint returned error") - errCannotProcessConfig = errors.New("cannot process configuration file") - errCannotOpenConfig = errors.New("cannot open configuration file") - errCannotDecodeConfig = errors.New("cannot decode configuration file") - errMissingConfigValue = errors.New("missing configuration value") - errInvalidCourseType = errors.New("invalid course type") - errInvalidCourseGroup = errors.New("invalid course group") - errMultipleChoicesInOneGroup = errors.New("multiple choices per group per user") - errUnsupportedDatabaseType = errors.New("unsupported db type") - errUnexpectedDBError = errors.New("unexpected database error") - errCannotSend = errors.New("cannot send") - errCannotGenerateRandomString = errors.New("cannot generate random string") - errContextCanceled = errors.New("context canceled") - errCannotReceiveMessage = errors.New("cannot receive message") - errNoSuchCourse = errors.New("no such course") - errInvalidState = errors.New("invalid state") - errWebSocketWrite = errors.New("error writing to websocket") - errCannotCheckCookie = errors.New("error checking cookie") - errNoCookie = errors.New("no cookie found") - errNoSuchUser = errors.New("no such user") - errNoSuchYearGroup = errors.New("no such year group") + errCannotSetupJwks = errors.New("cannot set up jwks") + errInsufficientFields = errors.New("insufficient fields") + errCannotGetDepartment = errors.New("cannot get department") + errUnknownDepartment = errors.New("unknown department") + errCannotFetchAccessToken = errors.New("cannot fetch access token") + errTokenEndpointReturnedError = errors.New("token endpoint returned error") + errCannotProcessConfig = errors.New("cannot process configuration file") + errCannotOpenConfig = errors.New("cannot open configuration file") + errCannotDecodeConfig = errors.New("cannot decode configuration file") + errMissingConfigValue = errors.New("missing configuration value") + errInvalidCourseType = errors.New("invalid course type") + errInvalidCourseGroup = errors.New("invalid course group") + errMultipleChoicesInOneGroup = errors.New("multiple choices per group per user") + errUnsupportedDatabaseType = errors.New("unsupported db type") + errUnexpectedDBError = errors.New("unexpected database error") + errCannotSend = errors.New("cannot send") + errCannotGenerateRandomString = errors.New("cannot generate random string") + errContextCanceled = errors.New("context canceled") + errCannotReceiveMessage = errors.New("cannot receive message") + errNoSuchCourse = errors.New("reference to non-existent course") + errInvalidState = errors.New("invalid state") + errCannotSetState = errors.New("cannot set state") + errWebSocketWrite = errors.New("error writing to websocket") + errHTTPWrite = errors.New("error writing to http writer") + errCannotCheckCookie = errors.New("error checking cookie") + errNoCookie = errors.New("no cookie found") + errNoSuchUser = errors.New("no such user") + errNoSuchYearGroup = errors.New("no such year group") + errPostOnly = errors.New("only post is supported on this endpoint") + errMalformedForm = errors.New("malformed form") + errAuthorizeEndpointError = errors.New("authorize endpoint returned error") + errCannotParseClaims = errors.New("cannot parse claims") + errCannotUnpackClaims = errors.New("cannot unpack claims") + errJWTMalformed = errors.New("jwt token is malformed") + errJWTSignatureInvalid = errors.New("jwt token has invalid signature") + errJWTExpired = errors.New("jwt token has expired or is not yet valid") + errJWTInvalid = errors.New("jwt token is somehow invalid") + errStaffOnly = errors.New("this page is only available to staff") + errDisableStudentAccessFirst = errors.New("you must disable student access before performing this operation") + errFormNoFile = errors.New("you need to select a file before submitting the form") + errNotACSV = errors.New("the file you uploaded is not a csv file") + errCannotReadCSV = errors.New("cannot read csv") + errBadCSVFormat = errors.New("bad csv format") + errMissingCSVColumn = errors.New("missing csv column") + errUnexpectedNilCSVLine = errors.New("unexpected nil csv line") + errWhileSetttingUpCourseTablesAgain = errors.New("error while setting up course tables again") + errCannotWriteTemplate = errors.New("cannot write template") + // errInvalidCourseID = errors.New("invalid course id") ) func wrapError(a, b error) error { @@ -59,3 +82,10 @@ func wrapError(a, b error) error { } return fmt.Errorf("%w: %w", a, b) } + +func wrapAny(a error, b any) error { + if a == nil && b == nil { + return nil + } + return fmt.Errorf("%w: %v", a, b) +} @@ -116,13 +116,13 @@ func main() { ) log.Println("Registering handlers") - http.HandleFunc("/{$}", handleIndex) - http.HandleFunc("/export/choices", handleExportChoices) - http.HandleFunc("/export/students", handleExportStudents) - http.HandleFunc("/auth", handleAuth) http.HandleFunc("/ws", handleWs) - http.HandleFunc("/state/{s}", handleState) - http.HandleFunc("/newcourses", handleNewCourses) + setHandler("/{$}", handleIndex) + setHandler("/export/choices", handleExportChoices) + setHandler("/export/students", handleExportStudents) + setHandler("/auth", handleAuth) + setHandler("/state/{s}", handleState) + setHandler("/newcourses", handleNewCourses) var l net.Listener diff --git a/sethandler.go b/sethandler.go new file mode 100644 index 0000000..4022d5c --- /dev/null +++ b/sethandler.go @@ -0,0 +1,46 @@ +/* + * HTTP handler setting + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "net/http" +) + +func setHandler(pattern string, handler func(http.ResponseWriter, *http.Request) (string, int, error)) { + http.HandleFunc(pattern, func(w http.ResponseWriter, req *http.Request) { + msg, statusCode, err := handler(w, req) + if err != nil { + if statusCode == -1 || statusCode == 0 { + statusCode = 500 + } + if msg != "" { + wstr(w, statusCode, msg+"\n"+err.Error()) + } else { + wstr(w, statusCode, err.Error()) + } + } else { + if statusCode == -1 || statusCode == 0 { + statusCode = 200 + } + wstr(w, statusCode, msg) + } + }) +} |