aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--endpoint_auth.go116
-rw-r--r--endpoint_export_choices.go73
-rw-r--r--endpoint_export_students.go60
-rw-r--r--endpoint_index.go31
-rw-r--r--endpoint_newcourses.go217
-rw-r--r--endpoint_state.go31
-rw-r--r--endpoint_ws.go7
-rw-r--r--errors.go80
-rw-r--r--main.go12
-rw-r--r--sethandler.go46
10 files changed, 249 insertions, 424 deletions
diff --git a/endpoint_auth.go b/endpoint_auth.go
index 85acab8..4a9c770 100644
--- a/endpoint_auth.go
+++ b/endpoint_auth.go
@@ -81,48 +81,31 @@ func generateAuthorizationURL() (string, error) {
* Expects JSON Web Keys to be already set up correctly; if myKeyfunc is null,
* a null pointer is dereferenced and the thread panics.
*/
-func handleAuth(w http.ResponseWriter, req *http.Request) {
+func handleAuth(w http.ResponseWriter, req *http.Request) (string, int, error) {
if req.Method != http.MethodPost {
- wstr(
- w,
- http.StatusMethodNotAllowed,
- "Only POST is supported on the authentication endpoint",
- )
- return
+ return "", http.StatusMethodNotAllowed, errPostOnly
}
err := req.ParseForm()
if err != nil {
- wstr(w, http.StatusBadRequest, "Malformed form data")
- return
+ return "", http.StatusBadRequest, wrapError(errMalformedForm, err)
}
returnedError := req.PostFormValue("error")
if returnedError != "" {
returnedErrorDescription := req.PostFormValue("error_description")
- if returnedErrorDescription == "" {
- wstr(
- w,
- http.StatusBadRequest,
- fmt.Sprintf(
- "authorize endpoint returned error: %v",
- returnedErrorDescription,
- ),
- )
- return
- }
- wstr(w, http.StatusBadRequest, fmt.Sprintf(
- "%s: %s",
- returnedError,
- returnedErrorDescription,
- ))
- return
+ return "", http.StatusUnauthorized, wrapAny(
+ errAuthorizeEndpointError,
+ returnedError+": "+returnedErrorDescription,
+ )
}
idTokenString := req.PostFormValue("id_token")
if idTokenString == "" {
- wstr(w, http.StatusBadRequest, "Missing id_token")
- return
+ return "", http.StatusBadRequest, wrapAny(
+ errInsufficientFields,
+ "id_token",
+ )
}
claimsTemplate := &msclaimsT{} //exhaustruct:ignore
@@ -132,79 +115,68 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
myKeyfunc.Keyfunc,
)
if err != nil {
- wstr(w, http.StatusBadRequest, "Cannot parse claims")
- return
+ return "", http.StatusBadRequest, wrapError(
+ errCannotParseClaims,
+ err,
+ )
}
switch {
case token.Valid:
break
case errors.Is(err, jwt.ErrTokenMalformed):
- wstr(w, http.StatusBadRequest, "Malformed JWT token")
- return
+ return "", http.StatusBadRequest, wrapError(
+ errJWTMalformed,
+ err,
+ )
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
- wstr(w, http.StatusBadRequest, "Invalid JWS signature")
- return
+ return "", http.StatusBadRequest, wrapError(
+ errJWTSignatureInvalid,
+ err,
+ )
case errors.Is(err, jwt.ErrTokenExpired) ||
errors.Is(err, jwt.ErrTokenNotValidYet):
- wstr(
- w,
- http.StatusBadRequest,
- "JWT token expired or not yet valid",
+ return "", http.StatusBadRequest, wrapError(
+ errJWTExpired,
+ err,
)
- return
default:
- wstr(w, http.StatusBadRequest, "Unhandled JWT token error")
- return
+ return "", http.StatusBadRequest, wrapError(
+ errJWTInvalid,
+ err,
+ )
}
claims, claimsOk := token.Claims.(*msclaimsT)
if !claimsOk {
- wstr(w, http.StatusBadRequest, "Cannot unpack claims")
- return
+ return "", http.StatusBadRequest, errCannotUnpackClaims
}
authorizationCode := req.PostFormValue("code")
accessToken, err := getAccessToken(req.Context(), authorizationCode)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- fmt.Sprintf("Unable to fetch access token: %v", err),
- )
- return
+ return "", http.StatusInternalServerError, err
}
department, err := getDepartment(req.Context(), *(accessToken.Content))
if err != nil {
- wstr(w, http.StatusInternalServerError, err.Error())
- return
+ return "", http.StatusInternalServerError, err
}
switch {
- case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" ||
- department == "High School Teaching & Learning 高中教学部门":
+ case department == "SJ Co-Curricular Activities Office 松江课外项目办公室":
department = staffDepartment
case department == "Y9" || department == "Y10" ||
department == "Y11" || department == "Y12":
default:
- wstr(
- w,
- http.StatusForbidden,
- fmt.Sprintf(
- "Your department \"%s\" is unknown.\nWe currently only allow Y9, Y10, Y11, Y12, and the CCA office.",
- department,
- ),
- )
- return
+ return "", http.StatusForbidden, errUnknownDepartment
}
cookieValue, err := randomString(tokenLength)
if err != nil {
- wstr(w, http.StatusInternalServerError, err.Error())
- return
+ return "", http.StatusInternalServerError, err
}
now := time.Now()
@@ -246,24 +218,16 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
claims.Oid,
)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Database error while updating account.",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err)
}
} else {
- wstr(
- w,
- http.StatusInternalServerError,
- "Database error while writing account info.",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err)
}
}
http.Redirect(w, req, "/", http.StatusSeeOther)
+
+ return "", -1, nil
}
func setupJwks() error {
diff --git a/endpoint_export_choices.go b/endpoint_export_choices.go
index 890f170..80734b2 100644
--- a/endpoint_export_choices.go
+++ b/endpoint_export_choices.go
@@ -27,7 +27,7 @@ import (
"strings"
)
-func handleExportChoices(w http.ResponseWriter, req *http.Request) {
+func handleExportChoices(w http.ResponseWriter, req *http.Request) (string, int, error) {
_, _, department, err := getUserInfoFromRequest(req)
if err != nil {
wstr(
@@ -37,12 +37,7 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) {
)
}
if department != staffDepartment {
- wstr(
- w,
- http.StatusForbidden,
- "You are not authorized to view this page",
- )
- return
+ return "", http.StatusForbidden, errStaffOnly
}
type userCacheT struct {
@@ -54,24 +49,14 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) {
rows, err := db.Query(req.Context(), "SELECT userid, courseid FROM choices")
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err)
}
output := make([][]string, 0)
for {
if !rows.Next() {
err := rows.Err()
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err)
}
break
}
@@ -79,12 +64,7 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) {
var currentCourseID int
err := rows.Scan(&currentUserID, &currentCourseID)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err)
}
currentUserCache, ok := userCacheMap[currentUserID]
if ok {
@@ -103,12 +83,7 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) {
&currentDepartment,
)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err)
}
before, _, found := strings.Cut(currentUserEmail, "@")
if found {
@@ -125,24 +100,14 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) {
_course, ok := courses.Load(currentCourseID)
if !ok {
- wstr(
- w,
- http.StatusInternalServerError,
- "Reference to non-existent course",
- )
- return
+ return "", http.StatusInternalServerError, wrapAny(errNoSuchCourse, currentCourseID)
}
course, ok := _course.(*courseT)
if !ok {
panic("courses map has non-\"*courseT\" items")
}
if course == nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Course is nil",
- )
- return
+ return "", http.StatusInternalServerError, wrapAny(errNoSuchCourse, currentCourseID)
}
output = append(
output,
@@ -171,29 +136,15 @@ func handleExportChoices(w http.ResponseWriter, req *http.Request) {
"Course ID",
})
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Error writing output",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errHTTPWrite, err)
}
err = csvWriter.WriteAll(output)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Error writing output",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errHTTPWrite, err)
}
csvWriter.Flush()
if csvWriter.Error() != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Error occurred flushing output",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errHTTPWrite, err)
}
+ return "", -1, nil
}
diff --git a/endpoint_export_students.go b/endpoint_export_students.go
index 75f9150..932f7fd 100644
--- a/endpoint_export_students.go
+++ b/endpoint_export_students.go
@@ -22,49 +22,29 @@ package main
import (
"encoding/csv"
- "fmt"
"net/http"
"strconv"
)
-func handleExportStudents(w http.ResponseWriter, req *http.Request) {
+func handleExportStudents(w http.ResponseWriter, req *http.Request) (string, int, error) {
_, _, department, err := getUserInfoFromRequest(req)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- fmt.Sprintf("Error: %v", err),
- )
+ return "", http.StatusInternalServerError, err
}
if department != staffDepartment {
- wstr(
- w,
- http.StatusForbidden,
- "You are not authorized to view this page",
- )
- return
+ return "", http.StatusInternalServerError, errStaffOnly
}
rows, err := db.Query(req.Context(), "SELECT name, email, department, confirmed FROM users")
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err)
}
output := make([][]string, 0)
for {
if !rows.Next() {
err := rows.Err()
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err)
}
break
}
@@ -77,12 +57,7 @@ func handleExportStudents(w http.ResponseWriter, req *http.Request) {
&currentConfirmed,
)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
- return
+ return "", http.StatusInternalServerError, wrapError(errUnexpectedDBError, err)
}
if currentDepartment == staffDepartment {
@@ -113,29 +88,16 @@ func handleExportStudents(w http.ResponseWriter, req *http.Request) {
"Course ID",
})
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Error writing output",
- )
- return
+ return "", http.StatusInternalServerError, errHTTPWrite
}
err = csvWriter.WriteAll(output)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Error writing output",
- )
- return
+ return "", http.StatusInternalServerError, errHTTPWrite
}
csvWriter.Flush()
if csvWriter.Error() != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Error occurred flushing output",
- )
- return
+ return "", http.StatusInternalServerError, errHTTPWrite
}
+
+ return "", -1, nil
}
diff --git a/endpoint_index.go b/endpoint_index.go
index dc81275..7376f76 100644
--- a/endpoint_index.go
+++ b/endpoint_index.go
@@ -22,23 +22,17 @@ package main
import (
"errors"
- "fmt"
"log"
"net/http"
"sync/atomic"
)
-func handleIndex(w http.ResponseWriter, req *http.Request) {
+func handleIndex(w http.ResponseWriter, req *http.Request) (string, int, error) {
_, username, department, err := getUserInfoFromRequest(req)
if errors.Is(err, errNoCookie) || errors.Is(err, errNoSuchUser) {
authURL, err2 := generateAuthorizationURL()
if err2 != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Cannot generate authorization URL",
- )
- return
+ return "", -1, err2
}
var noteString string
if errors.Is(err, errNoSuchUser) {
@@ -57,10 +51,11 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
)
if err2 != nil {
log.Println(err2)
+ return "", -1, wrapError(errCannotWriteTemplate, err2)
}
- return
+ return "", -1, nil
} else if err != nil {
- wstr(w, http.StatusInternalServerError, fmt.Sprintf("Error: %v", err))
+ return "", -1, err
}
/* TODO: The below should be completed on-update. */
@@ -106,9 +101,9 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
},
)
if err != nil {
- log.Println(err)
+ return "", -1, wrapError(errCannotWriteTemplate, err)
}
- return
+ return "", -1, nil
}
if atomic.LoadUint32(&state) == 0 {
@@ -124,17 +119,17 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
},
)
if err != nil {
- log.Println(err)
+ return "", -1, wrapError(errCannotWriteTemplate, err)
}
- return
+ return "", -1, nil
}
sportRequired, err := getCourseTypeMinimumForYearGroup(department, sport)
if err != nil {
- wstr(w, http.StatusInternalServerError, "Failed to get sport requirement")
+ return "", -1, err
}
nonSportRequired, err := getCourseTypeMinimumForYearGroup(department, nonSport)
if err != nil {
- wstr(w, http.StatusInternalServerError, "Failed to get non-sport requirement")
+ return "", -1, err
}
err = tmpl.ExecuteTemplate(
@@ -159,7 +154,7 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
},
)
if err != nil {
- log.Println(err)
- return
+ return "", -1, wrapError(errCannotWriteTemplate, err)
}
+ return "", -1, nil
}
diff --git a/endpoint_newcourses.go b/endpoint_newcourses.go
index 9f26b73..e11ed1e 100644
--- a/endpoint_newcourses.go
+++ b/endpoint_newcourses.go
@@ -33,84 +33,44 @@ import (
"github.com/jackc/pgx/v5"
)
-func handleNewCourses(w http.ResponseWriter, req *http.Request) {
+func handleNewCourses(w http.ResponseWriter, req *http.Request) (string, int, error) {
if req.Method != http.MethodPost {
- wstr(w, http.StatusMethodNotAllowed, "Only POST is allowed here")
- return
+ return "", http.StatusMethodNotAllowed, errPostOnly
}
_, _, department, err := getUserInfoFromRequest(req)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- fmt.Sprintf("Error: %v", err),
- )
+ return "", -1, err
}
if department != staffDepartment {
- wstr(
- w,
- http.StatusForbidden,
- "You are not authorized to view this page",
- )
- return
+ return "", http.StatusForbidden, errStaffOnly
}
if atomic.LoadUint32(&state) != 0 {
- wstr(
- w,
- http.StatusBadRequest,
- "Uploading the course table is only supported when student-access is disabled",
- )
- return
+ return "", http.StatusBadRequest, errDisableStudentAccessFirst
}
/* TODO: Potential race. The global state may need to be write-locked. */
file, fileHeader, err := req.FormFile("coursecsv")
if err != nil {
- wstr(
- w,
- http.StatusBadRequest,
- "Failed loading file from request... did you select a file before hitting that red button?",
- )
- return
+ return "", http.StatusBadRequest, wrapError(errFormNoFile, err)
}
if fileHeader.Header.Get("Content-Type") != "text/csv" {
- wstr(
- w,
- http.StatusBadRequest,
- "Does not look like a proper CSV file",
- )
- return
+ return "", http.StatusBadRequest, errNotACSV
}
csvReader := csv.NewReader(file)
titleLine, err := csvReader.Read()
if err != nil {
- wstr(
- w,
- http.StatusBadRequest,
- "Error reading CSV",
- )
- return
+ return "", http.StatusBadRequest, wrapError(errCannotReadCSV, err)
}
if titleLine == nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected nil titleLine slice",
- )
- return
+ return "", -1, errUnexpectedNilCSVLine
}
if len(titleLine) != 8 {
- wstr(
- w,
- http.StatusBadRequest,
- "First line has more than 8 elements",
- )
- return
+ return "", -1, wrapAny(errBadCSVFormat, "expecting 8 fields on the first line")
}
var titleIndex, maxIndex, teacherIndex, locationIndex,
typeIndex, groupIndex, sectionIDIndex,
@@ -136,66 +96,41 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) {
}
}
- {
- check := func(indexName string, indexNum int) bool {
- if indexNum == -1 {
- wstr(
- w,
- http.StatusBadRequest,
- fmt.Sprintf(
- "Missing column \"%s\"",
- indexName,
- ),
- )
- return true
- }
- return false
- }
-
- if check("Title", titleIndex) {
- return
- }
- if check("Max", maxIndex) {
- return
- }
- if check("Teacher", teacherIndex) {
- return
- }
- if check("Location", locationIndex) {
- return
- }
- if check("Type", typeIndex) {
- return
- }
- if check("Group", groupIndex) {
- return
- }
- if check("Course ID", courseIDIndex) {
- return
- }
- if check("Section ID", sectionIDIndex) {
- return
- }
+ if titleIndex == -1 {
+ return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Title")
+ }
+ if maxIndex == -1 {
+ return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Max")
+ }
+ if teacherIndex == -1 {
+ return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Teacher")
+ }
+ if locationIndex == -1 {
+ return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Location")
+ }
+ if typeIndex == -1 {
+ return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Type")
+ }
+ if groupIndex == -1 {
+ return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Group")
+ }
+ if courseIDIndex == -1 {
+ return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Course ID")
+ }
+ if sectionIDIndex == -1 {
+ return "", http.StatusBadRequest, wrapAny(errMissingCSVColumn, "Section ID")
}
lineNumber := 1
- ok := func(ctx context.Context) bool {
+ ok, statusCode, err := func(ctx context.Context) (retBool bool, retStatus int, retErr error) {
tx, err := db.Begin(ctx)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
+ return false, -1, wrapError(errUnexpectedDBError, err)
}
defer func() {
err := tx.Rollback(ctx)
if err != nil && (!errors.Is(err, pgx.ErrTxClosed)) {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
+ retBool, retStatus, retErr = false, -1, wrapError(errUnexpectedDBError, err)
return
}
}()
@@ -204,22 +139,14 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) {
"DELETE FROM choices",
)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
+ return false, -1, wrapError(errUnexpectedDBError, err)
}
_, err = tx.Exec(
ctx,
"DELETE FROM courses",
)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
+ return false, -1, wrapError(errUnexpectedDBError, err)
}
for {
@@ -229,57 +156,36 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) {
if errors.Is(err, io.EOF) {
break
}
- wstr(
- w,
- http.StatusInternalServerError,
- "Error reading CSV",
- )
- return false
+ return false, -1, wrapError(errCannotReadCSV, err)
}
if line == nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected nil line",
- )
- return false
+ return false, -1, wrapError(errCannotReadCSV, errUnexpectedNilCSVLine)
}
if len(line) != 8 {
- wstr(
- w,
- http.StatusBadRequest,
- fmt.Sprintf(
- "Line %d has insufficient items",
- lineNumber,
- ),
- )
- return false
+ return false, -1, wrapAny(errInsufficientFields, fmt.Sprintf(
+ "line %d has insufficient items",
+ lineNumber,
+ ))
}
if !checkCourseType(line[typeIndex]) {
- wstr(
- w,
- http.StatusBadRequest,
+ return false, -1, wrapAny(errInvalidCourseType,
fmt.Sprintf(
- "Line %d has invalid course type \"%s\"\nAllowed course types: %s",
+ "line %d has invalid course type \"%s\"\nallowed course types: %s",
lineNumber,
line[typeIndex],
strings.Join(getKeysOfMap(courseTypes), ", "),
),
)
- return false
}
if !checkCourseGroup(line[groupIndex]) {
- wstr(
- w,
- http.StatusBadRequest,
+ return false, -1, wrapAny(errInvalidCourseGroup,
fmt.Sprintf(
- "Line %d has invalid course group \"%s\"\nAllowed course groups: %s",
+ "line %d has invalid course group \"%s\"\nallowed course groups: %s",
lineNumber,
line[groupIndex],
strings.Join(getKeysOfMap(courseGroups), ", "),
),
)
- return false
}
_, err = tx.Exec(
ctx,
@@ -294,39 +200,26 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) {
line[courseIDIndex],
)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
- return false
+ return false, -1, wrapError(errUnexpectedDBError, err)
}
}
err = tx.Commit(ctx)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Unexpected database error",
- )
- return false
+ return false, -1, wrapError(errUnexpectedDBError, err)
}
- return true
+ return true, -1, nil
}(req.Context())
if !ok {
- return
+ return "", statusCode, err
}
courses.Clear()
err = setupCourses(req.Context())
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Error setting up course table again, the data might be corrupted!",
- )
- return
+ return "", -1, wrapError(errWhileSetttingUpCourseTablesAgain, err)
}
http.Redirect(w, req, "/", http.StatusSeeOther)
+
+ return "", -1, nil
}
diff --git a/endpoint_state.go b/endpoint_state.go
index 4e0ee56..ecba26a 100644
--- a/endpoint_state.go
+++ b/endpoint_state.go
@@ -21,48 +21,29 @@
package main
import (
- "fmt"
"net/http"
"strconv"
)
-func handleState(w http.ResponseWriter, req *http.Request) {
+func handleState(w http.ResponseWriter, req *http.Request) (string, int, error) {
_, _, department, err := getUserInfoFromRequest(req)
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- fmt.Sprintf("Error: %v", err),
- )
+ return "", http.StatusUnauthorized, err
}
if department != staffDepartment {
- wstr(
- w,
- http.StatusForbidden,
- "You are not authorized to view this page",
- )
- return
+ return "", http.StatusForbidden, errStaffOnly
}
basePath := req.PathValue("s")
newState, err := strconv.ParseUint(basePath, 10, 32)
if err != nil {
- wstr(
- w,
- http.StatusBadRequest,
- "State must be an unsigned 32-bit integer",
- )
- return
+ return "", http.StatusBadRequest, wrapError(errInvalidState, err)
}
err = setState(req.Context(), uint32(newState))
if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Failed setting state, please return to previous page; are you sure it's within limits?",
- )
- return
+ return "", http.StatusBadRequest, wrapError(errCannotSetState, err)
}
http.Redirect(w, req, "/", http.StatusSeeOther)
+ return "", -1, nil
}
diff --git a/endpoint_ws.go b/endpoint_ws.go
index a46ea27..27dead6 100644
--- a/endpoint_ws.go
+++ b/endpoint_ws.go
@@ -45,7 +45,7 @@ func handleWs(w http.ResponseWriter, req *http.Request) {
wstr(
w,
http.StatusBadRequest,
- "This endpoint only supports valid WebSocket connections.",
+ "this endpoint only supports valid WebSocket connections: "+err.Error(),
)
return
}
@@ -64,7 +64,10 @@ func handleWs(w http.ResponseWriter, req *http.Request) {
err = handleConn(req.Context(), c, userID, department)
if err != nil {
- log.Println(err)
+ err := writeText(req.Context(), c, "E :"+err.Error())
+ if err != nil {
+ log.Println(err)
+ }
return
}
}
diff --git a/errors.go b/errors.go
index 257c935..3aa92f4 100644
--- a/errors.go
+++ b/errors.go
@@ -26,31 +26,54 @@ import (
)
var (
- errCannotSetupJwks = errors.New("cannot set up jwks")
- errInsufficientFields = errors.New("insufficient fields")
- errCannotGetDepartment = errors.New("cannot get department")
- errCannotFetchAccessToken = errors.New("cannot fetch access token")
- errTokenEndpointReturnedError = errors.New("token endpoint returned error")
- errCannotProcessConfig = errors.New("cannot process configuration file")
- errCannotOpenConfig = errors.New("cannot open configuration file")
- errCannotDecodeConfig = errors.New("cannot decode configuration file")
- errMissingConfigValue = errors.New("missing configuration value")
- errInvalidCourseType = errors.New("invalid course type")
- errInvalidCourseGroup = errors.New("invalid course group")
- errMultipleChoicesInOneGroup = errors.New("multiple choices per group per user")
- errUnsupportedDatabaseType = errors.New("unsupported db type")
- errUnexpectedDBError = errors.New("unexpected database error")
- errCannotSend = errors.New("cannot send")
- errCannotGenerateRandomString = errors.New("cannot generate random string")
- errContextCanceled = errors.New("context canceled")
- errCannotReceiveMessage = errors.New("cannot receive message")
- errNoSuchCourse = errors.New("no such course")
- errInvalidState = errors.New("invalid state")
- errWebSocketWrite = errors.New("error writing to websocket")
- errCannotCheckCookie = errors.New("error checking cookie")
- errNoCookie = errors.New("no cookie found")
- errNoSuchUser = errors.New("no such user")
- errNoSuchYearGroup = errors.New("no such year group")
+ errCannotSetupJwks = errors.New("cannot set up jwks")
+ errInsufficientFields = errors.New("insufficient fields")
+ errCannotGetDepartment = errors.New("cannot get department")
+ errUnknownDepartment = errors.New("unknown department")
+ errCannotFetchAccessToken = errors.New("cannot fetch access token")
+ errTokenEndpointReturnedError = errors.New("token endpoint returned error")
+ errCannotProcessConfig = errors.New("cannot process configuration file")
+ errCannotOpenConfig = errors.New("cannot open configuration file")
+ errCannotDecodeConfig = errors.New("cannot decode configuration file")
+ errMissingConfigValue = errors.New("missing configuration value")
+ errInvalidCourseType = errors.New("invalid course type")
+ errInvalidCourseGroup = errors.New("invalid course group")
+ errMultipleChoicesInOneGroup = errors.New("multiple choices per group per user")
+ errUnsupportedDatabaseType = errors.New("unsupported db type")
+ errUnexpectedDBError = errors.New("unexpected database error")
+ errCannotSend = errors.New("cannot send")
+ errCannotGenerateRandomString = errors.New("cannot generate random string")
+ errContextCanceled = errors.New("context canceled")
+ errCannotReceiveMessage = errors.New("cannot receive message")
+ errNoSuchCourse = errors.New("reference to non-existent course")
+ errInvalidState = errors.New("invalid state")
+ errCannotSetState = errors.New("cannot set state")
+ errWebSocketWrite = errors.New("error writing to websocket")
+ errHTTPWrite = errors.New("error writing to http writer")
+ errCannotCheckCookie = errors.New("error checking cookie")
+ errNoCookie = errors.New("no cookie found")
+ errNoSuchUser = errors.New("no such user")
+ errNoSuchYearGroup = errors.New("no such year group")
+ errPostOnly = errors.New("only post is supported on this endpoint")
+ errMalformedForm = errors.New("malformed form")
+ errAuthorizeEndpointError = errors.New("authorize endpoint returned error")
+ errCannotParseClaims = errors.New("cannot parse claims")
+ errCannotUnpackClaims = errors.New("cannot unpack claims")
+ errJWTMalformed = errors.New("jwt token is malformed")
+ errJWTSignatureInvalid = errors.New("jwt token has invalid signature")
+ errJWTExpired = errors.New("jwt token has expired or is not yet valid")
+ errJWTInvalid = errors.New("jwt token is somehow invalid")
+ errStaffOnly = errors.New("this page is only available to staff")
+ errDisableStudentAccessFirst = errors.New("you must disable student access before performing this operation")
+ errFormNoFile = errors.New("you need to select a file before submitting the form")
+ errNotACSV = errors.New("the file you uploaded is not a csv file")
+ errCannotReadCSV = errors.New("cannot read csv")
+ errBadCSVFormat = errors.New("bad csv format")
+ errMissingCSVColumn = errors.New("missing csv column")
+ errUnexpectedNilCSVLine = errors.New("unexpected nil csv line")
+ errWhileSetttingUpCourseTablesAgain = errors.New("error while setting up course tables again")
+ errCannotWriteTemplate = errors.New("cannot write template")
+ // errInvalidCourseID = errors.New("invalid course id")
)
func wrapError(a, b error) error {
@@ -59,3 +82,10 @@ func wrapError(a, b error) error {
}
return fmt.Errorf("%w: %w", a, b)
}
+
+func wrapAny(a error, b any) error {
+ if a == nil && b == nil {
+ return nil
+ }
+ return fmt.Errorf("%w: %v", a, b)
+}
diff --git a/main.go b/main.go
index 4afac06..00c8ae7 100644
--- a/main.go
+++ b/main.go
@@ -116,13 +116,13 @@ func main() {
)
log.Println("Registering handlers")
- http.HandleFunc("/{$}", handleIndex)
- http.HandleFunc("/export/choices", handleExportChoices)
- http.HandleFunc("/export/students", handleExportStudents)
- http.HandleFunc("/auth", handleAuth)
http.HandleFunc("/ws", handleWs)
- http.HandleFunc("/state/{s}", handleState)
- http.HandleFunc("/newcourses", handleNewCourses)
+ setHandler("/{$}", handleIndex)
+ setHandler("/export/choices", handleExportChoices)
+ setHandler("/export/students", handleExportStudents)
+ setHandler("/auth", handleAuth)
+ setHandler("/state/{s}", handleState)
+ setHandler("/newcourses", handleNewCourses)
var l net.Listener
diff --git a/sethandler.go b/sethandler.go
new file mode 100644
index 0000000..4022d5c
--- /dev/null
+++ b/sethandler.go
@@ -0,0 +1,46 @@
+/*
+ * HTTP handler setting
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "net/http"
+)
+
+func setHandler(pattern string, handler func(http.ResponseWriter, *http.Request) (string, int, error)) {
+ http.HandleFunc(pattern, func(w http.ResponseWriter, req *http.Request) {
+ msg, statusCode, err := handler(w, req)
+ if err != nil {
+ if statusCode == -1 || statusCode == 0 {
+ statusCode = 500
+ }
+ if msg != "" {
+ wstr(w, statusCode, msg+"\n"+err.Error())
+ } else {
+ wstr(w, statusCode, err.Error())
+ }
+ } else {
+ if statusCode == -1 || statusCode == 0 {
+ statusCode = 200
+ }
+ wstr(w, statusCode, msg)
+ }
+ })
+}