summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.go23
-rw-r--r--course_groups.go9
-rw-r--r--courses.go13
-rw-r--r--database.go3
-rw-r--r--endpoint_auth.go17
-rw-r--r--errors.go8
-rw-r--r--misc_utils.go3
-rw-r--r--state.go6
-rw-r--r--ws_connection.go18
-rw-r--r--ws_utils.go2
-rw-r--r--wsmsg_choose.go21
-rw-r--r--wsmsg_hello.go8
-rw-r--r--wsmsg_unchoose.go13
13 files changed, 53 insertions, 91 deletions
diff --git a/config.go b/config.go
index e991516..576ac6d 100644
--- a/config.go
+++ b/config.go
@@ -111,36 +111,19 @@ var config struct {
func fetchConfig(path string) (retErr error) {
defer func() {
- if v := recover(); v != nil {
- s, ok := v.(error)
- if ok {
- retErr = fmt.Errorf(
- "%w: %w",
- errCannotProcessConfig,
- s,
- )
- }
- retErr = fmt.Errorf("%w: %v", errCannotProcessConfig, v)
- return
- }
if retErr != nil {
- retErr = fmt.Errorf(
- "%w: %w",
- errCannotProcessConfig,
- retErr,
- )
- return
+ retErr = wrapError(errCannotProcessConfig, retErr)
}
}()
f, err := os.Open(path)
if err != nil {
- return fmt.Errorf("%w: %w", errCannotOpenConfig, err)
+ return wrapError(errCannotOpenConfig, err)
}
err = scfg.NewDecoder(bufio.NewReader(f)).Decode(&configWithPointers)
if err != nil {
- return fmt.Errorf("%w: %w", errCannotDecodeConfig, err)
+ return wrapError(errCannotDecodeConfig, err)
}
if configWithPointers.URL == nil {
diff --git a/course_groups.go b/course_groups.go
index c1541ce..f81fcfa 100644
--- a/course_groups.go
+++ b/course_groups.go
@@ -63,8 +63,7 @@ func populateUserCourseGroups(
userID,
)
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errUnexpectedDBError,
err,
)
@@ -73,8 +72,7 @@ func populateUserCourseGroups(
if !rows.Next() {
err := rows.Err()
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errUnexpectedDBError,
err,
)
@@ -84,8 +82,7 @@ func populateUserCourseGroups(
var thisCourseID int
err := rows.Scan(&thisCourseID)
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errUnexpectedDBError,
err,
)
diff --git a/courses.go b/courses.go
index 4591cfb..8751fa1 100644
--- a/courses.go
+++ b/courses.go
@@ -66,15 +66,14 @@ func setupCourses(ctx context.Context) error {
"SELECT id, nmax, title, ctype, cgroup, teacher, location FROM courses",
)
if err != nil {
- return fmt.Errorf("%w: %w", errUnexpectedDBError, err)
+ return wrapError(errUnexpectedDBError, err)
}
for {
if !rows.Next() {
err := rows.Err()
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errUnexpectedDBError,
err,
)
@@ -92,7 +91,7 @@ func setupCourses(ctx context.Context) error {
&currentCourse.Location,
)
if err != nil {
- return fmt.Errorf("%w: %w", errUnexpectedDBError, err)
+ return wrapError(errUnexpectedDBError, err)
}
if !checkCourseType(currentCourse.Type) {
return fmt.Errorf(
@@ -116,8 +115,7 @@ func setupCourses(ctx context.Context) error {
currentCourse.ID,
).Scan(&currentCourse.Selected)
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errUnexpectedDBError,
err,
)
@@ -141,8 +139,7 @@ func (course *courseT) decrementSelectedAndPropagate(
go propagateSelectedUpdate(course)
err := sendSelectedUpdate(ctx, conn, course.ID)
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
diff --git a/database.go b/database.go
index e08f42f..337793a 100644
--- a/database.go
+++ b/database.go
@@ -22,7 +22,6 @@ package main
import (
"context"
- "fmt"
"github.com/jackc/pgx/v5/pgxpool"
)
@@ -42,7 +41,7 @@ func setupDatabase() error {
}
db, err = pgxpool.New(context.Background(), config.DB.Conn)
if err != nil {
- return fmt.Errorf("%w: %w", errUnexpectedDBError, err)
+ return wrapError(errUnexpectedDBError, err)
}
return nil
}
diff --git a/endpoint_auth.go b/endpoint_auth.go
index 58eb46b..85acab8 100644
--- a/endpoint_auth.go
+++ b/endpoint_auth.go
@@ -270,7 +270,7 @@ func setupJwks() error {
var err error
myKeyfunc, err = keyfunc.NewDefault([]string{config.Auth.Jwks})
if err != nil {
- return fmt.Errorf("%w: %w", errCannotSetupJwks, err)
+ return wrapError(errCannotSetupJwks, err)
}
return nil
}
@@ -291,14 +291,14 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) {
nil,
)
if err != nil {
- return "", fmt.Errorf("%w: %w", errCannotGetDepartment, err)
+ return "", wrapError(errCannotGetDepartment, err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
client := &http.Client{} //exhaustruct:ignore
resp, err := client.Do(req)
if err != nil {
- return "", fmt.Errorf("%w: %w", errCannotGetDepartment, err)
+ return "", wrapError(errCannotGetDepartment, err)
}
defer resp.Body.Close()
@@ -309,7 +309,7 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) {
decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(&departmentWrap)
if err != nil {
- return "", fmt.Errorf("%w: %w", errCannotGetDepartment, err)
+ return "", wrapError(errCannotGetDepartment, err)
}
if departmentWrap.Department == nil {
@@ -318,8 +318,7 @@ func getDepartment(ctx context.Context, accessToken string) (string, error) {
* "department" field, which hopefully doesn't occur as we
* have specified $select=department in the OData query.
*/
- return "", fmt.Errorf(
- "%w: %w",
+ return "", wrapError(
errCannotGetDepartment,
errInsufficientFields,
)
@@ -355,12 +354,12 @@ func getAccessToken(
)
if err != nil {
return accessToken,
- fmt.Errorf("%w: %w", errCannotFetchAccessToken, err)
+ wrapError(errCannotFetchAccessToken, err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return accessToken,
- fmt.Errorf("%w: %w", errCannotFetchAccessToken, err)
+ wrapError(errCannotFetchAccessToken, err)
}
defer resp.Body.Close()
@@ -368,7 +367,7 @@ func getAccessToken(
err = decoder.Decode(&accessToken)
if err != nil {
return accessToken,
- fmt.Errorf("%w: %w", errCannotFetchAccessToken, err)
+ wrapError(errCannotFetchAccessToken, err)
}
if accessToken.Error != nil || accessToken.ErrorCodes != nil ||
accessToken.ErrorDescription != nil {
diff --git a/errors.go b/errors.go
index 4597b3f..77de2a3 100644
--- a/errors.go
+++ b/errors.go
@@ -22,6 +22,7 @@ package main
import (
"errors"
+ "fmt"
)
var (
@@ -48,3 +49,10 @@ var (
errInvalidState = errors.New("invalid state")
errWebSocketWrite = errors.New("error writing to websocket")
)
+
+func wrapError(a, b error) error {
+ if a == nil && b == nil {
+ return nil
+ }
+ return fmt.Errorf("%w: %w", a, b)
+}
diff --git a/misc_utils.go b/misc_utils.go
index 4830fc1..d541118 100644
--- a/misc_utils.go
+++ b/misc_utils.go
@@ -23,7 +23,6 @@ package main
import (
"crypto/rand"
"encoding/base64"
- "fmt"
"log"
"net/http"
)
@@ -48,7 +47,7 @@ func randomString(sz int) (string, error) {
r := make([]byte, 3*sz)
_, err := rand.Read(r)
if err != nil {
- return "", fmt.Errorf("%w: %w", errCannotGenerateRandomString, err)
+ return "", wrapError(errCannotGenerateRandomString, err)
}
return base64.RawURLEncoding.EncodeToString(r), nil
}
diff --git a/state.go b/state.go
index e4c468b..20e1d84 100644
--- a/state.go
+++ b/state.go
@@ -53,10 +53,10 @@ func loadState() error {
_state,
)
if err != nil {
- return fmt.Errorf("%w: %w", errUnexpectedDBError, err)
+ return wrapError(errUnexpectedDBError, err)
}
} else {
- return fmt.Errorf("%w: %w", errUnexpectedDBError, err)
+ return wrapError(errUnexpectedDBError, err)
}
}
atomic.StoreUint32(&state, _state)
@@ -70,7 +70,7 @@ func saveStateValue(ctx context.Context, newState uint32) error {
newState,
)
if err != nil {
- return fmt.Errorf("%w: %w", errUnexpectedDBError, err)
+ return wrapError(errUnexpectedDBError, err)
}
return nil
}
diff --git a/ws_connection.go b/ws_connection.go
index 63ae86b..48c2b3a 100644
--- a/ws_connection.go
+++ b/ws_connection.go
@@ -233,16 +233,14 @@ func handleConn(
* the cancel signal and another event arrive while
* processing a select cycle.
*/
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errContextCancelled,
newCtx.Err(),
)
case sendText := <-send:
select {
case <-newCtx.Done():
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errContextCancelled,
newCtx.Err(),
)
@@ -256,8 +254,7 @@ func handleConn(
case courseID := <-usemParent:
select {
case <-newCtx.Done():
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errContextCancelled,
newCtx.Err(),
)
@@ -266,8 +263,7 @@ func handleConn(
err := sendSelectedUpdate(newCtx, c, courseID)
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
@@ -276,8 +272,7 @@ func handleConn(
case errbytes := <-recv:
select {
case <-newCtx.Done():
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errContextCancelled,
newCtx.Err(),
)
@@ -285,8 +280,7 @@ func handleConn(
}
if errbytes.err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotReceiveMessage,
errbytes.err,
)
diff --git a/ws_utils.go b/ws_utils.go
index 949863f..adf4677 100644
--- a/ws_utils.go
+++ b/ws_utils.go
@@ -165,7 +165,7 @@ func propagate(msg string) {
func writeText(ctx context.Context, c *websocket.Conn, msg string) error {
err := c.Write(ctx, websocket.MessageText, []byte(msg))
if err != nil {
- return fmt.Errorf("%w: %w", errWebSocketWrite, err)
+ return wrapError(errWebSocketWrite, err)
}
return nil
}
diff --git a/wsmsg_choose.go b/wsmsg_choose.go
index bde9bb8..9b8708b 100644
--- a/wsmsg_choose.go
+++ b/wsmsg_choose.go
@@ -44,8 +44,7 @@ func messageChooseCourse(
if atomic.LoadUint32(&state) != 2 {
err := writeText(ctx, c, "E :Course selections are not open")
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
@@ -55,8 +54,7 @@ func messageChooseCourse(
select {
case <-ctx.Done():
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errContextCancelled,
ctx.Err(),
)
@@ -87,8 +85,7 @@ func messageChooseCourse(
if _, ok := (*userCourseGroups)[course.Group]; ok {
err := writeText(ctx, c, "R "+mar[1]+" :Group conflict")
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
@@ -159,8 +156,7 @@ func messageChooseCourse(
if err != nil {
err := course.decrementSelectedAndPropagate(ctx, c)
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
@@ -178,8 +174,7 @@ func messageChooseCourse(
err = writeText(ctx, c, "Y "+mar[1])
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
@@ -188,8 +183,7 @@ func messageChooseCourse(
if config.Perf.PropagateImmediate {
err = sendSelectedUpdate(ctx, c, courseID)
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
@@ -204,8 +198,7 @@ func messageChooseCourse(
}
err = writeText(ctx, c, "R "+mar[1]+" :Full")
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
diff --git a/wsmsg_hello.go b/wsmsg_hello.go
index f0bc191..f7a984c 100644
--- a/wsmsg_hello.go
+++ b/wsmsg_hello.go
@@ -22,7 +22,6 @@ package main
import (
"context"
- "fmt"
"strings"
"sync/atomic"
@@ -41,8 +40,7 @@ func messageHello(
select {
case <-ctx.Done():
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errContextCancelled,
ctx.Err(),
)
@@ -65,12 +63,12 @@ func messageHello(
if atomic.LoadUint32(&state) == 2 {
err = writeText(ctx, c, "START")
if err != nil {
- return fmt.Errorf("%w: %w", errCannotSend, err)
+ return wrapError(errCannotSend, err)
}
}
err = writeText(ctx, c, "HI :"+strings.Join(courseIDs, ","))
if err != nil {
- return fmt.Errorf("%w: %w", errCannotSend, err)
+ return wrapError(errCannotSend, err)
}
return nil
diff --git a/wsmsg_unchoose.go b/wsmsg_unchoose.go
index e3a7ec6..a99e3f4 100644
--- a/wsmsg_unchoose.go
+++ b/wsmsg_unchoose.go
@@ -22,7 +22,6 @@ package main
import (
"context"
- "fmt"
"strconv"
"sync/atomic"
@@ -40,8 +39,7 @@ func messageUnchooseCourse(
if atomic.LoadUint32(&state) != 2 {
err := writeText(ctx, c, "E :Course selections are not open")
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
@@ -51,8 +49,7 @@ func messageUnchooseCourse(
select {
case <-ctx.Done():
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errContextCancelled,
ctx.Err(),
)
@@ -95,8 +92,7 @@ func messageUnchooseCourse(
if ct.RowsAffected() != 0 {
err := course.decrementSelectedAndPropagate(ctx, c)
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)
@@ -122,8 +118,7 @@ func messageUnchooseCourse(
err = writeText(ctx, c, "N "+mar[1])
if err != nil {
- return fmt.Errorf(
- "%w: %w",
+ return wrapError(
errCannotSend,
err,
)