From 079e4240534cbdc8751f1a127def20f2d1e58da6 Mon Sep 17 00:00:00 2001 From: James O'Doherty Date: Fri, 22 May 2026 11:20:24 -0400 Subject: Refactor lifecycle to support XDG_RUNTIME_DIR and fix binary pathing in E2E tests --- internal/namespace/lifecycle.go | 51 ++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 29 deletions(-) (limited to 'internal/namespace/lifecycle.go') diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go index 493fba8..4ca725f 100644 --- a/internal/namespace/lifecycle.go +++ b/internal/namespace/lifecycle.go @@ -8,33 +8,30 @@ import ( "syscall" ) -var runtimeBaseDir = func() string { - uid := os.Getuid() - base := fmt.Sprintf("/run/user/%d/wg-wrap", uid) - if envBase := os.Getenv("WG_WRAP_RUNTIME_DIR"); envBase != "" { - return envBase +// GetProfileNamespacePath returns the path to the pinned namespace file for a profile. +func GetProfileNamespacePath(baseDir, profile string) string { + if baseDir == "" { + baseDir = getRuntimeBaseDir() } - return base -}() - -// SetRuntimeBaseDir allows tests to override the base directory for namespace pins and PID tracking. -func SetRuntimeBaseDir(path string) { - runtimeBaseDir = path + return filepath.Join(baseDir, "profiles", profile) } -// GetProfileNamespacePath returns the path to the pinned namespace file for a profile. -func GetProfileNamespacePath(profile string) string { - return filepath.Join(runtimeBaseDir, "profiles", profile) +func getRuntimeBaseDir() string { + if envDir := os.Getenv("XDG_RUNTIME_DIR"); envDir != "" { + return envDir + } + uid := os.Getuid() + return fmt.Sprintf("/run/user/%d", uid) } // GetPidsDirPath returns the path to the directory where process PIDs are tracked for a profile. -func GetPidsDirPath(profile string) string { - return filepath.Join(GetProfileNamespacePath(profile), "pids") +func GetPidsDirPath(baseDir, profile string) string { + return filepath.Join(GetProfileNamespacePath(baseDir, profile), "pids") } // RegisterProcess marks the current process as using the specified profile. -func RegisterProcess(profile string) error { - pidsDir := GetPidsDirPath(profile) +func RegisterProcess(baseDir, profile string) error { + pidsDir := GetPidsDirPath(baseDir, profile) if err := os.MkdirAll(pidsDir, 0755); err != nil { return fmt.Errorf("failed to create pids directory: %v", err) } @@ -48,9 +45,9 @@ func RegisterProcess(profile string) error { } // UnregisterProcess removes the current process from the profile's tracking. -func UnregisterProcess(profile string) error { +func UnregisterProcess(baseDir, profile string) error { pid := os.Getpid() - pidFile := filepath.Join(GetPidsDirPath(profile), strconv.Itoa(pid)) + pidFile := filepath.Join(GetPidsDirPath(baseDir, profile), strconv.Itoa(pid)) if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to unregister process pid %d: %v", pid, err) } @@ -58,8 +55,8 @@ func UnregisterProcess(profile string) error { } // PruneStalePids removes PID files that no longer correspond to active processes. -func PruneStalePids(profile string) error { - pidsDir := GetPidsDirPath(profile) +func PruneStalePids(baseDir, profile string) error { + pidsDir := GetPidsDirPath(baseDir, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { @@ -74,17 +71,14 @@ func PruneStalePids(profile string) error { continue // Ignore non-numeric files } - // Sending signal 0 checks if the process exists without actually killing it. process, err := os.FindProcess(pid) if err != nil { os.Remove(filepath.Join(pidsDir, file.Name())) continue } - // On Unix, FindProcess always succeeds. We need to actually check if it's alive. err = process.Signal(syscall.Signal(0)) if err != nil { - // Process is gone os.Remove(filepath.Join(pidsDir, file.Name())) } } @@ -92,17 +86,16 @@ func PruneStalePids(profile string) error { } // IsLastProcess checks if the current process is the only active user of the profile. -func IsLastProcess(profile string) (bool, error) { - pidsDir := GetPidsDirPath(profile) +func IsLastProcess(baseDir, profile string) (bool, error) { + pidsDir := GetPidsDirPath(baseDir, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { return true, nil } - return false, fmt.Errorf("failed to read pids directory: %w", err) + return false, fmt.Errorf("failed to read pids directory: %v", err) } - // We count how many PIDs are active, including ourselves. activeCount := 0 for _, file := range files { pid, err := strconv.Atoi(file.Name()) -- cgit v1.2.3