diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/cli/cli.go | 57 | ||||
| -rw-r--r-- | internal/wireguard/wireguard.go | 58 |
2 files changed, 89 insertions, 26 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 85b9ae3..9b3409e 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -159,25 +159,39 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { // Acquire execution lock during configuration and startup inside the namespace lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) + if lockErr == nil { + defer namespace.ReleaseProfileLock(lockFile) + } if err := namespace.PruneStalePids(pm, cfg.Profile); err != nil { fmt.Printf("failed to prune stale pids: %v\n", err) } if err := namespace.RegisterProcess(pm, cfg.Profile); err != nil { - if lockErr == nil { - namespace.ReleaseProfileLock(lockFile) - } return fmt.Errorf("failed to register process: %w", err) } defer func() { - if err := namespace.UnregisterProcess(pm, cfg.Profile); err != nil { - fmt.Printf("failed to unregister process: %v\n", err) - } - if last, err := namespace.IsLastProcess(pm, cfg.Profile); err == nil && last { - fmt.Printf("Last process exiting. Cleaning up profile %s...\n", cfg.Profile) - if err := namespace.UnpinNamespace(pm, cfg.Profile); err != nil { - fmt.Printf("failed to unpin namespace: %v\n", err) + // Re-acquire lock for the entire cleanup sequence to ensure atomic unregister and unpin + cleanupLock, cleanupErr := namespace.AcquireProfileLock(pm, cfg.Profile) + if cleanupErr == nil { + // Check if we are the last active process before unregistering + last, lastErr := namespace.IsLastProcess(pm, cfg.Profile) + + if err := namespace.UnregisterProcess(pm, cfg.Profile); err != nil { + fmt.Printf("failed to unregister process: %v\n", err) + } + + if lastErr == nil && last { + fmt.Printf("Last process exiting. Cleaning up profile %s...\n", cfg.Profile) + if err := namespace.UnpinNamespace(pm, cfg.Profile); err != nil { + fmt.Printf("failed to unpin namespace: %v\n", err) + } + } + namespace.ReleaseProfileLock(cleanupLock) + } else { + // Fallback if lock fails to ensure we still unregister + if err := namespace.UnregisterProcess(pm, cfg.Profile); err != nil { + fmt.Printf("failed to unregister process: %v\n", err) } } }() @@ -192,9 +206,6 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { if _, err := os.Stat(profilePath); err == nil { wgCfg, err := wgconf.Parse(profilePath) if err != nil { - if lockErr == nil { - namespace.ReleaseProfileLock(lockFile) - } return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) } @@ -225,9 +236,6 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) if err != nil { - if lockErr == nil { - namespace.ReleaseProfileLock(lockFile) - } return fmt.Errorf("failed to start WireGuard tunnel: %w", err) } defer tunnel.Close() @@ -239,18 +247,13 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { } else { // If profile is not default or it was explicitly requested but doesn't exist, we error if cfg.Profile != "default" { - if lockErr == nil { - namespace.ReleaseProfileLock(lockFile) - } return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) } fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } - // Setup and initialization are complete. We can now safely release the startup lock! - if lockErr == nil { - namespace.ReleaseProfileLock(lockFile) - } + // We can now release the startup lock and execute the command + namespace.ReleaseProfileLock(lockFile) cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...) cmd.Stdin = os.Stdin @@ -327,9 +330,15 @@ func (a *App) handleProfileConfigure(name string) error { editor = "vi" // Sensible fallback } + // Split editor string into command and arguments (e.g., "vim -R" -> ["vim", "-R"]) + editorArgs := strings.Fields(editor) + if len(editorArgs) == 0 { + editorArgs = []string{"vi"} + } + fmt.Printf("Opening profile %s in default editor (%s)...\n", name, editor) - cmd := exec.Command(editor, profilePath) + cmd := exec.Command(editorArgs[0], append(editorArgs[1:], profilePath)...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 48bd562..3f17392 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -42,6 +42,11 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (*Tunnel, error) { fmt.Printf("warning: failed to make mount namespace private: %v\n", err) } + // Block host services (D-Bus, nscd) to prevent name resolution leak bypasses + if err := BlockHostServices(); err != nil { + fmt.Printf("warning: failed to block host services: %v\n", err) + } + tunDev, err := tun.CreateTUN(tunName, mtu) if err != nil { return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err) @@ -254,13 +259,62 @@ func ConfigureResolvConf(dns string) error { // 2. Make the mount private to ensure it doesn't propagate back to the host // and to satisfy kernel requirements for mount transitions in some environments. - if err := unix.Mount("/etc/resolv.conf", "/etc/resolv.conf", "", unix.MS_REMOUNT|unix.MS_BIND|unix.MS_PRIVATE, ""); err != nil { - return fmt.Errorf("failed to make /etc/resolv.conf mount private: %w", err) + // We do this by applying MS_PRIVATE in a separate mount call. + if err := unix.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil { + // If MS_PRIVATE fails, we can log a warning but proceed since / is already private + fmt.Printf("warning: failed to make /etc/resolv.conf mount private: %v\n", err) } return nil } +// BlockHostServices blocks local D-Bus and name service cache daemon (nscd) sockets +// inside the mount namespace. This prevents glibc from bypassing the network namespace +// isolation via host services (e.g. systemd-resolved via D-Bus). +func BlockHostServices() error { + tmpDir, err := os.MkdirTemp("", "wg-wrap-block-") + if err != nil { + return fmt.Errorf("failed to create temp dir: %w", err) + } + defer func() { _ = os.Remove(tmpDir) }() + + tmpFile, err := os.CreateTemp("", "wg-wrap-block-file-") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + tmpFileName := tmpFile.Name() + _ = tmpFile.Close() + defer func() { _ = os.Remove(tmpFileName) }() + + // Specific socket files and directories to block + pathsToBlock := []string{ + "/run/dbus/system_bus_socket", + "/run/systemd/resolve/io.systemd.Resolve", + "/run/systemd/resolve/io.systemd.Resolve.Monitor", + "/run/nscd/socket", + "/var/run/dbus/system_bus_socket", + "/var/run/systemd/resolve/io.systemd.Resolve", + "/var/run/systemd/resolve/io.systemd.Resolve.Monitor", + "/var/run/nscd/socket", + } + + for _, p := range pathsToBlock { + stat, err := os.Stat(p) + if err == nil { + source := tmpFileName + if stat.IsDir() { + source = tmpDir + } + if err := unix.Mount(source, p, "", unix.MS_BIND, ""); err != nil { + fmt.Printf("warning: failed to bind-mount block over %s: %v\n", p, err) + } else { + _ = unix.Mount("", p, "", unix.MS_PRIVATE, "") + } + } + } + return nil +} + // HostBind wraps a standard conn.Bind so that its socket creation (Open) // is forced to execute within a host network namespace. type HostBind struct { |
