diff options
Diffstat (limited to 'internal/namespace/lifecycle.go')
| -rw-r--r-- | internal/namespace/lifecycle.go | 122 |
1 files changed, 122 insertions, 0 deletions
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go new file mode 100644 index 0000000..493fba8 --- /dev/null +++ b/internal/namespace/lifecycle.go @@ -0,0 +1,122 @@ +package namespace + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + "syscall" +) + +var runtimeBaseDir = func() string { + uid := os.Getuid() + base := fmt.Sprintf("/run/user/%d/wg-wrap", uid) + if envBase := os.Getenv("WG_WRAP_RUNTIME_DIR"); envBase != "" { + return envBase + } + return base +}() + +// SetRuntimeBaseDir allows tests to override the base directory for namespace pins and PID tracking. +func SetRuntimeBaseDir(path string) { + runtimeBaseDir = path +} + +// GetProfileNamespacePath returns the path to the pinned namespace file for a profile. +func GetProfileNamespacePath(profile string) string { + return filepath.Join(runtimeBaseDir, "profiles", profile) +} + +// GetPidsDirPath returns the path to the directory where process PIDs are tracked for a profile. +func GetPidsDirPath(profile string) string { + return filepath.Join(GetProfileNamespacePath(profile), "pids") +} + +// RegisterProcess marks the current process as using the specified profile. +func RegisterProcess(profile string) error { + pidsDir := GetPidsDirPath(profile) + if err := os.MkdirAll(pidsDir, 0755); err != nil { + return fmt.Errorf("failed to create pids directory: %v", err) + } + + pid := os.Getpid() + pidFile := filepath.Join(pidsDir, strconv.Itoa(pid)) + if err := os.WriteFile(pidFile, []byte(""), 0644); err != nil { + return fmt.Errorf("failed to register process pid %d: %v", pid, err) + } + return nil +} + +// UnregisterProcess removes the current process from the profile's tracking. +func UnregisterProcess(profile string) error { + pid := os.Getpid() + pidFile := filepath.Join(GetPidsDirPath(profile), strconv.Itoa(pid)) + if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to unregister process pid %d: %v", pid, err) + } + return nil +} + +// PruneStalePids removes PID files that no longer correspond to active processes. +func PruneStalePids(profile string) error { + pidsDir := GetPidsDirPath(profile) + files, err := os.ReadDir(pidsDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to read pids directory: %v", err) + } + + for _, file := range files { + pid, err := strconv.Atoi(file.Name()) + if err != nil { + continue // Ignore non-numeric files + } + + // Sending signal 0 checks if the process exists without actually killing it. + process, err := os.FindProcess(pid) + if err != nil { + os.Remove(filepath.Join(pidsDir, file.Name())) + continue + } + + // On Unix, FindProcess always succeeds. We need to actually check if it's alive. + err = process.Signal(syscall.Signal(0)) + if err != nil { + // Process is gone + os.Remove(filepath.Join(pidsDir, file.Name())) + } + } + return nil +} + +// IsLastProcess checks if the current process is the only active user of the profile. +func IsLastProcess(profile string) (bool, error) { + pidsDir := GetPidsDirPath(profile) + files, err := os.ReadDir(pidsDir) + if err != nil { + if os.IsNotExist(err) { + return true, nil + } + return false, fmt.Errorf("failed to read pids directory: %w", err) + } + + // We count how many PIDs are active, including ourselves. + activeCount := 0 + for _, file := range files { + pid, err := strconv.Atoi(file.Name()) + if err != nil { + continue + } + process, err := os.FindProcess(pid) + if err != nil { + continue + } + if process.Signal(syscall.Signal(0)) == nil { + activeCount++ + } + } + + return activeCount <= 1, nil +} |
