summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--auth.go97
-rw-r--r--config.go56
-rw-r--r--courses.go68
-rw-r--r--index.go21
-rw-r--r--main.go75
-rw-r--r--wsc.go67
-rw-r--r--wsh.go22
-rw-r--r--wsm.go103
-rw-r--r--wsp.go17
9 files changed, 431 insertions, 95 deletions
diff --git a/auth.go b/auth.go
index 9ef1254..1fd6086 100644
--- a/auth.go
+++ b/auth.go
@@ -96,7 +96,11 @@ func generateAuthorizationURL() (string, error) {
*/
func handleAuth(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
- wstr(w, http.StatusMethodNotAllowed, "Only POST is supported on the authentication endpoint")
+ wstr(
+ w,
+ http.StatusMethodNotAllowed,
+ "Only POST is supported on the authentication endpoint",
+ )
return
}
@@ -110,7 +114,14 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
if returnedError != "" {
returnedErrorDescription := req.PostFormValue("error_description")
if returnedErrorDescription == "" {
- wstr(w, http.StatusBadRequest, fmt.Sprintf("authorize endpoint returned error: %v", returnedErrorDescription))
+ wstr(
+ w,
+ http.StatusBadRequest,
+ fmt.Sprintf(
+ "authorize endpoint returned error: %v",
+ returnedErrorDescription,
+ ),
+ )
return
}
wstr(w, http.StatusBadRequest, fmt.Sprintf(
@@ -149,7 +160,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
return
case errors.Is(err, jwt.ErrTokenExpired) ||
errors.Is(err, jwt.ErrTokenNotValidYet):
- wstr(w, http.StatusBadRequest, "JWT token expired or not yet valid")
+ wstr(
+ w,
+ http.StatusBadRequest,
+ "JWT token expired or not yet valid",
+ )
return
default:
wstr(w, http.StatusBadRequest, "Unhandled JWT token error")
@@ -167,7 +182,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
accessToken, err := getAccessToken(req.Context(), authorizationCode)
if err != nil {
- wstr(w, http.StatusInternalServerError, fmt.Sprintf("Unable to fetch access token: %v", err))
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ fmt.Sprintf("Unable to fetch access token: %v", err),
+ )
return
}
@@ -178,9 +197,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
}
switch {
- case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || department == "High School Teaching & Learning 高中教学部门":
+ case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" ||
+ department == "High School Teaching & Learning 高中教学部门":
department = "Staff"
- case department == "Y9" || department == "Y10" || department == "Y11" || department == "Y12":
+ case department == "Y9" || department == "Y10" ||
+ department == "Y11" || department == "Y12":
default:
wstr(
w,
@@ -245,11 +266,19 @@ func handleAuth(w http.ResponseWriter, req *http.Request) {
claims.Oid,
)
if err != nil {
- wstr(w, http.StatusInternalServerError, "Database error while updating account.")
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ "Database error while updating account.",
+ )
return
}
} else {
- wstr(w, http.StatusInternalServerError, "Database error while writing account info.")
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ "Database error while writing account info.",
+ )
return
}
}
@@ -312,7 +341,11 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) {
* "department" field, which hopefully doesn't occur as we
* have specified $select=department in the OData query.
*/
- return "", fmt.Errorf("error getting department: %w", errInsufficientFields)
+ return "",
+ fmt.Errorf(
+ "error getting department: %w",
+ errInsufficientFields,
+ )
}
return *(departmentWrap.Department), nil
@@ -322,7 +355,7 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) {
* TODO: Access token expiration is not checked anywhere.
*/
type accessTokenT struct {
- OriginalExpiresIn *int `json:"expires_in"` /* Original time to expiration */
+ OriginalExpiresIn *int `json:"expires_in"` /* Original time to expr */
Expiration time.Time
Content *string `json:"access_token"`
Error *string `json:"error"`
@@ -334,7 +367,10 @@ type accessTokenT struct {
* Obtain an access token from the token endpoint with an existing
* authorization code.
*/
-func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT, error) {
+func getAccessToken(
+ ctx context.Context,
+ authorizationCode string,
+) (accessTokenT, error) {
var accessToken accessTokenT
t := time.Now()
v := url.Values{}
@@ -344,31 +380,52 @@ func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT
v.Set("redirect_uri", config.URL+"/auth")
v.Set("grant_type", "authorization_code")
v.Set("client_secret", config.Auth.Secret)
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, config.Auth.Token, strings.NewReader(v.Encode()))
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodPost,
+ config.Auth.Token,
+ strings.NewReader(v.Encode()),
+ )
if err != nil {
- return accessToken, fmt.Errorf("error making access token request: %w", err)
+ return accessToken,
+ fmt.Errorf("error making access token request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
- return accessToken, fmt.Errorf("error requesting access token: %w", err)
+ return accessToken,
+ fmt.Errorf("error requesting access token: %w", err)
}
defer resp.Body.Close()
decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(&accessToken)
if err != nil {
- return accessToken, fmt.Errorf("error decoding access token: %w", err)
+ return accessToken,
+ fmt.Errorf("error decoding access token: %w", err)
}
- if accessToken.Error != nil || accessToken.ErrorCodes != nil || accessToken.ErrorDescription != nil {
- if accessToken.Error == nil || accessToken.ErrorCodes == nil || accessToken.ErrorDescription == nil {
+ if accessToken.Error != nil || accessToken.ErrorCodes != nil ||
+ accessToken.ErrorDescription != nil {
+ if accessToken.Error == nil || accessToken.ErrorCodes == nil ||
+ accessToken.ErrorDescription == nil {
return accessToken, errAccessTokenIncompleteError
}
- return accessToken, fmt.Errorf("%w: %v", errTokenEndpointReturnedError, *accessToken.ErrorDescription)
+ return accessToken,
+ fmt.Errorf(
+ "%w: %v",
+ errTokenEndpointReturnedError,
+ *accessToken.ErrorDescription,
+ )
}
if accessToken.Content == nil || accessToken.OriginalExpiresIn == nil {
- return accessToken, fmt.Errorf("error extracting access token: %w", errInsufficientFields)
+ return accessToken,
+ fmt.Errorf(
+ "error extracting access token: %w",
+ errInsufficientFields,
+ )
}
- accessToken.Expiration = t.Add(time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second)
+ accessToken.Expiration = t.Add(
+ time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second,
+ )
return accessToken, nil
}
diff --git a/config.go b/config.go
index 0f072f1..682bcd3 100644
--- a/config.go
+++ b/config.go
@@ -119,13 +119,21 @@ func fetchConfig(path string) (retErr error) {
if v := recover(); v != nil {
s, ok := v.(error)
if ok {
- retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, s)
+ retErr = fmt.Errorf(
+ "%w: %w",
+ errCannotProcessConfig,
+ s,
+ )
}
retErr = fmt.Errorf("%w: %v", errCannotProcessConfig, v)
return
}
if retErr != nil {
- retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, retErr)
+ retErr = fmt.Errorf(
+ "%w: %w",
+ errCannotProcessConfig,
+ retErr,
+ )
return
}
}()
@@ -172,12 +180,18 @@ func fetchConfig(path string) (retErr error) {
if config.Listen.Trans == "tls" {
if configWithPointers.Listen.TLS.Cert == nil {
- return fmt.Errorf("%w: listen.tls.cert", errMissingConfigValue)
+ return fmt.Errorf(
+ "%w: listen.tls.cert",
+ errMissingConfigValue,
+ )
}
config.Listen.TLS.Cert = *(configWithPointers.Listen.TLS.Cert)
if configWithPointers.Listen.TLS.Key == nil {
- return fmt.Errorf("%w: listen.tls.key", errMissingConfigValue)
+ return fmt.Errorf(
+ "%w: listen.tls.key",
+ errMissingConfigValue,
+ )
}
config.Listen.TLS.Key = *(configWithPointers.Listen.TLS.Key)
}
@@ -201,11 +215,19 @@ func fetchConfig(path string) (retErr error) {
/* It's okay to set it to 0 in production */
case 4712, 9080: /* Don't use them unless you know what you're doing */
if config.Prod {
- return fmt.Errorf("%w: fake authentication is incompatible with production mode", errIllegalConfig)
+ return fmt.Errorf(
+ "%w: fake authentication is incompatible with production mode",
+ errIllegalConfig,
+ )
}
- log.Println("!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.")
+ log.Println(
+ "!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.",
+ )
default:
- return fmt.Errorf("%w: invalid option for auth.fake", errIllegalConfig)
+ return fmt.Errorf(
+ "%w: invalid option for auth.fake",
+ errIllegalConfig,
+ )
}
}
@@ -240,22 +262,34 @@ func fetchConfig(path string) (retErr error) {
config.Auth.Expr = *(configWithPointers.Auth.Expr)
if configWithPointers.Perf.MessageArgumentsCap == nil {
- return fmt.Errorf("%w: perf.msg_args_cap", errMissingConfigValue)
+ return fmt.Errorf(
+ "%w: perf.msg_args_cap",
+ errMissingConfigValue,
+ )
}
config.Perf.MessageArgumentsCap = *(configWithPointers.Perf.MessageArgumentsCap)
if configWithPointers.Perf.MessageBytesCap == nil {
- return fmt.Errorf("%w: perf.msg_bytes_cap", errMissingConfigValue)
+ return fmt.Errorf(
+ "%w: perf.msg_bytes_cap",
+ errMissingConfigValue,
+ )
}
config.Perf.MessageBytesCap = *(configWithPointers.Perf.MessageBytesCap)
if configWithPointers.Perf.ReadHeaderTimeout == nil {
- return fmt.Errorf("%w: perf.read_header_timeout", errMissingConfigValue)
+ return fmt.Errorf(
+ "%w: perf.read_header_timeout",
+ errMissingConfigValue,
+ )
}
config.Perf.ReadHeaderTimeout = *(configWithPointers.Perf.ReadHeaderTimeout)
if configWithPointers.Perf.UsemDelayShiftBits == nil {
- return fmt.Errorf("%w: perf.usem_delay_shift_bits", errMissingConfigValue)
+ return fmt.Errorf(
+ "%w: perf.usem_delay_shift_bits",
+ errMissingConfigValue,
+ )
}
config.Perf.UsemDelayShiftBits = *(configWithPointers.Perf.UsemDelayShiftBits)
diff --git a/courses.go b/courses.go
index 17fa311..4f4d0cf 100644
--- a/courses.go
+++ b/courses.go
@@ -143,7 +143,10 @@ func setupCourses() error {
if !rows.Next() {
err := rows.Err()
if err != nil {
- return fmt.Errorf("error fetching courses: %w", err)
+ return fmt.Errorf(
+ "error fetching courses: %w",
+ err,
+ )
}
break
}
@@ -163,17 +166,30 @@ func setupCourses() error {
return fmt.Errorf("error fetching courses: %w", err)
}
if !checkCourseType(currentCourse.Type) {
- return fmt.Errorf("%w: %d %s", errInvalidCourseType, currentCourse.ID, currentCourse.Type)
+ return fmt.Errorf(
+ "%w: %d %s",
+ errInvalidCourseType,
+ currentCourse.ID,
+ currentCourse.Type,
+ )
}
if !checkCourseGroup(currentCourse.Group) {
- return fmt.Errorf("%w: %d %s", errInvalidCourseGroup, currentCourse.ID, currentCourse.Group)
+ return fmt.Errorf(
+ "%w: %d %s",
+ errInvalidCourseGroup,
+ currentCourse.ID,
+ currentCourse.Group,
+ )
}
err := db.QueryRow(context.Background(),
"SELECT COUNT (*) FROM choices WHERE courseid = $1",
currentCourse.ID,
).Scan(&currentCourse.Selected)
if err != nil {
- return fmt.Errorf("error querying course member number: %w", err)
+ return fmt.Errorf(
+ "error querying course member number: %w",
+ err,
+ )
}
courses[currentCourse.ID] = &currentCourse
}
@@ -183,23 +199,40 @@ func setupCourses() error {
type userCourseGroupsT map[courseGroupT]bool
-func populateUserCourseGroups(ctx context.Context, userCourseGroups *userCourseGroupsT, userID string) error {
- rows, err := db.Query(ctx, "SELECT courseid FROM choices WHERE userid = $1", userID)
+func populateUserCourseGroups(
+ ctx context.Context,
+ userCourseGroups *userCourseGroupsT,
+ userID string,
+) error {
+ rows, err := db.Query(
+ ctx,
+ "SELECT courseid FROM choices WHERE userid = $1",
+ userID,
+ )
if err != nil {
- return fmt.Errorf("error querying user's choices while populating course groups: %w", err)
+ return fmt.Errorf(
+ "error querying user's choices while populating course groups: %w",
+ err,
+ )
}
for {
if !rows.Next() {
err := rows.Err()
if err != nil {
- return fmt.Errorf("error iterating user's choices while populating course groups: %w", err)
+ return fmt.Errorf(
+ "error iterating user's choices while populating course groups: %w",
+ err,
+ )
}
break
}
var thisCourseID int
err := rows.Scan(&thisCourseID)
if err != nil {
- return fmt.Errorf("error fetching user's choices while populating course groups: %w", err)
+ return fmt.Errorf(
+ "error fetching user's choices while populating course groups: %w",
+ err,
+ )
}
var thisGroupName courseGroupT
func() {
@@ -208,14 +241,22 @@ func populateUserCourseGroups(ctx context.Context, userCourseGroups *userCourseG
thisGroupName = courses[thisCourseID].Group
}()
if (*userCourseGroups)[thisGroupName] {
- return fmt.Errorf("%w: user %v, group %v", errMultipleChoicesInOneGroup, userID, thisGroupName)
+ return fmt.Errorf(
+ "%w: user %v, group %v",
+ errMultipleChoicesInOneGroup,
+ userID,
+ thisGroupName,
+ )
}
(*userCourseGroups)[thisGroupName] = true
}
return nil
}
-func (course *courseT) decrementSelectedAndPropagate(ctx context.Context, conn *websocket.Conn) error {
+func (course *courseT) decrementSelectedAndPropagate(
+ ctx context.Context,
+ conn *websocket.Conn,
+) error {
func() {
course.SelectedLock.Lock()
defer course.SelectedLock.Unlock()
@@ -224,7 +265,10 @@ func (course *courseT) decrementSelectedAndPropagate(ctx context.Context, conn *
go propagateSelectedUpdate(course.ID)
err := sendSelectedUpdate(ctx, conn, course.ID)
if err != nil {
- return fmt.Errorf("error sending selected update on decrement: %w", err)
+ return fmt.Errorf(
+ "error sending selected update on decrement: %w",
+ err,
+ )
}
return nil
}
diff --git a/index.go b/index.go
index 41d2f4d..931cd2d 100644
--- a/index.go
+++ b/index.go
@@ -38,7 +38,11 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
if errors.Is(err, http.ErrNoCookie) {
authURL, err := generateAuthorizationURL()
if err != nil {
- wstr(w, http.StatusInternalServerError, "Cannot generate authorization URL")
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ "Cannot generate authorization URL",
+ )
return
}
err = tmpl.ExecuteTemplate(
@@ -73,7 +77,11 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
if errors.Is(err, pgx.ErrNoRows) {
authURL, err := generateAuthorizationURL()
if err != nil {
- wstr(w, http.StatusInternalServerError, "Cannot generate authorization URL")
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ "Cannot generate authorization URL",
+ )
return
}
err = tmpl.ExecuteTemplate(
@@ -90,7 +98,14 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
}
return
}
- wstr(w, http.StatusInternalServerError, fmt.Sprintf("Error: Unexpected database error: %s", err))
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ fmt.Sprintf(
+ "Error: Unexpected database error: %s",
+ err,
+ ),
+ )
return
}
diff --git a/main.go b/main.go
index 7488d3b..3090579 100644
--- a/main.go
+++ b/main.go
@@ -34,10 +34,15 @@ import (
var tmpl *template.Template
-//go:embed build/static/* tmpl/* build/iadocs/*.pdf build/iadocs/*.htm build/docs/*
+//go:embed build/static/* tmpl/*
+//go:embed build/iadocs/*.pdf build/iadocs/*.htm build/docs/*
var runFS embed.FS
-//go:embed go.* *.go docs/* frontend/* README.md LICENSE Makefile iadocs/* sql/* tmpl/* scripts/* .editorconfig .gitignore
+//go:embed go.* *.go
+//go:embed docs/* iadocs/*
+//go:embed frontend/* tmpl/*
+//go:embed README.md LICENSE Makefile .editorconfig .gitignore
+//go:embed scripts/* sql/*
var srcFS embed.FS
func main() {
@@ -45,7 +50,12 @@ func main() {
var configPath string
- flag.StringVar(&configPath, "config", "cca.scfg", "path to configuration file")
+ flag.StringVar(
+ &configPath,
+ "config",
+ "cca.scfg",
+ "path to configuration file",
+ )
flag.Parse()
if err := fetchConfig(configPath); err != nil {
@@ -63,24 +73,45 @@ func main() {
if err != nil {
log.Fatal(err)
}
- http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS))))
+ http.Handle("/static/",
+ http.StripPrefix(
+ "/static/",
+ http.FileServer(http.FS(staticFS)),
+ ),
+ )
log.Println("Registering iadocs handle")
iaDocsFS, err := fs.Sub(runFS, "build/iadocs")
if err != nil {
log.Fatal(err)
}
- http.Handle("/iadocs/", http.StripPrefix("/iadocs/", http.FileServer(http.FS(iaDocsFS))))
+ http.Handle("/iadocs/",
+ http.StripPrefix(
+ "/iadocs/",
+ http.FileServer(http.FS(iaDocsFS)),
+ ),
+ )
log.Println("Registering docs handle")
docsFS, err := fs.Sub(runFS, "build/docs")
if err != nil {
log.Fatal(err)
}
- http.Handle("/docs/", http.StripPrefix("/docs/", http.FileServer(http.FS(docsFS))))
+ http.Handle(
+ "/docs/",
+ http.StripPrefix(
+ "/docs/",
+ http.FileServer(http.FS(docsFS)),
+ ),
+ )
log.Println("Registering source handle")
- http.Handle("/src/", http.StripPrefix("/src/", http.FileServer(http.FS(srcFS))))
+ http.Handle(
+ "/src/",
+ http.StripPrefix(
+ "/src/", http.FileServer(http.FS(srcFS)),
+ ),
+ )
log.Println("Registering handlers")
http.HandleFunc("/{$}", handleIndex)
@@ -98,12 +129,21 @@ func main() {
)
l, err = net.Listen(config.Listen.Net, config.Listen.Addr)
if err != nil {
- log.Fatalf("Failed to establish plain listener: %v\n", err)
+ log.Fatalf(
+ "Failed to establish plain listener: %v\n",
+ err,
+ )
}
case "tls":
- cer, err := tls.LoadX509KeyPair(config.Listen.TLS.Cert, config.Listen.TLS.Key)
+ cer, err := tls.LoadX509KeyPair(
+ config.Listen.TLS.Cert,
+ config.Listen.TLS.Key,
+ )
if err != nil {
- log.Fatalf("Failed to load TLS certificate and key: %v\n", err)
+ log.Fatalf(
+ "Failed to load TLS certificate and key: %v\n",
+ err,
+ )
}
tlsconfig := &tls.Config{
Certificates: []tls.Certificate{cer},
@@ -114,9 +154,16 @@ func main() {
config.Listen.Net,
config.Listen.Addr,
)
- l, err = tls.Listen(config.Listen.Net, config.Listen.Addr, tlsconfig)
+ l, err = tls.Listen(
+ config.Listen.Net,
+ config.Listen.Addr,
+ tlsconfig,
+ )
if err != nil {
- log.Fatalf("Failed to establish TLS listener: %v\n", err)
+ log.Fatalf(
+ "Failed to establish TLS listener: %v\n",
+ err,
+ )
}
default:
log.Fatalln("listen.trans must be \"plain\" or \"tls\"")
@@ -141,7 +188,9 @@ func main() {
if config.Listen.Proto == "http" {
log.Println("Serving http")
srv := &http.Server{
- ReadHeaderTimeout: time.Duration(config.Perf.ReadHeaderTimeout) * time.Second,
+ ReadHeaderTimeout: time.Duration(
+ config.Perf.ReadHeaderTimeout,
+ ) * time.Second,
} //exhaustruct:ignore
err = srv.Serve(l)
} else {
diff --git a/wsc.go b/wsc.go
index c085789..afdc489 100644
--- a/wsc.go
+++ b/wsc.go
@@ -122,7 +122,12 @@ func handleConn(
case usemParent <- courseID:
}
}
- time.Sleep(time.Duration(usemCount>>config.Perf.UsemDelayShiftBits) * time.Millisecond)
+ time.Sleep(
+ time.Duration(
+ usemCount>>
+ config.Perf.UsemDelayShiftBits,
+ ) * time.Millisecond,
+ )
}
}()
}
@@ -134,7 +139,12 @@ func handleConn(
var userCourseGroups userCourseGroupsT = make(map[courseGroupT]bool)
err := populateUserCourseGroups(newCtx, &userCourseGroups, userID)
if err != nil {
- return reportError(fmt.Sprintf("cannot populate user course groups: %v", err))
+ return reportError(
+ fmt.Sprintf(
+ "cannot populate user course groups: %v",
+ err,
+ ),
+ )
}
/*
@@ -180,7 +190,11 @@ func handleConn(
*/
select {
case <-newCtx.Done():
- _ = writeText(ctx, c, "E :Context canceled")
+ _ = writeText(
+ ctx,
+ c,
+ "E :Context canceled",
+ )
/* Not a typo to use ctx here */
return
case recv <- &errbytesT{err: err, bytes: nil}:
@@ -202,9 +216,13 @@ func handleConn(
select {
case <-newCtx.Done():
/*
- * TODO: Somehow prioritize this case over all other cases
+ * TODO: Somehow prioritize this case over all other
+ * cases
*/
- return fmt.Errorf("context done in main event loop: %w", newCtx.Err())
+ return fmt.Errorf(
+ "context done in main event loop: %w",
+ newCtx.Err(),
+ )
/*
* There are other times when the context could be
* cancelled, and apparently some WebSocket functions
@@ -218,12 +236,18 @@ func handleConn(
case courseID := <-usemParent:
err := sendSelectedUpdate(newCtx, c, courseID)
if err != nil {
- return fmt.Errorf("error acting on usem: %w", err)
+ return fmt.Errorf(
+ "error acting on usem: %w",
+ err,
+ )
}
continue
case errbytes := <-recv:
if errbytes.err != nil {
- return fmt.Errorf("error fetching message from recv channel: %w", errbytes.err)
+ return fmt.Errorf(
+ "error fetching message from recv channel: %w",
+ errbytes.err,
+ )
/*
* Note that this cannot return newCtx.Err(),
* so we handle the error reporting in the
@@ -233,17 +257,40 @@ func handleConn(
mar = splitMsg(errbytes.bytes)
switch mar[0] {
case "HELLO":
- err := messageHello(newCtx, c, reportError, mar, userID, session)
+ err := messageHello(
+ newCtx,
+ c,
+ reportError,
+ mar,
+ userID,
+ session,
+ )
if err != nil {
return err
}
case "Y":
- err := messageChooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups)
+ err := messageChooseCourse(
+ newCtx,
+ c,
+ reportError,
+ mar,
+ userID,
+ session,
+ &userCourseGroups,
+ )
if err != nil {
return err
}
case "N":
- err := messageUnchooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups)
+ err := messageUnchooseCourse(
+ newCtx,
+ c,
+ reportError,
+ mar,
+ userID,
+ session,
+ &userCourseGroups,
+ )
if err != nil {
return err
}
diff --git a/wsh.go b/wsh.go
index abf4e61..4bcb033 100644
--- a/wsh.go
+++ b/wsh.go
@@ -46,7 +46,11 @@ func handleWs(w http.ResponseWriter, req *http.Request) {
wsOptions,
)
if err != nil {
- wstr(w, http.StatusBadRequest, "This endpoint only supports valid WebSocket connections.")
+ wstr(
+ w,
+ http.StatusBadRequest,
+ "This endpoint only supports valid WebSocket connections.",
+ )
return
}
defer func() {
@@ -104,11 +108,17 @@ func handleWs(w http.ResponseWriter, req *http.Request) {
"fake@runxiyu.org",
"Y11",
session,
- time.Now().Add(time.Duration(config.Auth.Expr)*time.Second).Unix(),
+ time.Now().Add(
+ time.Duration(config.Auth.Expr)*time.Second,
+ ).Unix(),
)
if err != nil && config.Auth.Fake != 4712 {
/* TODO check pgerr */
- err := writeText(req.Context(), c, "E :Database error while writing fake account info")
+ err := writeText(
+ req.Context(),
+ c,
+ "E :Database error while writing fake account info",
+ )
if err != nil {
log.Println(err)
}
@@ -133,7 +143,11 @@ func handleWs(w http.ResponseWriter, req *http.Request) {
}
return
} else if err != nil {
- err := writeText(req.Context(), c, "E :Database error while selecting session")
+ err := writeText(
+ req.Context(),
+ c,
+ "E :Database error while selecting session",
+ )
if err != nil {
log.Println(err)
}
diff --git a/wsm.go b/wsm.go
index 11ae0ee..3f9a703 100644
--- a/wsm.go
+++ b/wsm.go
@@ -33,12 +33,22 @@ import (
"github.com/jackc/pgx/v5/pgconn"
)
-func messageHello(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string) error {
+func messageHello(
+ ctx context.Context,
+ c *websocket.Conn,
+ reportError reportErrorT,
+ mar []string,
+ userID string,
+ session string,
+) error {
_, _ = mar, session
select {
case <-ctx.Done():
- return fmt.Errorf("context done when handling hello: %w", ctx.Err())
+ return fmt.Errorf(
+ "context done when handling hello: %w",
+ ctx.Err(),
+ )
default:
}
@@ -63,12 +73,23 @@ func messageHello(ctx context.Context, c *websocket.Conn, reportError reportErro
return nil
}
-func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error {
+func messageChooseCourse(
+ ctx context.Context,
+ c *websocket.Conn,
+ reportError reportErrorT,
+ mar []string,
+ userID string,
+ session string,
+ userCourseGroups *userCourseGroupsT,
+) error {
_ = session
select {
case <-ctx.Done():
- return fmt.Errorf("context done when handling choose: %w", ctx.Err())
+ return fmt.Errorf(
+ "context done when handling choose: %w",
+ ctx.Err(),
+ )
default:
}
@@ -91,7 +112,10 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep
if (*userCourseGroups)[thisCourseGroup] {
err := writeText(ctx, c, "R "+mar[1]+" :Group conflict")
if err != nil {
- return fmt.Errorf("error rejecting course choice due to group conflict: %w", err)
+ return fmt.Errorf(
+ "error rejecting course choice due to group conflict: %w",
+ err,
+ )
}
return nil
}
@@ -99,12 +123,16 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep
err = func() (returnedError error) { /* Named returns so I could modify them in defer */
tx, err := db.Begin(ctx)
if err != nil {
- return reportError("Database error while beginning transaction")
+ return reportError(
+ "Database error while beginning transaction",
+ )
}
defer func() {
err := tx.Rollback(ctx)
if err != nil && (!errors.Is(err, pgx.ErrTxClosed)) {
- returnedError = reportError("Database error while rolling back transaction in defer block")
+ returnedError = reportError(
+ "Database error while rolling back transaction in defer block",
+ )
return
}
}()
@@ -121,11 +149,16 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep
if errors.As(err, &pgErr) && pgErr.Code == pgErrUniqueViolation {
err := writeText(ctx, c, "Y "+mar[1])
if err != nil {
- return fmt.Errorf("error reaffirming course choice: %w", err)
+ return fmt.Errorf(
+ "error reaffirming course choice: %w",
+ err,
+ )
}
return nil
}
- return reportError("Database error while inserting course choice")
+ return reportError(
+ "Database error while inserting course choice",
+ )
}
ok := func() bool {
@@ -144,9 +177,14 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep
if err != nil {
err := course.decrementSelectedAndPropagate(ctx, c)
if err != nil {
- return fmt.Errorf("error decrementing and notifying: %w", err)
+ return fmt.Errorf(
+ "error decrementing and notifying: %w",
+ err,
+ )
}
- return reportError("Database error while committing transaction")
+ return reportError(
+ "Database error while committing transaction",
+ )
}
/*
@@ -157,20 +195,31 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep
err = writeText(ctx, c, "Y "+mar[1])
if err != nil {
- return fmt.Errorf("error affirming course choice: %w", err)
+ return fmt.Errorf(
+ "error affirming course choice: %w",
+ err,
+ )
}
err = sendSelectedUpdate(ctx, c, courseID)
if err != nil {
- return fmt.Errorf("error notifying after increment: %w", err)
+ return fmt.Errorf(
+ "error notifying after increment: %w",
+ err,
+ )
}
} else {
err := tx.Rollback(ctx)
if err != nil {
- return reportError("Database error while rolling back transaction due to course limit")
+ return reportError(
+ "Database error while rolling back transaction due to course limit",
+ )
}
err = writeText(ctx, c, "R "+mar[1]+" :Full")
if err != nil {
- return fmt.Errorf("error rejecting course choice: %w", err)
+ return fmt.Errorf(
+ "error rejecting course choice: %w",
+ err,
+ )
}
}
return nil
@@ -181,7 +230,15 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep
return nil
}
-func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error {
+func messageUnchooseCourse(
+ ctx context.Context,
+ c *websocket.Conn,
+ reportError reportErrorT,
+ mar []string,
+ userID string,
+ session string,
+ userCourseGroups *userCourseGroupsT,
+) error {
_ = session
select {
@@ -211,13 +268,18 @@ func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError r
courseID,
)
if err != nil {
- return reportError("Database error while deleting course choice")
+ return reportError(
+ "Database error while deleting course choice",
+ )
}
if ct.RowsAffected() != 0 {
err := course.decrementSelectedAndPropagate(ctx, c)
if err != nil {
- return fmt.Errorf("error decrementing and notifying: %w", err)
+ return fmt.Errorf(
+ "error decrementing and notifying: %w",
+ err,
+ )
}
var thisCourseGroup courseGroupT
func() {
@@ -233,7 +295,10 @@ func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError r
err = writeText(ctx, c, "N "+mar[1])
if err != nil {
- return fmt.Errorf("error replying that course has been deselected: %w", err)
+ return fmt.Errorf(
+ "error replying that course has been deselected: %w",
+ err,
+ )
}
return nil
diff --git a/wsp.go b/wsp.go
index d4b8f08..09a1360 100644
--- a/wsp.go
+++ b/wsp.go
@@ -55,7 +55,11 @@ endl:
return mar
}
-func baseReportError(ctx context.Context, conn *websocket.Conn, e string) error {
+func baseReportError(
+ ctx context.Context,
+ conn *websocket.Conn,
+ e string,
+) error {
err := writeText(ctx, conn, "E :"+e)
if err != nil {
return fmt.Errorf("error reporting protocol violation: %w", err)
@@ -84,7 +88,11 @@ func propagateSelectedUpdate(courseID int) {
}
}
-func sendSelectedUpdate(ctx context.Context, conn *websocket.Conn, courseID int) error {
+func sendSelectedUpdate(
+ ctx context.Context,
+ conn *websocket.Conn,
+ courseID int,
+) error {
var selected int
func() {
course := courses[courseID]
@@ -94,7 +102,10 @@ func sendSelectedUpdate(ctx context.Context, conn *websocket.Conn, courseID int)
}()
err := writeText(ctx, conn, fmt.Sprintf("M %d %d", courseID, selected))
if err != nil {
- return fmt.Errorf("error sending to websocket for course selected update: %w", err)
+ return fmt.Errorf(
+ "error sending to websocket for course selected update: %w",
+ err,
+ )
}
return nil
}