aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--config.go47
-rw-r--r--docs/cca.scfg.example18
-rw-r--r--endpoint_auth.go162
-rw-r--r--errors.go3
4 files changed, 69 insertions, 161 deletions
diff --git a/config.go b/config.go
index 5f9ccd8..2efa4dc 100644
--- a/config.go
+++ b/config.go
@@ -54,12 +54,13 @@ var configWithPointers struct {
Conn *string `scfg:"conn"`
} `scfg:"db"`
Auth struct {
- Client *string `scfg:"client"`
- Authorize *string `scfg:"authorize"`
- Jwks *string `scfg:"jwks"`
- Token *string `scfg:"token"`
- Secret *string `scfg:"secret"`
- Expr *int `scfg:"expr"`
+ Client *string `scfg:"client"`
+ Authorize *string `scfg:"authorize"`
+ Jwks *string `scfg:"jwks"`
+ Token *string `scfg:"token"`
+ Expr *int `scfg:"expr"`
+ Departments *map[string]string `scfg:"depts"`
+ Udepts *map[string]string `scfg:"udepts"`
} `scfg:"auth"`
Perf struct {
SendQ *int `scfg:"sendq"`
@@ -107,12 +108,13 @@ var config struct {
Conn string
}
Auth struct {
- Client string
- Authorize string
- Jwks string
- Token string
- Secret string
- Expr int
+ Client string
+ Authorize string
+ Jwks string
+ Token string
+ Expr int
+ Departments map[string]string
+ Udepts map[string]string
}
Perf struct {
SendQ int
@@ -237,16 +239,27 @@ func fetchConfig(path string) (retErr error) {
}
config.Auth.Token = *(configWithPointers.Auth.Token)
- if configWithPointers.Auth.Secret == nil {
- return fmt.Errorf("%w: auth.secret", errMissingConfigValue)
- }
- config.Auth.Secret = *(configWithPointers.Auth.Secret)
-
if configWithPointers.Auth.Expr == nil {
return fmt.Errorf("%w: auth.expr", errMissingConfigValue)
}
config.Auth.Expr = *(configWithPointers.Auth.Expr)
+ if configWithPointers.Auth.Departments == nil {
+ return fmt.Errorf("%w: auth.depts", errMissingConfigValue)
+ }
+ config.Auth.Departments = *(configWithPointers.Auth.Departments)
+ if config.Auth.Departments == nil {
+ return fmt.Errorf("%w: auth.depts", errMissingConfigValue)
+ }
+
+ if configWithPointers.Auth.Udepts == nil {
+ return fmt.Errorf("%w: auth.udepts", errMissingConfigValue)
+ }
+ config.Auth.Udepts = *(configWithPointers.Auth.Udepts)
+ if config.Auth.Udepts == nil {
+ return fmt.Errorf("%w: auth.udepts", errMissingConfigValue)
+ }
+
if configWithPointers.Perf.SendQ == nil {
return fmt.Errorf(
"%w: perf.sendq",
diff --git a/docs/cca.scfg.example b/docs/cca.scfg.example
index 5ce05d3..4ba7f84 100644
--- a/docs/cca.scfg.example
+++ b/docs/cca.scfg.example
@@ -60,12 +60,24 @@ auth {
# What is the URL to the JSON Web Key Set?
jwks https://login.microsoftonline.com/common/discovery/keys
-
- # What is the client secret? Certificates are not supported yet.
- secret something
# How long, in seconds, should cookies last?
expr 604800
+
+ # Which group IDs mean which departments?
+ depts {
+ dc3ab000-6352-4596-9f15-771e0b17f6f1 Y12
+ b006d3b8-2ab7-4038-9887-a8276f7ba8e6 Y11
+ a51fb4ab-704e-4c7a-b639-b84de0516e57 Y10
+ 4bae8dbe-ce80-4b5e-994f-d42f0307bd13 Y9
+ }
+
+ # User department overrides
+ udepts {
+ fa1f6b2b-0424-41db-bda0-13962abdadf4 Staff
+ a1a735c0-1ba8-4f08-b4d0-4c6f85552ac7 Staff
+ 34d4ee3c-6515-4e13-9679-57ccb9ca2835 Staff
+ }
}
# The following block contains some tweaks for performance.
diff --git a/endpoint_auth.go b/endpoint_auth.go
index 0ff3316..e27f74f 100644
--- a/endpoint_auth.go
+++ b/endpoint_auth.go
@@ -21,13 +21,9 @@
package main
import (
- "context"
- "encoding/json"
"errors"
"fmt"
"net/http"
- "net/url"
- "strings"
"time"
"github.com/MicahParks/keyfunc/v3"
@@ -46,9 +42,10 @@ const tokenLength = 20
* zero strings if it hasn't been configured correctly.
*/
type msclaimsT struct {
- Name string `json:"name"` /* Scope: profile */
- Email string `json:"email"` /* Scope: email */
- Oid string `json:"oid"` /* Scope: profile */
+ Name string `json:"name"`
+ Email string `json:"email"`
+ Oid string `json:"oid"`
+ Groups []string `json:"groups"`
jwt.RegisteredClaims
}
@@ -153,25 +150,14 @@ func handleAuth(w http.ResponseWriter, req *http.Request) (string, int, error) {
return "", http.StatusBadRequest, errCannotUnpackClaims
}
- authorizationCode := req.PostFormValue("code")
-
- accessToken, err := getAccessToken(req.Context(), authorizationCode)
- if err != nil {
- return "", -1, err
- }
-
- department, err := getDepartment(req.Context(), *(accessToken.Content))
- if err != nil {
- return "", -1, err
- }
-
- switch {
- case department == "SJ Co-Curricular Activities Office 松江课外项目办公室":
- department = staffDepartment
- case department == "Y9" || department == "Y10" ||
- department == "Y11" || department == "Y12":
- default:
- return "", http.StatusForbidden, errUnknownDepartment
+ var department string
+ var ok bool
+ department, ok = getDepartmentByUserIDOverride(claims.Oid)
+ if !ok {
+ department, ok = getDepartmentByGroups(claims.Groups)
+ if !ok {
+ return "", http.StatusBadRequest, errUnknownDepartment
+ }
}
cookieValue, err := randomString(tokenLength)
@@ -239,120 +225,20 @@ func setupJwks() error {
return nil
}
-/*
- * Fetch the department name of the user, mostly to identify which grade
- * a student is in. This expects an accessToken obtained from the OAUTH 2.0
- * token endpoint obtained via an authorization code. It might also be able
- * to use this as part of a hybrid flow that directly provides access tokens,
- * but this flow seems to be only usable for single-page applications according
- * to the Azure portal.
- */
-func getDepartment(ctx context.Context, accessToken string) (string, error) {
- req, err := http.NewRequestWithContext(
- ctx,
- http.MethodGet,
- "https://graph.microsoft.com/v1.0/me?$select=department",
- nil,
- )
- if err != nil {
- return "", wrapError(errCannotGetDepartment, err)
- }
- req.Header.Set("Authorization", "Bearer "+accessToken)
-
- client := &http.Client{} //exhaustruct:ignore
- resp, err := client.Do(req)
- if err != nil {
- return "", wrapError(errCannotGetDepartment, err)
- }
- defer resp.Body.Close()
-
- var departmentWrap struct {
- Department *string `json:"department"`
- }
-
- decoder := json.NewDecoder(resp.Body)
- err = decoder.Decode(&departmentWrap)
- if err != nil {
- return "", wrapError(errCannotGetDepartment, err)
- }
-
- if departmentWrap.Department == nil {
- /*
- * This is probably because the response does not contain a
- * "department" field, which hopefully doesn't occur as we
- * have specified $select=department in the OData query.
- */
- return "", wrapError(
- errCannotGetDepartment,
- errInsufficientFields,
- )
+func getDepartmentByGroups(groups []string) (string, bool) {
+ for _, g := range groups {
+ d, ok := config.Auth.Departments[g]
+ if ok {
+ return d, true
+ }
}
-
- return *(departmentWrap.Department), nil
-}
-
-type accessTokenT struct {
- Content *string `json:"access_token"`
- Error *string `json:"error"`
- ErrorDescription *string `json:"error_description"`
- ErrorCodes *[]int `json:"error_codes"`
+ return "", false
}
-func getAccessToken(
- ctx context.Context,
- authorizationCode string,
-) (accessTokenT, error) {
- var accessToken accessTokenT
- v := url.Values{}
- v.Set("client_id", config.Auth.Client)
- v.Set("scope", "https://graph.microsoft.com/User.Read")
- v.Set("code", authorizationCode)
- v.Set("redirect_uri", config.URL+"/auth")
- v.Set("grant_type", "authorization_code")
- v.Set("client_secret", config.Auth.Secret)
- req, err := http.NewRequestWithContext(
- ctx,
- http.MethodPost,
- config.Auth.Token,
- strings.NewReader(v.Encode()),
- )
- if err != nil {
- return accessToken,
- wrapError(errCannotFetchAccessToken, err)
- }
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- return accessToken,
- wrapError(errCannotFetchAccessToken, err)
- }
- defer resp.Body.Close()
-
- decoder := json.NewDecoder(resp.Body)
- err = decoder.Decode(&accessToken)
- if err != nil {
- return accessToken,
- wrapError(errCannotFetchAccessToken, err)
+func getDepartmentByUserIDOverride(userID string) (string, bool) {
+ d, ok := config.Auth.Udepts[userID]
+ if ok {
+ return d, true
}
- if accessToken.Error != nil || accessToken.ErrorCodes != nil ||
- accessToken.ErrorDescription != nil {
- if accessToken.Error == nil || accessToken.ErrorCodes == nil ||
- accessToken.ErrorDescription == nil {
- return accessToken, errCannotFetchAccessToken
- }
- return accessToken,
- fmt.Errorf(
- "%w: %v",
- errTokenEndpointReturnedError,
- *accessToken.ErrorDescription,
- )
- }
- if accessToken.Content == nil {
- return accessToken,
- fmt.Errorf(
- "error extracting access token: %w",
- errInsufficientFields,
- )
- }
-
- return accessToken, nil
+ return "", false
}
diff --git a/errors.go b/errors.go
index 3921d49..f81ff02 100644
--- a/errors.go
+++ b/errors.go
@@ -28,10 +28,7 @@ import (
var (
errCannotSetupJwks = errors.New("cannot set up jwks")
errInsufficientFields = errors.New("insufficient fields")
- errCannotGetDepartment = errors.New("cannot get department")
errUnknownDepartment = errors.New("unknown department")
- errCannotFetchAccessToken = errors.New("cannot fetch access token")
- errTokenEndpointReturnedError = errors.New("token endpoint returned error")
errCannotProcessConfig = errors.New("cannot process configuration file")
errCannotOpenConfig = errors.New("cannot open configuration file")
errCannotDecodeConfig = errors.New("cannot decode configuration file")