summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--courses.go25
-rw-r--r--ws.go6
-rw-r--r--wsm.go32
3 files changed, 51 insertions, 12 deletions
diff --git a/courses.go b/courses.go
index e49faf3..e38e2dc 100644
--- a/courses.go
+++ b/courses.go
@@ -166,7 +166,9 @@ func setupCourses() error {
return nil
}
-func populateUserCourseGroups(ctx context.Context, userCourseGroups *map[courseGroupT]bool, userID string) error {
+type userCourseGroupsT map[courseGroupT]bool
+
+func populateUserCourseGroups(ctx context.Context, userCourseGroups *userCourseGroupsT, userID string) error {
rows, err := db.Query(ctx, "SELECT courseid FROM choices WHERE userid = $1", userID)
if err != nil {
return fmt.Errorf("error querying user's choices while populating course groups: %w", err)
@@ -184,13 +186,9 @@ func populateUserCourseGroups(ctx context.Context, userCourseGroups *map[courseG
if err != nil {
return fmt.Errorf("error fetching user's choices while populating course groups: %w", err)
}
- var thisGroupName courseGroupT
- err = db.QueryRow(ctx,
- "SELECT cgroup FROM courses WHERE id = $1",
- thisCourseID,
- ).Scan(&thisGroupName)
+ thisGroupName, err := getCourseGroupFromCourseID(ctx, thisCourseID)
if err != nil {
- return fmt.Errorf("error querying group of course: %w", err)
+ return fmt.Errorf("error while populating course groups: %w", err)
}
if (*userCourseGroups)[thisGroupName] {
return fmt.Errorf("%w: user %v, group %v", errMultipleChoicesInOneGroup, userID, thisGroupName)
@@ -199,3 +197,16 @@ func populateUserCourseGroups(ctx context.Context, userCourseGroups *map[courseG
}
return nil
}
+
+func getCourseGroupFromCourseID(ctx context.Context, courseID int) (courseGroupT, error) {
+ var ret courseGroupT
+ err := db.QueryRow(
+ ctx,
+ "SELECT cgroup FROM courses WHERE id = $1",
+ courseID,
+ ).Scan(&ret)
+ if err != nil {
+ return ret, fmt.Errorf("error querying group of course: %w", err)
+ }
+ return ret, nil
+}
diff --git a/ws.go b/ws.go
index 2ac724d..9a0139d 100644
--- a/ws.go
+++ b/ws.go
@@ -381,7 +381,7 @@ func handleConn(
* userCourseGroups stores whether the user has already chosen a course
* in the courseGroup.
*/
- userCourseGroups := make(map[courseGroupT]bool)
+ var userCourseGroups userCourseGroupsT = make(map[courseGroupT]bool)
populateUserCourseGroups(newCtx, &userCourseGroups, userID)
/*
* TODO: No more HELLO command needed? Or otherwise integrate the two.
@@ -490,12 +490,12 @@ func handleConn(
return err
}
case "Y":
- err := messageChooseCourse(newCtx, c, reportError, mar, userID, session)
+ err := messageChooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups)
if err != nil {
return err
}
case "N":
- err := messageUnchooseCourse(newCtx, c, reportError, mar, userID, session)
+ err := messageUnchooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups)
if err != nil {
return err
}
diff --git a/wsm.go b/wsm.go
index 0808438..50527c0 100644
--- a/wsm.go
+++ b/wsm.go
@@ -63,7 +63,7 @@ func messageHello(ctx context.Context, c *websocket.Conn, reportError reportErro
return nil
}
-func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string) error {
+func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error {
_ = session
select {
@@ -140,6 +140,26 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep
}()
return reportError("Database error while committing transaction")
}
+ thisCourseGroup, err := getCourseGroupFromCourseID(ctx, courseID)
+ if err != nil {
+ go func() { /* Duplicate code, could turn into function */
+ course.SelectedLock.Lock()
+ defer course.SelectedLock.Unlock()
+ course.Selected--
+ propagateIgnoreFailures(fmt.Sprintf("M %d %d", courseID, course.Selected))
+ }()
+ return reportError("Database error while committing transaction")
+ }
+ if (*userCourseGroups)[thisCourseGroup] {
+ go func() { /* Duplicate code, could turn into function */
+ course.SelectedLock.Lock()
+ defer course.SelectedLock.Unlock()
+ course.Selected--
+ propagateIgnoreFailures(fmt.Sprintf("M %d %d", courseID, course.Selected))
+ }()
+ return reportError("inconsistent user course groups")
+ }
+ (*userCourseGroups)[thisCourseGroup] = true
err = writeText(ctx, c, "Y "+mar[1])
if err != nil {
return fmt.Errorf("error affirming course choice: %w", err)
@@ -162,7 +182,7 @@ func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError rep
return nil
}
-func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string) error {
+func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error {
_ = session
select {
@@ -202,6 +222,14 @@ func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError r
course.Selected--
propagateIgnoreFailures(fmt.Sprintf("M %d %d", courseID, course.Selected))
}()
+ thisCourseGroup, err := getCourseGroupFromCourseID(ctx, courseID)
+ if err != nil {
+ return reportError("error unsetting course group flag")
+ }
+ if (*userCourseGroups)[thisCourseGroup] == false {
+ return reportError("inconsistent user course groups")
+ }
+ (*userCourseGroups)[thisCourseGroup] = false
}
err = writeText(ctx, c, "N "+mar[1])