diff options
Diffstat (limited to 'internal/cli/cli.go')
| -rw-r--r-- | internal/cli/cli.go | 103 |
1 files changed, 56 insertions, 47 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 9b3409e..87ee34f 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -127,6 +127,9 @@ func (a *App) Run() error { pm := a.getPathManager() + // Preserve the host runtime base dir in the environment before bootstrapping + _ = os.Setenv("WG_WRAP_HOST_RUNTIME_BASE_DIR", pm.RuntimeBaseDir()) + // Acquire startup lock to prevent concurrent bootstrap/joining races lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) if lockErr == nil { @@ -135,12 +138,14 @@ func (a *App) Run() error { // Before bootstrapping, see if an active namespace/process for the profile exists. // If yes, we can join it! - joined, err := namespace.JoinExistingNamespace(pm, cfg.Profile) - if err == nil && joined { - // We have joined the active namespace (user, mnt, net). + activePid, err := namespace.FindActiveProfilePid(pm, cfg.Profile) + if err == nil && activePid > 0 { // Release the lock before executing the command to allow others to join namespace.ReleaseProfileLock(lockFile) - return a.ExecuteCommand(cfg) + if err := namespace.BootstrapJoin(activePid); err != nil { + return fmt.Errorf("failed to join existing namespace: %w", err) + } + return nil } if err := namespace.Bootstrap(); err != nil { @@ -196,60 +201,64 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { } }() - fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile) + if os.Getenv("WG_WRAP_JOINED") == "1" { + fmt.Printf("Joining active WireGuard tunnel session for profile %s...\n", cfg.Profile) + } else { + fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile) - // Parse the profile configuration - profilesDir := pm.ConfigDir() - profilePath := filepath.Join(profilesDir, cfg.Profile+".conf") + // Parse the profile configuration + profilesDir := pm.ConfigDir() + profilePath := filepath.Join(profilesDir, cfg.Profile+".conf") - // Create tunnel if the file exists - if _, err := os.Stat(profilePath); err == nil { - wgCfg, err := wgconf.Parse(profilePath) - if err != nil { - return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) - } + // Create tunnel if the file exists + if _, err := os.Stat(profilePath); err == nil { + wgCfg, err := wgconf.Parse(profilePath) + if err != nil { + return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) + } - // Start the WireGuard userspace device & routing table setup - dnsServer := cfg.DNSServer - if dnsServer == "" { - dnsServer = wgCfg.DNS - } - if dnsServer == "" { - dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks - hasDefaultRoute := false - for _, peer := range wgCfg.Peers { - for _, ip := range peer.AllowedIPs { - trimmed := strings.TrimSpace(ip) - if trimmed == "0.0.0.0/0" || trimmed == "::/0" { - hasDefaultRoute = true + // Start the WireGuard userspace device & routing table setup + dnsServer := cfg.DNSServer + if dnsServer == "" { + dnsServer = wgCfg.DNS + } + if dnsServer == "" { + dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks + hasDefaultRoute := false + for _, peer := range wgCfg.Peers { + for _, ip := range peer.AllowedIPs { + trimmed := strings.TrimSpace(ip) + if trimmed == "0.0.0.0/0" || trimmed == "::/0" { + hasDefaultRoute = true + break + } + } + if hasDefaultRoute { break } } - if hasDefaultRoute { - break + if !hasDefaultRoute { + fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n") } } - if !hasDefaultRoute { - fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n") - } - } - tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) - if err != nil { - return fmt.Errorf("failed to start WireGuard tunnel: %w", err) - } - defer tunnel.Close() + tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) + if err != nil { + return fmt.Errorf("failed to start WireGuard tunnel: %w", err) + } + defer tunnel.Close() - // Pin the namespace so others can join it - if err := namespace.PinNamespace(pm, cfg.Profile); err != nil { - fmt.Printf("warning: failed to pin namespace: %v\n", err) - } - } else { - // If profile is not default or it was explicitly requested but doesn't exist, we error - if cfg.Profile != "default" { - return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) + // Pin the namespace so others can join it + if err := namespace.PinNamespace(pm, cfg.Profile); err != nil { + fmt.Printf("warning: failed to pin namespace: %v\n", err) + } + } else { + // If profile is not default or it was explicitly requested but doesn't exist, we error + if cfg.Profile != "default" { + return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) + } + fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } - fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } // We can now release the startup lock and execute the command |
