summaryrefslogtreecommitdiff
path: root/backend
diff options
context:
space:
mode:
Diffstat (limited to 'backend')
-rw-r--r--backend/Makefile2
-rw-r--r--backend/auth.go374
-rw-r--r--backend/config.go284
-rw-r--r--backend/courses.go229
-rw-r--r--backend/db.go51
-rw-r--r--backend/go.mod23
-rw-r--r--backend/go.sum42
-rw-r--r--backend/index.go120
-rw-r--r--backend/main.go125
-rw-r--r--backend/usem.go36
-rw-r--r--backend/utils.go58
-rw-r--r--backend/wsc.go266
-rw-r--r--backend/wsh.go165
-rw-r--r--backend/wsm.go224
-rw-r--r--backend/wsp.go85
-rw-r--r--backend/wsx.go59
16 files changed, 2143 insertions, 0 deletions
diff --git a/backend/Makefile b/backend/Makefile
new file mode 100644
index 0000000..08566d8
--- /dev/null
+++ b/backend/Makefile
@@ -0,0 +1,2 @@
+cca: *.go
+ go build -o cca
diff --git a/backend/auth.go b/backend/auth.go
new file mode 100644
index 0000000..9ef1254
--- /dev/null
+++ b/backend/auth.go
@@ -0,0 +1,374 @@
+/*
+ * Custom OAUTH 2.0 implementation for the CCA Selection Service
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/MicahParks/keyfunc/v3"
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/jackc/pgx/v5/pgconn"
+)
+
+var myKeyfunc keyfunc.Keyfunc
+
+var (
+ errInsufficientFields = errors.New("insufficient fields")
+ errAccessTokenIncompleteError = errors.New("access token has unpopulated error fields")
+ errTokenEndpointReturnedError = errors.New("token endpoint returned error")
+)
+
+const tokenLength = 20
+
+/*
+ * These are the claims in the JSON Web Token received from the client, after
+ * it redirects from the authorize endpoint. Some of these fields must be
+ * explicitly selected in the Azure app registration and might appear as
+ * zero strings if it hasn't been configured correctly.
+ */
+type msclaimsT struct {
+ Name string `json:"name"` /* Scope: profile */
+ Email string `json:"email"` /* Scope: email */
+ Oid string `json:"oid"` /* Scope: profile */
+ jwt.RegisteredClaims
+}
+
+func generateAuthorizationURL() (string, error) {
+ /*
+ * TODO: Handle nonces and anti-replay. Incremental nonces would be
+ * nice on memory and speed (depending on how maps are implemented in
+ * Go, hopefully it's some sort of btree), but that requires either
+ * hacky atomics or having a multiple goroutines to handle
+ * authentication, neither of which are desirable.
+ */
+ nonce, err := randomString(tokenLength)
+ if err != nil {
+ return "", err
+ }
+ /*
+ * Note that here we use a hybrid authentication flow to obtain an
+ * id_token for authentication and an authorization code. The
+ * authorization code may be used like any other; i.e., it may be used
+ * to obtain an access token directly, or the refresh token may be used
+ * to gain persistent access to the upstream API. Sometimes I wish that
+ * the JWT in id_token could have more claims. The only reason we
+ * presently use a hybrid flow is to use the authorization code to
+ * obtain an access code to call the user info endpoint to fetch the
+ * user's department information.
+ */
+ return fmt.Sprintf(
+ "https://login.microsoftonline.com/ddd3d26c-b197-4d00-a32d-1ffd84c0c295/oauth2/authorize?client_id=%s&response_type=id_token%%20code&redirect_uri=%s%%2Fauth&response_mode=form_post&scope=openid+profile+email+User.Read&nonce=%s",
+ config.Auth.Client,
+ config.URL,
+ nonce,
+ ), nil
+}
+
+/*
+ * Handles redirects to the /auth endpoint from the authorize endpoint.
+ * Expects JSON Web Keys to be already set up correctly; if myKeyfunc is null,
+ * a null pointer is dereferenced and the thread panics.
+ */
+func handleAuth(w http.ResponseWriter, req *http.Request) {
+ if req.Method != http.MethodPost {
+ wstr(w, http.StatusMethodNotAllowed, "Only POST is supported on the authentication endpoint")
+ return
+ }
+
+ err := req.ParseForm()
+ if err != nil {
+ wstr(w, http.StatusBadRequest, "Malformed form data")
+ return
+ }
+
+ returnedError := req.PostFormValue("error")
+ if returnedError != "" {
+ returnedErrorDescription := req.PostFormValue("error_description")
+ if returnedErrorDescription == "" {
+ wstr(w, http.StatusBadRequest, fmt.Sprintf("authorize endpoint returned error: %v", returnedErrorDescription))
+ return
+ }
+ wstr(w, http.StatusBadRequest, fmt.Sprintf(
+ "%s: %s",
+ returnedError,
+ returnedErrorDescription,
+ ))
+ return
+ }
+
+ idTokenString := req.PostFormValue("id_token")
+ if idTokenString == "" {
+ wstr(w, http.StatusBadRequest, "Missing id_token")
+ return
+ }
+
+ claimsTemplate := &msclaimsT{} //exhaustruct:ignore
+ token, err := jwt.ParseWithClaims(
+ idTokenString,
+ claimsTemplate,
+ myKeyfunc.Keyfunc,
+ )
+ if err != nil {
+ wstr(w, http.StatusBadRequest, "Cannot parse claims")
+ return
+ }
+
+ switch {
+ case token.Valid:
+ break
+ case errors.Is(err, jwt.ErrTokenMalformed):
+ wstr(w, http.StatusBadRequest, "Malformed JWT token")
+ return
+ case errors.Is(err, jwt.ErrTokenSignatureInvalid):
+ wstr(w, http.StatusBadRequest, "Invalid JWS signature")
+ return
+ case errors.Is(err, jwt.ErrTokenExpired) ||
+ errors.Is(err, jwt.ErrTokenNotValidYet):
+ wstr(w, http.StatusBadRequest, "JWT token expired or not yet valid")
+ return
+ default:
+ wstr(w, http.StatusBadRequest, "Unhandled JWT token error")
+ return
+ }
+
+ claims, claimsOk := token.Claims.(*msclaimsT)
+
+ if !claimsOk {
+ wstr(w, http.StatusBadRequest, "Cannot unpack claims")
+ return
+ }
+
+ authorizationCode := req.PostFormValue("code")
+
+ accessToken, err := getAccessToken(req.Context(), authorizationCode)
+ if err != nil {
+ wstr(w, http.StatusInternalServerError, fmt.Sprintf("Unable to fetch access token: %v", err))
+ return
+ }
+
+ department, err := getDepartment(req.Context(), *(accessToken.Content))
+ if err != nil {
+ wstr(w, http.StatusInternalServerError, err.Error())
+ return
+ }
+
+ switch {
+ case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || department == "High School Teaching & Learning 高中教学部门":
+ department = "Staff"
+ case department == "Y9" || department == "Y10" || department == "Y11" || department == "Y12":
+ default:
+ wstr(
+ w,
+ http.StatusForbidden,
+ fmt.Sprintf(
+ "Your department \"%s\" is unknown.\nWe currently only allow Y9, Y10, Y11, Y12, and the CCA office.",
+ department,
+ ),
+ )
+ return
+ }
+
+ cookieValue, err := randomString(tokenLength)
+ if err != nil {
+ wstr(w, http.StatusInternalServerError, err.Error())
+ return
+ }
+
+ now := time.Now()
+ expr := now.Add(time.Duration(config.Auth.Expr) * time.Second)
+ exprU := expr.Unix()
+
+ cookie := http.Cookie{
+ Name: "session",
+ Value: cookieValue,
+ SameSite: http.SameSiteLaxMode,
+ HttpOnly: true,
+ Secure: config.Prod,
+ Expires: expr,
+ } //exhaustruct:ignore
+
+ http.SetCookie(w, &cookie)
+
+ /*
+ * TODO: Here we attempt to insert and call update if we receive a
+ * conflict. This works but is not idiomatic (and could confuse the
+ * database administrator with database integrity warnings in the log.
+ * The INSERT statement actually supports updating on conflict:
+ * https://www.postgresql.org/docs/current/sql-insert.html
+ */
+ _, err = db.Exec(
+ req.Context(),
+ "INSERT INTO users (id, name, email, department, session, expr) VALUES ($1, $2, $3, $4, $5, $6)",
+ claims.Oid,
+ claims.Name,
+ claims.Email,
+ department,
+ cookieValue,
+ exprU,
+ )
+ if err != nil {
+ var pgErr *pgconn.PgError
+ if errors.As(err, &pgErr) && pgErr.Code == pgErrUniqueViolation {
+ _, err := db.Exec(
+ req.Context(),
+ "UPDATE users SET (name, email, department, session, expr) = ($1, $2, $3, $4, $5) WHERE id = $6",
+ claims.Name,
+ claims.Email,
+ department,
+ cookieValue,
+ exprU,
+ claims.Oid,
+ )
+ if err != nil {
+ wstr(w, http.StatusInternalServerError, "Database error while updating account.")
+ return
+ }
+ } else {
+ wstr(w, http.StatusInternalServerError, "Database error while writing account info.")
+ return
+ }
+ }
+
+ http.Redirect(w, req, "/", http.StatusSeeOther)
+}
+
+/*
+ * Setting up JSON Web Keys. Note that myKeyfunc is a global variable.
+ */
+func setupJwks() error {
+ var err error
+ myKeyfunc, err = keyfunc.NewDefault([]string{config.Auth.Jwks})
+ if err != nil {
+ return fmt.Errorf("error setting up jwks: %w", err)
+ }
+ return nil
+}
+
+/*
+ * Fetch the department name of the user, mostly to identify which grade
+ * a student is in. This expects an accessToken obtained from the OAUTH 2.0
+ * token endpoint obtained via an authorization code. It might also be able
+ * to use this as part of a hybrid flow that directly provides access tokens,
+ * but this flow seems to be only usable for single-page applications according
+ * to the Azure portal.
+ */
+func getDepartment(ctx context.Context, accessToken string) (string, error) {
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodGet,
+ "https://graph.microsoft.com/v1.0/me?$select=department",
+ nil,
+ )
+ if err != nil {
+ return "", fmt.Errorf("error getting department: %w", err)
+ }
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+
+ client := &http.Client{} //exhaustruct:ignore
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", fmt.Errorf("error getting department: %w", err)
+ }
+ defer resp.Body.Close()
+
+ var departmentWrap struct {
+ Department *string `json:"department"`
+ }
+
+ decoder := json.NewDecoder(resp.Body)
+ err = decoder.Decode(&departmentWrap)
+ if err != nil {
+ return "", fmt.Errorf("error getting department: %w", err)
+ }
+
+ if departmentWrap.Department == nil {
+ /*
+ * This is probably because the response does not contain a
+ * "department" field, which hopefully doesn't occur as we
+ * have specified $select=department in the OData query.
+ */
+ return "", fmt.Errorf("error getting department: %w", errInsufficientFields)
+ }
+
+ return *(departmentWrap.Department), nil
+}
+
+/*
+ * TODO: Access token expiration is not checked anywhere.
+ */
+type accessTokenT struct {
+ OriginalExpiresIn *int `json:"expires_in"` /* Original time to expiration */
+ Expiration time.Time
+ Content *string `json:"access_token"`
+ Error *string `json:"error"`
+ ErrorDescription *string `json:"error_description"`
+ ErrorCodes *[]int `json:"error_codes"`
+}
+
+/*
+ * Obtain an access token from the token endpoint with an existing
+ * authorization code.
+ */
+func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT, error) {
+ var accessToken accessTokenT
+ t := time.Now()
+ v := url.Values{}
+ v.Set("client_id", config.Auth.Client)
+ v.Set("scope", "https://graph.microsoft.com/User.Read")
+ v.Set("code", authorizationCode)
+ v.Set("redirect_uri", config.URL+"/auth")
+ v.Set("grant_type", "authorization_code")
+ v.Set("client_secret", config.Auth.Secret)
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, config.Auth.Token, strings.NewReader(v.Encode()))
+ if err != nil {
+ return accessToken, fmt.Errorf("error making access token request: %w", err)
+ }
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return accessToken, fmt.Errorf("error requesting access token: %w", err)
+ }
+ defer resp.Body.Close()
+
+ decoder := json.NewDecoder(resp.Body)
+ err = decoder.Decode(&accessToken)
+ if err != nil {
+ return accessToken, fmt.Errorf("error decoding access token: %w", err)
+ }
+ if accessToken.Error != nil || accessToken.ErrorCodes != nil || accessToken.ErrorDescription != nil {
+ if accessToken.Error == nil || accessToken.ErrorCodes == nil || accessToken.ErrorDescription == nil {
+ return accessToken, errAccessTokenIncompleteError
+ }
+ return accessToken, fmt.Errorf("%w: %v", errTokenEndpointReturnedError, *accessToken.ErrorDescription)
+ }
+ if accessToken.Content == nil || accessToken.OriginalExpiresIn == nil {
+ return accessToken, fmt.Errorf("error extracting access token: %w", errInsufficientFields)
+ }
+ accessToken.Expiration = t.Add(time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second)
+
+ return accessToken, nil
+}
diff --git a/backend/config.go b/backend/config.go
new file mode 100644
index 0000000..e5bea02
--- /dev/null
+++ b/backend/config.go
@@ -0,0 +1,284 @@
+/*
+ * Handle configuration
+ *
+ * Copyright (c) 2024 Runxi Yu <me@runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "log"
+ "os"
+
+ "git.sr.ht/~emersion/go-scfg"
+)
+
+/*
+ * We use two structs. The first has all of its values as pointers, and scfg
+ * unmarshals the configuration to it. Then we take each value, dereference
+ * it, and throw it into a normal config struct without pointers.
+ * This means that any missing configuration options will simply cause a
+ * segmentation fault.
+ */
+
+var configWithPointers struct {
+ URL *string `scfg:"url"`
+ Prod *bool `scfg:"prod"`
+ Tmpl *string `scfg:"tmpl"`
+ Static *string `scfg:"static"`
+ Source *string `scfg:"source"`
+ Listen struct {
+ Proto *string `scfg:"proto"`
+ Net *string `scfg:"net"`
+ Addr *string `scfg:"addr"`
+ Trans *string `scfg:"trans"`
+ TLS struct {
+ Cert *string `scfg:"cert"`
+ Key *string `scfg:"key"`
+ } `scfg:"tls"`
+ } `scfg:"listen"`
+ DB struct {
+ Type *string `scfg:"type"`
+ Conn *string `scfg:"conn"`
+ } `scfg:"db"`
+ Auth struct {
+ Fake *int `scfg:"fake"`
+ Client *string `scfg:"client"`
+ Authorize *string `scfg:"authorize"`
+ Jwks *string `scfg:"jwks"`
+ Token *string `scfg:"token"`
+ Secret *string `scfg:"secret"`
+ Expr *int `scfg:"expr"`
+ } `scfg:"auth"`
+ Perf struct {
+ MessageArgumentsCap *int `scfg:"msg_args_cap"`
+ MessageBytesCap *int `scfg:"msg_bytes_cap"`
+ ReadHeaderTimeout *int `scfg:"read_header_timeout"`
+ CourseUpdateInterval *int `scfg:"course_update_interval"`
+ } `scfg:"perf"`
+}
+
+var config struct {
+ URL string
+ Prod bool
+ Tmpl string
+ Static string
+ Source string
+ Listen struct {
+ Proto string
+ Net string
+ Addr string
+ Trans string
+ TLS struct {
+ Cert string
+ Key string
+ }
+ }
+ DB struct {
+ Type string
+ Conn string
+ }
+ Auth struct {
+ Fake int
+ Client string
+ Authorize string
+ Jwks string
+ Token string
+ Secret string
+ Expr int
+ }
+ Perf struct {
+ MessageArgumentsCap int
+ MessageBytesCap int
+ ReadHeaderTimeout int
+ CourseUpdateInterval int
+ } `scfg:"perf"`
+}
+
+var (
+ errCannotProcessConfig = errors.New("cannot process configuration file")
+ errCannotOpenConfig = errors.New("cannot open configuration file")
+ errCannotDecodeConfig = errors.New("cannot decode configuration file")
+ errMissingConfigValue = errors.New("missing configuration value")
+ errIllegalConfig = errors.New("illegal configuration")
+)
+
+func fetchConfig(path string) (retErr error) {
+ defer func() {
+ if v := recover(); v != nil {
+ s, ok := v.(error)
+ if ok {
+ retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, s)
+ }
+ retErr = fmt.Errorf("%w: %v", errCannotProcessConfig, v)
+ return
+ }
+ if retErr != nil {
+ retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, retErr)
+ return
+ }
+ }()
+
+ f, err := os.Open(path)
+ if err != nil {
+ return fmt.Errorf("%w: %w", errCannotOpenConfig, err)
+ }
+
+ err = scfg.NewDecoder(bufio.NewReader(f)).Decode(&configWithPointers)
+ if err != nil {
+ return fmt.Errorf("%w: %w", errCannotDecodeConfig, err)
+ }
+
+ if configWithPointers.URL == nil {
+ return fmt.Errorf("%w: url", errMissingConfigValue)
+ }
+ config.URL = *(configWithPointers.URL)
+
+ if configWithPointers.Prod == nil {
+ return fmt.Errorf("%w: prod", errMissingConfigValue)
+ }
+ config.Prod = *(configWithPointers.Prod)
+
+ if configWithPointers.Tmpl == nil {
+ return fmt.Errorf("%w: tmpl", errMissingConfigValue)
+ }
+ config.Tmpl = *(configWithPointers.Tmpl)
+
+ if configWithPointers.Static == nil {
+ return fmt.Errorf("%w: static", errMissingConfigValue)
+ }
+ config.Static = *(configWithPointers.Static)
+
+ if configWithPointers.Source == nil {
+ return fmt.Errorf("%w: source", errMissingConfigValue)
+ }
+ config.Source = *(configWithPointers.Source)
+
+ if configWithPointers.Listen.Proto == nil {
+ return fmt.Errorf("%w: listen.proto", errMissingConfigValue)
+ }
+ config.Listen.Proto = *(configWithPointers.Listen.Proto)
+
+ if configWithPointers.Listen.Net == nil {
+ return fmt.Errorf("%w: listen.net", errMissingConfigValue)
+ }
+ config.Listen.Net = *(configWithPointers.Listen.Net)
+
+ if configWithPointers.Listen.Addr == nil {
+ return fmt.Errorf("%w: listen.addr", errMissingConfigValue)
+ }
+ config.Listen.Addr = *(configWithPointers.Listen.Addr)
+
+ if configWithPointers.Listen.Trans == nil {
+ return fmt.Errorf("%w: listen.trans", errMissingConfigValue)
+ }
+ config.Listen.Trans = *(configWithPointers.Listen.Trans)
+
+ if config.Listen.Trans == "tls" {
+ if configWithPointers.Listen.TLS.Cert == nil {
+ return fmt.Errorf("%w: listen.tls.cert", errMissingConfigValue)
+ }
+ config.Listen.TLS.Cert = *(configWithPointers.Listen.TLS.Cert)
+
+ if configWithPointers.Listen.TLS.Key == nil {
+ return fmt.Errorf("%w: listen.tls.key", errMissingConfigValue)
+ }
+ config.Listen.TLS.Key = *(configWithPointers.Listen.TLS.Key)
+ }
+
+ if configWithPointers.DB.Type == nil {
+ return fmt.Errorf("%w: db.type", errMissingConfigValue)
+ }
+ config.DB.Type = *(configWithPointers.DB.Type)
+
+ if configWithPointers.DB.Conn == nil {
+ return fmt.Errorf("%w: db.conn", errMissingConfigValue)
+ }
+ config.DB.Conn = *(configWithPointers.DB.Conn)
+
+ if configWithPointers.Auth.Fake == nil {
+ config.Auth.Fake = 0
+ } else {
+ config.Auth.Fake = *(configWithPointers.Auth.Fake)
+ switch config.Auth.Fake {
+ case 0:
+ /* It's okay to set it to 0 in production */
+ case 4712, 9080: /* Don't use them unless you know what you're doing */
+ if config.Prod {
+ return fmt.Errorf("%w: fake authentication is incompatible with production mode", errIllegalConfig)
+ }
+ log.Println("!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.")
+ default:
+ return fmt.Errorf("%w: invalid option for auth.fake", errIllegalConfig)
+ }
+ }
+
+ if configWithPointers.Auth.Client == nil {
+ return fmt.Errorf("%w: auth.client", errMissingConfigValue)
+ }
+ config.Auth.Client = *(configWithPointers.Auth.Client)
+
+ if configWithPointers.Auth.Authorize == nil {
+ return fmt.Errorf("%w: auth.authorize", errMissingConfigValue)
+ }
+ config.Auth.Authorize = *(configWithPointers.Auth.Authorize)
+
+ if configWithPointers.Auth.Jwks == nil {
+ return fmt.Errorf("%w: auth.jwks", errMissingConfigValue)
+ }
+ config.Auth.Jwks = *(configWithPointers.Auth.Jwks)
+
+ if configWithPointers.Auth.Token == nil {
+ return fmt.Errorf("%w: auth.token", errMissingConfigValue)
+ }
+ config.Auth.Token = *(configWithPointers.Auth.Token)
+
+ if configWithPointers.Auth.Secret == nil {
+ return fmt.Errorf("%w: auth.secret", errMissingConfigValue)
+ }
+ config.Auth.Secret = *(configWithPointers.Auth.Secret)
+
+ if configWithPointers.Auth.Expr == nil {
+ return fmt.Errorf("%w: auth.expr", errMissingConfigValue)
+ }
+ config.Auth.Expr = *(configWithPointers.Auth.Expr)
+
+ if configWithPointers.Perf.MessageArgumentsCap == nil {
+ return fmt.Errorf("%w: perf.msg_args_cap", errMissingConfigValue)
+ }
+ config.Perf.MessageArgumentsCap = *(configWithPointers.Perf.MessageArgumentsCap)
+
+ if configWithPointers.Perf.MessageBytesCap == nil {
+ return fmt.Errorf("%w: perf.msg_bytes_cap", errMissingConfigValue)
+ }
+ config.Perf.MessageBytesCap = *(configWithPointers.Perf.MessageBytesCap)
+
+ if configWithPointers.Perf.ReadHeaderTimeout == nil {
+ return fmt.Errorf("%w: perf.read_header_timeout", errMissingConfigValue)
+ }
+ config.Perf.ReadHeaderTimeout = *(configWithPointers.Perf.ReadHeaderTimeout)
+
+ if configWithPointers.Perf.CourseUpdateInterval == nil {
+ return fmt.Errorf("%w: perf.course_update_interval", errMissingConfigValue)
+ }
+ config.Perf.CourseUpdateInterval = *(configWithPointers.Perf.CourseUpdateInterval)
+
+ return nil
+}
diff --git a/backend/courses.go b/backend/courses.go
new file mode 100644
index 0000000..4e1ea51
--- /dev/null
+++ b/backend/courses.go
@@ -0,0 +1,229 @@
+/*
+ * Course data structures and locking
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+)
+
+type (
+ courseTypeT string
+ courseGroupT string
+)
+
+type courseT struct {
+ ID int
+ /*
+ * TODO: There will be a lot of lock contention over Selected. It is
+ * probably more appropriate to directly use atomics.
+ * Except that it's actually hard to use atomics directly here
+ * because I need to "increment if less than Max"... I think I could
+ * just do compare and swap in a loop, but the loop would be intensive
+ * on the CPU so I'd have to look into how mutexes/semaphores are
+ * actually implemented and how I could interact with the runtime.
+ */
+ Selected int
+ SelectedLock sync.RWMutex
+ Max int
+ Title string
+ Type courseTypeT
+ Group courseGroupT
+ Teacher string
+ Location string
+ Usems map[string](*usemT)
+ UsemsLock sync.RWMutex
+}
+
+const (
+ sport courseTypeT = "Sport"
+ enrichment courseTypeT = "Enrichment"
+ culture courseTypeT = "Culture"
+)
+
+var courseTypes = map[courseTypeT]bool{
+ sport: true,
+ enrichment: true,
+ culture: true,
+}
+
+const (
+ mw1 courseGroupT = "MW1"
+ mw2 courseGroupT = "MW2"
+ mw3 courseGroupT = "MW3"
+ tt1 courseGroupT = "TT1"
+ tt2 courseGroupT = "TT2"
+ tt3 courseGroupT = "TT3"
+)
+
+var courseGroups = map[courseGroupT]bool{
+ mw1: true,
+ mw2: true,
+ mw3: true,
+ tt1: true,
+ tt2: true,
+ tt3: true,
+}
+
+func checkCourseType(ct courseTypeT) bool {
+ return courseTypes[ct]
+}
+
+func checkCourseGroup(cg courseGroupT) bool {
+ return courseGroups[cg]
+}
+
+var (
+ errInvalidCourseType = errors.New("invalid course type")
+ errInvalidCourseGroup = errors.New("invalid course group")
+ errMultipleChoicesInOneGroup = errors.New("multiple choices per group per user")
+)
+
+/*
+ * The courses are simply stored in a map indexed by the course ID, although
+ * the course struct itself also contains an ID field. A lock is embedded
+ * inside the struct; we use a lock here instead of a pointer to a lock as
+ * it would be easy to forget to initialize the lock when creating the
+ * struct. However, this means that the struct could not be copied (though
+ * this should only ever happen during creation anyways), therefore we use a
+ * pointer to the struct as the value of the map, instead of the struct itself.
+ */
+var courses map[int](*courseT)
+
+/*
+ * This RWMutex is only for massive modifications of the course struct, since
+ * locking it on every write would be inefficient; in normal operation the only
+ * write that could occur to the courses struct is changing the Selected
+ * number, which should be handled with courseT.SelectedLock.
+ */
+var coursesLock sync.RWMutex
+
+/*
+ * Read course information from the database. This should be called during
+ * setup. Failure to do so before accessing course information may lead to
+ * a null pointer dereference.
+ */
+func setupCourses() error {
+ coursesLock.Lock()
+ defer coursesLock.Unlock()
+
+ courses = make(map[int](*courseT))
+
+ rows, err := db.Query(
+ context.Background(),
+ "SELECT id, nmax, title, ctype, cgroup, teacher, location FROM courses",
+ )
+ if err != nil {
+ return fmt.Errorf("error fetching courses: %w", err)
+ }
+
+ for {
+ if !rows.Next() {
+ err := rows.Err()
+ if err != nil {
+ return fmt.Errorf("error fetching courses: %w", err)
+ }
+ break
+ }
+ currentCourse := courseT{
+ Usems: make(map[string]*usemT),
+ } //exhaustruct:ignore
+ err = rows.Scan(
+ &currentCourse.ID,
+ &currentCourse.Max,
+ &currentCourse.Title,
+ &currentCourse.Type,
+ &currentCourse.Group,
+ &currentCourse.Teacher,
+ &currentCourse.Location,
+ )
+ if err != nil {
+ return fmt.Errorf("error fetching courses: %w", err)
+ }
+ if !checkCourseType(currentCourse.Type) {
+ return fmt.Errorf("%w: %d %s", errInvalidCourseType, currentCourse.ID, currentCourse.Type)
+ }
+ if !checkCourseGroup(currentCourse.Group) {
+ return fmt.Errorf("%w: %d %s", errInvalidCourseGroup, currentCourse.ID, currentCourse.Group)
+ }
+ err := db.QueryRow(context.Background(),
+ "SELECT COUNT (*) FROM choices WHERE courseid = $1",
+ currentCourse.ID,
+ ).Scan(&currentCourse.Selected)
+ if err != nil {
+ return fmt.Errorf("error querying course member number: %w", err)
+ }
+ courses[currentCourse.ID] = &currentCourse
+ }
+
+ return nil
+}
+
+type userCourseGroupsT map[courseGroupT]bool
+
+func populateUserCourseGroups(ctx context.Context, userCourseGroups *userCourseGroupsT, userID string) error {
+ rows, err := db.Query(ctx, "SELECT courseid FROM choices WHERE userid = $1", userID)
+ if err != nil {
+ return fmt.Errorf("error querying user's choices while populating course groups: %w", err)
+ }
+ for {
+ if !rows.Next() {
+ err := rows.Err()
+ if err != nil {
+ return fmt.Errorf("error iterating user's choices while populating course groups: %w", err)
+ }
+ break
+ }
+ var thisCourseID int
+ err := rows.Scan(&thisCourseID)
+ if err != nil {
+ return fmt.Errorf("error fetching user's choices while populating course groups: %w", err)
+ }
+ var thisGroupName courseGroupT
+ func() {
+ coursesLock.RLock()
+ defer coursesLock.RUnlock()
+ thisGroupName = courses[thisCourseID].Group
+ }()
+ if (*userCourseGroups)[thisGroupName] {
+ return fmt.Errorf("%w: user %v, group %v", errMultipleChoicesInOneGroup, userID, thisGroupName)
+ }
+ (*userCourseGroups)[thisGroupName] = true
+ }
+ return nil
+}
+
+func (course *courseT) decrementSelectedAndPropagate() {
+ func() {
+ course.SelectedLock.Lock()
+ defer course.SelectedLock.Unlock()
+ course.Selected--
+ }()
+ propagateSelectedUpdate(course.ID)
+}
+
+func getCourseByID(courseID int) *courseT {
+ coursesLock.RLock()
+ defer coursesLock.RUnlock()
+ return courses[courseID]
+}
diff --git a/backend/db.go b/backend/db.go
new file mode 100644
index 0000000..9ced627
--- /dev/null
+++ b/backend/db.go
@@ -0,0 +1,51 @@
+/*
+ * Database handling
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/jackc/pgx/v5/pgxpool"
+)
+
+var db *pgxpool.Pool
+
+var errUnsupportedDatabaseType = errors.New("unsupported db type")
+
+const pgErrUniqueViolation = "23505"
+
+/*
+ * This must be run during setup, before the database is accessed by any
+ * means. Otherwise, db would be a null pointer.
+ */
+func setupDatabase() error {
+ var err error
+ if config.DB.Type != "postgres" {
+ return errUnsupportedDatabaseType
+ }
+ db, err = pgxpool.New(context.Background(), config.DB.Conn)
+ if err != nil {
+ return fmt.Errorf("error opening database: %w", err)
+ }
+ return nil
+}
diff --git a/backend/go.mod b/backend/go.mod
new file mode 100644
index 0000000..717f737
--- /dev/null
+++ b/backend/go.mod
@@ -0,0 +1,23 @@
+module git.sr.ht/~runxiyu/cca/backend
+
+go 1.23.1
+
+require (
+ git.sr.ht/~emersion/go-scfg v0.0.0-20240128091534-2ae16e782082
+ github.com/MicahParks/keyfunc/v3 v3.3.5
+ github.com/coder/websocket v1.8.12
+ github.com/golang-jwt/jwt/v5 v5.2.1
+ github.com/google/uuid v1.6.0
+ github.com/jackc/pgx/v5 v5.7.1
+)
+
+require (
+ github.com/MicahParks/jwkset v0.5.20 // indirect
+ github.com/jackc/pgpassfile v1.0.0 // indirect
+ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
+ github.com/jackc/puddle/v2 v2.2.2 // indirect
+ golang.org/x/crypto v0.28.0 // indirect
+ golang.org/x/sync v0.8.0 // indirect
+ golang.org/x/text v0.19.0 // indirect
+ golang.org/x/time v0.7.0 // indirect
+)
diff --git a/backend/go.sum b/backend/go.sum
new file mode 100644
index 0000000..ca23612
--- /dev/null
+++ b/backend/go.sum
@@ -0,0 +1,42 @@
+git.sr.ht/~emersion/go-scfg v0.0.0-20240128091534-2ae16e782082 h1:9Udx5fm4vRtmgDIBjy2ef5QioHbzpw5oHabbhpAUyEw=
+git.sr.ht/~emersion/go-scfg v0.0.0-20240128091534-2ae16e782082/go.mod h1:ybgvEJTIx5XbaspSviB3KNa6OdPmAZqDoSud7z8fFlw=
+github.com/MicahParks/jwkset v0.5.20 h1:gTIKx9AofTqQJ0srd8AL7ty9NeadP5WUXSPOZadTpOI=
+github.com/MicahParks/jwkset v0.5.20/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY=
+github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo=
+github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8=
+github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
+github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
+github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
+github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
+github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
+github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
+github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
+github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
+github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
+golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
+golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
+golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
+golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
+golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
+golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/backend/index.go b/backend/index.go
new file mode 100644
index 0000000..f446072
--- /dev/null
+++ b/backend/index.go
@@ -0,0 +1,120 @@
+/*
+ * Index page
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "errors"
+ "fmt"
+ "log"
+ "net/http"
+
+ "github.com/jackc/pgx/v5"
+)
+
+/*
+ * Serve the index page. Also handles the login page in case the user doesn't
+ * have any valid login cookies.
+ */
+func handleIndex(w http.ResponseWriter, req *http.Request) {
+ sessionCookie, err := req.Cookie("session")
+ if errors.Is(err, http.ErrNoCookie) {
+ authURL, err := generateAuthorizationURL()
+ if err != nil {
+ wstr(w, http.StatusInternalServerError, "Cannot generate authorization URL")
+ return
+ }
+ err = tmpl.ExecuteTemplate(
+ w,
+ "login",
+ map[string]string{
+ "authURL": authURL,
+ "source": config.Source,
+ /*
+ * We directly generate the login URL here
+ * instead of doing so in a redirect to save
+ * requests.
+ */
+ },
+ )
+ if err != nil {
+ log.Println(err)
+ return
+ }
+ return
+ } else if err != nil {
+ wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.")
+ return
+ }
+
+ var userID, userName, userDepartment string
+ err = db.QueryRow(
+ req.Context(),
+ "SELECT id, name, department FROM users WHERE session = $1",
+ sessionCookie.Value,
+ ).Scan(&userID, &userName, &userDepartment)
+ if err != nil {
+ if errors.Is(err, pgx.ErrNoRows) {
+ authURL, err := generateAuthorizationURL()
+ if err != nil {
+ wstr(w, http.StatusInternalServerError, "Cannot generate authorization URL")
+ return
+ }
+ err = tmpl.ExecuteTemplate(
+ w,
+ "login",
+ map[string]interface{}{
+ "authURL": authURL,
+ "notes": "You sent an invalid session cookie.",
+ "source": config.Source,
+ },
+ )
+ if err != nil {
+ log.Println(err)
+ return
+ }
+ return
+ }
+ wstr(w, http.StatusInternalServerError, fmt.Sprintf("Error: Unexpected database error: %s", err))
+ return
+ }
+
+ err = func() error {
+ coursesLock.RLock()
+ defer coursesLock.RUnlock()
+ return tmpl.ExecuteTemplate(
+ w,
+ "student",
+ map[string]interface{}{
+ "open": true,
+ "user": map[string]interface{}{
+ "Name": userName,
+ "Department": userDepartment,
+ },
+ "courses": courses,
+ "source": config.Source,
+ },
+ )
+ }()
+ if err != nil {
+ log.Println(err)
+ return
+ }
+}
diff --git a/backend/main.go b/backend/main.go
new file mode 100644
index 0000000..a185de9
--- /dev/null
+++ b/backend/main.go
@@ -0,0 +1,125 @@
+/*
+ * Main listener
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "crypto/tls"
+ "flag"
+ "html/template"
+ "log"
+ "net"
+ "net/http"
+ "time"
+)
+
+var tmpl *template.Template
+
+func main() {
+ var err error
+
+ var configPath string
+
+ flag.StringVar(&configPath, "config", "cca.scfg", "path to configuration file")
+ flag.Parse()
+
+ if err := fetchConfig(configPath); err != nil {
+ log.Fatal(err)
+ }
+
+ log.Println("Setting up database")
+ if err := setupDatabase(); err != nil {
+ log.Fatal(err)
+ }
+
+ log.Println("Setting up JWKS")
+ if err := setupJwks(); err != nil {
+ log.Fatal(err)
+ }
+
+ log.Println("Setting up templates")
+ tmpl, err = template.ParseGlob(config.Tmpl)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ log.Println("Setting up courses")
+ err = setupCourses()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ log.Println("Registering static handle")
+ fs := http.FileServer(http.Dir(config.Static))
+ http.Handle("/static/", http.StripPrefix("/static/", fs))
+
+ log.Println("Registering handlers")
+ http.HandleFunc("/{$}", handleIndex)
+ http.HandleFunc("/auth", handleAuth)
+ http.HandleFunc("/ws", handleWs)
+
+ var l net.Listener
+
+ switch config.Listen.Trans {
+ case "plain":
+ log.Printf(
+ "Establishing plain listener for net \"%s\", addr \"%s\"\n",
+ config.Listen.Net,
+ config.Listen.Addr,
+ )
+ l, err = net.Listen(config.Listen.Net, config.Listen.Addr)
+ if err != nil {
+ log.Fatalf("Failed to establish plain listener: %v\n", err)
+ }
+ case "tls":
+ cer, err := tls.LoadX509KeyPair(config.Listen.TLS.Cert, config.Listen.TLS.Key)
+ if err != nil {
+ log.Fatalf("Failed to load TLS certificate and key: %v\n", err)
+ }
+ tlsconfig := &tls.Config{
+ Certificates: []tls.Certificate{cer},
+ MinVersion: tls.VersionTLS13,
+ } //exhaustruct:ignore
+ log.Printf(
+ "Establishing TLS listener for net \"%s\", addr \"%s\"\n",
+ config.Listen.Net,
+ config.Listen.Addr,
+ )
+ l, err = tls.Listen(config.Listen.Net, config.Listen.Addr, tlsconfig)
+ if err != nil {
+ log.Fatalf("Failed to establish TLS listener: %v\n", err)
+ }
+ default:
+ log.Fatalln("listen.trans must be \"plain\" or \"tls\"")
+ }
+
+ if config.Listen.Proto == "http" {
+ log.Println("Serving http")
+ srv := &http.Server{
+ ReadHeaderTimeout: time.Duration(config.Perf.ReadHeaderTimeout) * time.Second,
+ } //exhaustruct:ignore
+ err = srv.Serve(l)
+ } else {
+ log.Fatalln("Unsupported protocol")
+ }
+ if err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/backend/usem.go b/backend/usem.go
new file mode 100644
index 0000000..0b839a4
--- /dev/null
+++ b/backend/usem.go
@@ -0,0 +1,36 @@
+/*
+ * Additional synchronization routines
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+type usemT struct {
+ ch (chan struct{})
+}
+
+func (s *usemT) init() {
+ s.ch = make(chan struct{}, 1)
+}
+
+func (s *usemT) set() {
+ select {
+ case s.ch <- struct{}{}:
+ default:
+ }
+}
diff --git a/backend/utils.go b/backend/utils.go
new file mode 100644
index 0000000..3628c80
--- /dev/null
+++ b/backend/utils.go
@@ -0,0 +1,58 @@
+/*
+ * Utility functions
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "crypto/rand"
+ "encoding/base64"
+ "fmt"
+ "log"
+ "net/http"
+)
+
+/*
+ * Write a string to a http.ResponseWriter, setting the Content-Type and status
+ * code.
+ */
+func wstr(w http.ResponseWriter, code int, msg string) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(code)
+ _, err := w.Write([]byte(msg))
+ if err != nil {
+ log.Printf("Error wstr'ing to writer: %v", err)
+ }
+}
+
+/*
+ * Generate a random url-safe string.
+ * Note that the "sz" parameter specifies the number of bytes taken from the
+ * random source divided by three and does NOT represent the length of the
+ * encoded string. It's divided by three because we're using base64 and it's
+ * ideal to ensure that the entropy remains consistent throughout the string.
+ */
+func randomString(sz int) (string, error) {
+ r := make([]byte, 3*sz)
+ _, err := rand.Read(r)
+ if err != nil {
+ return "", fmt.Errorf("error generating random string: %w", err)
+ }
+ return base64.RawURLEncoding.EncodeToString(r), nil
+}
diff --git a/backend/wsc.go b/backend/wsc.go
new file mode 100644
index 0000000..982c1fb
--- /dev/null
+++ b/backend/wsc.go
@@ -0,0 +1,266 @@
+/*
+ * WebSocket connection routine
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/coder/websocket"
+)
+
+type errbytesT struct {
+ err error
+ bytes *[]byte
+}
+
+/*
+ * The actual logic in handling the connection, after authentication has been
+ * completed.
+ */
+func handleConn(
+ ctx context.Context,
+ c *websocket.Conn,
+ session string,
+ userID string,
+) (retErr error) {
+ reportError := makeReportError(ctx, c)
+ newCtx, newCancel := context.WithCancel(ctx)
+
+ func() {
+ cancelPoolLock.Lock()
+ defer cancelPoolLock.Unlock()
+ cancel := cancelPool[userID]
+ if cancel != nil {
+ (*cancel)()
+ /* TODO: Make the cancel synchronous */
+ }
+ cancelPool[userID] = &newCancel
+ }()
+ defer func() {
+ cancelPoolLock.Lock()
+ defer cancelPoolLock.Unlock()
+ if cancelPool[userID] == &newCancel {
+ delete(cancelPool, userID)
+ }
+ if errors.Is(retErr, context.Canceled) {
+ /*
+ * Only works if it's newCtx that has been cancelled
+ * rather than the original ctx, which is kinda what
+ * we intend
+ */
+ _ = writeText(ctx, c, "E :Context canceled")
+ }
+ }()
+
+ /* TODO: Tell the user their current choices here. Deprecate HELLO. */
+
+ usems := make(map[int]*usemT)
+ func() {
+ coursesLock.RLock()
+ defer coursesLock.RUnlock()
+ for courseID, course := range courses {
+ usem := &usemT{} //exhaustruct:ignore
+ usem.init()
+ func() {
+ course.UsemsLock.Lock()
+ defer course.UsemsLock.Unlock()
+ course.Usems[userID] = usem
+ }()
+ usems[courseID] = usem
+ }
+ }()
+ defer func() {
+ coursesLock.RLock()
+ defer coursesLock.RUnlock()
+ for _, course := range courses {
+ func() {
+ course.UsemsLock.Lock()
+ defer course.UsemsLock.Unlock()
+ delete(course.Usems, userID)
+ }()
+ }
+ }()
+
+ usemParent := make(chan int)
+ for courseID, usem := range usems {
+ go func() {
+ for {
+ select {
+ case <-newCtx.Done():
+ return
+ case <-usem.ch:
+ select {
+ case <-newCtx.Done():
+ return
+ case usemParent <- courseID:
+ }
+ }
+ time.Sleep(time.Duration(config.Perf.CourseUpdateInterval) * time.Millisecond)
+ }
+ }()
+ }
+
+ /*
+ * userCourseGroups stores whether the user has already chosen a course
+ * in the courseGroup.
+ */
+ var userCourseGroups userCourseGroupsT = make(map[courseGroupT]bool)
+ err := populateUserCourseGroups(newCtx, &userCourseGroups, userID)
+ if err != nil {
+ return reportError(fmt.Sprintf("cannot populate user course groups: %v", err))
+ }
+
+ /*
+ * Later we need to select from recv and send and perform the
+ * corresponding action. But we can't just select from c.Read because
+ * the function blocks. Therefore, we must spawn a goroutine that
+ * blocks on c.Read and send what it receives to a channel "recv"; and
+ * then we can select from that channel.
+ */
+ recv := make(chan *errbytesT)
+ go func() {
+ for {
+ /*
+ * Here we use the original connection context instead
+ * of the new context we just created. Apparently when
+ * the context passed to Read expires, the connection
+ * gets closed, which makes it impossible for us to
+ * write the context expiry message to the client.
+ * So we pass the original connection context, which
+ * would get cancelled anyway once we close the
+ * connection.
+ * We still need to take care of this while sending so
+ * we don't infinitely block, and leak goroutines and
+ * cause the channel to remain out of reach of the
+ * garbage collector.
+ * It would be nice to return the newCtx.Err() but
+ * the only way to really do that is to use the recv
+ * channel which might not have a listener anymore.
+ * It's not really crucial anyways so we could just
+ * close this goroutine by returning here.
+ */
+ _, b, err := c.Read(ctx)
+ if err != nil {
+ /*
+ * TODO: Prioritize context dones... except
+ * that it's not really possible. I would just
+ * have placed newCtx in here but apparently
+ * that causes the connection to be closed when
+ * the context expires, which makes it
+ * impossible to deliver the final error
+ * message. Probably need to look into this
+ * design again.
+ */
+ select {
+ case <-newCtx.Done():
+ _ = writeText(ctx, c, "E :Context canceled")
+ /* Not a typo to use ctx here */
+ return
+ case recv <- &errbytesT{err: err, bytes: nil}:
+ }
+ return
+ }
+ select {
+ case <-newCtx.Done():
+ _ = writeText(ctx, c, "E :Context cancelled")
+ /* Not a typo to use ctx here */
+ return
+ case recv <- &errbytesT{err: nil, bytes: &b}:
+ }
+ }
+ }()
+
+ for {
+ var mar []string
+ select {
+ case <-newCtx.Done():
+ /*
+ * TODO: Somehow prioritize this case over all other cases
+ */
+ return fmt.Errorf("context done in main event loop: %w", newCtx.Err())
+ /*
+ * There are other times when the context could be
+ * cancelled, and apparently some WebSocket functions
+ * just close the connection when the context is
+ * cancelled. So it's kinda impossible to reliably
+ * send this message due to newCtx cancellation.
+ * But in any case, the WebSocket connection would
+ * be closed, and the user would see the connection
+ * closed page which should explain it.
+ */
+ case courseID := <-usemParent:
+ var selected int
+ func() {
+ course := courses[courseID]
+ course.SelectedLock.RLock()
+ defer course.SelectedLock.RUnlock()
+ selected = course.Selected
+ }()
+ err := writeText(newCtx, c, fmt.Sprintf("M %d %d", courseID, selected))
+ if err != nil {
+ return fmt.Errorf("error sending to websocket for course selected update: %w", err)
+ }
+ continue
+ case errbytes := <-recv:
+ if errbytes.err != nil {
+ return fmt.Errorf("error fetching message from recv channel: %w", errbytes.err)
+ /*
+ * Note that this cannot return newCtx.Err(),
+ * so we handle the error reporting in the
+ * reading routine
+ */
+ }
+ mar = splitMsg(errbytes.bytes)
+ switch mar[0] {
+ case "HELLO":
+ err := messageHello(newCtx, c, reportError, mar, userID, session)
+ if err != nil {
+ return err
+ }
+ case "Y":
+ err := messageChooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups)
+ if err != nil {
+ return err
+ }
+ case "N":
+ err := messageUnchooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups)
+ if err != nil {
+ return err
+ }
+ default:
+ return reportError("Unknown command " + mar[0])
+ }
+ }
+ }
+}
+
+var (
+ cancelPool = make(map[string](*context.CancelFunc))
+ /*
+ * Normal Go maps are not thread safe, so we protect large cancelPool
+ * operations such as addition and deletion under a RWMutex.
+ */
+ cancelPoolLock sync.RWMutex
+)
diff --git a/backend/wsh.go b/backend/wsh.go
new file mode 100644
index 0000000..abf4e61
--- /dev/null
+++ b/backend/wsh.go
@@ -0,0 +1,165 @@
+/*
+ * WebSocket endpoint handler
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "errors"
+ "log"
+ "net/http"
+ "time"
+
+ "github.com/coder/websocket"
+ "github.com/google/uuid"
+ "github.com/jackc/pgx/v5"
+)
+
+/*
+ * Handle requests to the WebSocket endpoint and establish a connection.
+ * Authentication is handled here, but afterwards, the connection is really
+ * handled in handleConn.
+ */
+func handleWs(w http.ResponseWriter, req *http.Request) {
+ wsOptions := &websocket.AcceptOptions{
+ Subprotocols: []string{"cca1"},
+ } //exhaustruct:ignore
+ c, err := websocket.Accept(
+ w,
+ req,
+ wsOptions,
+ )
+ if err != nil {
+ wstr(w, http.StatusBadRequest, "This endpoint only supports valid WebSocket connections.")
+ return
+ }
+ defer func() {
+ _ = c.CloseNow()
+ }()
+
+ fake := false
+
+ sessionCookie, err := req.Cookie("session")
+ if errors.Is(err, http.ErrNoCookie) {
+ if config.Auth.Fake == 0 {
+ err := writeText(req.Context(), c, "U")
+ if err != nil {
+ log.Println(err)
+ }
+ return
+ }
+ fake = true
+ } else if err != nil {
+ err := writeText(req.Context(), c, "E :Error fetching cookie")
+ if err != nil {
+ log.Println(err)
+ }
+ return
+ }
+
+ var userID string
+ var session string
+ var expr int
+
+ if fake {
+ switch config.Auth.Fake {
+ case 9080:
+ _uuid, err := uuid.NewRandom()
+ if err != nil {
+ log.Println(err)
+ return
+ }
+ userID = _uuid.String()
+ case 4712:
+ userID = "fake"
+ default:
+ panic("not supposed to happen")
+ }
+ session, err = randomString(20)
+ if err != nil {
+ log.Println(err)
+ return
+ }
+ _, err = db.Exec(
+ req.Context(),
+ "INSERT INTO users (id, name, email, department, session, expr) VALUES ($1, $2, $3, $4, $5, $6)",
+ userID,
+ "Fake User",
+ "fake@runxiyu.org",
+ "Y11",
+ session,
+ time.Now().Add(time.Duration(config.Auth.Expr)*time.Second).Unix(),
+ )
+ if err != nil && config.Auth.Fake != 4712 {
+ /* TODO check pgerr */
+ err := writeText(req.Context(), c, "E :Database error while writing fake account info")
+ if err != nil {
+ log.Println(err)
+ }
+ return
+ }
+ err = writeText(req.Context(), c, "FAKE "+userID+" "+session)
+ if err != nil {
+ log.Println(err)
+ return
+ }
+ } else {
+ session = sessionCookie.Value
+ err = db.QueryRow(
+ req.Context(),
+ "SELECT id, expr FROM users WHERE session = $1",
+ session,
+ ).Scan(&userID, &expr)
+ if errors.Is(err, pgx.ErrNoRows) {
+ err := writeText(req.Context(), c, "U")
+ if err != nil {
+ log.Println(err)
+ }
+ return
+ } else if err != nil {
+ err := writeText(req.Context(), c, "E :Database error while selecting session")
+ if err != nil {
+ log.Println(err)
+ }
+ return
+ }
+ }
+
+ /*
+ * Now that we have an authenticated request, this WebSocket connection
+ * may be simply associated with the session and userID.
+ * TODO: There are various race conditions that could occur if one user
+ * creates multiple connections, with the same or different session
+ * cookies. The last situation could occur in normal use when a user
+ * opens multiple instances of the page in one browser, and is not
+ * unique to custom clients or malicious users. Some effort must be
+ * taken to ensure that each user may only have one connection at a
+ * time.
+ */
+ err = handleConn(
+ req.Context(),
+ c,
+ session,
+ userID,
+ )
+ if err != nil {
+ log.Printf("%v", err)
+ return
+ }
+}
diff --git a/backend/wsm.go b/backend/wsm.go
new file mode 100644
index 0000000..5f271ff
--- /dev/null
+++ b/backend/wsm.go
@@ -0,0 +1,224 @@
+/*
+ * WebSocket message handlers
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/coder/websocket"
+ "github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgconn"
+)
+
+func messageHello(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string) error {
+ _, _ = mar, session
+
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("context done when handling hello: %w", ctx.Err())
+ default:
+ }
+
+ rows, err := db.Query(
+ ctx,
+ "SELECT courseid FROM choices WHERE userid = $1",
+ userID,
+ )
+ if err != nil {
+ return reportError("error fetching choices")
+ }
+ courseIDs, err := pgx.CollectRows(rows, pgx.RowTo[string])
+ if err != nil {
+ return reportError("error collecting choices")
+ }
+
+ err = writeText(ctx, c, "HI :"+strings.Join(courseIDs, ","))
+ if err != nil {
+ return fmt.Errorf("error replying to HELLO: %w", err)
+ }
+
+ return nil
+}
+
+func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error {
+ _ = session
+
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("context done when handling choose: %w", ctx.Err())
+ default:
+ }
+
+ if len(mar) != 2 {
+ return reportError("Invalid number of arguments for Y")
+ }
+ _courseID, err := strconv.ParseInt(mar[1], 10, strconv.IntSize)
+ if err != nil {
+ return reportError("Course ID must be an integer")
+ }
+ courseID := int(_courseID)
+ course := getCourseByID(courseID)
+
+ if course == nil {
+ return reportError("nil course")
+ }
+
+ err = func() (returnedError error) { /* Named returns so I could modify them in defer */
+ tx, err := db.Begin(ctx)
+ if err != nil {
+ return reportError("Database error while beginning transaction")
+ }
+ defer func() {
+ err := tx.Rollback(ctx)
+ if err != nil && (!errors.Is(err, pgx.ErrTxClosed)) {
+ returnedError = reportError("Database error while rolling back transaction in defer block")
+ return
+ }
+ }()
+
+ _, err = tx.Exec(
+ ctx,
+ "INSERT INTO choices (seltime, userid, courseid) VALUES ($1, $2, $3)",
+ time.Now().UnixMicro(),
+ userID,
+ courseID,
+ )
+ if err != nil {
+ var pgErr *pgconn.PgError
+ if errors.As(err, &pgErr) && pgErr.Code == pgErrUniqueViolation {
+ err := writeText(ctx, c, "Y "+mar[1])
+ if err != nil {
+ return fmt.Errorf("error reaffirming course choice: %w", err)
+ }
+ return nil
+ }
+ return reportError("Database error while inserting course choice")
+ }
+
+ ok := func() bool {
+ course.SelectedLock.Lock()
+ defer course.SelectedLock.Unlock()
+ if course.Selected < course.Max {
+ course.Selected++
+ return true
+ }
+ return false
+ }()
+
+ if ok {
+ go propagateSelectedUpdate(courseID)
+ err := tx.Commit(ctx)
+ if err != nil {
+ go course.decrementSelectedAndPropagate()
+ return reportError("Database error while committing transaction")
+ }
+ var thisCourseGroup courseGroupT
+ func() {
+ coursesLock.RLock()
+ defer coursesLock.RUnlock()
+ thisCourseGroup = courses[courseID].Group
+ }()
+ if (*userCourseGroups)[thisCourseGroup] {
+ go course.decrementSelectedAndPropagate()
+ return reportError("inconsistent user course groups")
+ }
+ (*userCourseGroups)[thisCourseGroup] = true
+ err = writeText(ctx, c, "Y "+mar[1])
+ if err != nil {
+ return fmt.Errorf("error affirming course choice: %w", err)
+ }
+ } else {
+ err := tx.Rollback(ctx)
+ if err != nil {
+ return reportError("Database error while rolling back transaction due to course limit")
+ }
+ err = writeText(ctx, c, "R "+mar[1]+" :Full")
+ if err != nil {
+ return fmt.Errorf("error rejecting course choice: %w", err)
+ }
+ }
+ return nil
+ }()
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error {
+ _ = session
+
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("context done when handling unchoose: %w", ctx.Err())
+ default:
+ }
+
+ if len(mar) != 2 {
+ return reportError("Invalid number of arguments for N")
+ }
+ _courseID, err := strconv.ParseInt(mar[1], 10, strconv.IntSize)
+ if err != nil {
+ return reportError("Course ID must be an integer")
+ }
+ courseID := int(_courseID)
+ course := getCourseByID(courseID)
+
+ if course == nil {
+ return reportError("nil course")
+ }
+
+ ct, err := db.Exec(
+ ctx,
+ "DELETE FROM choices WHERE userid = $1 AND courseid = $2",
+ userID,
+ courseID,
+ )
+ if err != nil {
+ return reportError("Database error while deleting course choice")
+ }
+
+ if ct.RowsAffected() != 0 {
+ go course.decrementSelectedAndPropagate()
+ var thisCourseGroup courseGroupT
+ func() {
+ coursesLock.RLock()
+ defer coursesLock.RUnlock()
+ thisCourseGroup = courses[courseID].Group
+ }()
+ if !(*userCourseGroups)[thisCourseGroup] {
+ return reportError("inconsistent user course groups")
+ }
+ (*userCourseGroups)[thisCourseGroup] = false
+ }
+
+ err = writeText(ctx, c, "N "+mar[1])
+ if err != nil {
+ return fmt.Errorf("error replying that course has been deselected: %w", err)
+ }
+
+ return nil
+}
diff --git a/backend/wsp.go b/backend/wsp.go
new file mode 100644
index 0000000..7a1f6ab
--- /dev/null
+++ b/backend/wsp.go
@@ -0,0 +1,85 @@
+/*
+ * WebSocket-based protocol auxiliary functions
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/coder/websocket"
+)
+
+/*
+ * Split an IRC-style message of type []byte into type []string where each
+ * element is a complete argument. Generally, arguments are separated by
+ * spaces, and an argument that begins with a ':' causes the rest of the
+ * line to be treated as a single argument.
+ */
+func splitMsg(b *[]byte) []string {
+ mar := make([]string, 0, config.Perf.MessageArgumentsCap)
+ elem := make([]byte, 0, config.Perf.MessageBytesCap)
+ for i, c := range *b {
+ switch c {
+ case ' ':
+ if (*b)[i+1] == ':' {
+ mar = append(mar, string(elem))
+ mar = append(mar, string((*b)[i+2:]))
+ goto endl
+ }
+ mar = append(mar, string(elem))
+ elem = make([]byte, 0, config.Perf.MessageBytesCap)
+ default:
+ elem = append(elem, c)
+ }
+ }
+ mar = append(mar, string(elem))
+endl:
+ return mar
+}
+
+func baseReportError(ctx context.Context, conn *websocket.Conn, e string) error {
+ err := writeText(ctx, conn, "E :"+e)
+ if err != nil {
+ return fmt.Errorf("error reporting protocol violation: %w", err)
+ }
+ err = conn.Close(websocket.StatusProtocolError, e)
+ if err != nil {
+ return fmt.Errorf("error closing websocket: %w", err)
+ }
+ return nil
+}
+
+type reportErrorT func(e string) error
+
+func makeReportError(ctx context.Context, conn *websocket.Conn) reportErrorT {
+ return func(e string) error {
+ return baseReportError(ctx, conn, e)
+ }
+}
+
+func propagateSelectedUpdate(courseID int) {
+ course := courses[courseID]
+ course.UsemsLock.RLock()
+ defer course.UsemsLock.RUnlock()
+ for _, usem := range course.Usems {
+ usem.set()
+ }
+}
diff --git a/backend/wsx.go b/backend/wsx.go
new file mode 100644
index 0000000..a51b976
--- /dev/null
+++ b/backend/wsx.go
@@ -0,0 +1,59 @@
+/*
+ * Generic WebSocket auxiliary functions
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+/*
+ * The message format is a WebSocket message separated with spaces.
+ * The contents of each field could contain anything other than spaces,
+ * The first character of each argument cannot be a colon. As an exception, the
+ * last argument may contain spaces and the first character thereof may be a
+ * colon, if the argument is prefixed with a colon. The colon used for the
+ * prefix is not considered part of the content of the message. For example, in
+ *
+ * SQUISH POP :cat purr!!
+ *
+ * the first field is "SQUISH", the second field is "POP", and the third
+ * field is "cat purr!!".
+ *
+ * It is essentially an RFC 1459 IRC message without trailing CR-LF and
+ * without prefixes. See section 2.3.1 of RFC 1459 for an approximate
+ * BNF representation.
+ *
+ * The reason this was chosen instead of using protobuf etc. is that it
+ * is simple to parse without external libraries, and it also happens to
+ * be a format I'm very familiar with, having extensively worked with the
+ * IRC protocol.
+ */
+
+package main
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/coder/websocket"
+)
+
+func writeText(ctx context.Context, c *websocket.Conn, msg string) error {
+ err := c.Write(ctx, websocket.MessageText, []byte(msg))
+ if err != nil {
+ return fmt.Errorf("error writing to connection: %w", err)
+ }
+ return nil
+}