summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.go124
-rw-r--r--main.go8
2 files changed, 124 insertions, 8 deletions
diff --git a/config.go b/config.go
index 8761a36..e5bea02 100644
--- a/config.go
+++ b/config.go
@@ -113,36 +113,106 @@ var config struct {
}
var (
- errAuthFakeProd = errors.New("auth.fake not allowed in production mode")
- errAuthFakeInvalid = errors.New("invalid auth.fake value")
+ errCannotProcessConfig = errors.New("cannot process configuration file")
+ errCannotOpenConfig = errors.New("cannot open configuration file")
+ errCannotDecodeConfig = errors.New("cannot decode configuration file")
+ errMissingConfigValue = errors.New("missing configuration value")
+ errIllegalConfig = errors.New("illegal configuration")
)
-func fetchConfig(path string) error {
+func fetchConfig(path string) (retErr error) {
+ defer func() {
+ if v := recover(); v != nil {
+ s, ok := v.(error)
+ if ok {
+ retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, s)
+ }
+ retErr = fmt.Errorf("%w: %v", errCannotProcessConfig, v)
+ return
+ }
+ if retErr != nil {
+ retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, retErr)
+ return
+ }
+ }()
+
f, err := os.Open(path)
if err != nil {
- return fmt.Errorf("error opening configuration file: %w", err)
+ return fmt.Errorf("%w: %w", errCannotOpenConfig, err)
}
err = scfg.NewDecoder(bufio.NewReader(f)).Decode(&configWithPointers)
if err != nil {
- return fmt.Errorf("error decoding configuration file: %w", err)
+ return fmt.Errorf("%w: %w", errCannotDecodeConfig, err)
}
+ if configWithPointers.URL == nil {
+ return fmt.Errorf("%w: url", errMissingConfigValue)
+ }
config.URL = *(configWithPointers.URL)
+
+ if configWithPointers.Prod == nil {
+ return fmt.Errorf("%w: prod", errMissingConfigValue)
+ }
config.Prod = *(configWithPointers.Prod)
+
+ if configWithPointers.Tmpl == nil {
+ return fmt.Errorf("%w: tmpl", errMissingConfigValue)
+ }
config.Tmpl = *(configWithPointers.Tmpl)
+
+ if configWithPointers.Static == nil {
+ return fmt.Errorf("%w: static", errMissingConfigValue)
+ }
config.Static = *(configWithPointers.Static)
+
+ if configWithPointers.Source == nil {
+ return fmt.Errorf("%w: source", errMissingConfigValue)
+ }
config.Source = *(configWithPointers.Source)
+
+ if configWithPointers.Listen.Proto == nil {
+ return fmt.Errorf("%w: listen.proto", errMissingConfigValue)
+ }
config.Listen.Proto = *(configWithPointers.Listen.Proto)
+
+ if configWithPointers.Listen.Net == nil {
+ return fmt.Errorf("%w: listen.net", errMissingConfigValue)
+ }
config.Listen.Net = *(configWithPointers.Listen.Net)
+
+ if configWithPointers.Listen.Addr == nil {
+ return fmt.Errorf("%w: listen.addr", errMissingConfigValue)
+ }
config.Listen.Addr = *(configWithPointers.Listen.Addr)
+
+ if configWithPointers.Listen.Trans == nil {
+ return fmt.Errorf("%w: listen.trans", errMissingConfigValue)
+ }
config.Listen.Trans = *(configWithPointers.Listen.Trans)
+
if config.Listen.Trans == "tls" {
+ if configWithPointers.Listen.TLS.Cert == nil {
+ return fmt.Errorf("%w: listen.tls.cert", errMissingConfigValue)
+ }
config.Listen.TLS.Cert = *(configWithPointers.Listen.TLS.Cert)
+
+ if configWithPointers.Listen.TLS.Key == nil {
+ return fmt.Errorf("%w: listen.tls.key", errMissingConfigValue)
+ }
config.Listen.TLS.Key = *(configWithPointers.Listen.TLS.Key)
}
+
+ if configWithPointers.DB.Type == nil {
+ return fmt.Errorf("%w: db.type", errMissingConfigValue)
+ }
config.DB.Type = *(configWithPointers.DB.Type)
+
+ if configWithPointers.DB.Conn == nil {
+ return fmt.Errorf("%w: db.conn", errMissingConfigValue)
+ }
config.DB.Conn = *(configWithPointers.DB.Conn)
+
if configWithPointers.Auth.Fake == nil {
config.Auth.Fake = 0
} else {
@@ -152,22 +222,62 @@ func fetchConfig(path string) error {
/* It's okay to set it to 0 in production */
case 4712, 9080: /* Don't use them unless you know what you're doing */
if config.Prod {
- return errAuthFakeProd
+ return fmt.Errorf("%w: fake authentication is incompatible with production mode", errIllegalConfig)
}
log.Println("!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.")
default:
- return errAuthFakeInvalid
+ return fmt.Errorf("%w: invalid option for auth.fake", errIllegalConfig)
}
}
+
+ if configWithPointers.Auth.Client == nil {
+ return fmt.Errorf("%w: auth.client", errMissingConfigValue)
+ }
config.Auth.Client = *(configWithPointers.Auth.Client)
+
+ if configWithPointers.Auth.Authorize == nil {
+ return fmt.Errorf("%w: auth.authorize", errMissingConfigValue)
+ }
config.Auth.Authorize = *(configWithPointers.Auth.Authorize)
+
+ if configWithPointers.Auth.Jwks == nil {
+ return fmt.Errorf("%w: auth.jwks", errMissingConfigValue)
+ }
config.Auth.Jwks = *(configWithPointers.Auth.Jwks)
+
+ if configWithPointers.Auth.Token == nil {
+ return fmt.Errorf("%w: auth.token", errMissingConfigValue)
+ }
config.Auth.Token = *(configWithPointers.Auth.Token)
+
+ if configWithPointers.Auth.Secret == nil {
+ return fmt.Errorf("%w: auth.secret", errMissingConfigValue)
+ }
config.Auth.Secret = *(configWithPointers.Auth.Secret)
+
+ if configWithPointers.Auth.Expr == nil {
+ return fmt.Errorf("%w: auth.expr", errMissingConfigValue)
+ }
config.Auth.Expr = *(configWithPointers.Auth.Expr)
+
+ if configWithPointers.Perf.MessageArgumentsCap == nil {
+ return fmt.Errorf("%w: perf.msg_args_cap", errMissingConfigValue)
+ }
config.Perf.MessageArgumentsCap = *(configWithPointers.Perf.MessageArgumentsCap)
+
+ if configWithPointers.Perf.MessageBytesCap == nil {
+ return fmt.Errorf("%w: perf.msg_bytes_cap", errMissingConfigValue)
+ }
config.Perf.MessageBytesCap = *(configWithPointers.Perf.MessageBytesCap)
+
+ if configWithPointers.Perf.ReadHeaderTimeout == nil {
+ return fmt.Errorf("%w: perf.read_header_timeout", errMissingConfigValue)
+ }
config.Perf.ReadHeaderTimeout = *(configWithPointers.Perf.ReadHeaderTimeout)
+
+ if configWithPointers.Perf.CourseUpdateInterval == nil {
+ return fmt.Errorf("%w: perf.course_update_interval", errMissingConfigValue)
+ }
config.Perf.CourseUpdateInterval = *(configWithPointers.Perf.CourseUpdateInterval)
return nil
diff --git a/main.go b/main.go
index 412caa3..a185de9 100644
--- a/main.go
+++ b/main.go
@@ -22,6 +22,7 @@ package main
import (
"crypto/tls"
+ "flag"
"html/template"
"log"
"net"
@@ -34,7 +35,12 @@ var tmpl *template.Template
func main() {
var err error
- if err := fetchConfig("cca.scfg"); err != nil {
+ var configPath string
+
+ flag.StringVar(&configPath, "config", "cca.scfg", "path to configuration file")
+ flag.Parse()
+
+ if err := fetchConfig(configPath); err != nil {
log.Fatal(err)
}