diff options
Diffstat (limited to 'backend')
-rw-r--r-- | backend/Makefile | 2 | ||||
-rw-r--r-- | backend/auth.go | 374 | ||||
-rw-r--r-- | backend/config.go | 284 | ||||
-rw-r--r-- | backend/courses.go | 229 | ||||
-rw-r--r-- | backend/db.go | 51 | ||||
-rw-r--r-- | backend/go.mod | 23 | ||||
-rw-r--r-- | backend/go.sum | 42 | ||||
-rw-r--r-- | backend/index.go | 120 | ||||
-rw-r--r-- | backend/main.go | 125 | ||||
-rw-r--r-- | backend/usem.go | 36 | ||||
-rw-r--r-- | backend/utils.go | 58 | ||||
-rw-r--r-- | backend/wsc.go | 266 | ||||
-rw-r--r-- | backend/wsh.go | 165 | ||||
-rw-r--r-- | backend/wsm.go | 224 | ||||
-rw-r--r-- | backend/wsp.go | 85 | ||||
-rw-r--r-- | backend/wsx.go | 59 |
16 files changed, 2143 insertions, 0 deletions
diff --git a/backend/Makefile b/backend/Makefile new file mode 100644 index 0000000..08566d8 --- /dev/null +++ b/backend/Makefile @@ -0,0 +1,2 @@ +cca: *.go + go build -o cca diff --git a/backend/auth.go b/backend/auth.go new file mode 100644 index 0000000..9ef1254 --- /dev/null +++ b/backend/auth.go @@ -0,0 +1,374 @@ +/* + * Custom OAUTH 2.0 implementation for the CCA Selection Service + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/MicahParks/keyfunc/v3" + "github.com/golang-jwt/jwt/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +var myKeyfunc keyfunc.Keyfunc + +var ( + errInsufficientFields = errors.New("insufficient fields") + errAccessTokenIncompleteError = errors.New("access token has unpopulated error fields") + errTokenEndpointReturnedError = errors.New("token endpoint returned error") +) + +const tokenLength = 20 + +/* + * These are the claims in the JSON Web Token received from the client, after + * it redirects from the authorize endpoint. Some of these fields must be + * explicitly selected in the Azure app registration and might appear as + * zero strings if it hasn't been configured correctly. + */ +type msclaimsT struct { + Name string `json:"name"` /* Scope: profile */ + Email string `json:"email"` /* Scope: email */ + Oid string `json:"oid"` /* Scope: profile */ + jwt.RegisteredClaims +} + +func generateAuthorizationURL() (string, error) { + /* + * TODO: Handle nonces and anti-replay. Incremental nonces would be + * nice on memory and speed (depending on how maps are implemented in + * Go, hopefully it's some sort of btree), but that requires either + * hacky atomics or having a multiple goroutines to handle + * authentication, neither of which are desirable. + */ + nonce, err := randomString(tokenLength) + if err != nil { + return "", err + } + /* + * Note that here we use a hybrid authentication flow to obtain an + * id_token for authentication and an authorization code. The + * authorization code may be used like any other; i.e., it may be used + * to obtain an access token directly, or the refresh token may be used + * to gain persistent access to the upstream API. Sometimes I wish that + * the JWT in id_token could have more claims. The only reason we + * presently use a hybrid flow is to use the authorization code to + * obtain an access code to call the user info endpoint to fetch the + * user's department information. + */ + return fmt.Sprintf( + "https://login.microsoftonline.com/ddd3d26c-b197-4d00-a32d-1ffd84c0c295/oauth2/authorize?client_id=%s&response_type=id_token%%20code&redirect_uri=%s%%2Fauth&response_mode=form_post&scope=openid+profile+email+User.Read&nonce=%s", + config.Auth.Client, + config.URL, + nonce, + ), nil +} + +/* + * Handles redirects to the /auth endpoint from the authorize endpoint. + * Expects JSON Web Keys to be already set up correctly; if myKeyfunc is null, + * a null pointer is dereferenced and the thread panics. + */ +func handleAuth(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + wstr(w, http.StatusMethodNotAllowed, "Only POST is supported on the authentication endpoint") + return + } + + err := req.ParseForm() + if err != nil { + wstr(w, http.StatusBadRequest, "Malformed form data") + return + } + + returnedError := req.PostFormValue("error") + if returnedError != "" { + returnedErrorDescription := req.PostFormValue("error_description") + if returnedErrorDescription == "" { + wstr(w, http.StatusBadRequest, fmt.Sprintf("authorize endpoint returned error: %v", returnedErrorDescription)) + return + } + wstr(w, http.StatusBadRequest, fmt.Sprintf( + "%s: %s", + returnedError, + returnedErrorDescription, + )) + return + } + + idTokenString := req.PostFormValue("id_token") + if idTokenString == "" { + wstr(w, http.StatusBadRequest, "Missing id_token") + return + } + + claimsTemplate := &msclaimsT{} //exhaustruct:ignore + token, err := jwt.ParseWithClaims( + idTokenString, + claimsTemplate, + myKeyfunc.Keyfunc, + ) + if err != nil { + wstr(w, http.StatusBadRequest, "Cannot parse claims") + return + } + + switch { + case token.Valid: + break + case errors.Is(err, jwt.ErrTokenMalformed): + wstr(w, http.StatusBadRequest, "Malformed JWT token") + return + case errors.Is(err, jwt.ErrTokenSignatureInvalid): + wstr(w, http.StatusBadRequest, "Invalid JWS signature") + return + case errors.Is(err, jwt.ErrTokenExpired) || + errors.Is(err, jwt.ErrTokenNotValidYet): + wstr(w, http.StatusBadRequest, "JWT token expired or not yet valid") + return + default: + wstr(w, http.StatusBadRequest, "Unhandled JWT token error") + return + } + + claims, claimsOk := token.Claims.(*msclaimsT) + + if !claimsOk { + wstr(w, http.StatusBadRequest, "Cannot unpack claims") + return + } + + authorizationCode := req.PostFormValue("code") + + accessToken, err := getAccessToken(req.Context(), authorizationCode) + if err != nil { + wstr(w, http.StatusInternalServerError, fmt.Sprintf("Unable to fetch access token: %v", err)) + return + } + + department, err := getDepartment(req.Context(), *(accessToken.Content)) + if err != nil { + wstr(w, http.StatusInternalServerError, err.Error()) + return + } + + switch { + case department == "SJ Co-Curricular Activities Office 松江课外项目办公室" || department == "High School Teaching & Learning 高中教学部门": + department = "Staff" + case department == "Y9" || department == "Y10" || department == "Y11" || department == "Y12": + default: + wstr( + w, + http.StatusForbidden, + fmt.Sprintf( + "Your department \"%s\" is unknown.\nWe currently only allow Y9, Y10, Y11, Y12, and the CCA office.", + department, + ), + ) + return + } + + cookieValue, err := randomString(tokenLength) + if err != nil { + wstr(w, http.StatusInternalServerError, err.Error()) + return + } + + now := time.Now() + expr := now.Add(time.Duration(config.Auth.Expr) * time.Second) + exprU := expr.Unix() + + cookie := http.Cookie{ + Name: "session", + Value: cookieValue, + SameSite: http.SameSiteLaxMode, + HttpOnly: true, + Secure: config.Prod, + Expires: expr, + } //exhaustruct:ignore + + http.SetCookie(w, &cookie) + + /* + * TODO: Here we attempt to insert and call update if we receive a + * conflict. This works but is not idiomatic (and could confuse the + * database administrator with database integrity warnings in the log. + * The INSERT statement actually supports updating on conflict: + * https://www.postgresql.org/docs/current/sql-insert.html + */ + _, err = db.Exec( + req.Context(), + "INSERT INTO users (id, name, email, department, session, expr) VALUES ($1, $2, $3, $4, $5, $6)", + claims.Oid, + claims.Name, + claims.Email, + department, + cookieValue, + exprU, + ) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == pgErrUniqueViolation { + _, err := db.Exec( + req.Context(), + "UPDATE users SET (name, email, department, session, expr) = ($1, $2, $3, $4, $5) WHERE id = $6", + claims.Name, + claims.Email, + department, + cookieValue, + exprU, + claims.Oid, + ) + if err != nil { + wstr(w, http.StatusInternalServerError, "Database error while updating account.") + return + } + } else { + wstr(w, http.StatusInternalServerError, "Database error while writing account info.") + return + } + } + + http.Redirect(w, req, "/", http.StatusSeeOther) +} + +/* + * Setting up JSON Web Keys. Note that myKeyfunc is a global variable. + */ +func setupJwks() error { + var err error + myKeyfunc, err = keyfunc.NewDefault([]string{config.Auth.Jwks}) + if err != nil { + return fmt.Errorf("error setting up jwks: %w", err) + } + return nil +} + +/* + * Fetch the department name of the user, mostly to identify which grade + * a student is in. This expects an accessToken obtained from the OAUTH 2.0 + * token endpoint obtained via an authorization code. It might also be able + * to use this as part of a hybrid flow that directly provides access tokens, + * but this flow seems to be only usable for single-page applications according + * to the Azure portal. + */ +func getDepartment(ctx context.Context, accessToken string) (string, error) { + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + "https://graph.microsoft.com/v1.0/me?$select=department", + nil, + ) + if err != nil { + return "", fmt.Errorf("error getting department: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + client := &http.Client{} //exhaustruct:ignore + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("error getting department: %w", err) + } + defer resp.Body.Close() + + var departmentWrap struct { + Department *string `json:"department"` + } + + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&departmentWrap) + if err != nil { + return "", fmt.Errorf("error getting department: %w", err) + } + + if departmentWrap.Department == nil { + /* + * This is probably because the response does not contain a + * "department" field, which hopefully doesn't occur as we + * have specified $select=department in the OData query. + */ + return "", fmt.Errorf("error getting department: %w", errInsufficientFields) + } + + return *(departmentWrap.Department), nil +} + +/* + * TODO: Access token expiration is not checked anywhere. + */ +type accessTokenT struct { + OriginalExpiresIn *int `json:"expires_in"` /* Original time to expiration */ + Expiration time.Time + Content *string `json:"access_token"` + Error *string `json:"error"` + ErrorDescription *string `json:"error_description"` + ErrorCodes *[]int `json:"error_codes"` +} + +/* + * Obtain an access token from the token endpoint with an existing + * authorization code. + */ +func getAccessToken(ctx context.Context, authorizationCode string) (accessTokenT, error) { + var accessToken accessTokenT + t := time.Now() + v := url.Values{} + v.Set("client_id", config.Auth.Client) + v.Set("scope", "https://graph.microsoft.com/User.Read") + v.Set("code", authorizationCode) + v.Set("redirect_uri", config.URL+"/auth") + v.Set("grant_type", "authorization_code") + v.Set("client_secret", config.Auth.Secret) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, config.Auth.Token, strings.NewReader(v.Encode())) + if err != nil { + return accessToken, fmt.Errorf("error making access token request: %w", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return accessToken, fmt.Errorf("error requesting access token: %w", err) + } + defer resp.Body.Close() + + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&accessToken) + if err != nil { + return accessToken, fmt.Errorf("error decoding access token: %w", err) + } + if accessToken.Error != nil || accessToken.ErrorCodes != nil || accessToken.ErrorDescription != nil { + if accessToken.Error == nil || accessToken.ErrorCodes == nil || accessToken.ErrorDescription == nil { + return accessToken, errAccessTokenIncompleteError + } + return accessToken, fmt.Errorf("%w: %v", errTokenEndpointReturnedError, *accessToken.ErrorDescription) + } + if accessToken.Content == nil || accessToken.OriginalExpiresIn == nil { + return accessToken, fmt.Errorf("error extracting access token: %w", errInsufficientFields) + } + accessToken.Expiration = t.Add(time.Duration(*(accessToken.OriginalExpiresIn)) * time.Second) + + return accessToken, nil +} diff --git a/backend/config.go b/backend/config.go new file mode 100644 index 0000000..e5bea02 --- /dev/null +++ b/backend/config.go @@ -0,0 +1,284 @@ +/* + * Handle configuration + * + * Copyright (c) 2024 Runxi Yu <me@runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "bufio" + "errors" + "fmt" + "log" + "os" + + "git.sr.ht/~emersion/go-scfg" +) + +/* + * We use two structs. The first has all of its values as pointers, and scfg + * unmarshals the configuration to it. Then we take each value, dereference + * it, and throw it into a normal config struct without pointers. + * This means that any missing configuration options will simply cause a + * segmentation fault. + */ + +var configWithPointers struct { + URL *string `scfg:"url"` + Prod *bool `scfg:"prod"` + Tmpl *string `scfg:"tmpl"` + Static *string `scfg:"static"` + Source *string `scfg:"source"` + Listen struct { + Proto *string `scfg:"proto"` + Net *string `scfg:"net"` + Addr *string `scfg:"addr"` + Trans *string `scfg:"trans"` + TLS struct { + Cert *string `scfg:"cert"` + Key *string `scfg:"key"` + } `scfg:"tls"` + } `scfg:"listen"` + DB struct { + Type *string `scfg:"type"` + Conn *string `scfg:"conn"` + } `scfg:"db"` + Auth struct { + Fake *int `scfg:"fake"` + Client *string `scfg:"client"` + Authorize *string `scfg:"authorize"` + Jwks *string `scfg:"jwks"` + Token *string `scfg:"token"` + Secret *string `scfg:"secret"` + Expr *int `scfg:"expr"` + } `scfg:"auth"` + Perf struct { + MessageArgumentsCap *int `scfg:"msg_args_cap"` + MessageBytesCap *int `scfg:"msg_bytes_cap"` + ReadHeaderTimeout *int `scfg:"read_header_timeout"` + CourseUpdateInterval *int `scfg:"course_update_interval"` + } `scfg:"perf"` +} + +var config struct { + URL string + Prod bool + Tmpl string + Static string + Source string + Listen struct { + Proto string + Net string + Addr string + Trans string + TLS struct { + Cert string + Key string + } + } + DB struct { + Type string + Conn string + } + Auth struct { + Fake int + Client string + Authorize string + Jwks string + Token string + Secret string + Expr int + } + Perf struct { + MessageArgumentsCap int + MessageBytesCap int + ReadHeaderTimeout int + CourseUpdateInterval int + } `scfg:"perf"` +} + +var ( + errCannotProcessConfig = errors.New("cannot process configuration file") + errCannotOpenConfig = errors.New("cannot open configuration file") + errCannotDecodeConfig = errors.New("cannot decode configuration file") + errMissingConfigValue = errors.New("missing configuration value") + errIllegalConfig = errors.New("illegal configuration") +) + +func fetchConfig(path string) (retErr error) { + defer func() { + if v := recover(); v != nil { + s, ok := v.(error) + if ok { + retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, s) + } + retErr = fmt.Errorf("%w: %v", errCannotProcessConfig, v) + return + } + if retErr != nil { + retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, retErr) + return + } + }() + + f, err := os.Open(path) + if err != nil { + return fmt.Errorf("%w: %w", errCannotOpenConfig, err) + } + + err = scfg.NewDecoder(bufio.NewReader(f)).Decode(&configWithPointers) + if err != nil { + return fmt.Errorf("%w: %w", errCannotDecodeConfig, err) + } + + if configWithPointers.URL == nil { + return fmt.Errorf("%w: url", errMissingConfigValue) + } + config.URL = *(configWithPointers.URL) + + if configWithPointers.Prod == nil { + return fmt.Errorf("%w: prod", errMissingConfigValue) + } + config.Prod = *(configWithPointers.Prod) + + if configWithPointers.Tmpl == nil { + return fmt.Errorf("%w: tmpl", errMissingConfigValue) + } + config.Tmpl = *(configWithPointers.Tmpl) + + if configWithPointers.Static == nil { + return fmt.Errorf("%w: static", errMissingConfigValue) + } + config.Static = *(configWithPointers.Static) + + if configWithPointers.Source == nil { + return fmt.Errorf("%w: source", errMissingConfigValue) + } + config.Source = *(configWithPointers.Source) + + if configWithPointers.Listen.Proto == nil { + return fmt.Errorf("%w: listen.proto", errMissingConfigValue) + } + config.Listen.Proto = *(configWithPointers.Listen.Proto) + + if configWithPointers.Listen.Net == nil { + return fmt.Errorf("%w: listen.net", errMissingConfigValue) + } + config.Listen.Net = *(configWithPointers.Listen.Net) + + if configWithPointers.Listen.Addr == nil { + return fmt.Errorf("%w: listen.addr", errMissingConfigValue) + } + config.Listen.Addr = *(configWithPointers.Listen.Addr) + + if configWithPointers.Listen.Trans == nil { + return fmt.Errorf("%w: listen.trans", errMissingConfigValue) + } + config.Listen.Trans = *(configWithPointers.Listen.Trans) + + if config.Listen.Trans == "tls" { + if configWithPointers.Listen.TLS.Cert == nil { + return fmt.Errorf("%w: listen.tls.cert", errMissingConfigValue) + } + config.Listen.TLS.Cert = *(configWithPointers.Listen.TLS.Cert) + + if configWithPointers.Listen.TLS.Key == nil { + return fmt.Errorf("%w: listen.tls.key", errMissingConfigValue) + } + config.Listen.TLS.Key = *(configWithPointers.Listen.TLS.Key) + } + + if configWithPointers.DB.Type == nil { + return fmt.Errorf("%w: db.type", errMissingConfigValue) + } + config.DB.Type = *(configWithPointers.DB.Type) + + if configWithPointers.DB.Conn == nil { + return fmt.Errorf("%w: db.conn", errMissingConfigValue) + } + config.DB.Conn = *(configWithPointers.DB.Conn) + + if configWithPointers.Auth.Fake == nil { + config.Auth.Fake = 0 + } else { + config.Auth.Fake = *(configWithPointers.Auth.Fake) + switch config.Auth.Fake { + case 0: + /* It's okay to set it to 0 in production */ + case 4712, 9080: /* Don't use them unless you know what you're doing */ + if config.Prod { + return fmt.Errorf("%w: fake authentication is incompatible with production mode", errIllegalConfig) + } + log.Println("!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.") + default: + return fmt.Errorf("%w: invalid option for auth.fake", errIllegalConfig) + } + } + + if configWithPointers.Auth.Client == nil { + return fmt.Errorf("%w: auth.client", errMissingConfigValue) + } + config.Auth.Client = *(configWithPointers.Auth.Client) + + if configWithPointers.Auth.Authorize == nil { + return fmt.Errorf("%w: auth.authorize", errMissingConfigValue) + } + config.Auth.Authorize = *(configWithPointers.Auth.Authorize) + + if configWithPointers.Auth.Jwks == nil { + return fmt.Errorf("%w: auth.jwks", errMissingConfigValue) + } + config.Auth.Jwks = *(configWithPointers.Auth.Jwks) + + if configWithPointers.Auth.Token == nil { + return fmt.Errorf("%w: auth.token", errMissingConfigValue) + } + config.Auth.Token = *(configWithPointers.Auth.Token) + + if configWithPointers.Auth.Secret == nil { + return fmt.Errorf("%w: auth.secret", errMissingConfigValue) + } + config.Auth.Secret = *(configWithPointers.Auth.Secret) + + if configWithPointers.Auth.Expr == nil { + return fmt.Errorf("%w: auth.expr", errMissingConfigValue) + } + config.Auth.Expr = *(configWithPointers.Auth.Expr) + + if configWithPointers.Perf.MessageArgumentsCap == nil { + return fmt.Errorf("%w: perf.msg_args_cap", errMissingConfigValue) + } + config.Perf.MessageArgumentsCap = *(configWithPointers.Perf.MessageArgumentsCap) + + if configWithPointers.Perf.MessageBytesCap == nil { + return fmt.Errorf("%w: perf.msg_bytes_cap", errMissingConfigValue) + } + config.Perf.MessageBytesCap = *(configWithPointers.Perf.MessageBytesCap) + + if configWithPointers.Perf.ReadHeaderTimeout == nil { + return fmt.Errorf("%w: perf.read_header_timeout", errMissingConfigValue) + } + config.Perf.ReadHeaderTimeout = *(configWithPointers.Perf.ReadHeaderTimeout) + + if configWithPointers.Perf.CourseUpdateInterval == nil { + return fmt.Errorf("%w: perf.course_update_interval", errMissingConfigValue) + } + config.Perf.CourseUpdateInterval = *(configWithPointers.Perf.CourseUpdateInterval) + + return nil +} diff --git a/backend/courses.go b/backend/courses.go new file mode 100644 index 0000000..4e1ea51 --- /dev/null +++ b/backend/courses.go @@ -0,0 +1,229 @@ +/* + * Course data structures and locking + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "context" + "errors" + "fmt" + "sync" +) + +type ( + courseTypeT string + courseGroupT string +) + +type courseT struct { + ID int + /* + * TODO: There will be a lot of lock contention over Selected. It is + * probably more appropriate to directly use atomics. + * Except that it's actually hard to use atomics directly here + * because I need to "increment if less than Max"... I think I could + * just do compare and swap in a loop, but the loop would be intensive + * on the CPU so I'd have to look into how mutexes/semaphores are + * actually implemented and how I could interact with the runtime. + */ + Selected int + SelectedLock sync.RWMutex + Max int + Title string + Type courseTypeT + Group courseGroupT + Teacher string + Location string + Usems map[string](*usemT) + UsemsLock sync.RWMutex +} + +const ( + sport courseTypeT = "Sport" + enrichment courseTypeT = "Enrichment" + culture courseTypeT = "Culture" +) + +var courseTypes = map[courseTypeT]bool{ + sport: true, + enrichment: true, + culture: true, +} + +const ( + mw1 courseGroupT = "MW1" + mw2 courseGroupT = "MW2" + mw3 courseGroupT = "MW3" + tt1 courseGroupT = "TT1" + tt2 courseGroupT = "TT2" + tt3 courseGroupT = "TT3" +) + +var courseGroups = map[courseGroupT]bool{ + mw1: true, + mw2: true, + mw3: true, + tt1: true, + tt2: true, + tt3: true, +} + +func checkCourseType(ct courseTypeT) bool { + return courseTypes[ct] +} + +func checkCourseGroup(cg courseGroupT) bool { + return courseGroups[cg] +} + +var ( + errInvalidCourseType = errors.New("invalid course type") + errInvalidCourseGroup = errors.New("invalid course group") + errMultipleChoicesInOneGroup = errors.New("multiple choices per group per user") +) + +/* + * The courses are simply stored in a map indexed by the course ID, although + * the course struct itself also contains an ID field. A lock is embedded + * inside the struct; we use a lock here instead of a pointer to a lock as + * it would be easy to forget to initialize the lock when creating the + * struct. However, this means that the struct could not be copied (though + * this should only ever happen during creation anyways), therefore we use a + * pointer to the struct as the value of the map, instead of the struct itself. + */ +var courses map[int](*courseT) + +/* + * This RWMutex is only for massive modifications of the course struct, since + * locking it on every write would be inefficient; in normal operation the only + * write that could occur to the courses struct is changing the Selected + * number, which should be handled with courseT.SelectedLock. + */ +var coursesLock sync.RWMutex + +/* + * Read course information from the database. This should be called during + * setup. Failure to do so before accessing course information may lead to + * a null pointer dereference. + */ +func setupCourses() error { + coursesLock.Lock() + defer coursesLock.Unlock() + + courses = make(map[int](*courseT)) + + rows, err := db.Query( + context.Background(), + "SELECT id, nmax, title, ctype, cgroup, teacher, location FROM courses", + ) + if err != nil { + return fmt.Errorf("error fetching courses: %w", err) + } + + for { + if !rows.Next() { + err := rows.Err() + if err != nil { + return fmt.Errorf("error fetching courses: %w", err) + } + break + } + currentCourse := courseT{ + Usems: make(map[string]*usemT), + } //exhaustruct:ignore + err = rows.Scan( + ¤tCourse.ID, + ¤tCourse.Max, + ¤tCourse.Title, + ¤tCourse.Type, + ¤tCourse.Group, + ¤tCourse.Teacher, + ¤tCourse.Location, + ) + if err != nil { + return fmt.Errorf("error fetching courses: %w", err) + } + if !checkCourseType(currentCourse.Type) { + return fmt.Errorf("%w: %d %s", errInvalidCourseType, currentCourse.ID, currentCourse.Type) + } + if !checkCourseGroup(currentCourse.Group) { + return fmt.Errorf("%w: %d %s", errInvalidCourseGroup, currentCourse.ID, currentCourse.Group) + } + err := db.QueryRow(context.Background(), + "SELECT COUNT (*) FROM choices WHERE courseid = $1", + currentCourse.ID, + ).Scan(¤tCourse.Selected) + if err != nil { + return fmt.Errorf("error querying course member number: %w", err) + } + courses[currentCourse.ID] = ¤tCourse + } + + return nil +} + +type userCourseGroupsT map[courseGroupT]bool + +func populateUserCourseGroups(ctx context.Context, userCourseGroups *userCourseGroupsT, userID string) error { + rows, err := db.Query(ctx, "SELECT courseid FROM choices WHERE userid = $1", userID) + if err != nil { + return fmt.Errorf("error querying user's choices while populating course groups: %w", err) + } + for { + if !rows.Next() { + err := rows.Err() + if err != nil { + return fmt.Errorf("error iterating user's choices while populating course groups: %w", err) + } + break + } + var thisCourseID int + err := rows.Scan(&thisCourseID) + if err != nil { + return fmt.Errorf("error fetching user's choices while populating course groups: %w", err) + } + var thisGroupName courseGroupT + func() { + coursesLock.RLock() + defer coursesLock.RUnlock() + thisGroupName = courses[thisCourseID].Group + }() + if (*userCourseGroups)[thisGroupName] { + return fmt.Errorf("%w: user %v, group %v", errMultipleChoicesInOneGroup, userID, thisGroupName) + } + (*userCourseGroups)[thisGroupName] = true + } + return nil +} + +func (course *courseT) decrementSelectedAndPropagate() { + func() { + course.SelectedLock.Lock() + defer course.SelectedLock.Unlock() + course.Selected-- + }() + propagateSelectedUpdate(course.ID) +} + +func getCourseByID(courseID int) *courseT { + coursesLock.RLock() + defer coursesLock.RUnlock() + return courses[courseID] +} diff --git a/backend/db.go b/backend/db.go new file mode 100644 index 0000000..9ced627 --- /dev/null +++ b/backend/db.go @@ -0,0 +1,51 @@ +/* + * Database handling + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5/pgxpool" +) + +var db *pgxpool.Pool + +var errUnsupportedDatabaseType = errors.New("unsupported db type") + +const pgErrUniqueViolation = "23505" + +/* + * This must be run during setup, before the database is accessed by any + * means. Otherwise, db would be a null pointer. + */ +func setupDatabase() error { + var err error + if config.DB.Type != "postgres" { + return errUnsupportedDatabaseType + } + db, err = pgxpool.New(context.Background(), config.DB.Conn) + if err != nil { + return fmt.Errorf("error opening database: %w", err) + } + return nil +} diff --git a/backend/go.mod b/backend/go.mod new file mode 100644 index 0000000..717f737 --- /dev/null +++ b/backend/go.mod @@ -0,0 +1,23 @@ +module git.sr.ht/~runxiyu/cca/backend + +go 1.23.1 + +require ( + git.sr.ht/~emersion/go-scfg v0.0.0-20240128091534-2ae16e782082 + github.com/MicahParks/keyfunc/v3 v3.3.5 + github.com/coder/websocket v1.8.12 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/google/uuid v1.6.0 + github.com/jackc/pgx/v5 v5.7.1 +) + +require ( + github.com/MicahParks/jwkset v0.5.20 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + golang.org/x/crypto v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/text v0.19.0 // indirect + golang.org/x/time v0.7.0 // indirect +) diff --git a/backend/go.sum b/backend/go.sum new file mode 100644 index 0000000..ca23612 --- /dev/null +++ b/backend/go.sum @@ -0,0 +1,42 @@ +git.sr.ht/~emersion/go-scfg v0.0.0-20240128091534-2ae16e782082 h1:9Udx5fm4vRtmgDIBjy2ef5QioHbzpw5oHabbhpAUyEw= +git.sr.ht/~emersion/go-scfg v0.0.0-20240128091534-2ae16e782082/go.mod h1:ybgvEJTIx5XbaspSviB3KNa6OdPmAZqDoSud7z8fFlw= +github.com/MicahParks/jwkset v0.5.20 h1:gTIKx9AofTqQJ0srd8AL7ty9NeadP5WUXSPOZadTpOI= +github.com/MicahParks/jwkset v0.5.20/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY= +github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= +github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= +github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/backend/index.go b/backend/index.go new file mode 100644 index 0000000..f446072 --- /dev/null +++ b/backend/index.go @@ -0,0 +1,120 @@ +/* + * Index page + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "errors" + "fmt" + "log" + "net/http" + + "github.com/jackc/pgx/v5" +) + +/* + * Serve the index page. Also handles the login page in case the user doesn't + * have any valid login cookies. + */ +func handleIndex(w http.ResponseWriter, req *http.Request) { + sessionCookie, err := req.Cookie("session") + if errors.Is(err, http.ErrNoCookie) { + authURL, err := generateAuthorizationURL() + if err != nil { + wstr(w, http.StatusInternalServerError, "Cannot generate authorization URL") + return + } + err = tmpl.ExecuteTemplate( + w, + "login", + map[string]string{ + "authURL": authURL, + "source": config.Source, + /* + * We directly generate the login URL here + * instead of doing so in a redirect to save + * requests. + */ + }, + ) + if err != nil { + log.Println(err) + return + } + return + } else if err != nil { + wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.") + return + } + + var userID, userName, userDepartment string + err = db.QueryRow( + req.Context(), + "SELECT id, name, department FROM users WHERE session = $1", + sessionCookie.Value, + ).Scan(&userID, &userName, &userDepartment) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + authURL, err := generateAuthorizationURL() + if err != nil { + wstr(w, http.StatusInternalServerError, "Cannot generate authorization URL") + return + } + err = tmpl.ExecuteTemplate( + w, + "login", + map[string]interface{}{ + "authURL": authURL, + "notes": "You sent an invalid session cookie.", + "source": config.Source, + }, + ) + if err != nil { + log.Println(err) + return + } + return + } + wstr(w, http.StatusInternalServerError, fmt.Sprintf("Error: Unexpected database error: %s", err)) + return + } + + err = func() error { + coursesLock.RLock() + defer coursesLock.RUnlock() + return tmpl.ExecuteTemplate( + w, + "student", + map[string]interface{}{ + "open": true, + "user": map[string]interface{}{ + "Name": userName, + "Department": userDepartment, + }, + "courses": courses, + "source": config.Source, + }, + ) + }() + if err != nil { + log.Println(err) + return + } +} diff --git a/backend/main.go b/backend/main.go new file mode 100644 index 0000000..a185de9 --- /dev/null +++ b/backend/main.go @@ -0,0 +1,125 @@ +/* + * Main listener + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "crypto/tls" + "flag" + "html/template" + "log" + "net" + "net/http" + "time" +) + +var tmpl *template.Template + +func main() { + var err error + + var configPath string + + flag.StringVar(&configPath, "config", "cca.scfg", "path to configuration file") + flag.Parse() + + if err := fetchConfig(configPath); err != nil { + log.Fatal(err) + } + + log.Println("Setting up database") + if err := setupDatabase(); err != nil { + log.Fatal(err) + } + + log.Println("Setting up JWKS") + if err := setupJwks(); err != nil { + log.Fatal(err) + } + + log.Println("Setting up templates") + tmpl, err = template.ParseGlob(config.Tmpl) + if err != nil { + log.Fatal(err) + } + + log.Println("Setting up courses") + err = setupCourses() + if err != nil { + log.Fatal(err) + } + + log.Println("Registering static handle") + fs := http.FileServer(http.Dir(config.Static)) + http.Handle("/static/", http.StripPrefix("/static/", fs)) + + log.Println("Registering handlers") + http.HandleFunc("/{$}", handleIndex) + http.HandleFunc("/auth", handleAuth) + http.HandleFunc("/ws", handleWs) + + var l net.Listener + + switch config.Listen.Trans { + case "plain": + log.Printf( + "Establishing plain listener for net \"%s\", addr \"%s\"\n", + config.Listen.Net, + config.Listen.Addr, + ) + l, err = net.Listen(config.Listen.Net, config.Listen.Addr) + if err != nil { + log.Fatalf("Failed to establish plain listener: %v\n", err) + } + case "tls": + cer, err := tls.LoadX509KeyPair(config.Listen.TLS.Cert, config.Listen.TLS.Key) + if err != nil { + log.Fatalf("Failed to load TLS certificate and key: %v\n", err) + } + tlsconfig := &tls.Config{ + Certificates: []tls.Certificate{cer}, + MinVersion: tls.VersionTLS13, + } //exhaustruct:ignore + log.Printf( + "Establishing TLS listener for net \"%s\", addr \"%s\"\n", + config.Listen.Net, + config.Listen.Addr, + ) + l, err = tls.Listen(config.Listen.Net, config.Listen.Addr, tlsconfig) + if err != nil { + log.Fatalf("Failed to establish TLS listener: %v\n", err) + } + default: + log.Fatalln("listen.trans must be \"plain\" or \"tls\"") + } + + if config.Listen.Proto == "http" { + log.Println("Serving http") + srv := &http.Server{ + ReadHeaderTimeout: time.Duration(config.Perf.ReadHeaderTimeout) * time.Second, + } //exhaustruct:ignore + err = srv.Serve(l) + } else { + log.Fatalln("Unsupported protocol") + } + if err != nil { + log.Fatal(err) + } +} diff --git a/backend/usem.go b/backend/usem.go new file mode 100644 index 0000000..0b839a4 --- /dev/null +++ b/backend/usem.go @@ -0,0 +1,36 @@ +/* + * Additional synchronization routines + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +type usemT struct { + ch (chan struct{}) +} + +func (s *usemT) init() { + s.ch = make(chan struct{}, 1) +} + +func (s *usemT) set() { + select { + case s.ch <- struct{}{}: + default: + } +} diff --git a/backend/utils.go b/backend/utils.go new file mode 100644 index 0000000..3628c80 --- /dev/null +++ b/backend/utils.go @@ -0,0 +1,58 @@ +/* + * Utility functions + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "log" + "net/http" +) + +/* + * Write a string to a http.ResponseWriter, setting the Content-Type and status + * code. + */ +func wstr(w http.ResponseWriter, code int, msg string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(code) + _, err := w.Write([]byte(msg)) + if err != nil { + log.Printf("Error wstr'ing to writer: %v", err) + } +} + +/* + * Generate a random url-safe string. + * Note that the "sz" parameter specifies the number of bytes taken from the + * random source divided by three and does NOT represent the length of the + * encoded string. It's divided by three because we're using base64 and it's + * ideal to ensure that the entropy remains consistent throughout the string. + */ +func randomString(sz int) (string, error) { + r := make([]byte, 3*sz) + _, err := rand.Read(r) + if err != nil { + return "", fmt.Errorf("error generating random string: %w", err) + } + return base64.RawURLEncoding.EncodeToString(r), nil +} diff --git a/backend/wsc.go b/backend/wsc.go new file mode 100644 index 0000000..982c1fb --- /dev/null +++ b/backend/wsc.go @@ -0,0 +1,266 @@ +/* + * WebSocket connection routine + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/coder/websocket" +) + +type errbytesT struct { + err error + bytes *[]byte +} + +/* + * The actual logic in handling the connection, after authentication has been + * completed. + */ +func handleConn( + ctx context.Context, + c *websocket.Conn, + session string, + userID string, +) (retErr error) { + reportError := makeReportError(ctx, c) + newCtx, newCancel := context.WithCancel(ctx) + + func() { + cancelPoolLock.Lock() + defer cancelPoolLock.Unlock() + cancel := cancelPool[userID] + if cancel != nil { + (*cancel)() + /* TODO: Make the cancel synchronous */ + } + cancelPool[userID] = &newCancel + }() + defer func() { + cancelPoolLock.Lock() + defer cancelPoolLock.Unlock() + if cancelPool[userID] == &newCancel { + delete(cancelPool, userID) + } + if errors.Is(retErr, context.Canceled) { + /* + * Only works if it's newCtx that has been cancelled + * rather than the original ctx, which is kinda what + * we intend + */ + _ = writeText(ctx, c, "E :Context canceled") + } + }() + + /* TODO: Tell the user their current choices here. Deprecate HELLO. */ + + usems := make(map[int]*usemT) + func() { + coursesLock.RLock() + defer coursesLock.RUnlock() + for courseID, course := range courses { + usem := &usemT{} //exhaustruct:ignore + usem.init() + func() { + course.UsemsLock.Lock() + defer course.UsemsLock.Unlock() + course.Usems[userID] = usem + }() + usems[courseID] = usem + } + }() + defer func() { + coursesLock.RLock() + defer coursesLock.RUnlock() + for _, course := range courses { + func() { + course.UsemsLock.Lock() + defer course.UsemsLock.Unlock() + delete(course.Usems, userID) + }() + } + }() + + usemParent := make(chan int) + for courseID, usem := range usems { + go func() { + for { + select { + case <-newCtx.Done(): + return + case <-usem.ch: + select { + case <-newCtx.Done(): + return + case usemParent <- courseID: + } + } + time.Sleep(time.Duration(config.Perf.CourseUpdateInterval) * time.Millisecond) + } + }() + } + + /* + * userCourseGroups stores whether the user has already chosen a course + * in the courseGroup. + */ + var userCourseGroups userCourseGroupsT = make(map[courseGroupT]bool) + err := populateUserCourseGroups(newCtx, &userCourseGroups, userID) + if err != nil { + return reportError(fmt.Sprintf("cannot populate user course groups: %v", err)) + } + + /* + * Later we need to select from recv and send and perform the + * corresponding action. But we can't just select from c.Read because + * the function blocks. Therefore, we must spawn a goroutine that + * blocks on c.Read and send what it receives to a channel "recv"; and + * then we can select from that channel. + */ + recv := make(chan *errbytesT) + go func() { + for { + /* + * Here we use the original connection context instead + * of the new context we just created. Apparently when + * the context passed to Read expires, the connection + * gets closed, which makes it impossible for us to + * write the context expiry message to the client. + * So we pass the original connection context, which + * would get cancelled anyway once we close the + * connection. + * We still need to take care of this while sending so + * we don't infinitely block, and leak goroutines and + * cause the channel to remain out of reach of the + * garbage collector. + * It would be nice to return the newCtx.Err() but + * the only way to really do that is to use the recv + * channel which might not have a listener anymore. + * It's not really crucial anyways so we could just + * close this goroutine by returning here. + */ + _, b, err := c.Read(ctx) + if err != nil { + /* + * TODO: Prioritize context dones... except + * that it's not really possible. I would just + * have placed newCtx in here but apparently + * that causes the connection to be closed when + * the context expires, which makes it + * impossible to deliver the final error + * message. Probably need to look into this + * design again. + */ + select { + case <-newCtx.Done(): + _ = writeText(ctx, c, "E :Context canceled") + /* Not a typo to use ctx here */ + return + case recv <- &errbytesT{err: err, bytes: nil}: + } + return + } + select { + case <-newCtx.Done(): + _ = writeText(ctx, c, "E :Context cancelled") + /* Not a typo to use ctx here */ + return + case recv <- &errbytesT{err: nil, bytes: &b}: + } + } + }() + + for { + var mar []string + select { + case <-newCtx.Done(): + /* + * TODO: Somehow prioritize this case over all other cases + */ + return fmt.Errorf("context done in main event loop: %w", newCtx.Err()) + /* + * There are other times when the context could be + * cancelled, and apparently some WebSocket functions + * just close the connection when the context is + * cancelled. So it's kinda impossible to reliably + * send this message due to newCtx cancellation. + * But in any case, the WebSocket connection would + * be closed, and the user would see the connection + * closed page which should explain it. + */ + case courseID := <-usemParent: + var selected int + func() { + course := courses[courseID] + course.SelectedLock.RLock() + defer course.SelectedLock.RUnlock() + selected = course.Selected + }() + err := writeText(newCtx, c, fmt.Sprintf("M %d %d", courseID, selected)) + if err != nil { + return fmt.Errorf("error sending to websocket for course selected update: %w", err) + } + continue + case errbytes := <-recv: + if errbytes.err != nil { + return fmt.Errorf("error fetching message from recv channel: %w", errbytes.err) + /* + * Note that this cannot return newCtx.Err(), + * so we handle the error reporting in the + * reading routine + */ + } + mar = splitMsg(errbytes.bytes) + switch mar[0] { + case "HELLO": + err := messageHello(newCtx, c, reportError, mar, userID, session) + if err != nil { + return err + } + case "Y": + err := messageChooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups) + if err != nil { + return err + } + case "N": + err := messageUnchooseCourse(newCtx, c, reportError, mar, userID, session, &userCourseGroups) + if err != nil { + return err + } + default: + return reportError("Unknown command " + mar[0]) + } + } + } +} + +var ( + cancelPool = make(map[string](*context.CancelFunc)) + /* + * Normal Go maps are not thread safe, so we protect large cancelPool + * operations such as addition and deletion under a RWMutex. + */ + cancelPoolLock sync.RWMutex +) diff --git a/backend/wsh.go b/backend/wsh.go new file mode 100644 index 0000000..abf4e61 --- /dev/null +++ b/backend/wsh.go @@ -0,0 +1,165 @@ +/* + * WebSocket endpoint handler + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "errors" + "log" + "net/http" + "time" + + "github.com/coder/websocket" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" +) + +/* + * Handle requests to the WebSocket endpoint and establish a connection. + * Authentication is handled here, but afterwards, the connection is really + * handled in handleConn. + */ +func handleWs(w http.ResponseWriter, req *http.Request) { + wsOptions := &websocket.AcceptOptions{ + Subprotocols: []string{"cca1"}, + } //exhaustruct:ignore + c, err := websocket.Accept( + w, + req, + wsOptions, + ) + if err != nil { + wstr(w, http.StatusBadRequest, "This endpoint only supports valid WebSocket connections.") + return + } + defer func() { + _ = c.CloseNow() + }() + + fake := false + + sessionCookie, err := req.Cookie("session") + if errors.Is(err, http.ErrNoCookie) { + if config.Auth.Fake == 0 { + err := writeText(req.Context(), c, "U") + if err != nil { + log.Println(err) + } + return + } + fake = true + } else if err != nil { + err := writeText(req.Context(), c, "E :Error fetching cookie") + if err != nil { + log.Println(err) + } + return + } + + var userID string + var session string + var expr int + + if fake { + switch config.Auth.Fake { + case 9080: + _uuid, err := uuid.NewRandom() + if err != nil { + log.Println(err) + return + } + userID = _uuid.String() + case 4712: + userID = "fake" + default: + panic("not supposed to happen") + } + session, err = randomString(20) + if err != nil { + log.Println(err) + return + } + _, err = db.Exec( + req.Context(), + "INSERT INTO users (id, name, email, department, session, expr) VALUES ($1, $2, $3, $4, $5, $6)", + userID, + "Fake User", + "fake@runxiyu.org", + "Y11", + session, + time.Now().Add(time.Duration(config.Auth.Expr)*time.Second).Unix(), + ) + if err != nil && config.Auth.Fake != 4712 { + /* TODO check pgerr */ + err := writeText(req.Context(), c, "E :Database error while writing fake account info") + if err != nil { + log.Println(err) + } + return + } + err = writeText(req.Context(), c, "FAKE "+userID+" "+session) + if err != nil { + log.Println(err) + return + } + } else { + session = sessionCookie.Value + err = db.QueryRow( + req.Context(), + "SELECT id, expr FROM users WHERE session = $1", + session, + ).Scan(&userID, &expr) + if errors.Is(err, pgx.ErrNoRows) { + err := writeText(req.Context(), c, "U") + if err != nil { + log.Println(err) + } + return + } else if err != nil { + err := writeText(req.Context(), c, "E :Database error while selecting session") + if err != nil { + log.Println(err) + } + return + } + } + + /* + * Now that we have an authenticated request, this WebSocket connection + * may be simply associated with the session and userID. + * TODO: There are various race conditions that could occur if one user + * creates multiple connections, with the same or different session + * cookies. The last situation could occur in normal use when a user + * opens multiple instances of the page in one browser, and is not + * unique to custom clients or malicious users. Some effort must be + * taken to ensure that each user may only have one connection at a + * time. + */ + err = handleConn( + req.Context(), + c, + session, + userID, + ) + if err != nil { + log.Printf("%v", err) + return + } +} diff --git a/backend/wsm.go b/backend/wsm.go new file mode 100644 index 0000000..5f271ff --- /dev/null +++ b/backend/wsm.go @@ -0,0 +1,224 @@ +/* + * WebSocket message handlers + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/coder/websocket" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +func messageHello(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string) error { + _, _ = mar, session + + select { + case <-ctx.Done(): + return fmt.Errorf("context done when handling hello: %w", ctx.Err()) + default: + } + + rows, err := db.Query( + ctx, + "SELECT courseid FROM choices WHERE userid = $1", + userID, + ) + if err != nil { + return reportError("error fetching choices") + } + courseIDs, err := pgx.CollectRows(rows, pgx.RowTo[string]) + if err != nil { + return reportError("error collecting choices") + } + + err = writeText(ctx, c, "HI :"+strings.Join(courseIDs, ",")) + if err != nil { + return fmt.Errorf("error replying to HELLO: %w", err) + } + + return nil +} + +func messageChooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error { + _ = session + + select { + case <-ctx.Done(): + return fmt.Errorf("context done when handling choose: %w", ctx.Err()) + default: + } + + if len(mar) != 2 { + return reportError("Invalid number of arguments for Y") + } + _courseID, err := strconv.ParseInt(mar[1], 10, strconv.IntSize) + if err != nil { + return reportError("Course ID must be an integer") + } + courseID := int(_courseID) + course := getCourseByID(courseID) + + if course == nil { + return reportError("nil course") + } + + err = func() (returnedError error) { /* Named returns so I could modify them in defer */ + tx, err := db.Begin(ctx) + if err != nil { + return reportError("Database error while beginning transaction") + } + defer func() { + err := tx.Rollback(ctx) + if err != nil && (!errors.Is(err, pgx.ErrTxClosed)) { + returnedError = reportError("Database error while rolling back transaction in defer block") + return + } + }() + + _, err = tx.Exec( + ctx, + "INSERT INTO choices (seltime, userid, courseid) VALUES ($1, $2, $3)", + time.Now().UnixMicro(), + userID, + courseID, + ) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == pgErrUniqueViolation { + err := writeText(ctx, c, "Y "+mar[1]) + if err != nil { + return fmt.Errorf("error reaffirming course choice: %w", err) + } + return nil + } + return reportError("Database error while inserting course choice") + } + + ok := func() bool { + course.SelectedLock.Lock() + defer course.SelectedLock.Unlock() + if course.Selected < course.Max { + course.Selected++ + return true + } + return false + }() + + if ok { + go propagateSelectedUpdate(courseID) + err := tx.Commit(ctx) + if err != nil { + go course.decrementSelectedAndPropagate() + return reportError("Database error while committing transaction") + } + var thisCourseGroup courseGroupT + func() { + coursesLock.RLock() + defer coursesLock.RUnlock() + thisCourseGroup = courses[courseID].Group + }() + if (*userCourseGroups)[thisCourseGroup] { + go course.decrementSelectedAndPropagate() + return reportError("inconsistent user course groups") + } + (*userCourseGroups)[thisCourseGroup] = true + err = writeText(ctx, c, "Y "+mar[1]) + if err != nil { + return fmt.Errorf("error affirming course choice: %w", err) + } + } else { + err := tx.Rollback(ctx) + if err != nil { + return reportError("Database error while rolling back transaction due to course limit") + } + err = writeText(ctx, c, "R "+mar[1]+" :Full") + if err != nil { + return fmt.Errorf("error rejecting course choice: %w", err) + } + } + return nil + }() + if err != nil { + return err + } + return nil +} + +func messageUnchooseCourse(ctx context.Context, c *websocket.Conn, reportError reportErrorT, mar []string, userID string, session string, userCourseGroups *userCourseGroupsT) error { + _ = session + + select { + case <-ctx.Done(): + return fmt.Errorf("context done when handling unchoose: %w", ctx.Err()) + default: + } + + if len(mar) != 2 { + return reportError("Invalid number of arguments for N") + } + _courseID, err := strconv.ParseInt(mar[1], 10, strconv.IntSize) + if err != nil { + return reportError("Course ID must be an integer") + } + courseID := int(_courseID) + course := getCourseByID(courseID) + + if course == nil { + return reportError("nil course") + } + + ct, err := db.Exec( + ctx, + "DELETE FROM choices WHERE userid = $1 AND courseid = $2", + userID, + courseID, + ) + if err != nil { + return reportError("Database error while deleting course choice") + } + + if ct.RowsAffected() != 0 { + go course.decrementSelectedAndPropagate() + var thisCourseGroup courseGroupT + func() { + coursesLock.RLock() + defer coursesLock.RUnlock() + thisCourseGroup = courses[courseID].Group + }() + if !(*userCourseGroups)[thisCourseGroup] { + return reportError("inconsistent user course groups") + } + (*userCourseGroups)[thisCourseGroup] = false + } + + err = writeText(ctx, c, "N "+mar[1]) + if err != nil { + return fmt.Errorf("error replying that course has been deselected: %w", err) + } + + return nil +} diff --git a/backend/wsp.go b/backend/wsp.go new file mode 100644 index 0000000..7a1f6ab --- /dev/null +++ b/backend/wsp.go @@ -0,0 +1,85 @@ +/* + * WebSocket-based protocol auxiliary functions + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "context" + "fmt" + + "github.com/coder/websocket" +) + +/* + * Split an IRC-style message of type []byte into type []string where each + * element is a complete argument. Generally, arguments are separated by + * spaces, and an argument that begins with a ':' causes the rest of the + * line to be treated as a single argument. + */ +func splitMsg(b *[]byte) []string { + mar := make([]string, 0, config.Perf.MessageArgumentsCap) + elem := make([]byte, 0, config.Perf.MessageBytesCap) + for i, c := range *b { + switch c { + case ' ': + if (*b)[i+1] == ':' { + mar = append(mar, string(elem)) + mar = append(mar, string((*b)[i+2:])) + goto endl + } + mar = append(mar, string(elem)) + elem = make([]byte, 0, config.Perf.MessageBytesCap) + default: + elem = append(elem, c) + } + } + mar = append(mar, string(elem)) +endl: + return mar +} + +func baseReportError(ctx context.Context, conn *websocket.Conn, e string) error { + err := writeText(ctx, conn, "E :"+e) + if err != nil { + return fmt.Errorf("error reporting protocol violation: %w", err) + } + err = conn.Close(websocket.StatusProtocolError, e) + if err != nil { + return fmt.Errorf("error closing websocket: %w", err) + } + return nil +} + +type reportErrorT func(e string) error + +func makeReportError(ctx context.Context, conn *websocket.Conn) reportErrorT { + return func(e string) error { + return baseReportError(ctx, conn, e) + } +} + +func propagateSelectedUpdate(courseID int) { + course := courses[courseID] + course.UsemsLock.RLock() + defer course.UsemsLock.RUnlock() + for _, usem := range course.Usems { + usem.set() + } +} diff --git a/backend/wsx.go b/backend/wsx.go new file mode 100644 index 0000000..a51b976 --- /dev/null +++ b/backend/wsx.go @@ -0,0 +1,59 @@ +/* + * Generic WebSocket auxiliary functions + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +/* + * The message format is a WebSocket message separated with spaces. + * The contents of each field could contain anything other than spaces, + * The first character of each argument cannot be a colon. As an exception, the + * last argument may contain spaces and the first character thereof may be a + * colon, if the argument is prefixed with a colon. The colon used for the + * prefix is not considered part of the content of the message. For example, in + * + * SQUISH POP :cat purr!! + * + * the first field is "SQUISH", the second field is "POP", and the third + * field is "cat purr!!". + * + * It is essentially an RFC 1459 IRC message without trailing CR-LF and + * without prefixes. See section 2.3.1 of RFC 1459 for an approximate + * BNF representation. + * + * The reason this was chosen instead of using protobuf etc. is that it + * is simple to parse without external libraries, and it also happens to + * be a format I'm very familiar with, having extensively worked with the + * IRC protocol. + */ + +package main + +import ( + "context" + "fmt" + + "github.com/coder/websocket" +) + +func writeText(ctx context.Context, c *websocket.Conn, msg string) error { + err := c.Write(ctx, websocket.MessageText, []byte(msg)) + if err != nil { + return fmt.Errorf("error writing to connection: %w", err) + } + return nil +} |