summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.go28
-rw-r--r--endpoint_export.go40
-rw-r--r--endpoint_index.go83
-rw-r--r--endpoint_newcourses.go44
-rw-r--r--endpoint_state.go68
-rw-r--r--endpoint_ws.go110
-rw-r--r--errors.go4
-rw-r--r--session.go54
-rw-r--r--state.go75
9 files changed, 157 insertions, 349 deletions
diff --git a/config.go b/config.go
index 576ac6d..4fc098c 100644
--- a/config.go
+++ b/config.go
@@ -23,7 +23,6 @@ package main
import (
"bufio"
"fmt"
- "log"
"os"
"git.sr.ht/~emersion/go-scfg"
@@ -55,7 +54,6 @@ var configWithPointers struct {
Conn *string `scfg:"conn"`
} `scfg:"db"`
Auth struct {
- Fake *int `scfg:"fake"`
Client *string `scfg:"client"`
Authorize *string `scfg:"authorize"`
Jwks *string `scfg:"jwks"`
@@ -91,7 +89,6 @@ var config struct {
Conn string
}
Auth struct {
- Fake int
Client string
Authorize string
Jwks string
@@ -184,31 +181,6 @@ func fetchConfig(path string) (retErr error) {
}
config.DB.Conn = *(configWithPointers.DB.Conn)
- if configWithPointers.Auth.Fake == nil {
- config.Auth.Fake = 0
- } else {
- config.Auth.Fake = *(configWithPointers.Auth.Fake)
- switch config.Auth.Fake {
- case 0:
- /* It's okay to set it to 0 in production */
- case 4712, 9080: /* Don't use them unless you know what you're doing */
- if config.Prod {
- return fmt.Errorf(
- "%w: fake authentication is incompatible with production mode",
- errIllegalConfig,
- )
- }
- log.Println(
- "!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.",
- )
- default:
- return fmt.Errorf(
- "%w: invalid option for auth.fake",
- errIllegalConfig,
- )
- }
- }
-
if configWithPointers.Auth.Client == nil {
return fmt.Errorf("%w: auth.client", errMissingConfigValue)
}
diff --git a/endpoint_export.go b/endpoint_export.go
index 401c632..5118181 100644
--- a/endpoint_export.go
+++ b/endpoint_export.go
@@ -22,55 +22,21 @@ package main
import (
"encoding/csv"
- "errors"
"fmt"
"net/http"
"strings"
-
- "github.com/jackc/pgx/v5"
)
func handleExport(w http.ResponseWriter, req *http.Request) {
- sessionCookie, err := req.Cookie("session")
- if errors.Is(err, http.ErrNoCookie) {
- wstr(
- w,
- http.StatusUnauthorized,
- "No session cookie, which is required for this endpoint",
- )
- return
- } else if err != nil {
- wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.")
- return
- }
-
- var userDepartment string
- err = db.QueryRow(
- req.Context(),
- "SELECT department FROM users WHERE session = $1",
- sessionCookie.Value,
- ).Scan(&userDepartment)
+ _, _, department, err := getUserInfoFromRequest(req)
if err != nil {
- if errors.Is(err, pgx.ErrNoRows) {
- wstr(
- w,
- http.StatusForbidden,
- "Invalid session cookie",
- )
- return
- }
wstr(
w,
http.StatusInternalServerError,
- fmt.Sprintf(
- "Error: Unexpected database error: %s",
- err,
- ),
+ fmt.Sprintf("Error: %v", err),
)
- return
}
-
- if userDepartment != staffDepartment {
+ if department != staffDepartment {
wstr(
w,
http.StatusForbidden,
diff --git a/endpoint_index.go b/endpoint_index.go
index 512fe1d..e6a5741 100644
--- a/endpoint_index.go
+++ b/endpoint_index.go
@@ -26,15 +26,13 @@ import (
"log"
"net/http"
"sync/atomic"
-
- "github.com/jackc/pgx/v5"
)
func handleIndex(w http.ResponseWriter, req *http.Request) {
- sessionCookie, err := req.Cookie("session")
- if errors.Is(err, http.ErrNoCookie) {
- authURL, err := generateAuthorizationURL()
- if err != nil {
+ _, username, department, err := getUserInfoFromRequest(req)
+ if errors.Is(err, errNoCookie) || errors.Is(err, errNoSuchUser) {
+ authURL, err2 := generateAuthorizationURL()
+ if err2 != nil {
wstr(
w,
http.StatusInternalServerError,
@@ -42,7 +40,11 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
)
return
}
- err = tmpl.ExecuteTemplate(
+ var noteString string
+ if errors.Is(err, errNoSuchUser) {
+ noteString = "Your browser provided an invalid session cookie."
+ }
+ err2 = tmpl.ExecuteTemplate(
w,
"login",
struct {
@@ -50,62 +52,15 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
Notes string
}{
authURL,
- "",
+ noteString,
},
)
- if err != nil {
- log.Println(err)
+ if err2 != nil {
+ log.Println(err2)
return
}
- return
} else if err != nil {
- wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.")
- return
- }
-
- var userID, userName, userDepartment string
- err = db.QueryRow(
- req.Context(),
- "SELECT id, name, department FROM users WHERE session = $1",
- sessionCookie.Value,
- ).Scan(&userID, &userName, &userDepartment)
- if err != nil {
- if errors.Is(err, pgx.ErrNoRows) {
- authURL, err := generateAuthorizationURL()
- if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Cannot generate authorization URL",
- )
- return
- }
- err = tmpl.ExecuteTemplate(
- w,
- "login",
- struct {
- AuthURL string
- Notes string
- }{
- authURL,
- "Your session is invalid or has expired.",
- },
- )
- if err != nil {
- log.Println(err)
- return
- }
- return
- }
- wstr(
- w,
- http.StatusInternalServerError,
- fmt.Sprintf(
- "Error: Unexpected database error: %s",
- err,
- ),
- )
- return
+ wstr(w, http.StatusInternalServerError, fmt.Sprintf("Error: %v", err))
}
/* TODO: The below should be completed on-update. */
@@ -136,7 +91,7 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
return true
})
- if userDepartment == staffDepartment {
+ if department == staffDepartment {
err := tmpl.ExecuteTemplate(
w,
"staff",
@@ -145,7 +100,7 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
State uint32
Groups *map[courseGroupT]groupT
}{
- userName,
+ username,
state,
&_groups,
},
@@ -164,8 +119,8 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
Name string
Department string
}{
- userName,
- userDepartment,
+ username,
+ department,
},
)
if err != nil {
@@ -182,8 +137,8 @@ func handleIndex(w http.ResponseWriter, req *http.Request) {
Department string
Groups *map[courseGroupT]groupT
}{
- userName,
- userDepartment,
+ username,
+ department,
&_groups,
},
)
diff --git a/endpoint_newcourses.go b/endpoint_newcourses.go
index 5963e1b..bf8f570 100644
--- a/endpoint_newcourses.go
+++ b/endpoint_newcourses.go
@@ -34,54 +34,19 @@ import (
func handleNewCourses(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
- wstr(
- w,
- http.StatusMethodNotAllowed,
- "Only POST is allowed here",
- )
- return
- }
-
- sessionCookie, err := req.Cookie("session")
- if errors.Is(err, http.ErrNoCookie) {
- wstr(
- w,
- http.StatusUnauthorized,
- "No session cookie, which is required for this endpoint",
- )
- return
- } else if err != nil {
- wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.")
+ wstr(w, http.StatusMethodNotAllowed, "Only POST is allowed here")
return
}
- var userDepartment string
- err = db.QueryRow(
- req.Context(),
- "SELECT department FROM users WHERE session = $1",
- sessionCookie.Value,
- ).Scan(&userDepartment)
+ _, _, department, err := getUserInfoFromRequest(req)
if err != nil {
- if errors.Is(err, pgx.ErrNoRows) {
- wstr(
- w,
- http.StatusForbidden,
- "Invalid session cookie",
- )
- return
- }
wstr(
w,
http.StatusInternalServerError,
- fmt.Sprintf(
- "Error: Unexpected database error: %s",
- err,
- ),
+ fmt.Sprintf("Error: %v", err),
)
- return
}
-
- if userDepartment != staffDepartment {
+ if department != staffDepartment {
wstr(
w,
http.StatusForbidden,
@@ -98,6 +63,7 @@ func handleNewCourses(w http.ResponseWriter, req *http.Request) {
)
return
}
+
/* TODO: Potential race. The global state may need to be write-locked. */
file, fileHeader, err := req.FormFile("coursecsv")
diff --git a/endpoint_state.go b/endpoint_state.go
new file mode 100644
index 0000000..4e0ee56
--- /dev/null
+++ b/endpoint_state.go
@@ -0,0 +1,68 @@
+/*
+ * Let staff update state
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "fmt"
+ "net/http"
+ "strconv"
+)
+
+func handleState(w http.ResponseWriter, req *http.Request) {
+ _, _, department, err := getUserInfoFromRequest(req)
+ if err != nil {
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ fmt.Sprintf("Error: %v", err),
+ )
+ }
+ if department != staffDepartment {
+ wstr(
+ w,
+ http.StatusForbidden,
+ "You are not authorized to view this page",
+ )
+ return
+ }
+
+ basePath := req.PathValue("s")
+ newState, err := strconv.ParseUint(basePath, 10, 32)
+ if err != nil {
+ wstr(
+ w,
+ http.StatusBadRequest,
+ "State must be an unsigned 32-bit integer",
+ )
+ return
+ }
+ err = setState(req.Context(), uint32(newState))
+ if err != nil {
+ wstr(
+ w,
+ http.StatusInternalServerError,
+ "Failed setting state, please return to previous page; are you sure it's within limits?",
+ )
+ return
+ }
+
+ http.Redirect(w, req, "/", http.StatusSeeOther)
+}
diff --git a/endpoint_ws.go b/endpoint_ws.go
index 4bf2319..45cee60 100644
--- a/endpoint_ws.go
+++ b/endpoint_ws.go
@@ -21,14 +21,10 @@
package main
import (
- "errors"
"log"
"net/http"
- "time"
"github.com/coder/websocket"
- "github.com/google/uuid"
- "github.com/jackc/pgx/v5"
)
/*
@@ -57,114 +53,18 @@ func handleWs(w http.ResponseWriter, req *http.Request) {
_ = c.CloseNow()
}()
- fake := false
-
- sessionCookie, err := req.Cookie("session")
- if errors.Is(err, http.ErrNoCookie) {
- if config.Auth.Fake == 0 {
- err := writeText(req.Context(), c, "U")
- if err != nil {
- log.Println(err)
- }
- return
- }
- fake = true
- } else if err != nil {
- err := writeText(req.Context(), c, "E :Error fetching cookie")
+ userID, _, _, err := getUserInfoFromRequest(req)
+ if err != nil {
+ err := writeText(req.Context(), c, "U")
if err != nil {
log.Println(err)
}
return
}
- var userID string
- var session string
- var expr int
-
- if fake {
- switch config.Auth.Fake {
- case 9080:
- _uuid, err := uuid.NewRandom()
- if err != nil {
- log.Println(err)
- return
- }
- userID = _uuid.String()
- case 4712:
- userID = "fake"
- default:
- panic("not supposed to happen")
- }
- session, err = randomString(20)
- if err != nil {
- log.Println(err)
- return
- }
- _, err = db.Exec(
- req.Context(),
- "INSERT INTO users (id, name, email, department, session, expr) VALUES ($1, $2, $3, $4, $5, $6)",
- userID,
- "Fake User",
- "fake@runxiyu.org",
- "Y11",
- session,
- time.Now().Add(
- time.Duration(config.Auth.Expr)*time.Second,
- ).Unix(),
- )
- if err != nil && config.Auth.Fake != 4712 {
- err := writeText(
- req.Context(),
- c,
- "E :Database error while writing fake account info",
- )
- if err != nil {
- log.Println(err)
- }
- return
- }
- err = writeText(req.Context(), c, "FAKE "+userID+" "+session)
- if err != nil {
- log.Println(err)
- return
- }
- } else {
- session = sessionCookie.Value
- err = db.QueryRow(
- req.Context(),
- "SELECT id, expr FROM users WHERE session = $1",
- session,
- ).Scan(&userID, &expr)
- if errors.Is(err, pgx.ErrNoRows) {
- err := writeText(req.Context(), c, "U")
- if err != nil {
- log.Println(err)
- }
- return
- } else if err != nil {
- err := writeText(
- req.Context(),
- c,
- "E :Database error while selecting session",
- )
- if err != nil {
- log.Println(err)
- }
- return
- }
- }
-
- /*
- * Now that we have an authenticated request, this WebSocket connection
- * may be simply associated with the session and userID.
- */
- err = handleConn(
- req.Context(),
- c,
- userID,
- )
+ err = handleConn(req.Context(), c, userID)
if err != nil {
- log.Printf("%v", err)
+ log.Println(err)
return
}
}
diff --git a/errors.go b/errors.go
index 77de2a3..6f1c130 100644
--- a/errors.go
+++ b/errors.go
@@ -35,7 +35,6 @@ var (
errCannotOpenConfig = errors.New("cannot open configuration file")
errCannotDecodeConfig = errors.New("cannot decode configuration file")
errMissingConfigValue = errors.New("missing configuration value")
- errIllegalConfig = errors.New("illegal configuration")
errInvalidCourseType = errors.New("invalid course type")
errInvalidCourseGroup = errors.New("invalid course group")
errMultipleChoicesInOneGroup = errors.New("multiple choices per group per user")
@@ -48,6 +47,9 @@ var (
errNoSuchCourse = errors.New("no such course")
errInvalidState = errors.New("invalid state")
errWebSocketWrite = errors.New("error writing to websocket")
+ errCannotCheckCookie = errors.New("error checking cookie")
+ errNoCookie = errors.New("no cookie found")
+ errNoSuchUser = errors.New("no such user")
)
func wrapError(a, b error) error {
diff --git a/session.go b/session.go
new file mode 100644
index 0000000..be83c5c
--- /dev/null
+++ b/session.go
@@ -0,0 +1,54 @@
+/*
+ * Session checking functions
+ *
+ * Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+ "errors"
+ "net/http"
+
+ "github.com/jackc/pgx/v5"
+)
+
+func getUserInfoFromRequest(req *http.Request) (userID, username, department string, retErr error) {
+ sessionCookie, err := req.Cookie("session")
+ if errors.Is(err, http.ErrNoCookie) {
+ retErr = wrapError(errNoCookie, err)
+ return
+ } else if err != nil {
+ retErr = wrapError(errCannotCheckCookie, err)
+ return
+ }
+
+ err = db.QueryRow(
+ req.Context(),
+ "SELECT id, name, department FROM users WHERE session = $1",
+ sessionCookie.Value,
+ ).Scan(&userID, &username, &department)
+ if err != nil {
+ if errors.Is(err, pgx.ErrNoRows) {
+ retErr = errNoSuchUser
+ return
+ }
+ retErr = wrapError(errUnexpectedDBError, err)
+ return
+ }
+ return
+}
diff --git a/state.go b/state.go
index 20e1d84..23d5c5d 100644
--- a/state.go
+++ b/state.go
@@ -23,9 +23,6 @@ package main
import (
"context"
"errors"
- "fmt"
- "net/http"
- "strconv"
"sync/atomic"
"github.com/jackc/pgx/v5"
@@ -100,75 +97,3 @@ func setState(ctx context.Context, newState uint32) error {
atomic.StoreUint32(&state, newState)
return nil
}
-
-func handleState(w http.ResponseWriter, req *http.Request) {
- sessionCookie, err := req.Cookie("session")
- if errors.Is(err, http.ErrNoCookie) {
- wstr(
- w,
- http.StatusUnauthorized,
- "No session cookie, which is required for this endpoint",
- )
- return
- } else if err != nil {
- wstr(w, http.StatusBadRequest, "Error: Unable to check cookie.")
- return
- }
-
- var userID, userName, userDepartment string
- err = db.QueryRow(
- req.Context(),
- "SELECT id, name, department FROM users WHERE session = $1",
- sessionCookie.Value,
- ).Scan(&userID, &userName, &userDepartment)
- if err != nil {
- if errors.Is(err, pgx.ErrNoRows) {
- wstr(
- w,
- http.StatusForbidden,
- "Invalid session cookie",
- )
- return
- }
- wstr(
- w,
- http.StatusInternalServerError,
- fmt.Sprintf(
- "Error: Unexpected database error: %s",
- err,
- ),
- )
- return
- }
-
- if userDepartment != staffDepartment {
- wstr(
- w,
- http.StatusForbidden,
- "You are not authorized to view this page",
- )
- return
- }
-
- basePath := req.PathValue("s")
- newState, err := strconv.ParseUint(basePath, 10, 32)
- if err != nil {
- wstr(
- w,
- http.StatusBadRequest,
- "State must be an unsigned 32-bit integer",
- )
- return
- }
- err = setState(req.Context(), uint32(newState))
- if err != nil {
- wstr(
- w,
- http.StatusInternalServerError,
- "Failed setting state, please return to previous page; are you sure it's within limits?",
- )
- return
- }
-
- http.Redirect(w, req, "/", http.StatusSeeOther)
-}