summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/cli/cli.go9
-rw-r--r--internal/cli/cli_test.go29
-rw-r--r--internal/namespace/lifecycle.go2
-rw-r--r--internal/namespace/lifecycle_test.go27
-rw-r--r--internal/namespace/pinning.go29
-rw-r--r--internal/wireguard/wireguard.go78
-rw-r--r--tests/e2e/lifecycle_test.go19
-rw-r--r--tests/e2e/mount_leak_test.go9
-rw-r--r--tests/e2e/test_helpers.go27
9 files changed, 166 insertions, 63 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index 4d028a2..4b3e36a 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -123,10 +123,12 @@ func (a *App) Run() error {
}
if namespace.IsIsolated() {
+ fmt.Printf("DEBUG: IsIsolated=true, RuntimeBaseDir=%s\n", a.getPathManager().RuntimeBaseDir())
return a.ExecuteCommand(cfg)
}
pm := a.getPathManager()
+ fmt.Printf("DEBUG: IsIsolated=false, RuntimeBaseDir=%s\n", pm.RuntimeBaseDir())
// Preserve the host runtime base dir in the environment before bootstrapping
_ = os.Setenv("WG_WRAP_HOST_RUNTIME_BASE_DIR", pm.RuntimeBaseDir())
@@ -143,6 +145,11 @@ func (a *App) Run() error {
if err == nil && activePid > 0 {
// Release the lock before executing the command to allow others to join
namespace.ReleaseProfileLock(lockFile)
+
+ // Register this PID before joining to prevent the race where the joining process
+ // hasn't registered itself yet, causing the existing process to think it's the last one.
+ _ = namespace.RegisterProcess(pm, cfg.Profile)
+
if err := namespace.BootstrapJoin(activePid); err != nil {
return fmt.Errorf("failed to join existing namespace: %w", err)
}
@@ -265,7 +272,7 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
}
}
- tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer)
+ tunnel, err := wireguard.StartTunnel(pm, cfg.Profile, wgCfg, dnsServer)
if err != nil {
return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
}
diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go
index 2e85283..093bac3 100644
--- a/internal/cli/cli_test.go
+++ b/internal/cli/cli_test.go
@@ -1,14 +1,26 @@
package cli
import (
+ "fmt"
"os"
+ "os/exec"
"path/filepath"
"strings"
"testing"
)
+func getTestBinary(t *testing.T) string {
+ binPath := "../../wg-wrap"
+ if _, err := os.Stat(binPath); err != nil {
+ t.Fatalf("test binary not found at %s. please run 'make' first", binPath)
+ }
+ return binPath
+}
+
func TestAppRun_ProfileDirInjection(t *testing.T) {
t.Parallel()
+ bin := getTestBinary(t)
+
// Set up a temporary directory to simulate XDG_CONFIG_HOME/wg-wrap/profiles
tmpDir := t.TempDir()
@@ -34,23 +46,26 @@ AllowedIPs = 10.0.0.0/24
}{
{
name: "valid profile with injected dir",
- args: []string{"wg-wrap", "--profile", "test-vpn", "true"},
+ args: []string{"--profile", "test-vpn", "true"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- app := NewApp(tt.args)
- app.ConfigDir = tmpDir // Inject temporary directory
- app.RuntimeBaseDir = tmpDir // Inject temporary directory for PID tracking
+ cmd := exec.Command(bin, tt.args...)
+ cmd.Env = append(os.Environ(),
+ fmt.Sprintf("WG_WRAP_CONFIG_DIR=%s", tmpDir),
+ fmt.Sprintf("WG_WRAP_RUNTIME_BASE_DIR=%s", tmpDir),
+ )
- err := app.Run()
+ err := cmd.Run()
if (err != nil) != tt.wantErr {
- if err != nil && strings.Contains(err.Error(), "command execution failed") {
+ if err != nil && strings.Contains(err.Error(), "exit status 1") {
+ // In some environments, 'true' might fail or isolation might fail
return
}
- t.Errorf("App.Run() error = %v, wantErr %v", err, tt.wantErr)
+ t.Errorf("cmd.Run() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go
index 3bd1753..5f729d3 100644
--- a/internal/namespace/lifecycle.go
+++ b/internal/namespace/lifecycle.go
@@ -168,7 +168,7 @@ func IsLastProcess(pm *paths.PathManager, profile string) (bool, error) {
}
}
- return activeCount <= 1, nil
+ return activeCount == 0, nil
}
// SetControllerPid records the current process as the owner of the namespace.
diff --git a/internal/namespace/lifecycle_test.go b/internal/namespace/lifecycle_test.go
index 1fb0a13..9962e14 100644
--- a/internal/namespace/lifecycle_test.go
+++ b/internal/namespace/lifecycle_test.go
@@ -87,17 +87,40 @@ func TestLifecycleReferenceCounting(t *testing.T) {
if err := RegisterProcess(pm, profile); err != nil {
t.Fatal(err)
}
+
+ // Simulate the application flow: Unregister before checking if we are the last one
+ if err := UnregisterProcess(pm, profile); err != nil {
+ t.Fatal(err)
+ }
+
isLast, err = IsLastProcess(pm, profile)
if err != nil || !isLast {
- t.Errorf("Expected IsLastProcess to be true for single process, got %v, err: %v", isLast, err)
+ t.Errorf("Expected IsLastProcess to be true after unregistering the only process, got %v, err: %v", isLast, err)
}
+ // Add a "stale" process to ensure it's pruned and doesn't count as active
if err := os.WriteFile(filepath.Join(pidsDir, "1234567"), []byte(""), 0644); err != nil {
t.Fatal(err)
}
+
+ // Register a real process so that pruning has something to do
+ if err := RegisterProcess(pm, profile); err != nil {
+ t.Fatal(err)
+ }
+
+ // Prune the stale one
+ if err := PruneStalePids(pm, profile); err != nil {
+ t.Fatal(err)
+ }
+
+ // Unregister the real one
+ if err := UnregisterProcess(pm, profile); err != nil {
+ t.Fatal(err)
+ }
+
isLast, err = IsLastProcess(pm, profile)
if err != nil || !isLast {
- t.Errorf("Expected IsLastProcess to be true because 1234567 is dead, got %v, err: %v", isLast, err)
+ t.Errorf("Expected IsLastProcess to be true after pruning stale and unregistering current, got %v, err: %v", isLast, err)
}
})
}
diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go
index 07f15f8..9bf4fee 100644
--- a/internal/namespace/pinning.go
+++ b/internal/namespace/pinning.go
@@ -12,6 +12,24 @@ import (
"golang.org/x/sys/unix"
)
+// blockPaths defines the host services that are bind-mounted over to block access
+// from within the isolated namespace.
+var blockPaths = []string{
+ "/run/dbus/system_bus_socket",
+ "/run/systemd/resolve/io.systemd.Resolve",
+ "/run/systemd/resolve/io.systemd.Resolve.Monitor",
+ "/run/nscd/socket",
+ "/var/run/dbus/system_bus_socket",
+ "/var/run/systemd/resolve/io.systemd.Resolve",
+ "/var/run/systemd/resolve/io.systemd.Resolve.Monitor",
+ "/var/run/nscd/socket",
+}
+
+// GetBlockPaths returns the list of paths blocked for namespace isolation.
+func GetBlockPaths() []string {
+ return blockPaths
+}
+
// PinNamespace binds the current network namespace to the profile's namespace path.
// This prevents the kernel from destroying the namespace when all processes exit.
func PinNamespace(pm *paths.PathManager, profile string) error {
@@ -44,7 +62,6 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
}
// 1. Unmount the namespace first.
- // If this is the last reference to the namespace, the kernel will destroy it.
if err := unix.Unmount(nsPath, 0); err != nil {
return fmt.Errorf("failed to unmount namespace %s: %w", nsPath, err)
}
@@ -54,6 +71,16 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
return fmt.Errorf("failed to remove pin file %s: %w", nsPath, err)
}
+ // 3. Unmount and clean up blocking services.
+ // Since the block files are located within the profile directory,
+ // we must unmount them before we can remove the directory.
+ for _, p := range GetBlockPaths() {
+ _ = unix.Unmount(p, unix.MNT_DETACH)
+ }
+
+ blockDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block")
+ _ = os.RemoveAll(blockDir)
+
pidsDir := GetPidsDirPath(pm, profile)
// Try to remove pids directory and empty parent directories
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 8ac7e63..45e6292 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -9,9 +9,12 @@ import (
"net"
"net/netip"
"os"
+ "path/filepath"
"strconv"
"strings"
+ "git.theodohertyfamily.com/wg-wrap/internal/namespace"
+ "git.theodohertyfamily.com/wg-wrap/internal/paths"
"git.theodohertyfamily.com/wg-wrap/pkg/wgconf"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
@@ -22,13 +25,13 @@ import (
// Tunnel represents an active Userspace WireGuard tunnel inside a network namespace.
type Tunnel struct {
- Device *device.Device
- Tun tun.Device
- dnsFile string
+ Device *device.Device
+ Tun tun.Device
+ dnsFile string
}
// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes.
-func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
+func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
var cleanups []func()
defer func() {
if err != nil {
@@ -46,7 +49,7 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
fmt.Printf("warning: failed to make mount namespace private: %v\n", err)
}
- if err := BlockHostServices(); err != nil {
+ if err := BlockHostServices(pm, profile); err != nil {
fmt.Printf("warning: failed to block host services: %v\n", err)
}
@@ -99,10 +102,14 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
}
var dnsFile string
- if path, err := ConfigureResolvConf(dnsServer); err != nil {
+ profileDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile)
+ if path, err := ConfigureResolvConf(dnsServer, profileDir); err != nil {
fmt.Printf("warning: failed to configure DNS resolver: %v\n", err)
} else {
dnsFile = path
+ cleanups = append(cleanups, func() {
+ UnmountResolvConf(dnsFile)
+ })
}
return &Tunnel{
@@ -232,13 +239,16 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
return ip.String(), nil
}
-func ConfigureResolvConf(dns string) (string, error) {
+func ConfigureResolvConf(dns string, profileDir string) (string, error) {
if dns == "" {
return "", nil
}
- tmpFile, err := os.CreateTemp("", "resolvconf")
+
+ // Create the temporary resolv.conf file within the profile directory
+ // so it can be cleaned up during namespace unpinning.
+ tmpFile, err := os.CreateTemp(profileDir, "resolvconf")
if err != nil {
- return "", fmt.Errorf("failed to create temp resolv.conf: %w", err)
+ return "", fmt.Errorf("failed to create temp resolv.conf in %s: %w", profileDir, err)
}
launcherPath := tmpFile.Name()
content := fmt.Sprintf("nameserver %s\n", dns)
@@ -263,50 +273,38 @@ func UnmountResolvConf(path string) error {
if path == "" {
return nil
}
- if err := unix.Unmount("/etc/resolv.conf", 0); err != nil {
- return fmt.Errorf("failed to unmount /etc/resolv.conf: %w", err)
- }
+
+ fmt.Printf("DEBUG: Unmounting resolv.conf file: %s\n", path)
+
+ // Attempt to unmount. If it fails, it might already be unmounted
+ // or the namespace might be gone.
+ _ = unix.Unmount("/etc/resolv.conf", unix.MNT_DETACH)
+
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err)
}
return nil
}
-func BlockHostServices() error {
- tmpDir, err := os.MkdirTemp("", "wg-wrap-block-")
+func BlockHostServices(pm *paths.PathManager, profile string) error {
+ blockDirBase := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block")
+ if err := os.MkdirAll(blockDirBase, 0755); err != nil {
+ return fmt.Errorf("failed to create block base directory: %w", err)
+ }
+
+ tmpDir, err := os.MkdirTemp(blockDirBase, "dir-")
if err != nil {
- return fmt.Errorf("failed to create temp dir: %w", err)
+ return fmt.Errorf("failed to create temp block dir: %w", err)
}
- defer func() {
- if err := os.RemoveAll(tmpDir); err != nil {
- fmt.Printf("warning: failed to remove temp dir %s: %v\n", tmpDir, err)
- }
- }()
- tmpFile, err := os.CreateTemp("", "wg-wrap-block-file-")
+ tmpFile, err := os.CreateTemp(blockDirBase, "file-")
if err != nil {
- return fmt.Errorf("failed to create temp file: %w", err)
+ return fmt.Errorf("failed to create temp block file: %w", err)
}
tmpFileName := tmpFile.Name()
_ = tmpFile.Close()
- defer func() {
- if err := os.Remove(tmpFileName); err != nil && !os.IsNotExist(err) {
- fmt.Printf("warning: failed to remove temp file %s: %v\n", tmpFileName, err)
- }
- }()
-
- pathsToBlock := []string{
- "/run/dbus/system_bus_socket",
- "/run/systemd/resolve/io.systemd.Resolve",
- "/run/systemd/resolve/io.systemd.Resolve.Monitor",
- "/run/nscd/socket",
- "/var/run/dbus/system_bus_socket",
- "/var/run/systemd/resolve/io.systemd.Resolve",
- "/var/run/systemd/resolve/io.systemd.Resolve.Monitor",
- "/var/run/nscd/socket",
- }
- for _, p := range pathsToBlock {
+ for _, p := range namespace.GetBlockPaths() {
stat, err := os.Stat(p)
if err == nil {
source := tmpFileName
@@ -337,7 +335,7 @@ func (h *HostBind) Close() error { return ni
func (h *HostBind) SetMark(mark uint32) error { return nil }
func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { return nil }
func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
-func (h *HostBind) BatchSize() int { return 0 }
+func (h *HostBind) BatchSize() int { return 0}
type FDBind struct {
originalFd int
diff --git a/tests/e2e/lifecycle_test.go b/tests/e2e/lifecycle_test.go
index 6cff0c8..ffd731f 100644
--- a/tests/e2e/lifecycle_test.go
+++ b/tests/e2e/lifecycle_test.go
@@ -1,7 +1,6 @@
package e2e
import (
- "fmt"
"os"
"os/exec"
"path/filepath"
@@ -48,7 +47,9 @@ func waitForLifecycle(t *testing.T, binaryPath, runtimeDir, profile string, expe
t.Fatalf("Timed out waiting for lifecycle state: expected active=%v", expectedActive)
case <-tick.C:
cmd := exec.Command(binaryPath, "test-lifecycle", "--profile", profile)
- cmd.Env = append(os.Environ(), fmt.Sprintf("XDG_RUNTIME_DIR=%s", runtimeDir))
+ cmd.Env = SetEnvOverrides(map[string]string{"XDG_RUNTIME_DIR": runtimeDir})
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
err := cmd.Run()
isActive := err == nil
@@ -79,7 +80,9 @@ func TestNamespaceLifecycleAutomation(t *testing.T) {
t.Run("ReferenceCounting", func(t *testing.T) {
// Start a process that exits quickly
cmd1 := exec.Command(binaryPath, "--profile", "default", "--", "sleep", "1.0")
- cmd1.Env = append(os.Environ(), fmt.Sprintf("XDG_RUNTIME_DIR=%s", tmpRuntimeDir))
+ cmd1.Env = SetEnvOverrides(map[string]string{"XDG_RUNTIME_DIR": tmpRuntimeDir})
+ cmd1.Stdout = os.Stdout
+ cmd1.Stderr = os.Stderr
if err := cmd1.Start(); err != nil {
t.Fatalf("Failed to start cmd1: %v", err)
}
@@ -87,9 +90,11 @@ func TestNamespaceLifecycleAutomation(t *testing.T) {
// Verify PID file exists using polling
waitForLifecycle(t, binaryPath, tmpRuntimeDir, "default", true)
- // Start a second process using the same profile
- cmd2 := exec.Command(binaryPath, "--profile", "default", "--", "sleep", "1.0")
- cmd2.Env = append(os.Environ(), fmt.Sprintf("XDG_RUNTIME_DIR=%s", tmpRuntimeDir))
+ // Start a second process using the same profile with a longer sleep
+ cmd2 := exec.Command(binaryPath, "--profile", "default", "--", "sleep", "5.0")
+ cmd2.Env = SetEnvOverrides(map[string]string{"XDG_RUNTIME_DIR": tmpRuntimeDir})
+ cmd2.Stdout = os.Stdout
+ cmd2.Stderr = os.Stderr
if err := cmd2.Start(); err != nil {
t.Fatalf("Failed to start cmd2: %v", err)
}
@@ -100,7 +105,7 @@ func TestNamespaceLifecycleAutomation(t *testing.T) {
t.Fatalf("cmd1 failed: %v", err)
}
- // Poll for the count to drop back to 1
+ // Poll for the count to drop back to 1 (cmd2 should still be running)
waitForLifecycle(t, binaryPath, tmpRuntimeDir, "default", true)
// Wait for second process to exit naturally
diff --git a/tests/e2e/mount_leak_test.go b/tests/e2e/mount_leak_test.go
index bdc9d75..428675f 100644
--- a/tests/e2e/mount_leak_test.go
+++ b/tests/e2e/mount_leak_test.go
@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"os/exec"
+ "path/filepath"
"strings"
"testing"
)
@@ -21,11 +22,11 @@ func TestDNSMountLeak(t *testing.T) {
dnsServer := "8.8.8.8"
// Pre-create a dummy config for the profile
- configDir := "/tmp/wg-wrap-test-configs"
- if err := os.MkdirAll(configDir, 0755); err != nil {
- t.Fatalf("failed to create config dir: %v", err)
+ configDir := t.TempDir()
+ if err := os.MkdirAll(filepath.Join(configDir, "profiles"), 0755); err != nil {
+ t.Fatalf("failed to create profiles dir: %v", err)
}
- configPath := fmt.Sprintf("%s/%s.conf", configDir, profile)
+ configPath := filepath.Join(configDir, "profiles", profile+".conf")
if err := os.WriteFile(configPath, []byte("[Interface]\nAddress = 10.0.0.1/24\nPrivateKey = aAAA\n"), 0644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
diff --git a/tests/e2e/test_helpers.go b/tests/e2e/test_helpers.go
index 0ae83aa..6d65011 100644
--- a/tests/e2e/test_helpers.go
+++ b/tests/e2e/test_helpers.go
@@ -3,6 +3,7 @@ package e2e
import (
"fmt"
"os"
+ "strings"
)
// GetBinaryPath resolves the path to the wg-wrap binary.
@@ -19,3 +20,29 @@ func GetBinaryPath() (string, error) {
return path, nil
}
+
+// SetEnvOverrides returns a new slice of environment variables with the provided overrides applied.
+// It ensures that overriding an existing variable replaces it rather than appending it.
+func SetEnvOverrides(overrides map[string]string) []string {
+ env := os.Environ()
+ newEnv := make([]string, 0, len(env)+len(overrides))
+
+ for _, e := range env {
+ matched := false
+ for k := range overrides {
+ if strings.HasPrefix(e, k+"=") {
+ matched = true
+ break
+ }
+ }
+ if !matched {
+ newEnv = append(newEnv, e)
+ }
+ }
+
+ for k, v := range overrides {
+ newEnv = append(newEnv, fmt.Sprintf("%s=%s", k, v))
+ }
+
+ return newEnv
+}