diff options
| author | James O'Doherty <james@theodohertyfamily.com> | 2026-05-22 11:12:21 -0400 |
|---|---|---|
| committer | James O'Doherty <james@theodohertyfamily.com> | 2026-05-22 11:12:21 -0400 |
| commit | 3b56ccecf46b83fa9b0e4b6c54be6ffda395910c (patch) | |
| tree | 2a4f7b8598cfdfaec2627ec13d4bfb30c14e28fd /internal/namespace | |
| parent | cefff85a054d64f124aa1f3e91b9425695aa210b (diff) | |
Implement automatic namespace lifecycle cleanup with last-man-out reference counting
Diffstat (limited to 'internal/namespace')
| -rw-r--r-- | internal/namespace/lifecycle.go | 122 | ||||
| -rw-r--r-- | internal/namespace/lifecycle_test.go | 110 |
2 files changed, 232 insertions, 0 deletions
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go new file mode 100644 index 0000000..493fba8 --- /dev/null +++ b/internal/namespace/lifecycle.go @@ -0,0 +1,122 @@ +package namespace + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + "syscall" +) + +var runtimeBaseDir = func() string { + uid := os.Getuid() + base := fmt.Sprintf("/run/user/%d/wg-wrap", uid) + if envBase := os.Getenv("WG_WRAP_RUNTIME_DIR"); envBase != "" { + return envBase + } + return base +}() + +// SetRuntimeBaseDir allows tests to override the base directory for namespace pins and PID tracking. +func SetRuntimeBaseDir(path string) { + runtimeBaseDir = path +} + +// GetProfileNamespacePath returns the path to the pinned namespace file for a profile. +func GetProfileNamespacePath(profile string) string { + return filepath.Join(runtimeBaseDir, "profiles", profile) +} + +// GetPidsDirPath returns the path to the directory where process PIDs are tracked for a profile. +func GetPidsDirPath(profile string) string { + return filepath.Join(GetProfileNamespacePath(profile), "pids") +} + +// RegisterProcess marks the current process as using the specified profile. +func RegisterProcess(profile string) error { + pidsDir := GetPidsDirPath(profile) + if err := os.MkdirAll(pidsDir, 0755); err != nil { + return fmt.Errorf("failed to create pids directory: %v", err) + } + + pid := os.Getpid() + pidFile := filepath.Join(pidsDir, strconv.Itoa(pid)) + if err := os.WriteFile(pidFile, []byte(""), 0644); err != nil { + return fmt.Errorf("failed to register process pid %d: %v", pid, err) + } + return nil +} + +// UnregisterProcess removes the current process from the profile's tracking. +func UnregisterProcess(profile string) error { + pid := os.Getpid() + pidFile := filepath.Join(GetPidsDirPath(profile), strconv.Itoa(pid)) + if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to unregister process pid %d: %v", pid, err) + } + return nil +} + +// PruneStalePids removes PID files that no longer correspond to active processes. +func PruneStalePids(profile string) error { + pidsDir := GetPidsDirPath(profile) + files, err := os.ReadDir(pidsDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to read pids directory: %v", err) + } + + for _, file := range files { + pid, err := strconv.Atoi(file.Name()) + if err != nil { + continue // Ignore non-numeric files + } + + // Sending signal 0 checks if the process exists without actually killing it. + process, err := os.FindProcess(pid) + if err != nil { + os.Remove(filepath.Join(pidsDir, file.Name())) + continue + } + + // On Unix, FindProcess always succeeds. We need to actually check if it's alive. + err = process.Signal(syscall.Signal(0)) + if err != nil { + // Process is gone + os.Remove(filepath.Join(pidsDir, file.Name())) + } + } + return nil +} + +// IsLastProcess checks if the current process is the only active user of the profile. +func IsLastProcess(profile string) (bool, error) { + pidsDir := GetPidsDirPath(profile) + files, err := os.ReadDir(pidsDir) + if err != nil { + if os.IsNotExist(err) { + return true, nil + } + return false, fmt.Errorf("failed to read pids directory: %w", err) + } + + // We count how many PIDs are active, including ourselves. + activeCount := 0 + for _, file := range files { + pid, err := strconv.Atoi(file.Name()) + if err != nil { + continue + } + process, err := os.FindProcess(pid) + if err != nil { + continue + } + if process.Signal(syscall.Signal(0)) == nil { + activeCount++ + } + } + + return activeCount <= 1, nil +} diff --git a/internal/namespace/lifecycle_test.go b/internal/namespace/lifecycle_test.go new file mode 100644 index 0000000..981cfd4 --- /dev/null +++ b/internal/namespace/lifecycle_test.go @@ -0,0 +1,110 @@ +package namespace + +import ( + "os" + "os/exec" + "path/filepath" + "strconv" + "testing" +) + +func TestLifecycleReferenceCounting(t *testing.T) { + // Use a temporary directory to avoid polluting the system + tmpDir := t.TempDir() + SetRuntimeBaseDir(tmpDir) + + profile := "test-vpn" + + t.Run("RegisterAndUnregister", func(t *testing.T) { + err := RegisterProcess(profile) + if err != nil { + t.Fatalf("failed to register: %v", err) + } + + pidsDir := GetPidsDirPath(profile) + pidFile := filepath.Join(pidsDir, strconv.Itoa(os.Getpid())) + if _, err := os.Stat(pidFile); os.IsNotExist(err) { + t.Errorf("PID file should exist at %s", pidFile) + } + + err = UnregisterProcess(profile) + if err != nil { + t.Fatalf("failed to unregister: %v", err) + } + + if _, err := os.Stat(pidFile); err == nil { + t.Errorf("PID file should have been removed at %s", pidFile) + } + }) + + t.Run("PruneStalePids", func(t *testing.T) { + pidsDir := GetPidsDirPath(profile) + if err := os.MkdirAll(pidsDir, 0755); err != nil { + t.Fatal(err) + } + + // Create a fake PID file for a process that definitely doesn't exist + // Using a very high PID or -1 usually works, but let's use a known invalid one. + fakePid := "9999999" + fakePidFile := filepath.Join(pidsDir, fakePid) + if err := os.WriteFile(fakePidFile, []byte(""), 0644); err != nil { + t.Fatal(err) + } + + // Also register the current process so it stays + RegisterProcess(profile) + + err := PruneStalePids(profile) + if err != nil { + t.Fatalf("prune failed: %v", err) + } + + if _, err := os.Stat(fakePidFile); err == nil { + t.Errorf("Stale PID file %s should have been pruned", fakePidFile) + } + + // Current process should still be there + currentPidFile := filepath.Join(pidsDir, strconv.Itoa(os.Getpid())) + if _, err := os.Stat(currentPidFile); os.IsNotExist(err) { + t.Errorf("Current PID file %s should not have been pruned", currentPidFile) + } + + UnregisterProcess(profile) + }) + + t.Run("IsLastProcess", func(t *testing.T) { + pidsDir := GetPidsDirPath(profile) + os.RemoveAll(pidsDir) // Reset + + // Case 1: No processes (should return true as it's a clean state) + isLast, err := IsLastProcess(profile) + if err != nil || !isLast { + t.Errorf("Expected IsLastProcess to be true for empty profile, got %v, err: %v", isLast, err) + } + + // Case 2: Only ourselves + RegisterProcess(profile) + isLast, err = IsLastProcess(profile) + if err != nil || !isLast { + t.Errorf("Expected IsLastProcess to be true for single process, got %v, err: %v", isLast, err) + } + + // Case 3: Ourselves + another active process + // To test this, we'll actually start a dummy process. + cmd := exec.Command("sleep", "1") + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start sleep process: %v", err) + } + defer cmd.Process.Kill() + + // Manually add the sleep process PID to the tracking + os.WriteFile(filepath.Join(pidsDir, strconv.Itoa(cmd.Process.Pid)), []byte(""), 0644) + + isLast, err = IsLastProcess(profile) + if err != nil || isLast { + t.Errorf("Expected IsLastProcess to be false with two active processes, got %v, err: %v", isLast, err) + } + + UnregisterProcess(profile) + }) +} |
