summaryrefslogtreecommitdiff
path: root/ws_connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'ws_connection.go')
-rw-r--r--ws_connection.go349
1 files changed, 349 insertions, 0 deletions
diff --git a/ws_connection.go b/ws_connection.go
new file mode 100644
index 0000000..c407a70
--- /dev/null
+++ b/ws_connection.go
@@ -0,0 +1,349 @@
+/*
+ * WebSocket connection routine
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/coder/websocket"
+)
+
+type errbytesT struct {
+ err error
+ bytes *[]byte
+}
+
+var usemCount int64 /* atomic */
+
+/*
+ * This is more appropriately typed as uint64, but it needs to be cast to int64
+ * later anyway due to time.Duration, so let's just use int64.
+ */
+
+/*
+ * The actual logic in handling the connection, after authentication has been
+ * completed.
+ */
+func handleConn(
+ ctx context.Context,
+ c *websocket.Conn,
+ session string,
+ userID string,
+) (retErr error) {
+ send := make(chan string, config.Perf.SendQ)
+ chanPool.Store(userID, &send)
+ defer chanPool.CompareAndDelete(userID, &send)
+
+ reportError := makeReportError(ctx, c)
+ newCtx, newCancel := context.WithCancel(ctx)
+
+ _cancel, ok := cancelPool.Load(userID)
+ if ok {
+ cancel, ok := _cancel.(*context.CancelFunc)
+ if ok && cancel != nil {
+ (*cancel)()
+ }
+ /* TODO: Make the cancel synchronous */
+ }
+ cancelPool.Store(userID, &newCancel)
+
+ defer func() {
+ cancelPool.CompareAndDelete(userID, &newCancel)
+ if errors.Is(retErr, context.Canceled) {
+ /*
+ * Only works if it's newCtx that has been cancelled
+ * rather than the original ctx, which is kinda what
+ * we intend
+ */
+ _ = writeText(ctx, c, "E :Context canceled")
+ }
+ /* TODO: Report errors properly */
+ }()
+
+ /* TODO: Tell the user their current choices here. Deprecate HELLO. */
+
+ usems := make(map[int]*usemT)
+
+ /* TODO: Check if the LoadUint32 here is a bit too much overhead */
+ atomic.AddInt64(&usemCount, int64(atomic.LoadUint32(&numCourses)))
+ courses.Range(func(key, value interface{}) bool {
+ /* TODO: Remember to change this too when changing the courseID type */
+ courseID, ok := key.(int)
+ if !ok {
+ panic("courses map has non-\"int\" keys")
+ }
+ course, ok := value.(*courseT)
+ if !ok {
+ panic("courses map has non-\"*courseT\" items")
+ }
+ usem := &usemT{} //exhaustruct:ignore
+ usem.init()
+ course.Usems.Store(userID, usem)
+ usems[courseID] = usem
+ return true
+ })
+
+ defer func() {
+ courses.Range(func(key, value interface{}) bool {
+ _ = key
+ course, ok := value.(*courseT)
+ if !ok {
+ panic("courses map has non-\"*courseT\" items")
+ }
+ course.Usems.Delete(userID)
+ return true
+ })
+ atomic.AddInt64(&usemCount, -int64(atomic.LoadUint32(&numCourses)))
+ }()
+
+ usemParent := make(chan int)
+ for courseID, usem := range usems {
+ go func() {
+ for {
+ select {
+ case <-newCtx.Done():
+ return
+ case <-usem.ch:
+ select {
+ case <-newCtx.Done():
+ return
+ case usemParent <- courseID:
+ }
+ }
+ time.Sleep(
+ time.Duration(
+ atomic.LoadInt64(&usemCount)>>
+ config.Perf.UsemDelayShiftBits,
+ ) * time.Millisecond,
+ )
+ }
+ }()
+ }
+
+ /*
+ * userCourseGroups stores whether the user has already chosen a course
+ * in the courseGroup.
+ */
+ var userCourseGroups userCourseGroupsT = make(map[courseGroupT]struct{})
+ err := populateUserCourseGroups(newCtx, &userCourseGroups, userID)
+ if err != nil {
+ return reportError(
+ fmt.Sprintf(
+ "cannot populate user course groups: %v",
+ err,
+ ),
+ )
+ }
+
+ /*
+ * Later we need to select from recv and send and perform the
+ * corresponding action. But we can't just select from c.Read because
+ * the function blocks. Therefore, we must spawn a goroutine that
+ * blocks on c.Read and send what it receives to a channel "recv"; and
+ * then we can select from that channel.
+ */
+ recv := make(chan *errbytesT)
+ go func() {
+ for {
+ /*
+ * Here we use the original connection context instead
+ * of the new context we just created. Apparently when
+ * the context passed to Read expires, the connection
+ * gets closed, which makes it impossible for us to
+ * write the context expiry message to the client.
+ * So we pass the original connection context, which
+ * would get cancelled anyway once we close the
+ * connection.
+ * See: https://github.com/coder/websocket/issues/242
+ * We still need to take care of this while sending so
+ * we don't infinitely block, and leak goroutines and
+ * cause the channel to remain out of reach of the
+ * garbage collector.
+ * It would be nice to return the newCtx.Err() but
+ * the only way to really do that is to use the recv
+ * channel which might not have a listener anymore.
+ * It's not really crucial anyways so we could just
+ * close this goroutine by returning here.
+ */
+ _, b, err := c.Read(ctx)
+ if err != nil {
+ /*
+ * TODO: Prioritize context dones... except
+ * that it's not really possible. I would just
+ * have placed newCtx in here but apparently
+ * that causes the connection to be closed when
+ * the context expires, which makes it
+ * impossible to deliver the final error
+ * message. Probably need to look into this
+ * design again.
+ */
+ select {
+ case <-newCtx.Done():
+ _ = writeText(
+ ctx,
+ c,
+ "E :Context canceled",
+ )
+ /* Not a typo to use ctx here */
+ return
+ case recv <- &errbytesT{err: err, bytes: nil}:
+ }
+ return
+ }
+ select {
+ case <-newCtx.Done():
+ _ = writeText(ctx, c, "E :Context cancelled")
+ /* Not a typo to use ctx here */
+ return
+ case recv <- &errbytesT{err: nil, bytes: &b}:
+ }
+ }
+ }()
+
+ for {
+ var mar []string
+ select {
+ case <-newCtx.Done():
+ /*
+ * We select this context done channel when entering
+ * other cases too (see below) because we need to
+ * make sure the context cancel works even if both
+ * the cancel signal and another event arrive while
+ * processing a select cycle.
+ */
+ return fmt.Errorf(
+ "%w: %w",
+ errContextCancelled,
+ newCtx.Err(),
+ )
+ case sendText := <-send:
+ select {
+ case <-newCtx.Done():
+ return fmt.Errorf(
+ "%w: %w",
+ errContextCancelled,
+ newCtx.Err(),
+ )
+ default:
+ }
+
+ err := writeText(newCtx, c, sendText)
+ if err != nil {
+ return err
+ }
+ case courseID := <-usemParent:
+ select {
+ case <-newCtx.Done():
+ return fmt.Errorf(
+ "%w: %w",
+ errContextCancelled,
+ newCtx.Err(),
+ )
+ default:
+ }
+
+ err := sendSelectedUpdate(newCtx, c, courseID)
+ if err != nil {
+ return fmt.Errorf(
+ "%w: %w",
+ errCannotSend,
+ err,
+ )
+ }
+ continue
+ case errbytes := <-recv:
+ select {
+ case <-newCtx.Done():
+ return fmt.Errorf(
+ "%w: %w",
+ errContextCancelled,
+ newCtx.Err(),
+ )
+ default:
+ }
+
+ if errbytes.err != nil {
+ return fmt.Errorf(
+ "%w: %w",
+ errCannotReceiveMessage,
+ errbytes.err,
+ )
+ /*
+ * Note that this cannot return newCtx.Err(),
+ * so we handle the error reporting in the
+ * reading routine
+ */
+ }
+ mar = splitMsg(errbytes.bytes)
+ switch mar[0] {
+ case "HELLO":
+ err := messageHello(
+ newCtx,
+ c,
+ reportError,
+ mar,
+ userID,
+ session,
+ )
+ if err != nil {
+ return err
+ }
+ case "Y":
+ err := messageChooseCourse(
+ newCtx,
+ c,
+ reportError,
+ mar,
+ userID,
+ session,
+ &userCourseGroups,
+ )
+ if err != nil {
+ return err
+ }
+ case "N":
+ err := messageUnchooseCourse(
+ newCtx,
+ c,
+ reportError,
+ mar,
+ userID,
+ session,
+ &userCourseGroups,
+ )
+ if err != nil {
+ return err
+ }
+ default:
+ return reportError("Unknown command " + mar[0])
+ }
+ }
+ }
+}
+
+var cancelPool sync.Map /* string, *context.CancelFunc */
+
+var chanPool sync.Map /* string, *chan string */