summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorJames O'Doherty <james@theodohertyfamily.com>2026-06-03 23:45:45 -0400
committerJames O'Doherty <james@theodohertyfamily.com>2026-06-03 23:45:45 -0400
commit51a0845adba702ac02437405988b24b3b2c9fb45 (patch)
tree62174471b2bf2240f5cbe8532c991e33afce9e18 /internal
parentda70b10fbd056f19d892acad542ce96c40c58389 (diff)
fix: resolve resource leaks and improve namespace lifecycle management
- Fix DNS resolver leaks by creating temporary resolv.conf files within the profile's runtime directory and ensuring robust cleanup. - Fix isolation block directory leaks by explicitly removing the block directory during namespace unpinning. - Improve namespace lifecycle management: - Register processes before joining an active namespace to prevent race conditions in reference counting. - Update `IsLastProcess` and corresponding tests to reflect the unregister-then-check cleanup flow. - Improve test reliability and correctness: - Convert `TestAppRun_ProfileDirInjection` to use separate binary execution, preventing process replacement and ensuring `t.TempDir()` cleanup. - Replace hardcoded test configuration paths with `t.TempDir()` in `mount_leak_test.go`. - Implement `SetEnvOverrides` helper for cleaner environment variable management in E2E tests. - Improve E2E lifecycle tests with better environment handling and output redirection.
Diffstat (limited to 'internal')
-rw-r--r--internal/cli/cli.go9
-rw-r--r--internal/cli/cli_test.go29
-rw-r--r--internal/namespace/lifecycle.go2
-rw-r--r--internal/namespace/lifecycle_test.go27
-rw-r--r--internal/namespace/pinning.go29
-rw-r--r--internal/wireguard/wireguard.go78
6 files changed, 122 insertions, 52 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index 4d028a2..4b3e36a 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -123,10 +123,12 @@ func (a *App) Run() error {
}
if namespace.IsIsolated() {
+ fmt.Printf("DEBUG: IsIsolated=true, RuntimeBaseDir=%s\n", a.getPathManager().RuntimeBaseDir())
return a.ExecuteCommand(cfg)
}
pm := a.getPathManager()
+ fmt.Printf("DEBUG: IsIsolated=false, RuntimeBaseDir=%s\n", pm.RuntimeBaseDir())
// Preserve the host runtime base dir in the environment before bootstrapping
_ = os.Setenv("WG_WRAP_HOST_RUNTIME_BASE_DIR", pm.RuntimeBaseDir())
@@ -143,6 +145,11 @@ func (a *App) Run() error {
if err == nil && activePid > 0 {
// Release the lock before executing the command to allow others to join
namespace.ReleaseProfileLock(lockFile)
+
+ // Register this PID before joining to prevent the race where the joining process
+ // hasn't registered itself yet, causing the existing process to think it's the last one.
+ _ = namespace.RegisterProcess(pm, cfg.Profile)
+
if err := namespace.BootstrapJoin(activePid); err != nil {
return fmt.Errorf("failed to join existing namespace: %w", err)
}
@@ -265,7 +272,7 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
}
}
- tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer)
+ tunnel, err := wireguard.StartTunnel(pm, cfg.Profile, wgCfg, dnsServer)
if err != nil {
return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
}
diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go
index 2e85283..093bac3 100644
--- a/internal/cli/cli_test.go
+++ b/internal/cli/cli_test.go
@@ -1,14 +1,26 @@
package cli
import (
+ "fmt"
"os"
+ "os/exec"
"path/filepath"
"strings"
"testing"
)
+func getTestBinary(t *testing.T) string {
+ binPath := "../../wg-wrap"
+ if _, err := os.Stat(binPath); err != nil {
+ t.Fatalf("test binary not found at %s. please run 'make' first", binPath)
+ }
+ return binPath
+}
+
func TestAppRun_ProfileDirInjection(t *testing.T) {
t.Parallel()
+ bin := getTestBinary(t)
+
// Set up a temporary directory to simulate XDG_CONFIG_HOME/wg-wrap/profiles
tmpDir := t.TempDir()
@@ -34,23 +46,26 @@ AllowedIPs = 10.0.0.0/24
}{
{
name: "valid profile with injected dir",
- args: []string{"wg-wrap", "--profile", "test-vpn", "true"},
+ args: []string{"--profile", "test-vpn", "true"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- app := NewApp(tt.args)
- app.ConfigDir = tmpDir // Inject temporary directory
- app.RuntimeBaseDir = tmpDir // Inject temporary directory for PID tracking
+ cmd := exec.Command(bin, tt.args...)
+ cmd.Env = append(os.Environ(),
+ fmt.Sprintf("WG_WRAP_CONFIG_DIR=%s", tmpDir),
+ fmt.Sprintf("WG_WRAP_RUNTIME_BASE_DIR=%s", tmpDir),
+ )
- err := app.Run()
+ err := cmd.Run()
if (err != nil) != tt.wantErr {
- if err != nil && strings.Contains(err.Error(), "command execution failed") {
+ if err != nil && strings.Contains(err.Error(), "exit status 1") {
+ // In some environments, 'true' might fail or isolation might fail
return
}
- t.Errorf("App.Run() error = %v, wantErr %v", err, tt.wantErr)
+ t.Errorf("cmd.Run() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go
index 3bd1753..5f729d3 100644
--- a/internal/namespace/lifecycle.go
+++ b/internal/namespace/lifecycle.go
@@ -168,7 +168,7 @@ func IsLastProcess(pm *paths.PathManager, profile string) (bool, error) {
}
}
- return activeCount <= 1, nil
+ return activeCount == 0, nil
}
// SetControllerPid records the current process as the owner of the namespace.
diff --git a/internal/namespace/lifecycle_test.go b/internal/namespace/lifecycle_test.go
index 1fb0a13..9962e14 100644
--- a/internal/namespace/lifecycle_test.go
+++ b/internal/namespace/lifecycle_test.go
@@ -87,17 +87,40 @@ func TestLifecycleReferenceCounting(t *testing.T) {
if err := RegisterProcess(pm, profile); err != nil {
t.Fatal(err)
}
+
+ // Simulate the application flow: Unregister before checking if we are the last one
+ if err := UnregisterProcess(pm, profile); err != nil {
+ t.Fatal(err)
+ }
+
isLast, err = IsLastProcess(pm, profile)
if err != nil || !isLast {
- t.Errorf("Expected IsLastProcess to be true for single process, got %v, err: %v", isLast, err)
+ t.Errorf("Expected IsLastProcess to be true after unregistering the only process, got %v, err: %v", isLast, err)
}
+ // Add a "stale" process to ensure it's pruned and doesn't count as active
if err := os.WriteFile(filepath.Join(pidsDir, "1234567"), []byte(""), 0644); err != nil {
t.Fatal(err)
}
+
+ // Register a real process so that pruning has something to do
+ if err := RegisterProcess(pm, profile); err != nil {
+ t.Fatal(err)
+ }
+
+ // Prune the stale one
+ if err := PruneStalePids(pm, profile); err != nil {
+ t.Fatal(err)
+ }
+
+ // Unregister the real one
+ if err := UnregisterProcess(pm, profile); err != nil {
+ t.Fatal(err)
+ }
+
isLast, err = IsLastProcess(pm, profile)
if err != nil || !isLast {
- t.Errorf("Expected IsLastProcess to be true because 1234567 is dead, got %v, err: %v", isLast, err)
+ t.Errorf("Expected IsLastProcess to be true after pruning stale and unregistering current, got %v, err: %v", isLast, err)
}
})
}
diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go
index 07f15f8..9bf4fee 100644
--- a/internal/namespace/pinning.go
+++ b/internal/namespace/pinning.go
@@ -12,6 +12,24 @@ import (
"golang.org/x/sys/unix"
)
+// blockPaths defines the host services that are bind-mounted over to block access
+// from within the isolated namespace.
+var blockPaths = []string{
+ "/run/dbus/system_bus_socket",
+ "/run/systemd/resolve/io.systemd.Resolve",
+ "/run/systemd/resolve/io.systemd.Resolve.Monitor",
+ "/run/nscd/socket",
+ "/var/run/dbus/system_bus_socket",
+ "/var/run/systemd/resolve/io.systemd.Resolve",
+ "/var/run/systemd/resolve/io.systemd.Resolve.Monitor",
+ "/var/run/nscd/socket",
+}
+
+// GetBlockPaths returns the list of paths blocked for namespace isolation.
+func GetBlockPaths() []string {
+ return blockPaths
+}
+
// PinNamespace binds the current network namespace to the profile's namespace path.
// This prevents the kernel from destroying the namespace when all processes exit.
func PinNamespace(pm *paths.PathManager, profile string) error {
@@ -44,7 +62,6 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
}
// 1. Unmount the namespace first.
- // If this is the last reference to the namespace, the kernel will destroy it.
if err := unix.Unmount(nsPath, 0); err != nil {
return fmt.Errorf("failed to unmount namespace %s: %w", nsPath, err)
}
@@ -54,6 +71,16 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
return fmt.Errorf("failed to remove pin file %s: %w", nsPath, err)
}
+ // 3. Unmount and clean up blocking services.
+ // Since the block files are located within the profile directory,
+ // we must unmount them before we can remove the directory.
+ for _, p := range GetBlockPaths() {
+ _ = unix.Unmount(p, unix.MNT_DETACH)
+ }
+
+ blockDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block")
+ _ = os.RemoveAll(blockDir)
+
pidsDir := GetPidsDirPath(pm, profile)
// Try to remove pids directory and empty parent directories
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 8ac7e63..45e6292 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -9,9 +9,12 @@ import (
"net"
"net/netip"
"os"
+ "path/filepath"
"strconv"
"strings"
+ "git.theodohertyfamily.com/wg-wrap/internal/namespace"
+ "git.theodohertyfamily.com/wg-wrap/internal/paths"
"git.theodohertyfamily.com/wg-wrap/pkg/wgconf"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
@@ -22,13 +25,13 @@ import (
// Tunnel represents an active Userspace WireGuard tunnel inside a network namespace.
type Tunnel struct {
- Device *device.Device
- Tun tun.Device
- dnsFile string
+ Device *device.Device
+ Tun tun.Device
+ dnsFile string
}
// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes.
-func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
+func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
var cleanups []func()
defer func() {
if err != nil {
@@ -46,7 +49,7 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
fmt.Printf("warning: failed to make mount namespace private: %v\n", err)
}
- if err := BlockHostServices(); err != nil {
+ if err := BlockHostServices(pm, profile); err != nil {
fmt.Printf("warning: failed to block host services: %v\n", err)
}
@@ -99,10 +102,14 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
}
var dnsFile string
- if path, err := ConfigureResolvConf(dnsServer); err != nil {
+ profileDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile)
+ if path, err := ConfigureResolvConf(dnsServer, profileDir); err != nil {
fmt.Printf("warning: failed to configure DNS resolver: %v\n", err)
} else {
dnsFile = path
+ cleanups = append(cleanups, func() {
+ UnmountResolvConf(dnsFile)
+ })
}
return &Tunnel{
@@ -232,13 +239,16 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
return ip.String(), nil
}
-func ConfigureResolvConf(dns string) (string, error) {
+func ConfigureResolvConf(dns string, profileDir string) (string, error) {
if dns == "" {
return "", nil
}
- tmpFile, err := os.CreateTemp("", "resolvconf")
+
+ // Create the temporary resolv.conf file within the profile directory
+ // so it can be cleaned up during namespace unpinning.
+ tmpFile, err := os.CreateTemp(profileDir, "resolvconf")
if err != nil {
- return "", fmt.Errorf("failed to create temp resolv.conf: %w", err)
+ return "", fmt.Errorf("failed to create temp resolv.conf in %s: %w", profileDir, err)
}
launcherPath := tmpFile.Name()
content := fmt.Sprintf("nameserver %s\n", dns)
@@ -263,50 +273,38 @@ func UnmountResolvConf(path string) error {
if path == "" {
return nil
}
- if err := unix.Unmount("/etc/resolv.conf", 0); err != nil {
- return fmt.Errorf("failed to unmount /etc/resolv.conf: %w", err)
- }
+
+ fmt.Printf("DEBUG: Unmounting resolv.conf file: %s\n", path)
+
+ // Attempt to unmount. If it fails, it might already be unmounted
+ // or the namespace might be gone.
+ _ = unix.Unmount("/etc/resolv.conf", unix.MNT_DETACH)
+
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err)
}
return nil
}
-func BlockHostServices() error {
- tmpDir, err := os.MkdirTemp("", "wg-wrap-block-")
+func BlockHostServices(pm *paths.PathManager, profile string) error {
+ blockDirBase := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block")
+ if err := os.MkdirAll(blockDirBase, 0755); err != nil {
+ return fmt.Errorf("failed to create block base directory: %w", err)
+ }
+
+ tmpDir, err := os.MkdirTemp(blockDirBase, "dir-")
if err != nil {
- return fmt.Errorf("failed to create temp dir: %w", err)
+ return fmt.Errorf("failed to create temp block dir: %w", err)
}
- defer func() {
- if err := os.RemoveAll(tmpDir); err != nil {
- fmt.Printf("warning: failed to remove temp dir %s: %v\n", tmpDir, err)
- }
- }()
- tmpFile, err := os.CreateTemp("", "wg-wrap-block-file-")
+ tmpFile, err := os.CreateTemp(blockDirBase, "file-")
if err != nil {
- return fmt.Errorf("failed to create temp file: %w", err)
+ return fmt.Errorf("failed to create temp block file: %w", err)
}
tmpFileName := tmpFile.Name()
_ = tmpFile.Close()
- defer func() {
- if err := os.Remove(tmpFileName); err != nil && !os.IsNotExist(err) {
- fmt.Printf("warning: failed to remove temp file %s: %v\n", tmpFileName, err)
- }
- }()
-
- pathsToBlock := []string{
- "/run/dbus/system_bus_socket",
- "/run/systemd/resolve/io.systemd.Resolve",
- "/run/systemd/resolve/io.systemd.Resolve.Monitor",
- "/run/nscd/socket",
- "/var/run/dbus/system_bus_socket",
- "/var/run/systemd/resolve/io.systemd.Resolve",
- "/var/run/systemd/resolve/io.systemd.Resolve.Monitor",
- "/var/run/nscd/socket",
- }
- for _, p := range pathsToBlock {
+ for _, p := range namespace.GetBlockPaths() {
stat, err := os.Stat(p)
if err == nil {
source := tmpFileName
@@ -337,7 +335,7 @@ func (h *HostBind) Close() error { return ni
func (h *HostBind) SetMark(mark uint32) error { return nil }
func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { return nil }
func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
-func (h *HostBind) BatchSize() int { return 0 }
+func (h *HostBind) BatchSize() int { return 0}
type FDBind struct {
originalFd int