diff options
Diffstat (limited to 'internal/cli/cli.go')
| -rw-r--r-- | internal/cli/cli.go | 91 |
1 files changed, 89 insertions, 2 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go index af408c5..85b9ae3 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -117,17 +117,29 @@ func (a *App) Run() error { cfg.Profile = "default" } + if !IsValidProfileName(cfg.Profile) { + return fmt.Errorf("invalid profile name: %q (only alphanumeric characters, underscores, and hyphens are allowed)", cfg.Profile) + } + if namespace.IsIsolated() { return a.ExecuteCommand(cfg) } + pm := a.getPathManager() + + // Acquire startup lock to prevent concurrent bootstrap/joining races + lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) + if lockErr == nil { + defer namespace.ReleaseProfileLock(lockFile) + } + // Before bootstrapping, see if an active namespace/process for the profile exists. // If yes, we can join it! - pm := a.getPathManager() joined, err := namespace.JoinExistingNamespace(pm, cfg.Profile) if err == nil && joined { // We have joined the active namespace (user, mnt, net). - // We can now execute the command immediately in this context! + // Release the lock before executing the command to allow others to join + namespace.ReleaseProfileLock(lockFile) return a.ExecuteCommand(cfg) } @@ -145,10 +157,16 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { pm := a.getPathManager() + // Acquire execution lock during configuration and startup inside the namespace + lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) + if err := namespace.PruneStalePids(pm, cfg.Profile); err != nil { fmt.Printf("failed to prune stale pids: %v\n", err) } if err := namespace.RegisterProcess(pm, cfg.Profile); err != nil { + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } return fmt.Errorf("failed to register process: %w", err) } @@ -174,6 +192,9 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { if _, err := os.Stat(profilePath); err == nil { wgCfg, err := wgconf.Parse(profilePath) if err != nil { + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) } @@ -184,10 +205,29 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { } if dnsServer == "" { dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks + hasDefaultRoute := false + for _, peer := range wgCfg.Peers { + for _, ip := range peer.AllowedIPs { + trimmed := strings.TrimSpace(ip) + if trimmed == "0.0.0.0/0" || trimmed == "::/0" { + hasDefaultRoute = true + break + } + } + if hasDefaultRoute { + break + } + } + if !hasDefaultRoute { + fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n") + } } tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) if err != nil { + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } return fmt.Errorf("failed to start WireGuard tunnel: %w", err) } defer tunnel.Close() @@ -199,11 +239,19 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { } else { // If profile is not default or it was explicitly requested but doesn't exist, we error if cfg.Profile != "default" { + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) } fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } + // Setup and initialization are complete. We can now safely release the startup lock! + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } + cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -249,6 +297,9 @@ func (a *App) handleProfileCmd() error { if len(a.Args) < 4 { return fmt.Errorf("usage: wg-wrap profile stop <name>") } + if !IsValidProfileName(a.Args[3]) { + return fmt.Errorf("invalid profile name: %q", a.Args[3]) + } fmt.Printf("Stopping profile %s and unpinning namespace...\n", a.Args[3]) pm := a.getPathManager() if err := namespace.UnpinNamespace(pm, a.Args[3]); err != nil { @@ -262,6 +313,9 @@ func (a *App) handleProfileCmd() error { } func (a *App) handleProfileConfigure(name string) error { + if !IsValidProfileName(name) { + return fmt.Errorf("invalid profile name: %q", name) + } profilesDir := a.getPathManager().ConfigDir() profilePath := filepath.Join(profilesDir, name+".conf") if _, err := os.Stat(profilePath); os.IsNotExist(err) { @@ -329,7 +383,15 @@ func (a *App) handleProfileImport(srcPath string, name string) error { } } + if !IsValidProfileName(name) { + return fmt.Errorf("invalid profile name: %q", name) + } + destPath := filepath.Join(profilesDir, name+".conf") + if _, err := os.Stat(destPath); err == nil { + return fmt.Errorf("profile '%s' already exists", name) + } + data, err := os.ReadFile(srcPath) if err != nil { return fmt.Errorf("failed to read source file: %w", err) @@ -344,6 +406,9 @@ func (a *App) handleProfileImport(srcPath string, name string) error { } func (a *App) handleProfileDelete(name string) error { + if !IsValidProfileName(name) { + return fmt.Errorf("invalid profile name: %q", name) + } profilesDir := a.getPathManager().ConfigDir() destPath := filepath.Join(profilesDir, name+".conf") @@ -369,6 +434,10 @@ func (a *App) showConfig() error { _ = fs.Parse(a.Args[2:]) } + if !IsValidProfileName(cfg.Profile) { + return fmt.Errorf("invalid profile name: %q", cfg.Profile) + } + pm := paths.NewPathManager(a.ConfigDir, a.RuntimeBaseDir) profilePath := pm.ProfileNamespacePath(cfg.Profile) pidsPath := pm.ProfilePidsDir(cfg.Profile) @@ -384,3 +453,21 @@ func (a *App) showConfig() error { fmt.Printf(" UID: %d\n", os.Getuid()) return nil } + +// IsValidProfileName checks if a WireGuard profile name is safe and valid. +// It allows only alphanumeric characters, underscores, and hyphens, and prevents +// directory traversal attacks and hidden files. +func IsValidProfileName(name string) bool { + if name == "" { + return false + } + for _, r := range name { + if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '_' && r != '-' { + return false + } + } + if name == "." || name == ".." || strings.HasPrefix(name, "-") { + return false + } + return true +} |
