diff options
Diffstat (limited to '')
-rw-r--r-- | config.go | 28 | ||||
-rw-r--r-- | endpoint_export.go | 40 | ||||
-rw-r--r-- | endpoint_index.go | 83 | ||||
-rw-r--r-- | endpoint_newcourses.go | 44 | ||||
-rw-r--r-- | endpoint_state.go | 68 | ||||
-rw-r--r-- | endpoint_ws.go | 110 | ||||
-rw-r--r-- | errors.go | 4 | ||||
-rw-r--r-- | session.go | 54 | ||||
-rw-r--r-- | state.go | 75 |
9 files changed, 157 insertions, 349 deletions
@@ -23,7 +23,6 @@ package main import ( "bufio" "fmt" - "log" "os" "git.sr.ht/~emersion/go-scfg" @@ -55,7 +54,6 @@ var configWithPointers struct { Conn *string `scfg:"conn"` } `scfg:"db"` Auth struct { - Fake *int `scfg:"fake"` Client *string `scfg:"client"` Authorize *string `scfg:"authorize"` Jwks *string `scfg:"jwks"` @@ -91,7 +89,6 @@ var config struct { Conn string } Auth struct { - Fake int Client string Authorize string Jwks string @@ -184,31 +181,6 @@ func fetchConfig(path string) (retErr error) { } config.DB.Conn = *(configWithPointers.DB.Conn) - if configWithPointers.Auth.Fake == nil { - config.Auth.Fake = 0 - } else { - config.Auth.Fake = *(configWithPointers.Auth.Fake) - switch config.Auth.Fake { - case 0: - /* It's okay to set it to 0 in production */ - case 4712, 9080: /* Don't use them unless you know what you're doing */ - if config.Prod { - return fmt.Errorf( - "%w: fake authentication is incompatible with production mode", - errIllegalConfig, - ) - } - log.Println( - "!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.", - ) - default: - return fmt.Errorf( - "%w: invalid option for auth.fake", - errIllegalConfig, - ) - } - } - if configWithPointers.Auth.Client == nil { return fmt.Errorf("%w: auth.client", errMissingConfigValue) } diff --git a/endpoint_export.go b/endpoint_export.go index 401c632..5118181 100644 --- a/endpoint_export.go +++ b/endpoint_export.go @@ -22,55 +22,21 @@ package main import ( "encoding/csv" - "errors" "fmt" "net/http" "strings" - - "github.com/jackc/pgx/v5" ) func handleExport(w http.ResponseWriter, req *http.Request) { - sessionCookie, err := req.Cookie("session") - if errors.Is(err, http.ErrNoCookie) { - wstr( - w, - http.StatusUnauthorized, - "No session cookie, which is required for this endpoint", - ) - return - } else if err != nil { - wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.") - return - } - - var userDepartment string - err = db.QueryRow( - req.Context(), - "SELECT department FROM users WHERE session = $1", - sessionCookie.Value, - ).Scan(&userDepartment) + _, _, department, err := getUserInfoFromRequest(req) if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - wstr( - w, - http.StatusForbidden, - "Invalid session cookie", - ) - return - } wstr( w, http.StatusInternalServerError, - fmt.Sprintf( - "Error: Unexpected database error: %s", - err, - ), + fmt.Sprintf("Error: %v", err), ) - return } - - if userDepartment != staffDepartment { + if department != staffDepartment { wstr( w, http.StatusForbidden, diff --git a/endpoint_index.go b/endpoint_index.go index 512fe1d..e6a5741 100644 --- a/endpoint_index.go +++ b/endpoint_index.go @@ -26,15 +26,13 @@ import ( "log" "net/http" "sync/atomic" - - "github.com/jackc/pgx/v5" ) func handleIndex(w http.ResponseWriter, req *http.Request) { - sessionCookie, err := req.Cookie("session") - if errors.Is(err, http.ErrNoCookie) { - authURL, err := generateAuthorizationURL() - if err != nil { + _, username, department, err := getUserInfoFromRequest(req) + if errors.Is(err, errNoCookie) || errors.Is(err, errNoSuchUser) { + authURL, err2 := generateAuthorizationURL() + if err2 != nil { wstr( w, http.StatusInternalServerError, @@ -42,7 +40,11 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { ) return } - err = tmpl.ExecuteTemplate( + var noteString string + if errors.Is(err, errNoSuchUser) { + noteString = "Your browser provided an invalid session cookie." + } + err2 = tmpl.ExecuteTemplate( w, "login", struct { @@ -50,62 +52,15 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { Notes string }{ authURL, - "", + noteString, }, ) - if err != nil { - log.Println(err) + if err2 != nil { + log.Println(err2) return } - return } else if err != nil { - wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.") - return - } - - var userID, userName, userDepartment string - err = db.QueryRow( - req.Context(), - "SELECT id, name, department FROM users WHERE session = $1", - sessionCookie.Value, - ).Scan(&userID, &userName, &userDepartment) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - authURL, err := generateAuthorizationURL() - if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Cannot generate authorization URL", - ) - return - } - err = tmpl.ExecuteTemplate( - w, - "login", - struct { - AuthURL string - Notes string - }{ - authURL, - "Your session is invalid or has expired.", - }, - ) - if err != nil { - log.Println(err) - return - } - return - } - wstr( - w, - http.StatusInternalServerError, - fmt.Sprintf( - "Error: Unexpected database error: %s", - err, - ), - ) - return + wstr(w, http.StatusInternalServerError, fmt.Sprintf("Error: %v", err)) } /* TODO: The below should be completed on-update. */ @@ -136,7 +91,7 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { return true }) - if userDepartment == staffDepartment { + if department == staffDepartment { err := tmpl.ExecuteTemplate( w, "staff", @@ -145,7 +100,7 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { State uint32 Groups *map[courseGroupT]groupT }{ - userName, + username, state, &_groups, }, @@ -164,8 +119,8 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { Name string Department string }{ - userName, - userDepartment, + username, + department, }, ) if err != nil { @@ -182,8 +137,8 @@ func handleIndex(w http.ResponseWriter, req *http.Request) { Department string Groups *map[courseGroupT]groupT }{ - userName, - userDepartment, + username, + department, &_groups, }, ) diff --git a/endpoint_newcourses.go b/endpoint_newcourses.go index 5963e1b..bf8f570 100644 --- a/endpoint_newcourses.go +++ b/endpoint_newcourses.go @@ -34,54 +34,19 @@ import ( func handleNewCourses(w http.ResponseWriter, req *http.Request) { if req.Method != http.MethodPost { - wstr( - w, - http.StatusMethodNotAllowed, - "Only POST is allowed here", - ) - return - } - - sessionCookie, err := req.Cookie("session") - if errors.Is(err, http.ErrNoCookie) { - wstr( - w, - http.StatusUnauthorized, - "No session cookie, which is required for this endpoint", - ) - return - } else if err != nil { - wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.") + wstr(w, http.StatusMethodNotAllowed, "Only POST is allowed here") return } - var userDepartment string - err = db.QueryRow( - req.Context(), - "SELECT department FROM users WHERE session = $1", - sessionCookie.Value, - ).Scan(&userDepartment) + _, _, department, err := getUserInfoFromRequest(req) if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - wstr( - w, - http.StatusForbidden, - "Invalid session cookie", - ) - return - } wstr( w, http.StatusInternalServerError, - fmt.Sprintf( - "Error: Unexpected database error: %s", - err, - ), + fmt.Sprintf("Error: %v", err), ) - return } - - if userDepartment != staffDepartment { + if department != staffDepartment { wstr( w, http.StatusForbidden, @@ -98,6 +63,7 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) { ) return } + /* TODO: Potential race. The global state may need to be write-locked. */ file, fileHeader, err := req.FormFile("coursecsv") diff --git a/endpoint_state.go b/endpoint_state.go new file mode 100644 index 0000000..4e0ee56 --- /dev/null +++ b/endpoint_state.go @@ -0,0 +1,68 @@ +/* + * Let staff update state + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "fmt" + "net/http" + "strconv" +) + +func handleState(w http.ResponseWriter, req *http.Request) { + _, _, department, err := getUserInfoFromRequest(req) + if err != nil { + wstr( + w, + http.StatusInternalServerError, + fmt.Sprintf("Error: %v", err), + ) + } + if department != staffDepartment { + wstr( + w, + http.StatusForbidden, + "You are not authorized to view this page", + ) + return + } + + basePath := req.PathValue("s") + newState, err := strconv.ParseUint(basePath, 10, 32) + if err != nil { + wstr( + w, + http.StatusBadRequest, + "State must be an unsigned 32-bit integer", + ) + return + } + err = setState(req.Context(), uint32(newState)) + if err != nil { + wstr( + w, + http.StatusInternalServerError, + "Failed setting state, please return to previous page; are you sure it's within limits?", + ) + return + } + + http.Redirect(w, req, "/", http.StatusSeeOther) +} diff --git a/endpoint_ws.go b/endpoint_ws.go index 4bf2319..45cee60 100644 --- a/endpoint_ws.go +++ b/endpoint_ws.go @@ -21,14 +21,10 @@ package main import ( - "errors" "log" "net/http" - "time" "github.com/coder/websocket" - "github.com/google/uuid" - "github.com/jackc/pgx/v5" ) /* @@ -57,114 +53,18 @@ func handleWs(w http.ResponseWriter, req *http.Request) { _ = c.CloseNow() }() - fake := false - - sessionCookie, err := req.Cookie("session") - if errors.Is(err, http.ErrNoCookie) { - if config.Auth.Fake == 0 { - err := writeText(req.Context(), c, "U") - if err != nil { - log.Println(err) - } - return - } - fake = true - } else if err != nil { - err := writeText(req.Context(), c, "E :Error fetching cookie") + userID, _, _, err := getUserInfoFromRequest(req) + if err != nil { + err := writeText(req.Context(), c, "U") if err != nil { log.Println(err) } return } - var userID string - var session string - var expr int - - if fake { - switch config.Auth.Fake { - case 9080: - _uuid, err := uuid.NewRandom() - if err != nil { - log.Println(err) - return - } - userID = _uuid.String() - case 4712: - userID = "fake" - default: - panic("not supposed to happen") - } - session, err = randomString(20) - if err != nil { - log.Println(err) - return - } - _, err = db.Exec( - req.Context(), - "INSERT INTO users (id, name, email, department, session, expr) VALUES ($1, $2, $3, $4, $5, $6)", - userID, - "Fake User", - "fake@runxiyu.org", - "Y11", - session, - time.Now().Add( - time.Duration(config.Auth.Expr)*time.Second, - ).Unix(), - ) - if err != nil && config.Auth.Fake != 4712 { - err := writeText( - req.Context(), - c, - "E :Database error while writing fake account info", - ) - if err != nil { - log.Println(err) - } - return - } - err = writeText(req.Context(), c, "FAKE "+userID+" "+session) - if err != nil { - log.Println(err) - return - } - } else { - session = sessionCookie.Value - err = db.QueryRow( - req.Context(), - "SELECT id, expr FROM users WHERE session = $1", - session, - ).Scan(&userID, &expr) - if errors.Is(err, pgx.ErrNoRows) { - err := writeText(req.Context(), c, "U") - if err != nil { - log.Println(err) - } - return - } else if err != nil { - err := writeText( - req.Context(), - c, - "E :Database error while selecting session", - ) - if err != nil { - log.Println(err) - } - return - } - } - - /* - * Now that we have an authenticated request, this WebSocket connection - * may be simply associated with the session and userID. - */ - err = handleConn( - req.Context(), - c, - userID, - ) + err = handleConn(req.Context(), c, userID) if err != nil { - log.Printf("%v", err) + log.Println(err) return } } @@ -35,7 +35,6 @@ var ( errCannotOpenConfig = errors.New("cannot open configuration file") errCannotDecodeConfig = errors.New("cannot decode configuration file") errMissingConfigValue = errors.New("missing configuration value") - errIllegalConfig = errors.New("illegal configuration") errInvalidCourseType = errors.New("invalid course type") errInvalidCourseGroup = errors.New("invalid course group") errMultipleChoicesInOneGroup = errors.New("multiple choices per group per user") @@ -48,6 +47,9 @@ var ( errNoSuchCourse = errors.New("no such course") errInvalidState = errors.New("invalid state") errWebSocketWrite = errors.New("error writing to websocket") + errCannotCheckCookie = errors.New("error checking cookie") + errNoCookie = errors.New("no cookie found") + errNoSuchUser = errors.New("no such user") ) func wrapError(a, b error) error { diff --git a/session.go b/session.go new file mode 100644 index 0000000..be83c5c --- /dev/null +++ b/session.go @@ -0,0 +1,54 @@ +/* + * Session checking functions + * + * Copyright (C) 2024 Runxi Yu <https://runxiyu.org> + * SPDX-License-Identifier: AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package main + +import ( + "errors" + "net/http" + + "github.com/jackc/pgx/v5" +) + +func getUserInfoFromRequest(req *http.Request) (userID, username, department string, retErr error) { + sessionCookie, err := req.Cookie("session") + if errors.Is(err, http.ErrNoCookie) { + retErr = wrapError(errNoCookie, err) + return + } else if err != nil { + retErr = wrapError(errCannotCheckCookie, err) + return + } + + err = db.QueryRow( + req.Context(), + "SELECT id, name, department FROM users WHERE session = $1", + sessionCookie.Value, + ).Scan(&userID, &username, &department) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + retErr = errNoSuchUser + return + } + retErr = wrapError(errUnexpectedDBError, err) + return + } + return +} @@ -23,9 +23,6 @@ package main import ( "context" "errors" - "fmt" - "net/http" - "strconv" "sync/atomic" "github.com/jackc/pgx/v5" @@ -100,75 +97,3 @@ func setState(ctx context.Context, newState uint32) error { atomic.StoreUint32(&state, newState) return nil } - -func handleState(w http.ResponseWriter, req *http.Request) { - sessionCookie, err := req.Cookie("session") - if errors.Is(err, http.ErrNoCookie) { - wstr( - w, - http.StatusUnauthorized, - "No session cookie, which is required for this endpoint", - ) - return - } else if err != nil { - wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.") - return - } - - var userID, userName, userDepartment string - err = db.QueryRow( - req.Context(), - "SELECT id, name, department FROM users WHERE session = $1", - sessionCookie.Value, - ).Scan(&userID, &userName, &userDepartment) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - wstr( - w, - http.StatusForbidden, - "Invalid session cookie", - ) - return - } - wstr( - w, - http.StatusInternalServerError, - fmt.Sprintf( - "Error: Unexpected database error: %s", - err, - ), - ) - return - } - - if userDepartment != staffDepartment { - wstr( - w, - http.StatusForbidden, - "You are not authorized to view this page", - ) - return - } - - basePath := req.PathValue("s") - newState, err := strconv.ParseUint(basePath, 10, 32) - if err != nil { - wstr( - w, - http.StatusBadRequest, - "State must be an unsigned 32-bit integer", - ) - return - } - err = setState(req.Context(), uint32(newState)) - if err != nil { - wstr( - w, - http.StatusInternalServerError, - "Failed setting state, please return to previous page; are you sure it's within limits?", - ) - return - } - - http.Redirect(w, req, "/", http.StatusSeeOther) -} |