package cli import ( "flag" "fmt" "os" "os/exec" "path/filepath" "strings" "git.theodohertyfamily.com/tools/wg-wrap/internal/config" "git.theodohertyfamily.com/tools/wg-wrap/internal/namespace" "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" "git.theodohertyfamily.com/tools/wg-wrap/internal/wireguard" "git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf" ) type App struct { Args []string ConfigDir string // Optional override for profile storage location RuntimeBaseDir string // Optional override for namespace/PID tracking } func NewApp(args []string) *App { return &App{Args: args} } func (a *App) getPathManager() *paths.PathManager { return paths.NewPathManager(a.ConfigDir, a.RuntimeBaseDir) } func (a *App) Route() error { for i, arg := range a.Args { for j := 0; j < len(arg); j++ { if arg[j] == 0 { return fmt.Errorf("argument %d contains null byte at position %d", i, j) } } } if len(a.Args) > 1 { switch a.Args[1] { case "show-config": return a.showConfig() } } if len(a.Args) > 1 && a.Args[1] == "profile" { return a.handleProfileCmd() } return a.Run() } func (a *App) Run() error { if len(a.Args) > 1 { switch a.Args[1] { case "test-ns": if !namespace.IsIsolated() { if err := namespace.Bootstrap(); err != nil { return fmt.Errorf("bootstrap failed: %w", err) } } ok, msg := namespace.VerifyIsolation() if !ok { return fmt.Errorf("isolation check failed: %s", msg) } fmt.Println("Isolation Verified: OK") return nil case "test-args": if !namespace.IsIsolated() { if err := namespace.Bootstrap(); err != nil { return fmt.Errorf("bootstrap failed: %w", err) } } return namespace.VerifyArguments(a.Args) } } cfg := &config.Config{} fs := flag.NewFlagSet("wg-wrap", flag.ExitOnError) fs.StringVar(&cfg.Profile, "profile", "", "WireGuard profile to use") fs.StringVar(&cfg.DNSServer, "dns-server", "", "Override DNS server to use") args := a.Args[1:] sepIdx := -1 for i, arg := range args { if arg == "--" { sepIdx = i break } } var flagsToParse []string if sepIdx != -1 { flagsToParse = args[:sepIdx] cfg.Command = args[sepIdx+1:] } else { flagsToParse = args } err := fs.Parse(flagsToParse) if err != nil { return fmt.Errorf("error parsing flags: %w", err) } if sepIdx == -1 { cfg.Command = fs.Args() } if len(cfg.Command) == 0 { return fmt.Errorf("no command provided. use --help for usage") } if cfg.Profile == "" { cfg.Profile = "default" } if !IsValidProfileName(cfg.Profile) { return fmt.Errorf("invalid profile name: %q (only alphanumeric characters, underscores, and hyphens are allowed)", cfg.Profile) } if namespace.IsIsolated() { return a.ExecuteCommand(cfg) } pm := a.getPathManager() // Preserve the host runtime base dir in the environment before bootstrapping _ = os.Setenv("WG_WRAP_HOST_RUNTIME_BASE_DIR", pm.RuntimeBaseDir()) // Acquire startup lock to prevent concurrent bootstrap/joining races lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) if lockErr == nil { defer namespace.ReleaseProfileLock(lockFile) } // Before bootstrapping, see if an active namespace/process for the profile exists. // If yes, we can join it! activePid, err := namespace.FindActiveProfilePid(pm, cfg.Profile) if err == nil && activePid > 0 { // Release the lock before executing the command to allow others to join namespace.ReleaseProfileLock(lockFile) if err := namespace.BootstrapJoin(activePid); err != nil { return fmt.Errorf("failed to join existing namespace: %w", err) } return nil } if err := namespace.Bootstrap(); err != nil { return fmt.Errorf("bootstrap failed: %w", err) } return nil } func (a *App) ExecuteCommand(cfg *config.Config) error { if !namespace.IsIsolated() { return fmt.Errorf("ExecuteCommand called without namespace isolation") } pm := a.getPathManager() // Acquire execution lock during configuration and startup inside the namespace lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) if lockErr == nil { defer namespace.ReleaseProfileLock(lockFile) } if err := namespace.PruneStalePids(pm, cfg.Profile); err != nil { fmt.Printf("failed to prune stale pids: %v\n", err) } if err := namespace.RegisterProcess(pm, cfg.Profile); err != nil { return fmt.Errorf("failed to register process: %w", err) } defer func() { // Re-acquire lock for the entire cleanup sequence to ensure atomic unregister and unpin cleanupLock, cleanupErr := namespace.AcquireProfileLock(pm, cfg.Profile) if cleanupErr == nil { // 1. Unregister the process first. if err := namespace.UnregisterProcess(pm, cfg.Profile); err != nil { fmt.Printf("failed to unregister process: %v\n", err) } // 2. Prune and check if we are the last process. if err := namespace.PruneStalePids(pm, cfg.Profile); err != nil { fmt.Printf("failed to prune stale pids during cleanup: %v\n", err) } last, lastErr := namespace.IsLastProcess(pm, cfg.Profile) if lastErr == nil && last { fmt.Printf("Last process exiting. Cleaning up profile %s...\n", cfg.Profile) if err := namespace.UnpinNamespace(pm, cfg.Profile); err != nil { fmt.Printf("failed to unpin namespace: %v\n", err) } } namespace.ReleaseProfileLock(cleanupLock) } else { // Fallback if lock fails to ensure we still unregister if err := namespace.UnregisterProcess(pm, cfg.Profile); err != nil { fmt.Printf("failed to unregister process: %v\n", err) } } }() if os.Getenv("WG_WRAP_JOINED") == "1" { fmt.Printf("Joining active WireGuard tunnel session for profile %s...\n", cfg.Profile) } else { fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile) // Parse the profile configuration profilesDir := pm.ConfigDir() profilePath := filepath.Join(profilesDir, cfg.Profile+".conf") // Create tunnel if the file exists if _, err := os.Stat(profilePath); err == nil { wgCfg, err := wgconf.Parse(profilePath) if err != nil { return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) } // Start the WireGuard userspace device & routing table setup dnsServer := cfg.DNSServer if dnsServer == "" { dnsServer = wgCfg.DNS } if dnsServer == "" { dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks hasDefaultRoute := false for _, peer := range wgCfg.Peers { for _, ip := range peer.AllowedIPs { trimmed := strings.TrimSpace(ip) if trimmed == "0.0.0.0/0" || trimmed == "::/0" { hasDefaultRoute = true break } } if hasDefaultRoute { break } } if !hasDefaultRoute { fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n") } } tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) if err != nil { return fmt.Errorf("failed to start WireGuard tunnel: %w", err) } defer tunnel.Close() // Pin the namespace so others can join it if err := namespace.PinNamespace(pm, cfg.Profile); err != nil { fmt.Printf("warning: failed to pin namespace: %v\n", err) } } else { // If profile is not default or it was explicitly requested but doesn't exist, we error if cfg.Profile != "default" { return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) } fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } } // We can now release the startup lock and execute the command namespace.ReleaseProfileLock(lockFile) cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Env = os.Environ() if err := cmd.Run(); err != nil { return fmt.Errorf("command execution failed: %w", err) } return nil } func (a *App) handleProfileCmd() error { if len(a.Args) < 3 { return fmt.Errorf("usage: wg-wrap profile [args]") } subCmd := a.Args[2] switch subCmd { case "list": return a.handleProfileList() case "import": if len(a.Args) < 4 { return fmt.Errorf("usage: wg-wrap profile import [name]") } var name string if len(a.Args) > 4 { name = a.Args[4] } return a.handleProfileImport(a.Args[3], name) case "configure": if len(a.Args) < 4 { return fmt.Errorf("usage: wg-wrap profile configure ") } return a.handleProfileConfigure(a.Args[3]) case "delete": if len(a.Args) < 4 { return fmt.Errorf("usage: wg-wrap profile delete ") } return a.handleProfileDelete(a.Args[3]) case "stop": if len(a.Args) < 4 { return fmt.Errorf("usage: wg-wrap profile stop ") } if !IsValidProfileName(a.Args[3]) { return fmt.Errorf("invalid profile name: %q", a.Args[3]) } fmt.Printf("Stopping profile %s and unpinning namespace...\n", a.Args[3]) pm := a.getPathManager() if err := namespace.UnpinNamespace(pm, a.Args[3]); err != nil { return fmt.Errorf("failed to stop profile: %w", err) } fmt.Printf("Profile %s stopped and unpinned.\n", a.Args[3]) return nil default: return fmt.Errorf("unknown profile subcommand: %s", subCmd) } } func (a *App) handleProfileConfigure(name string) error { if !IsValidProfileName(name) { return fmt.Errorf("invalid profile name: %q", name) } profilesDir := a.getPathManager().ConfigDir() profilePath := filepath.Join(profilesDir, name+".conf") if _, err := os.Stat(profilePath); os.IsNotExist(err) { return fmt.Errorf("profile '%s' not found", name) } editor := os.Getenv("EDITOR") if editor == "" { editor = "vi" // Sensible fallback } // Split editor string into command and arguments (e.g., "vim -R" -> ["vim", "-R"]) editorArgs := strings.Fields(editor) if len(editorArgs) == 0 { editorArgs = []string{"vi"} } fmt.Printf("Opening profile %s in default editor (%s)...\n", name, editor) cmd := exec.Command(editorArgs[0], append(editorArgs[1:], profilePath)...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { return fmt.Errorf("editor failed: %w", err) } return nil } func (a *App) handleProfileList() error { profilesDir := a.getPathManager().ConfigDir() entries, err := os.ReadDir(profilesDir) if err != nil { return fmt.Errorf("failed to read profiles directory %s: %w", profilesDir, err) } fmt.Println("Available profiles:") found := false for _, entry := range entries { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".conf") { name := strings.TrimSuffix(entry.Name(), ".conf") fmt.Printf("- %s\n", name) found = true } } if !found { fmt.Println(" (no profiles found)") } return nil } func (a *App) handleProfileImport(srcPath string, name string) error { profilesDir := a.getPathManager().ConfigDir() if err := os.MkdirAll(profilesDir, 0755); err != nil { return fmt.Errorf("failed to create profiles directory: %w", err) } if _, err := wgconf.Parse(srcPath); err != nil { return fmt.Errorf("invalid WireGuard configuration at %s: %w", srcPath, err) } if name == "" { baseName := filepath.Base(srcPath) name = strings.TrimSuffix(baseName, filepath.Ext(baseName)) if name == "" { return fmt.Errorf("invalid source filename") } } if !IsValidProfileName(name) { return fmt.Errorf("invalid profile name: %q", name) } destPath := filepath.Join(profilesDir, name+".conf") if _, err := os.Stat(destPath); err == nil { return fmt.Errorf("profile '%s' already exists", name) } data, err := os.ReadFile(srcPath) if err != nil { return fmt.Errorf("failed to read source file: %w", err) } if err := os.WriteFile(destPath, data, 0644); err != nil { return fmt.Errorf("failed to write profile to %s: %w", destPath, err) } fmt.Printf("Profile '%s' imported successfully to %s\n", name, destPath) return nil } func (a *App) handleProfileDelete(name string) error { if !IsValidProfileName(name) { return fmt.Errorf("invalid profile name: %q", name) } profilesDir := a.getPathManager().ConfigDir() destPath := filepath.Join(profilesDir, name+".conf") if _, err := os.Stat(destPath); os.IsNotExist(err) { return fmt.Errorf("profile '%s' not found", name) } if err := os.Remove(destPath); err != nil { return fmt.Errorf("failed to delete profile %s: %w", name, err) } fmt.Printf("Profile '%s' deleted successfully\n", name) return nil } func (a *App) showConfig() error { cfg := &config.Config{} fs := flag.NewFlagSet("wg-wrap", flag.ExitOnError) fs.StringVar(&cfg.Profile, "profile", "default", "WireGuard profile to use") fs.StringVar(&cfg.DNSServer, "dns-server", "", "Override DNS server to use") if len(a.Args) > 2 { _ = fs.Parse(a.Args[2:]) } if !IsValidProfileName(cfg.Profile) { return fmt.Errorf("invalid profile name: %q", cfg.Profile) } pm := paths.NewPathManager(a.ConfigDir, a.RuntimeBaseDir) profilePath := pm.ProfileNamespacePath(cfg.Profile) pidsPath := pm.ProfilePidsDir(cfg.Profile) fmt.Printf("Configuration:\n") fmt.Printf(" Profile: %s\n", cfg.Profile) fmt.Printf(" DNS Server: %s\n", cfg.DNSServer) fmt.Printf(" Config Dir: %s\n", pm.ConfigDir()) fmt.Printf(" Runtime Base: %s\n", pm.RuntimeBaseDir()) fmt.Printf(" Profile Path: %s\n", profilePath) fmt.Printf(" PIDs Path: %s\n", pidsPath) fmt.Printf(" Isolated: %v\n", namespace.IsIsolated()) fmt.Printf(" UID: %d\n", os.Getuid()) return nil } // IsValidProfileName checks if a WireGuard profile name is safe and valid. // It allows only alphanumeric characters, underscores, and hyphens, and prevents // directory traversal attacks and hidden files. func IsValidProfileName(name string) bool { if name == "" { return false } for _, r := range name { if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '_' && r != '-' { return false } } if name == "." || name == ".." || strings.HasPrefix(name, "-") { return false } return true }