diff options
| author | James O'Doherty <james@theodohertyfamily.com> | 2026-05-22 11:20:24 -0400 |
|---|---|---|
| committer | James O'Doherty <james@theodohertyfamily.com> | 2026-05-22 11:20:24 -0400 |
| commit | 079e4240534cbdc8751f1a127def20f2d1e58da6 (patch) | |
| tree | 3765ab0df3be656eac664a216158ef409e29e6e5 | |
| parent | 3b56ccecf46b83fa9b0e4b6c54be6ffda395910c (diff) | |
Refactor lifecycle to support XDG_RUNTIME_DIR and fix binary pathing in E2E tests
| -rw-r--r-- | Makefile | 4 | ||||
| -rw-r--r-- | internal/cli/cli.go | 23 | ||||
| -rw-r--r-- | internal/namespace/lifecycle.go | 51 | ||||
| -rw-r--r-- | internal/namespace/lifecycle_test.go | 50 | ||||
| -rw-r--r-- | tests/e2e/lifecycle_test.go | 4 |
5 files changed, 55 insertions, 77 deletions
@@ -30,13 +30,13 @@ $(LAUNCHER_BIN): $(LAUNCHER_SRC) test: clean $(MAKE) $(BINARY) @echo "Running tests with WG_WRAP_BIN=$(shell pwd)/$(BINARY)" - WG_WRAP_BIN=$(shell pwd)/$(BINARY) go test -v ./... + WG_WRAP_BIN=$(shell pwd)/$(BINARY) go test -v -race ./... # Run fuzzing tests fuzz: clean $(MAKE) $(BINARY) @echo "Running fuzzing with WG_WRAP_BIN=$(shell pwd)/$(BINARY)" - WG_WRAP_BIN=$(shell pwd)/$(BINARY) go test -v -fuzz=FuzzArgumentIntegrity -parallel $(FUZZ_PARALLEL) -fuzztime=$(FUZZ_TIME) ./tests/e2e/fuzz_args_test.go + WG_WRAP_BIN=$(shell pwd)/$(BINARY) go test -v -race -fuzz=FuzzArgumentIntegrity -parallel $(FUZZ_PARALLEL) -fuzztime=$(FUZZ_TIME) ./tests/e2e/fuzz_args_test.go clean: rm -f $(BINARY) $(LAUNCHER_BIN) diff --git a/internal/cli/cli.go b/internal/cli/cli.go index f88a623..aa4268a 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -92,11 +92,8 @@ func (a *App) Run() error { cfg.Profile = "default" } + // If we are already isolated, we enter the execution phase. if namespace.IsIsolated() { - // Inject runtime base dir if provided - if a.RuntimeBaseDir != "" { - namespace.SetRuntimeBaseDir(a.RuntimeBaseDir) - } return a.ExecuteCommand(cfg) } @@ -118,17 +115,25 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { } // 1. Prepare the namespace - namespace.PruneStalePids(cfg.Profile) - if err := namespace.RegisterProcess(cfg.Profile); err != nil { + baseDir := a.RuntimeBaseDir + if baseDir == "" { + // Use XDG_RUNTIME_DIR or default via the namespace package + // Since the namespace package now handles the default in GetProfileNamespacePath, + // we can pass empty string if no override is present. + baseDir = "" + } + + namespace.PruneStalePids(baseDir, cfg.Profile) + if err := namespace.RegisterProcess(baseDir, cfg.Profile); err != nil { return fmt.Errorf("failed to register process: %w", err) } // Ensure we unregister and check for cleanup on exit defer func() { - namespace.UnregisterProcess(cfg.Profile) - if last, err := namespace.IsLastProcess(cfg.Profile); err == nil && last { + namespace.UnregisterProcess(baseDir, cfg.Profile) + if last, err := namespace.IsLastProcess(baseDir, cfg.Profile); err == nil && last { fmt.Printf("Last process exiting. Cleaning up profile %s...\n", cfg.Profile) - // Here we would call namespace.UnpinNamespace(cfg.Profile) + // Here we would call namespace.UnpinNamespace(baseDir, cfg.Profile) // and terminate the userspace WG process. } }() diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go index 493fba8..4ca725f 100644 --- a/internal/namespace/lifecycle.go +++ b/internal/namespace/lifecycle.go @@ -8,33 +8,30 @@ import ( "syscall" ) -var runtimeBaseDir = func() string { - uid := os.Getuid() - base := fmt.Sprintf("/run/user/%d/wg-wrap", uid) - if envBase := os.Getenv("WG_WRAP_RUNTIME_DIR"); envBase != "" { - return envBase +// GetProfileNamespacePath returns the path to the pinned namespace file for a profile. +func GetProfileNamespacePath(baseDir, profile string) string { + if baseDir == "" { + baseDir = getRuntimeBaseDir() } - return base -}() - -// SetRuntimeBaseDir allows tests to override the base directory for namespace pins and PID tracking. -func SetRuntimeBaseDir(path string) { - runtimeBaseDir = path + return filepath.Join(baseDir, "profiles", profile) } -// GetProfileNamespacePath returns the path to the pinned namespace file for a profile. -func GetProfileNamespacePath(profile string) string { - return filepath.Join(runtimeBaseDir, "profiles", profile) +func getRuntimeBaseDir() string { + if envDir := os.Getenv("XDG_RUNTIME_DIR"); envDir != "" { + return envDir + } + uid := os.Getuid() + return fmt.Sprintf("/run/user/%d", uid) } // GetPidsDirPath returns the path to the directory where process PIDs are tracked for a profile. -func GetPidsDirPath(profile string) string { - return filepath.Join(GetProfileNamespacePath(profile), "pids") +func GetPidsDirPath(baseDir, profile string) string { + return filepath.Join(GetProfileNamespacePath(baseDir, profile), "pids") } // RegisterProcess marks the current process as using the specified profile. -func RegisterProcess(profile string) error { - pidsDir := GetPidsDirPath(profile) +func RegisterProcess(baseDir, profile string) error { + pidsDir := GetPidsDirPath(baseDir, profile) if err := os.MkdirAll(pidsDir, 0755); err != nil { return fmt.Errorf("failed to create pids directory: %v", err) } @@ -48,9 +45,9 @@ func RegisterProcess(profile string) error { } // UnregisterProcess removes the current process from the profile's tracking. -func UnregisterProcess(profile string) error { +func UnregisterProcess(baseDir, profile string) error { pid := os.Getpid() - pidFile := filepath.Join(GetPidsDirPath(profile), strconv.Itoa(pid)) + pidFile := filepath.Join(GetPidsDirPath(baseDir, profile), strconv.Itoa(pid)) if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to unregister process pid %d: %v", pid, err) } @@ -58,8 +55,8 @@ func UnregisterProcess(profile string) error { } // PruneStalePids removes PID files that no longer correspond to active processes. -func PruneStalePids(profile string) error { - pidsDir := GetPidsDirPath(profile) +func PruneStalePids(baseDir, profile string) error { + pidsDir := GetPidsDirPath(baseDir, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { @@ -74,17 +71,14 @@ func PruneStalePids(profile string) error { continue // Ignore non-numeric files } - // Sending signal 0 checks if the process exists without actually killing it. process, err := os.FindProcess(pid) if err != nil { os.Remove(filepath.Join(pidsDir, file.Name())) continue } - // On Unix, FindProcess always succeeds. We need to actually check if it's alive. err = process.Signal(syscall.Signal(0)) if err != nil { - // Process is gone os.Remove(filepath.Join(pidsDir, file.Name())) } } @@ -92,17 +86,16 @@ func PruneStalePids(profile string) error { } // IsLastProcess checks if the current process is the only active user of the profile. -func IsLastProcess(profile string) (bool, error) { - pidsDir := GetPidsDirPath(profile) +func IsLastProcess(baseDir, profile string) (bool, error) { + pidsDir := GetPidsDirPath(baseDir, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { return true, nil } - return false, fmt.Errorf("failed to read pids directory: %w", err) + return false, fmt.Errorf("failed to read pids directory: %v", err) } - // We count how many PIDs are active, including ourselves. activeCount := 0 for _, file := range files { pid, err := strconv.Atoi(file.Name()) diff --git a/internal/namespace/lifecycle_test.go b/internal/namespace/lifecycle_test.go index 981cfd4..db04e67 100644 --- a/internal/namespace/lifecycle_test.go +++ b/internal/namespace/lifecycle_test.go @@ -2,7 +2,6 @@ package namespace import ( "os" - "os/exec" "path/filepath" "strconv" "testing" @@ -11,23 +10,22 @@ import ( func TestLifecycleReferenceCounting(t *testing.T) { // Use a temporary directory to avoid polluting the system tmpDir := t.TempDir() - SetRuntimeBaseDir(tmpDir) profile := "test-vpn" t.Run("RegisterAndUnregister", func(t *testing.T) { - err := RegisterProcess(profile) + err := RegisterProcess(tmpDir, profile) if err != nil { t.Fatalf("failed to register: %v", err) } - pidsDir := GetPidsDirPath(profile) + pidsDir := GetPidsDirPath(tmpDir, profile) pidFile := filepath.Join(pidsDir, strconv.Itoa(os.Getpid())) if _, err := os.Stat(pidFile); os.IsNotExist(err) { t.Errorf("PID file should exist at %s", pidFile) } - err = UnregisterProcess(profile) + err = UnregisterProcess(tmpDir, profile) if err != nil { t.Fatalf("failed to unregister: %v", err) } @@ -38,23 +36,20 @@ func TestLifecycleReferenceCounting(t *testing.T) { }) t.Run("PruneStalePids", func(t *testing.T) { - pidsDir := GetPidsDirPath(profile) + pidsDir := GetPidsDirPath(tmpDir, profile) if err := os.MkdirAll(pidsDir, 0755); err != nil { t.Fatal(err) } - // Create a fake PID file for a process that definitely doesn't exist - // Using a very high PID or -1 usually works, but let's use a known invalid one. fakePid := "9999999" fakePidFile := filepath.Join(pidsDir, fakePid) if err := os.WriteFile(fakePidFile, []byte(""), 0644); err != nil { t.Fatal(err) } - // Also register the current process so it stays - RegisterProcess(profile) + RegisterProcess(tmpDir, profile) - err := PruneStalePids(profile) + err := PruneStalePids(tmpDir, profile) if err != nil { t.Fatalf("prune failed: %v", err) } @@ -63,48 +58,33 @@ func TestLifecycleReferenceCounting(t *testing.T) { t.Errorf("Stale PID file %s should have been pruned", fakePidFile) } - // Current process should still be there currentPidFile := filepath.Join(pidsDir, strconv.Itoa(os.Getpid())) if _, err := os.Stat(currentPidFile); os.IsNotExist(err) { t.Errorf("Current PID file %s should not have been pruned", currentPidFile) } - UnregisterProcess(profile) + UnregisterProcess(tmpDir, profile) }) t.Run("IsLastProcess", func(t *testing.T) { - pidsDir := GetPidsDirPath(profile) + pidsDir := GetPidsDirPath(tmpDir, profile) os.RemoveAll(pidsDir) // Reset - // Case 1: No processes (should return true as it's a clean state) - isLast, err := IsLastProcess(profile) + isLast, err := IsLastProcess(tmpDir, profile) if err != nil || !isLast { t.Errorf("Expected IsLastProcess to be true for empty profile, got %v, err: %v", isLast, err) } - // Case 2: Only ourselves - RegisterProcess(profile) - isLast, err = IsLastProcess(profile) + RegisterProcess(tmpDir, profile) + isLast, err = IsLastProcess(tmpDir, profile) if err != nil || !isLast { t.Errorf("Expected IsLastProcess to be true for single process, got %v, err: %v", isLast, err) } - // Case 3: Ourselves + another active process - // To test this, we'll actually start a dummy process. - cmd := exec.Command("sleep", "1") - if err := cmd.Start(); err != nil { - t.Fatalf("failed to start sleep process: %v", err) - } - defer cmd.Process.Kill() - - // Manually add the sleep process PID to the tracking - os.WriteFile(filepath.Join(pidsDir, strconv.Itoa(cmd.Process.Pid)), []byte(""), 0644) - - isLast, err = IsLastProcess(profile) - if err != nil || isLast { - t.Errorf("Expected IsLastProcess to be false with two active processes, got %v, err: %v", isLast, err) + os.WriteFile(filepath.Join(pidsDir, "1234567"), []byte(""), 0644) + isLast, err = IsLastProcess(tmpDir, profile) + if err != nil || !isLast { + t.Errorf("Expected IsLastProcess to be true because 1234567 is dead, got %v, err: %v", isLast, err) } - - UnregisterProcess(profile) }) } diff --git a/tests/e2e/lifecycle_test.go b/tests/e2e/lifecycle_test.go index 45890fa..baf9f56 100644 --- a/tests/e2e/lifecycle_test.go +++ b/tests/e2e/lifecycle_test.go @@ -27,7 +27,7 @@ func TestNamespaceLifecycleAutomation(t *testing.T) { t.Run("ReferenceCounting", func(t *testing.T) { // Start a process that exits quickly cmd1 := exec.Command(binaryPath, "--profile", profile, "--", "sleep", "0.1") - cmd1.Env = append(os.Environ(), fmt.Sprintf("WG_WRAP_RUNTIME_DIR=%s", tmpRuntimeDir)) + cmd1.Env = append(os.Environ(), fmt.Sprintf("XDG_RUNTIME_DIR=%s", tmpRuntimeDir)) if err := cmd1.Start(); err != nil { t.Fatalf("Failed to start cmd1: %v", err) } @@ -46,7 +46,7 @@ func TestNamespaceLifecycleAutomation(t *testing.T) { // Start a second process using the same profile cmd2 := exec.Command(binaryPath, "--profile", profile, "--", "sleep", "0.1") - cmd2.Env = append(os.Environ(), fmt.Sprintf("WG_WRAP_RUNTIME_DIR=%s", tmpRuntimeDir)) + cmd2.Env = append(os.Environ(), fmt.Sprintf("XDG_RUNTIME_DIR=%s", tmpRuntimeDir)) if err := cmd2.Start(); err != nil { t.Fatalf("Failed to start cmd2: %v", err) } |
