diff options
-rw-r--r-- | config.go | 124 | ||||
-rw-r--r-- | main.go | 8 |
2 files changed, 124 insertions, 8 deletions
@@ -113,36 +113,106 @@ var config struct { } var ( - errAuthFakeProd = errors.New("auth.fake not allowed in production mode") - errAuthFakeInvalid = errors.New("invalid auth.fake value") + errCannotProcessConfig = errors.New("cannot process configuration file") + errCannotOpenConfig = errors.New("cannot open configuration file") + errCannotDecodeConfig = errors.New("cannot decode configuration file") + errMissingConfigValue = errors.New("missing configuration value") + errIllegalConfig = errors.New("illegal configuration") ) -func fetchConfig(path string) error { +func fetchConfig(path string) (retErr error) { + defer func() { + if v := recover(); v != nil { + s, ok := v.(error) + if ok { + retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, s) + } + retErr = fmt.Errorf("%w: %v", errCannotProcessConfig, v) + return + } + if retErr != nil { + retErr = fmt.Errorf("%w: %w", errCannotProcessConfig, retErr) + return + } + }() + f, err := os.Open(path) if err != nil { - return fmt.Errorf("error opening configuration file: %w", err) + return fmt.Errorf("%w: %w", errCannotOpenConfig, err) } err = scfg.NewDecoder(bufio.NewReader(f)).Decode(&configWithPointers) if err != nil { - return fmt.Errorf("error decoding configuration file: %w", err) + return fmt.Errorf("%w: %w", errCannotDecodeConfig, err) } + if configWithPointers.URL == nil { + return fmt.Errorf("%w: url", errMissingConfigValue) + } config.URL = *(configWithPointers.URL) + + if configWithPointers.Prod == nil { + return fmt.Errorf("%w: prod", errMissingConfigValue) + } config.Prod = *(configWithPointers.Prod) + + if configWithPointers.Tmpl == nil { + return fmt.Errorf("%w: tmpl", errMissingConfigValue) + } config.Tmpl = *(configWithPointers.Tmpl) + + if configWithPointers.Static == nil { + return fmt.Errorf("%w: static", errMissingConfigValue) + } config.Static = *(configWithPointers.Static) + + if configWithPointers.Source == nil { + return fmt.Errorf("%w: source", errMissingConfigValue) + } config.Source = *(configWithPointers.Source) + + if configWithPointers.Listen.Proto == nil { + return fmt.Errorf("%w: listen.proto", errMissingConfigValue) + } config.Listen.Proto = *(configWithPointers.Listen.Proto) + + if configWithPointers.Listen.Net == nil { + return fmt.Errorf("%w: listen.net", errMissingConfigValue) + } config.Listen.Net = *(configWithPointers.Listen.Net) + + if configWithPointers.Listen.Addr == nil { + return fmt.Errorf("%w: listen.addr", errMissingConfigValue) + } config.Listen.Addr = *(configWithPointers.Listen.Addr) + + if configWithPointers.Listen.Trans == nil { + return fmt.Errorf("%w: listen.trans", errMissingConfigValue) + } config.Listen.Trans = *(configWithPointers.Listen.Trans) + if config.Listen.Trans == "tls" { + if configWithPointers.Listen.TLS.Cert == nil { + return fmt.Errorf("%w: listen.tls.cert", errMissingConfigValue) + } config.Listen.TLS.Cert = *(configWithPointers.Listen.TLS.Cert) + + if configWithPointers.Listen.TLS.Key == nil { + return fmt.Errorf("%w: listen.tls.key", errMissingConfigValue) + } config.Listen.TLS.Key = *(configWithPointers.Listen.TLS.Key) } + + if configWithPointers.DB.Type == nil { + return fmt.Errorf("%w: db.type", errMissingConfigValue) + } config.DB.Type = *(configWithPointers.DB.Type) + + if configWithPointers.DB.Conn == nil { + return fmt.Errorf("%w: db.conn", errMissingConfigValue) + } config.DB.Conn = *(configWithPointers.DB.Conn) + if configWithPointers.Auth.Fake == nil { config.Auth.Fake = 0 } else { @@ -152,22 +222,62 @@ func fetchConfig(path string) error { /* It's okay to set it to 0 in production */ case 4712, 9080: /* Don't use them unless you know what you're doing */ if config.Prod { - return errAuthFakeProd + return fmt.Errorf("%w: fake authentication is incompatible with production mode", errIllegalConfig) } log.Println("!!! WARNING: Fake authentication is enabled. Any WebSocket connection would have a fake account. This is a HUGE security hole. You should only use this while benchmarking.") default: - return errAuthFakeInvalid + return fmt.Errorf("%w: invalid option for auth.fake", errIllegalConfig) } } + + if configWithPointers.Auth.Client == nil { + return fmt.Errorf("%w: auth.client", errMissingConfigValue) + } config.Auth.Client = *(configWithPointers.Auth.Client) + + if configWithPointers.Auth.Authorize == nil { + return fmt.Errorf("%w: auth.authorize", errMissingConfigValue) + } config.Auth.Authorize = *(configWithPointers.Auth.Authorize) + + if configWithPointers.Auth.Jwks == nil { + return fmt.Errorf("%w: auth.jwks", errMissingConfigValue) + } config.Auth.Jwks = *(configWithPointers.Auth.Jwks) + + if configWithPointers.Auth.Token == nil { + return fmt.Errorf("%w: auth.token", errMissingConfigValue) + } config.Auth.Token = *(configWithPointers.Auth.Token) + + if configWithPointers.Auth.Secret == nil { + return fmt.Errorf("%w: auth.secret", errMissingConfigValue) + } config.Auth.Secret = *(configWithPointers.Auth.Secret) + + if configWithPointers.Auth.Expr == nil { + return fmt.Errorf("%w: auth.expr", errMissingConfigValue) + } config.Auth.Expr = *(configWithPointers.Auth.Expr) + + if configWithPointers.Perf.MessageArgumentsCap == nil { + return fmt.Errorf("%w: perf.msg_args_cap", errMissingConfigValue) + } config.Perf.MessageArgumentsCap = *(configWithPointers.Perf.MessageArgumentsCap) + + if configWithPointers.Perf.MessageBytesCap == nil { + return fmt.Errorf("%w: perf.msg_bytes_cap", errMissingConfigValue) + } config.Perf.MessageBytesCap = *(configWithPointers.Perf.MessageBytesCap) + + if configWithPointers.Perf.ReadHeaderTimeout == nil { + return fmt.Errorf("%w: perf.read_header_timeout", errMissingConfigValue) + } config.Perf.ReadHeaderTimeout = *(configWithPointers.Perf.ReadHeaderTimeout) + + if configWithPointers.Perf.CourseUpdateInterval == nil { + return fmt.Errorf("%w: perf.course_update_interval", errMissingConfigValue) + } config.Perf.CourseUpdateInterval = *(configWithPointers.Perf.CourseUpdateInterval) return nil @@ -22,6 +22,7 @@ package main import ( "crypto/tls" + "flag" "html/template" "log" "net" @@ -34,7 +35,12 @@ var tmpl *template.Template func main() { var err error - if err := fetchConfig("cca.scfg"); err != nil { + var configPath string + + flag.StringVar(&configPath, "config", "cca.scfg", "path to configuration file") + flag.Parse() + + if err := fetchConfig(configPath); err != nil { log.Fatal(err) } |