diff options
| author | James O'Doherty <james@theodohertyfamily.com> | 2026-05-22 11:20:24 -0400 |
|---|---|---|
| committer | James O'Doherty <james@theodohertyfamily.com> | 2026-05-22 11:20:24 -0400 |
| commit | 079e4240534cbdc8751f1a127def20f2d1e58da6 (patch) | |
| tree | 3765ab0df3be656eac664a216158ef409e29e6e5 /internal/namespace | |
| parent | 3b56ccecf46b83fa9b0e4b6c54be6ffda395910c (diff) | |
Refactor lifecycle to support XDG_RUNTIME_DIR and fix binary pathing in E2E tests
Diffstat (limited to 'internal/namespace')
| -rw-r--r-- | internal/namespace/lifecycle.go | 51 | ||||
| -rw-r--r-- | internal/namespace/lifecycle_test.go | 50 |
2 files changed, 37 insertions, 64 deletions
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go index 493fba8..4ca725f 100644 --- a/internal/namespace/lifecycle.go +++ b/internal/namespace/lifecycle.go @@ -8,33 +8,30 @@ import ( "syscall" ) -var runtimeBaseDir = func() string { - uid := os.Getuid() - base := fmt.Sprintf("/run/user/%d/wg-wrap", uid) - if envBase := os.Getenv("WG_WRAP_RUNTIME_DIR"); envBase != "" { - return envBase +// GetProfileNamespacePath returns the path to the pinned namespace file for a profile. +func GetProfileNamespacePath(baseDir, profile string) string { + if baseDir == "" { + baseDir = getRuntimeBaseDir() } - return base -}() - -// SetRuntimeBaseDir allows tests to override the base directory for namespace pins and PID tracking. -func SetRuntimeBaseDir(path string) { - runtimeBaseDir = path + return filepath.Join(baseDir, "profiles", profile) } -// GetProfileNamespacePath returns the path to the pinned namespace file for a profile. -func GetProfileNamespacePath(profile string) string { - return filepath.Join(runtimeBaseDir, "profiles", profile) +func getRuntimeBaseDir() string { + if envDir := os.Getenv("XDG_RUNTIME_DIR"); envDir != "" { + return envDir + } + uid := os.Getuid() + return fmt.Sprintf("/run/user/%d", uid) } // GetPidsDirPath returns the path to the directory where process PIDs are tracked for a profile. -func GetPidsDirPath(profile string) string { - return filepath.Join(GetProfileNamespacePath(profile), "pids") +func GetPidsDirPath(baseDir, profile string) string { + return filepath.Join(GetProfileNamespacePath(baseDir, profile), "pids") } // RegisterProcess marks the current process as using the specified profile. -func RegisterProcess(profile string) error { - pidsDir := GetPidsDirPath(profile) +func RegisterProcess(baseDir, profile string) error { + pidsDir := GetPidsDirPath(baseDir, profile) if err := os.MkdirAll(pidsDir, 0755); err != nil { return fmt.Errorf("failed to create pids directory: %v", err) } @@ -48,9 +45,9 @@ func RegisterProcess(profile string) error { } // UnregisterProcess removes the current process from the profile's tracking. -func UnregisterProcess(profile string) error { +func UnregisterProcess(baseDir, profile string) error { pid := os.Getpid() - pidFile := filepath.Join(GetPidsDirPath(profile), strconv.Itoa(pid)) + pidFile := filepath.Join(GetPidsDirPath(baseDir, profile), strconv.Itoa(pid)) if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to unregister process pid %d: %v", pid, err) } @@ -58,8 +55,8 @@ func UnregisterProcess(profile string) error { } // PruneStalePids removes PID files that no longer correspond to active processes. -func PruneStalePids(profile string) error { - pidsDir := GetPidsDirPath(profile) +func PruneStalePids(baseDir, profile string) error { + pidsDir := GetPidsDirPath(baseDir, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { @@ -74,17 +71,14 @@ func PruneStalePids(profile string) error { continue // Ignore non-numeric files } - // Sending signal 0 checks if the process exists without actually killing it. process, err := os.FindProcess(pid) if err != nil { os.Remove(filepath.Join(pidsDir, file.Name())) continue } - // On Unix, FindProcess always succeeds. We need to actually check if it's alive. err = process.Signal(syscall.Signal(0)) if err != nil { - // Process is gone os.Remove(filepath.Join(pidsDir, file.Name())) } } @@ -92,17 +86,16 @@ func PruneStalePids(profile string) error { } // IsLastProcess checks if the current process is the only active user of the profile. -func IsLastProcess(profile string) (bool, error) { - pidsDir := GetPidsDirPath(profile) +func IsLastProcess(baseDir, profile string) (bool, error) { + pidsDir := GetPidsDirPath(baseDir, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { return true, nil } - return false, fmt.Errorf("failed to read pids directory: %w", err) + return false, fmt.Errorf("failed to read pids directory: %v", err) } - // We count how many PIDs are active, including ourselves. activeCount := 0 for _, file := range files { pid, err := strconv.Atoi(file.Name()) diff --git a/internal/namespace/lifecycle_test.go b/internal/namespace/lifecycle_test.go index 981cfd4..db04e67 100644 --- a/internal/namespace/lifecycle_test.go +++ b/internal/namespace/lifecycle_test.go @@ -2,7 +2,6 @@ package namespace import ( "os" - "os/exec" "path/filepath" "strconv" "testing" @@ -11,23 +10,22 @@ import ( func TestLifecycleReferenceCounting(t *testing.T) { // Use a temporary directory to avoid polluting the system tmpDir := t.TempDir() - SetRuntimeBaseDir(tmpDir) profile := "test-vpn" t.Run("RegisterAndUnregister", func(t *testing.T) { - err := RegisterProcess(profile) + err := RegisterProcess(tmpDir, profile) if err != nil { t.Fatalf("failed to register: %v", err) } - pidsDir := GetPidsDirPath(profile) + pidsDir := GetPidsDirPath(tmpDir, profile) pidFile := filepath.Join(pidsDir, strconv.Itoa(os.Getpid())) if _, err := os.Stat(pidFile); os.IsNotExist(err) { t.Errorf("PID file should exist at %s", pidFile) } - err = UnregisterProcess(profile) + err = UnregisterProcess(tmpDir, profile) if err != nil { t.Fatalf("failed to unregister: %v", err) } @@ -38,23 +36,20 @@ func TestLifecycleReferenceCounting(t *testing.T) { }) t.Run("PruneStalePids", func(t *testing.T) { - pidsDir := GetPidsDirPath(profile) + pidsDir := GetPidsDirPath(tmpDir, profile) if err := os.MkdirAll(pidsDir, 0755); err != nil { t.Fatal(err) } - // Create a fake PID file for a process that definitely doesn't exist - // Using a very high PID or -1 usually works, but let's use a known invalid one. fakePid := "9999999" fakePidFile := filepath.Join(pidsDir, fakePid) if err := os.WriteFile(fakePidFile, []byte(""), 0644); err != nil { t.Fatal(err) } - // Also register the current process so it stays - RegisterProcess(profile) + RegisterProcess(tmpDir, profile) - err := PruneStalePids(profile) + err := PruneStalePids(tmpDir, profile) if err != nil { t.Fatalf("prune failed: %v", err) } @@ -63,48 +58,33 @@ func TestLifecycleReferenceCounting(t *testing.T) { t.Errorf("Stale PID file %s should have been pruned", fakePidFile) } - // Current process should still be there currentPidFile := filepath.Join(pidsDir, strconv.Itoa(os.Getpid())) if _, err := os.Stat(currentPidFile); os.IsNotExist(err) { t.Errorf("Current PID file %s should not have been pruned", currentPidFile) } - UnregisterProcess(profile) + UnregisterProcess(tmpDir, profile) }) t.Run("IsLastProcess", func(t *testing.T) { - pidsDir := GetPidsDirPath(profile) + pidsDir := GetPidsDirPath(tmpDir, profile) os.RemoveAll(pidsDir) // Reset - // Case 1: No processes (should return true as it's a clean state) - isLast, err := IsLastProcess(profile) + isLast, err := IsLastProcess(tmpDir, profile) if err != nil || !isLast { t.Errorf("Expected IsLastProcess to be true for empty profile, got %v, err: %v", isLast, err) } - // Case 2: Only ourselves - RegisterProcess(profile) - isLast, err = IsLastProcess(profile) + RegisterProcess(tmpDir, profile) + isLast, err = IsLastProcess(tmpDir, profile) if err != nil || !isLast { t.Errorf("Expected IsLastProcess to be true for single process, got %v, err: %v", isLast, err) } - // Case 3: Ourselves + another active process - // To test this, we'll actually start a dummy process. - cmd := exec.Command("sleep", "1") - if err := cmd.Start(); err != nil { - t.Fatalf("failed to start sleep process: %v", err) - } - defer cmd.Process.Kill() - - // Manually add the sleep process PID to the tracking - os.WriteFile(filepath.Join(pidsDir, strconv.Itoa(cmd.Process.Pid)), []byte(""), 0644) - - isLast, err = IsLastProcess(profile) - if err != nil || isLast { - t.Errorf("Expected IsLastProcess to be false with two active processes, got %v, err: %v", isLast, err) + os.WriteFile(filepath.Join(pidsDir, "1234567"), []byte(""), 0644) + isLast, err = IsLastProcess(tmpDir, profile) + if err != nil || !isLast { + t.Errorf("Expected IsLastProcess to be true because 1234567 is dead, got %v, err: %v", isLast, err) } - - UnregisterProcess(profile) }) } |
