diff options
Diffstat (limited to '')
-rw-r--r-- | auth.go | 97 | ||||
-rw-r--r-- | config.go | 56 | ||||
-rw-r--r-- | courses.go | 68 | ||||
-rw-r--r-- | index.go | 21 | ||||
-rw-r--r-- | main.go | 75 | ||||
-rw-r--r-- | wsc.go | 67 | ||||
-rw-r--r-- | wsh.go | 22 | ||||
-rw-r--r-- | wsm.go | 103 | ||||
-rw-r--r-- | wsp.go | 17 |
9 files changed, 431 insertions, 95 deletions
@@ -96,7 +96,11 @@ func generateAuthorizationURL() (string, error) { */ func handleAuth(w http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { - wstr(w, http.StatusMethodNotAllowed, "Only POST is supported on the authentication endpoint") + wstr( + w, + http.StatusMethodNotAllowed, + "Only POST is supported on the authentication endpoint", + ) return } @@ -110,7 +114,14 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { if returnedError != "" { returnedErrorDescription := req.PostFormValue("error_description") if returnedErrorDescription == "" { - wstr(w, http.StatusBadRequest, fmt.Sprintf("authorize endpoint returned error: %v", returnedErrorDescription)) + wstr( + w, + http.StatusBadRequest, + fmt.Sprintf( + "authorize endpoint returned error: %v", + returnedErrorDescription, + ), + ) return } wstr(w, http.StatusBadRequest, fmt.Sprintf( @@ -149,7 +160,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { return case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet): - wstr(w, http.StatusBadRequest, "JWT token expired or not yet valid") + wstr( + w, + http.StatusBadRequest, + "JWT token expired or not yet valid", + ) return default: wstr(w, http.StatusBadRequest, "Unhandled JWT token error") @@ -167,7 +182,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { accessToken, err := getAccessToken(req.Context(), authorizationCode) if err != nil { - wstr(w, http.StatusInternalServerError, fmt.Sprintf("Unable to fetch access token: %v", err)) + wstr( + w, + http.StatusInternalServerError, + fmt.Sprintf("Unable to fetch access token: %v", err), + ) return } @@ -178,9 +197,11 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { } switch { - case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || department == "High School Teaching & Learning 高中教学部门": + case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || + department == "High School Teaching & Learning 高中教学部门": department = "Staff" - case department == "Y9" || department == "Y10" || department == "Y11" || department == "Y12": + case department == "Y9" || department == "Y10" || + department == "Y11" || department == "Y12": default: wstr( w, @@ -245,11 +266,19 @@ func handleAuth(w http.ResponseWriter, req *http.Request) { claims.Oid, ) if err != nil { - wstr(w, http.StatusInternalServerError, "Database error while updating account.") + wstr( + w, + http.StatusInternalServerError, + "Database error while updating account.", + ) return } } else { - wstr(w, http.StatusInternalServerError, "Database error while writing account info.") + wstr( + w, + http.StatusInternalServerError, + "Database error while writing account info.", + ) return } } @@ -312,7 +341,11 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) { * "department" field, which hopefully doesn't occur as we * have specified $select=department in the OData query. */ - return "", fmt.Errorf("error getting department: %w", errInsufficientFields) + return "", + fmt.Errorf( + "error getting department: %w", + errInsufficientFields, + ) } return *(departmentWrap.Department), nil @@ -322,7 +355,7 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) { * TODO: Access token expiration is not checked anywhere. */ type accessTokenT struct { - OriginalExpiresIn *int `json:"expires_in"` /* Original time to expiration */ + OriginalExpiresIn *int `json:"expires_in"` /* Original time to expr */ Expiration time.Time Content *string `json:"access_token"` Error *string `json:"error"` @@ -334,7 +367,10 @@ type accessTokenT struct { * Obtain an access token from the token endpoint with an existing * authorization code. */ -func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT, error) { +func getAccessToken( + ctx context.Context, + authorizationCode string, +) (accessTokenT, error) { var accessToken accessTokenT t := time.Now() v := url.Values{} @@ -344,31 +380,52 @@ func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT v.Set("redirect_uri", config.URL+"/auth") v.Set("grant_type", "authorization_code") v.Set("client_secret", config.Auth.Secret) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, config.Auth.Token, strings.NewReader(v.Encode())) + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + config.Auth.Token, + strings.NewReader(v.Encode()), + ) if err != nil { - return accessToken, fmt.Errorf("error making access token request: %w", err) + return accessToken, + fmt.Errorf("error making access token request: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { - return accessToken, fmt.Errorf("error requesting access token: %w", err) + return accessToken, + fmt.Errorf("error requesting access token: %w", err) } defer resp.Body.Close() decoder := json.NewDecoder(resp.Body) err = decoder.Decode(&accessToken) if err != nil { - return accessToken, fmt.Errorf("error decoding access token: %w", err) + return accessToken, + fmt.Errorf("error decoding access token: %w", err) } - if accessToken.Error != nil || accessToken.ErrorCodes != nil || accessToken.ErrorDescription != nil { - if accessToken.Error == nil || accessToken.ErrorCodes == nil || accessToken.ErrorDescription == nil { + if accessToken.Error != nil || accessToken.ErrorCodes != nil || + accessToken.ErrorDescription != nil { + if accessToken.Error == nil || accessToken.ErrorCodes == nil || + accessToken.ErrorDescription == nil { return accessToken, errAccessTokenIncompleteError } - return accessToken, fmt.Errorf("%w: %v", errTokenEndpointReturnedError, *accessToken.ErrorDescription) + return accessToken, + fmt.Errorf( + "%w: %v", + errTokenEndpointReturnedError, + *accessToken.ErrorDescription, + ) } if accessToken.Content == nil || accessToken.OriginalExpiresIn == nil { - return accessToken, fmt.Errorf("error extracting access token: %w", errInsufficientFields) + return accessToken, + fmt.Errorf( + "error extracting access token: %w", + errInsufficientFields, + ) } - accessToken.Expiration = t.Add(time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second) + accessToken.Expiration = t.Add( + time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second, + ) return accessToken, nil } @@ -119,13 +119,21 @@ func fetchConfig(path string) (retErr error) { if v := recover(); v != nil { s, ok := v.(error) if ok { - retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, s) + retErr = fmt.Errorf( + "%w: %w", + errCannotProcessConfig, + s, + ) } retErr = fmt.Errorf("%w: %v", errCannotProcessConfig, v) return } if retErr != nil { - retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, retErr) + retErr = fmt.Errorf( + "%w: %w", + errCannotProcessConfig, + retErr, + ) return } }() @@ -172,12 +180,18 @@ func fetchConfig(path string) (retErr error) { if config.Listen.Trans == "tls" { if configWithPointers.Listen.TLS.Cert == nil { - return fmt.Errorf("%w: listen.tls.cert", errMissingConfigValue) + return fmt.Errorf( + "%w: listen.tls.cert", + errMissingConfigValue, + ) } config.Listen.TLS.Cert = *(configWithPointers.Listen.TLS.Cert) if configWithPointers.Listen.TLS.Key == nil { - return fmt.Errorf("%w: listen.tls.key", errMissingConfigValue) + return fmt.Errorf( + "%w: listen.tls.key", + errMissingConfigValue, + ) } config.Listen.TLS.Key = *(configWithPointers.Listen.TLS.Key) } @@ -201,11 +215,19 @@ func fetchConfig(path string) (retErr error) { /* It's okay to set it to 0 in production */ case 4712, 9080: /* Don't use them unless you know what you're doing */ if config.Prod { - return fmt.Errorf("%w: fake authentication is incompatible with production mode", errIllegalConfig) + return fmt.Errorf( + "%w: fake authentication is incompatible with production mode", + errIllegalConfig, + ) } - log.Println("!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.") + log.Println( + "!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.", + ) default: - return fmt.Errorf("%w: invalid option for auth.fake", errIllegalConfig) + return fmt.Errorf( + "%w: invalid option for auth.fake", + errIllegalConfig, + ) } } @@ -240,22 +262,34 @@ func fetchConfig(path string) (retErr error) { config.Auth.Expr = *(configWithPointers.Auth.Expr) if configWithPointers.Perf.MessageArgumentsCap == nil { - return fmt.Errorf("%w: perf.msg_args_cap", errMissingConfigValue) + return fmt.Errorf( + "%w: perf.msg_args_cap", + errMissingConfigValue, + ) } config.Perf.MessageArgumentsCap = *(configWithPointers.Perf.MessageArgumentsCap) if configWithPointers.Perf.MessageBytesCap == nil { - return fmt.Errorf("%w: perf.msg_bytes_cap", errMissingConfigValue) + return fmt.Errorf( + "%w: perf.msg_bytes_cap", + errMissingConfigValue, + ) } config.Perf.MessageBytesCap = *(configWithPointers.Perf.MessageBytesCap) if configWithPointers.Perf.ReadHeaderTimeout == nil { - return fmt.Errorf("%w: perf.read_header_timeout", errMissingConfigValue) + return fmt.Errorf( + "%w: perf.read_header_timeout", + errMissingConfigValue, + ) } config.Perf.ReadHeaderTimeout = *(configWithPointers.Perf.ReadHeaderTimeout) if configWithPointers.Perf.UsemDelayShiftBits == nil { - return fmt.Errorf("%w: perf.usem_delay_shift_bits", errMissingConfigValue) + return fmt.Errorf( + "%w: perf.usem_delay_shift_bits", + errMissingConfigValue, + ) } config.Perf.UsemDelayShiftBits = *(configWithPointers.Perf.UsemDelayShiftBits) @@ -143,7 +143,10 @@ func setupCourses() error { if !rows.Next() { err := rows.Err() if err != nil { - return fmt.Errorf("error fetching courses: %w", err) + return fmt.Errorf( + "error fetching courses: %w", + err, + ) } break } @@ -163,17 +166,30 @@ func setupCourses() error { return fmt.Errorf("error fetching courses: %w", err) } if !checkCourseType(currentCourse.Type) { - return fmt.Errorf("%w: %d %s", errInvalidCourseType, currentCourse.ID, currentCourse.Type) + return fmt.Errorf( + "%w: %d %s", + errInvalidCourseType, + currentCourse.ID, + currentCourse.Type, + ) } if !checkCourseGroup(currentCourse.Group) { - return fmt.Errorf("%w: %d %s", errInvalidCourseGroup, currentCourse.ID, currentCourse.Group) + return fmt.Errorf( + "%w: %d %s", + errInvalidCourseGroup, + currentCourse.ID, + currentCourse.Group, + ) } err := db.QueryRow(context.Background(), "SELECT COUNT (*) FROM choices WHERE courseid = $1", currentCourse.ID, ).Scan(¤tCourse.Selected) if err != nil { - return fmt.Errorf("error querying course member number: %w", err) + return fmt.Errorf( + "error querying course member number: %w", + err, + ) } courses[currentCourse.ID] = ¤tCourse } @@ -183,23 +199,40 @@ func setupCourses() error { type userCourseGroupsT map[courseGroupT]bool -func populateUserCourseGroups(ctx context.Context, userCourseGroups *userCourseGroupsT, userID string) error { - rows, err := db.Query(ctx, "SELECT courseid FROM choices WHERE userid = $1", userID) +func populateUserCourseGroups( + ctx context.Context, + userCourseGroups *userCourseGroupsT, + userID string, +) error { + rows, err := db.Query( + ctx, + "SELECT courseid FROM choices WHERE userid = $1", + userID, + ) if err != nil { - return fmt.Errorf("error querying user's choices while populating course groups: %w", err) + return fmt.Errorf( + "error querying user's choices while populating course groups: %w", + err, + ) } for { if !rows.Next() { err := rows.Err() if err != nil { - return fmt.Errorf("error iterating user's choices while populating course groups: %w", err) + return fmt.Errorf( + "error iterating user's choices while populating course groups: %w", + err, + ) } break } var thisCourseID int err := rows.Scan(&thisCourseID) if err != nil { - return fmt.Errorf("error fetching user's choices while populating course groups: %w", err) + return fmt.Errorf( + "error fetching user's choices while populating course groups: %w", + err, + ) } var thisGroupName courseGroupT func() { @@ -208,14 +241,22 @@ func populateUserCourseGroups(ctx context.Context, userCourseGroups *userCourseG thisGroupName = courses[thisCourseID].Group }() if (*userCourseGroups)[thisGroupName] { - return fmt.Errorf("%w: user %v, group %v", errMultipleChoicesInOneGroup, userID, thisGroupName) + return fmt.Errorf( + "%w: user %v, group %v", + errMultipleChoicesInOneGroup, + userID, + thisGroupName, + ) } (*userCourseGroups)[thisGroupName] = true } return nil } -func (course *courseT) decrementSelectedAndPropagate(ctx context.Context, conn *websocket.Conn) error { +func (course *courseT) decrementSelectedAndPropagate( + ctx context.Context, + conn *websocket.Conn, +) error { func() { course.SelectedLock.Lock() defer course.SelectedLock.Unlock() @@ -224,7 +265,10 @@ func (course *courseT) decrementSelectedAndPropagate(ctx context.Context, conn * go propagateSelectedUpdate(course.ID) err := sendSelectedUpdate(ctx, conn, course.ID) if err != nil { - return fmt.Errorf("error sending selected update on decrement: %w", err) + return fmt.Errorf( + "error sending selected update on decrement: %w", + err, + ) } return nil } @@ -38,7 +38,11 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { if errors.Is(err, http.ErrNoCookie) { authURL, err := generateAuthorizationURL() if err != nil { - wstr(w, http.StatusInternalServerError, "Cannot generate authorization URL") + wstr( + w, + http.StatusInternalServerError, + "Cannot generate authorization URL", + ) return } err = tmpl.ExecuteTemplate( @@ -73,7 +77,11 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { if errors.Is(err, pgx.ErrNoRows) { authURL, err := generateAuthorizationURL() if err != nil { - wstr(w, http.StatusInternalServerError, "Cannot generate authorization URL") + wstr( + w, + http.StatusInternalServerError, + "Cannot generate authorization URL", + ) return } err = tmpl.ExecuteTemplate( @@ -90,7 +98,14 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { } return } - wstr(w, http.StatusInternalServerError, fmt.Sprintf("Error: Unexpected database error: %s", err)) + wstr( + w, + http.StatusInternalServerError, + fmt.Sprintf( + "Error: Unexpected database error: %s", + err, + ), + ) return } @@ -34,10 +34,15 @@ import ( var tmpl *template.Template -//go:embed build/static/* tmpl/* build/iadocs/*.pdf build/iadocs/*.htm build/docs/* +//go:embed build/static/* tmpl/* +//go:embed build/iadocs/*.pdf build/iadocs/*.htm build/docs/* var runFS embed.FS -//go:embed go.* *.go docs/* frontend/* README.md LICENSE Makefile iadocs/* sql/* tmpl/* scripts/* .editorconfig .gitignore +//go:embed go.* *.go +//go:embed docs/* iadocs/* +//go:embed frontend/* tmpl/* +//go:embed README.md LICENSE Makefile .editorconfig .gitignore +//go:embed scripts/* sql/* var srcFS embed.FS func main() { @@ -45,7 +50,12 @@ func main() { var configPath string - flag.StringVar(&configPath, "config", "cca.scfg", "path to configuration file") + flag.StringVar( + &configPath, + "config", + "cca.scfg", + "path to configuration file", + ) flag.Parse() if err := fetchConfig(configPath); err != nil { @@ -63,24 +73,45 @@ func main() { if err != nil { log.Fatal(err) } - http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS)))) + http.Handle("/static/", + http.StripPrefix( + "/static/", + http.FileServer(http.FS(staticFS)), + ), + ) log.Println("Registering iadocs handle") iaDocsFS, err := fs.Sub(runFS, "build/iadocs") if err != nil { log.Fatal(err) } - http.Handle("/iadocs/", http.StripPrefix("/iadocs/", http.FileServer(http.FS(iaDocsFS)))) + http.Handle("/iadocs/", + http.StripPrefix( + "/iadocs/", + http.FileServer(http.FS(iaDocsFS)), + ), + ) log.Println("Registering docs handle") docsFS, err := fs.Sub(runFS, "build/docs") if err != nil { log.Fatal(err) } - http.Handle("/docs/", http.StripPrefix("/docs/", http.FileServer(http.FS(docsFS)))) + http.Handle( + "/docs/", + http.StripPrefix( + "/docs/", + http.FileServer(http.FS(docsFS)), + ), + ) log.Println("Registering source handle") - http.Handle("/src/", http.StripPrefix("/src/", http.FileServer(http.FS(srcFS)))) + http.Handle( + "/src/", + http.StripPrefix( + "/src/", http.FileServer(http.FS(srcFS)), + ), + ) log.Println("Registering handlers") http.HandleFunc("/{$}", handleIndex) @@ -98,12 +129,21 @@ func main() { ) l, err = net.Listen(config.Listen.Net, config.Listen.Addr) if err != nil { - log.Fatalf("Failed to establish plain listener: %v\n", err) + log.Fatalf( + "Failed to establish plain listener: %v\n", + err, + ) } case "tls": - cer, err := tls.LoadX509KeyPair(config.Listen.TLS.Cert, config.Listen.TLS.Key) + cer, err := tls.LoadX509KeyPair( + config.Listen.TLS.Cert, + config.Listen.TLS.Key, + ) if err != nil { - log.Fatalf("Failed to load TLS certificate and key: %v\n", err) + log.Fatalf( + "Failed to load TLS certificate and key: %v\n", + err, + ) } tlsconfig := &tls.Config{ Certificates: []tls.Certificate{cer}, @@ -114,9 +154,16 @@ func main() { config.Listen.Net, config.Listen.Addr, ) - l, err = tls.Listen(config.Listen.Net, config.Listen.Addr, tlsconfig) + l, err = tls.Listen( + config.Listen.Net, + config.Listen.Addr, + tlsconfig, + ) if err != nil { - log.Fatalf("Failed to establish TLS listener: %v\n", err) + log.Fatalf( + "Failed to establish TLS listener: %v\n", + err, + ) } default: log.Fatalln("listen.trans must be \"plain\" or \"tls\"") @@ -141,7 +188,9 @@ func main() { if config.Listen.Proto == "http" { log.Println("Serving http") srv := &http.Server{ - ReadHeaderTimeout: time.Duration(config.Perf.ReadHeaderTimeout) * time.Second, + ReadHeaderTimeout: time.Duration( + config.Perf.ReadHeaderTimeout, + ) * time.Second, } //exhaustruct:ignore err = srv.Serve(l) } else { @@ -122,7 +122,12 @@ func handleConn( case usemParent <- courseID: } } - time.Sleep(time.Duration(usemCount>>config.Perf.UsemDelayShiftBits) * time.Millisecond) + time.Sleep( + time.Duration( + usemCount>> + config.Perf.UsemDelayShiftBits, + ) * time.Millisecond, + ) } }() } @@ -134,7 +139,12 @@ func handleConn( var userCourseGroups userCourseGroupsT = make(map[courseGroupT]bool) err := populateUserCourseGroups(newCtx, &userCourseGroups, userID) if err != nil { - return reportError(fmt.Sprintf("cannot populate user course groups: %v", err)) + return reportError( + fmt.Sprintf( + "cannot populate user course groups: %v", + err, + ), + ) } /* @@ -180,7 +190,11 @@ func handleConn( */ select { case <-newCtx.Done(): - _ = writeText(ctx, c, "E :Context canceled") + _ = writeText( + ctx, + c, + "E :Context canceled", + ) /* Not a typo to use ctx here */ return case recv <- &errbytesT{err: err, bytes: nil}: @@ -202,9 +216,13 @@ func handleConn( select { case <-newCtx.Done(): /* - * TODO: Somehow prioritize this case over all other cases + * TODO: Somehow prioritize this case over all other + * cases */ - return fmt.Errorf("context done in main event loop: %w", newCtx.Err()) + return fmt.Errorf( + "context done in main event loop: %w", + newCtx.Err(), + ) /* * There are other times when the context could be * cancelled, and apparently some WebSocket functions @@ -218,12 +236,18 @@ func handleConn( case courseID := <-usemParent: err := sendSelectedUpdate(newCtx, c, courseID) if err != nil { - return fmt.Errorf("error acting on usem: %w", err) + return fmt.Errorf( + "error acting on usem: %w", + err, + ) } continue case errbytes := <-recv: if errbytes.err != nil { - return fmt.Errorf("error fetching message from recv channel: %w", errbytes.err) + return fmt.Errorf( + "error fetching message from recv channel: %w", + errbytes.err, + ) /* * Note that this cannot return newCtx.Err(), * so we handle the error reporting in the @@ -233,17 +257,40 @@ func handleConn( mar = splitMsg(errbytes.bytes) switch mar[0] { case "HELLO": - err := messageHello(newCtx, c, reportError, mar, userID, session) + err := messageHello( + newCtx, + c, + reportError, + mar, + userID, + session, + ) if err != nil { return err } case "Y": - err := messageChooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups) + err := messageChooseCourse( + newCtx, + c, + reportError, + mar, + userID, + session, + &userCourseGroups, + ) if err != nil { return err } case "N": - err := messageUnchooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups) + err := messageUnchooseCourse( + newCtx, + c, + reportError, + mar, + userID, + session, + &userCourseGroups, + ) if err != nil { return err } @@ -46,7 +46,11 @@ func handleWs(w http.ResponseWriter, req *http.Request) { wsOptions, ) if err != nil { - wstr(w, http.StatusBadRequest, "This endpoint only supports valid WebSocket connections.") + wstr( + w, + http.StatusBadRequest, + "This endpoint only supports valid WebSocket connections.", + ) return } defer func() { @@ -104,11 +108,17 @@ func handleWs(w http.ResponseWriter, req *http.Request) { "fake@runxiyu.org", "Y11", session, - time.Now().Add(time.Duration(config.Auth.Expr)*time.Second).Unix(), + time.Now().Add( + time.Duration(config.Auth.Expr)*time.Second, + ).Unix(), ) if err != nil && config.Auth.Fake != 4712 { /* TODO check pgerr */ - err := writeText(req.Context(), c, "E :Database error while writing fake account info") + err := writeText( + req.Context(), + c, + "E :Database error while writing fake account info", + ) if err != nil { log.Println(err) } @@ -133,7 +143,11 @@ func handleWs(w http.ResponseWriter, req *http.Request) { } return } else if err != nil { - err := writeText(req.Context(), c, "E :Database error while selecting session") + err := writeText( + req.Context(), + c, + "E :Database error while selecting session", + ) if err != nil { log.Println(err) } @@ -33,12 +33,22 @@ import ( "github.com/jackc/pgx/v5/pgconn" ) -func messageHello(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string) error { +func messageHello( + ctx context.Context, + c *websocket.Conn, + reportError reportErrorT, + mar []string, + userID string, + session string, +) error { _, _ = mar, session select { case <-ctx.Done(): - return fmt.Errorf("context done when handling hello: %w", ctx.Err()) + return fmt.Errorf( + "context done when handling hello: %w", + ctx.Err(), + ) default: } @@ -63,12 +73,23 @@ func messageHello(ctx context.Context, c *websocket.Conn, reportError reportErro return nil } -func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error { +func messageChooseCourse( + ctx context.Context, + c *websocket.Conn, + reportError reportErrorT, + mar []string, + userID string, + session string, + userCourseGroups *userCourseGroupsT, +) error { _ = session select { case <-ctx.Done(): - return fmt.Errorf("context done when handling choose: %w", ctx.Err()) + return fmt.Errorf( + "context done when handling choose: %w", + ctx.Err(), + ) default: } @@ -91,7 +112,10 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep if (*userCourseGroups)[thisCourseGroup] { err := writeText(ctx, c, "R "+mar[1]+" :Group conflict") if err != nil { - return fmt.Errorf("error rejecting course choice due to group conflict: %w", err) + return fmt.Errorf( + "error rejecting course choice due to group conflict: %w", + err, + ) } return nil } @@ -99,12 +123,16 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep err = func() (returnedError error) { /* Named returns so I could modify them in defer */ tx, err := db.Begin(ctx) if err != nil { - return reportError("Database error while beginning transaction") + return reportError( + "Database error while beginning transaction", + ) } defer func() { err := tx.Rollback(ctx) if err != nil && (!errors.Is(err, pgx.ErrTxClosed)) { - returnedError = reportError("Database error while rolling back transaction in defer block") + returnedError = reportError( + "Database error while rolling back transaction in defer block", + ) return } }() @@ -121,11 +149,16 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep if errors.As(err, &pgErr) && pgErr.Code == pgErrUniqueViolation { err := writeText(ctx, c, "Y "+mar[1]) if err != nil { - return fmt.Errorf("error reaffirming course choice: %w", err) + return fmt.Errorf( + "error reaffirming course choice: %w", + err, + ) } return nil } - return reportError("Database error while inserting course choice") + return reportError( + "Database error while inserting course choice", + ) } ok := func() bool { @@ -144,9 +177,14 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep if err != nil { err := course.decrementSelectedAndPropagate(ctx, c) if err != nil { - return fmt.Errorf("error decrementing and notifying: %w", err) + return fmt.Errorf( + "error decrementing and notifying: %w", + err, + ) } - return reportError("Database error while committing transaction") + return reportError( + "Database error while committing transaction", + ) } /* @@ -157,20 +195,31 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep err = writeText(ctx, c, "Y "+mar[1]) if err != nil { - return fmt.Errorf("error affirming course choice: %w", err) + return fmt.Errorf( + "error affirming course choice: %w", + err, + ) } err = sendSelectedUpdate(ctx, c, courseID) if err != nil { - return fmt.Errorf("error notifying after increment: %w", err) + return fmt.Errorf( + "error notifying after increment: %w", + err, + ) } } else { err := tx.Rollback(ctx) if err != nil { - return reportError("Database error while rolling back transaction due to course limit") + return reportError( + "Database error while rolling back transaction due to course limit", + ) } err = writeText(ctx, c, "R "+mar[1]+" :Full") if err != nil { - return fmt.Errorf("error rejecting course choice: %w", err) + return fmt.Errorf( + "error rejecting course choice: %w", + err, + ) } } return nil @@ -181,7 +230,15 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep return nil } -func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error { +func messageUnchooseCourse( + ctx context.Context, + c *websocket.Conn, + reportError reportErrorT, + mar []string, + userID string, + session string, + userCourseGroups *userCourseGroupsT, +) error { _ = session select { @@ -211,13 +268,18 @@ func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError r courseID, ) if err != nil { - return reportError("Database error while deleting course choice") + return reportError( + "Database error while deleting course choice", + ) } if ct.RowsAffected() != 0 { err := course.decrementSelectedAndPropagate(ctx, c) if err != nil { - return fmt.Errorf("error decrementing and notifying: %w", err) + return fmt.Errorf( + "error decrementing and notifying: %w", + err, + ) } var thisCourseGroup courseGroupT func() { @@ -233,7 +295,10 @@ func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError r err = writeText(ctx, c, "N "+mar[1]) if err != nil { - return fmt.Errorf("error replying that course has been deselected: %w", err) + return fmt.Errorf( + "error replying that course has been deselected: %w", + err, + ) } return nil @@ -55,7 +55,11 @@ endl: return mar } -func baseReportError(ctx context.Context, conn *websocket.Conn, e string) error { +func baseReportError( + ctx context.Context, + conn *websocket.Conn, + e string, +) error { err := writeText(ctx, conn, "E :"+e) if err != nil { return fmt.Errorf("error reporting protocol violation: %w", err) @@ -84,7 +88,11 @@ func propagateSelectedUpdate(courseID int) { } } -func sendSelectedUpdate(ctx context.Context, conn *websocket.Conn, courseID int) error { +func sendSelectedUpdate( + ctx context.Context, + conn *websocket.Conn, + courseID int, +) error { var selected int func() { course := courses[courseID] @@ -94,7 +102,10 @@ func sendSelectedUpdate(ctx context.Context, conn *websocket.Conn, courseID int) }() err := writeText(ctx, conn, fmt.Sprintf("M %d %d", courseID, selected)) if err != nil { - return fmt.Errorf("error sending to websocket for course selected update: %w", err) + return fmt.Errorf( + "error sending to websocket for course selected update: %w", + err, + ) } return nil } |