diff options
Diffstat (limited to '')
-rw-r--r-- | config.go | 23 | ||||
-rw-r--r-- | course_groups.go | 9 | ||||
-rw-r--r-- | courses.go | 13 | ||||
-rw-r--r-- | database.go | 3 | ||||
-rw-r--r-- | endpoint_auth.go | 17 | ||||
-rw-r--r-- | errors.go | 8 | ||||
-rw-r--r-- | misc_utils.go | 3 | ||||
-rw-r--r-- | state.go | 6 | ||||
-rw-r--r-- | ws_connection.go | 18 | ||||
-rw-r--r-- | ws_utils.go | 2 | ||||
-rw-r--r-- | wsmsg_choose.go | 21 | ||||
-rw-r--r-- | wsmsg_hello.go | 8 | ||||
-rw-r--r-- | wsmsg_unchoose.go | 13 |
13 files changed, 53 insertions, 91 deletions
@@ -111,36 +111,19 @@ var config struct { func fetchConfig(path string) (retErr error) { defer func() { - if v := recover(); v != nil { - s, ok := v.(error) - if ok { - retErr = fmt.Errorf( - "%w: %w", - errCannotProcessConfig, - s, - ) - } - retErr = fmt.Errorf("%w: %v", errCannotProcessConfig, v) - return - } if retErr != nil { - retErr = fmt.Errorf( - "%w: %w", - errCannotProcessConfig, - retErr, - ) - return + retErr = wrapError(errCannotProcessConfig, retErr) } }() f, err := os.Open(path) if err != nil { - return fmt.Errorf("%w: %w", errCannotOpenConfig, err) + return wrapError(errCannotOpenConfig, err) } err = scfg.NewDecoder(bufio.NewReader(f)).Decode(&configWithPointers) if err != nil { - return fmt.Errorf("%w: %w", errCannotDecodeConfig, err) + return wrapError(errCannotDecodeConfig, err) } if configWithPointers.URL == nil { diff --git a/course_groups.go b/course_groups.go index c1541ce..f81fcfa 100644 --- a/course_groups.go +++ b/course_groups.go @@ -63,8 +63,7 @@ func populateUserCourseGroups( userID, ) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errUnexpectedDBError, err, ) @@ -73,8 +72,7 @@ func populateUserCourseGroups( if !rows.Next() { err := rows.Err() if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errUnexpectedDBError, err, ) @@ -84,8 +82,7 @@ func populateUserCourseGroups( var thisCourseID int err := rows.Scan(&thisCourseID) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errUnexpectedDBError, err, ) @@ -66,15 +66,14 @@ func setupCourses(ctx context.Context) error { "SELECT id, nmax, title, ctype, cgroup, teacher, location FROM courses", ) if err != nil { - return fmt.Errorf("%w: %w", errUnexpectedDBError, err) + return wrapError(errUnexpectedDBError, err) } for { if !rows.Next() { err := rows.Err() if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errUnexpectedDBError, err, ) @@ -92,7 +91,7 @@ func setupCourses(ctx context.Context) error { ¤tCourse.Location, ) if err != nil { - return fmt.Errorf("%w: %w", errUnexpectedDBError, err) + return wrapError(errUnexpectedDBError, err) } if !checkCourseType(currentCourse.Type) { return fmt.Errorf( @@ -116,8 +115,7 @@ func setupCourses(ctx context.Context) error { currentCourse.ID, ).Scan(¤tCourse.Selected) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errUnexpectedDBError, err, ) @@ -141,8 +139,7 @@ func (course *courseT) decrementSelectedAndPropagate( go propagateSelectedUpdate(course) err := sendSelectedUpdate(ctx, conn, course.ID) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) diff --git a/database.go b/database.go index e08f42f..337793a 100644 --- a/database.go +++ b/database.go @@ -22,7 +22,6 @@ package main import ( "context" - "fmt" "github.com/jackc/pgx/v5/pgxpool" ) @@ -42,7 +41,7 @@ func setupDatabase() error { } db, err = pgxpool.New(context.Background(), config.DB.Conn) if err != nil { - return fmt.Errorf("%w: %w", errUnexpectedDBError, err) + return wrapError(errUnexpectedDBError, err) } return nil } diff --git a/endpoint_auth.go b/endpoint_auth.go index 58eb46b..85acab8 100644 --- a/endpoint_auth.go +++ b/endpoint_auth.go @@ -270,7 +270,7 @@ func setupJwks() error { var err error myKeyfunc, err = keyfunc.NewDefault([]string{config.Auth.Jwks}) if err != nil { - return fmt.Errorf("%w: %w", errCannotSetupJwks, err) + return wrapError(errCannotSetupJwks, err) } return nil } @@ -291,14 +291,14 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) { nil, ) if err != nil { - return "", fmt.Errorf("%w: %w", errCannotGetDepartment, err) + return "", wrapError(errCannotGetDepartment, err) } req.Header.Set("Authorization", "Bearer "+accessToken) client := &http.Client{} //exhaustruct:ignore resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("%w: %w", errCannotGetDepartment, err) + return "", wrapError(errCannotGetDepartment, err) } defer resp.Body.Close() @@ -309,7 +309,7 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) { decoder := json.NewDecoder(resp.Body) err = decoder.Decode(&departmentWrap) if err != nil { - return "", fmt.Errorf("%w: %w", errCannotGetDepartment, err) + return "", wrapError(errCannotGetDepartment, err) } if departmentWrap.Department == nil { @@ -318,8 +318,7 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) { * "department" field, which hopefully doesn't occur as we * have specified $select=department in the OData query. */ - return "", fmt.Errorf( - "%w: %w", + return "", wrapError( errCannotGetDepartment, errInsufficientFields, ) @@ -355,12 +354,12 @@ func getAccessToken( ) if err != nil { return accessToken, - fmt.Errorf("%w: %w", errCannotFetchAccessToken, err) + wrapError(errCannotFetchAccessToken, err) } resp, err := http.DefaultClient.Do(req) if err != nil { return accessToken, - fmt.Errorf("%w: %w", errCannotFetchAccessToken, err) + wrapError(errCannotFetchAccessToken, err) } defer resp.Body.Close() @@ -368,7 +367,7 @@ func getAccessToken( err = decoder.Decode(&accessToken) if err != nil { return accessToken, - fmt.Errorf("%w: %w", errCannotFetchAccessToken, err) + wrapError(errCannotFetchAccessToken, err) } if accessToken.Error != nil || accessToken.ErrorCodes != nil || accessToken.ErrorDescription != nil { @@ -22,6 +22,7 @@ package main import ( "errors" + "fmt" ) var ( @@ -48,3 +49,10 @@ var ( errInvalidState = errors.New("invalid state") errWebSocketWrite = errors.New("error writing to websocket") ) + +func wrapError(a, b error) error { + if a == nil && b == nil { + return nil + } + return fmt.Errorf("%w: %w", a, b) +} diff --git a/misc_utils.go b/misc_utils.go index 4830fc1..d541118 100644 --- a/misc_utils.go +++ b/misc_utils.go @@ -23,7 +23,6 @@ package main import ( "crypto/rand" "encoding/base64" - "fmt" "log" "net/http" ) @@ -48,7 +47,7 @@ func randomString(sz int) (string, error) { r := make([]byte, 3*sz) _, err := rand.Read(r) if err != nil { - return "", fmt.Errorf("%w: %w", errCannotGenerateRandomString, err) + return "", wrapError(errCannotGenerateRandomString, err) } return base64.RawURLEncoding.EncodeToString(r), nil } @@ -53,10 +53,10 @@ func loadState() error { _state, ) if err != nil { - return fmt.Errorf("%w: %w", errUnexpectedDBError, err) + return wrapError(errUnexpectedDBError, err) } } else { - return fmt.Errorf("%w: %w", errUnexpectedDBError, err) + return wrapError(errUnexpectedDBError, err) } } atomic.StoreUint32(&state, _state) @@ -70,7 +70,7 @@ func saveStateValue(ctx context.Context, newState uint32) error { newState, ) if err != nil { - return fmt.Errorf("%w: %w", errUnexpectedDBError, err) + return wrapError(errUnexpectedDBError, err) } return nil } diff --git a/ws_connection.go b/ws_connection.go index 63ae86b..48c2b3a 100644 --- a/ws_connection.go +++ b/ws_connection.go @@ -233,16 +233,14 @@ func handleConn( * the cancel signal and another event arrive while * processing a select cycle. */ - return fmt.Errorf( - "%w: %w", + return wrapError( errContextCancelled, newCtx.Err(), ) case sendText := <-send: select { case <-newCtx.Done(): - return fmt.Errorf( - "%w: %w", + return wrapError( errContextCancelled, newCtx.Err(), ) @@ -256,8 +254,7 @@ func handleConn( case courseID := <-usemParent: select { case <-newCtx.Done(): - return fmt.Errorf( - "%w: %w", + return wrapError( errContextCancelled, newCtx.Err(), ) @@ -266,8 +263,7 @@ func handleConn( err := sendSelectedUpdate(newCtx, c, courseID) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) @@ -276,8 +272,7 @@ func handleConn( case errbytes := <-recv: select { case <-newCtx.Done(): - return fmt.Errorf( - "%w: %w", + return wrapError( errContextCancelled, newCtx.Err(), ) @@ -285,8 +280,7 @@ func handleConn( } if errbytes.err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotReceiveMessage, errbytes.err, ) diff --git a/ws_utils.go b/ws_utils.go index 949863f..adf4677 100644 --- a/ws_utils.go +++ b/ws_utils.go @@ -165,7 +165,7 @@ func propagate(msg string) { func writeText(ctx context.Context, c *websocket.Conn, msg string) error { err := c.Write(ctx, websocket.MessageText, []byte(msg)) if err != nil { - return fmt.Errorf("%w: %w", errWebSocketWrite, err) + return wrapError(errWebSocketWrite, err) } return nil } diff --git a/wsmsg_choose.go b/wsmsg_choose.go index bde9bb8..9b8708b 100644 --- a/wsmsg_choose.go +++ b/wsmsg_choose.go @@ -44,8 +44,7 @@ func messageChooseCourse( if atomic.LoadUint32(&state) != 2 { err := writeText(ctx, c, "E :Course selections are not open") if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) @@ -55,8 +54,7 @@ func messageChooseCourse( select { case <-ctx.Done(): - return fmt.Errorf( - "%w: %w", + return wrapError( errContextCancelled, ctx.Err(), ) @@ -87,8 +85,7 @@ func messageChooseCourse( if _, ok := (*userCourseGroups)[course.Group]; ok { err := writeText(ctx, c, "R "+mar[1]+" :Group conflict") if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) @@ -159,8 +156,7 @@ func messageChooseCourse( if err != nil { err := course.decrementSelectedAndPropagate(ctx, c) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) @@ -178,8 +174,7 @@ func messageChooseCourse( err = writeText(ctx, c, "Y "+mar[1]) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) @@ -188,8 +183,7 @@ func messageChooseCourse( if config.Perf.PropagateImmediate { err = sendSelectedUpdate(ctx, c, courseID) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) @@ -204,8 +198,7 @@ func messageChooseCourse( } err = writeText(ctx, c, "R "+mar[1]+" :Full") if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) diff --git a/wsmsg_hello.go b/wsmsg_hello.go index f0bc191..f7a984c 100644 --- a/wsmsg_hello.go +++ b/wsmsg_hello.go @@ -22,7 +22,6 @@ package main import ( "context" - "fmt" "strings" "sync/atomic" @@ -41,8 +40,7 @@ func messageHello( select { case <-ctx.Done(): - return fmt.Errorf( - "%w: %w", + return wrapError( errContextCancelled, ctx.Err(), ) @@ -65,12 +63,12 @@ func messageHello( if atomic.LoadUint32(&state) == 2 { err = writeText(ctx, c, "START") if err != nil { - return fmt.Errorf("%w: %w", errCannotSend, err) + return wrapError(errCannotSend, err) } } err = writeText(ctx, c, "HI :"+strings.Join(courseIDs, ",")) if err != nil { - return fmt.Errorf("%w: %w", errCannotSend, err) + return wrapError(errCannotSend, err) } return nil diff --git a/wsmsg_unchoose.go b/wsmsg_unchoose.go index e3a7ec6..a99e3f4 100644 --- a/wsmsg_unchoose.go +++ b/wsmsg_unchoose.go @@ -22,7 +22,6 @@ package main import ( "context" - "fmt" "strconv" "sync/atomic" @@ -40,8 +39,7 @@ func messageUnchooseCourse( if atomic.LoadUint32(&state) != 2 { err := writeText(ctx, c, "E :Course selections are not open") if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) @@ -51,8 +49,7 @@ func messageUnchooseCourse( select { case <-ctx.Done(): - return fmt.Errorf( - "%w: %w", + return wrapError( errContextCancelled, ctx.Err(), ) @@ -95,8 +92,7 @@ func messageUnchooseCourse( if ct.RowsAffected() != 0 { err := course.decrementSelectedAndPropagate(ctx, c) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) @@ -122,8 +118,7 @@ func messageUnchooseCourse( err = writeText(ctx, c, "N "+mar[1]) if err != nil { - return fmt.Errorf( - "%w: %w", + return wrapError( errCannotSend, err, ) |