summaryrefslogtreecommitdiff
path: root/internal/namespace/lifecycle.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/namespace/lifecycle.go')
-rw-r--r--internal/namespace/lifecycle.go122
1 files changed, 122 insertions, 0 deletions
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go
new file mode 100644
index 0000000..493fba8
--- /dev/null
+++ b/internal/namespace/lifecycle.go
@@ -0,0 +1,122 @@
+package namespace
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "strconv"
+ "syscall"
+)
+
+var runtimeBaseDir = func() string {
+ uid := os.Getuid()
+ base := fmt.Sprintf("/run/user/%d/wg-wrap", uid)
+ if envBase := os.Getenv("WG_WRAP_RUNTIME_DIR"); envBase != "" {
+ return envBase
+ }
+ return base
+}()
+
+// SetRuntimeBaseDir allows tests to override the base directory for namespace pins and PID tracking.
+func SetRuntimeBaseDir(path string) {
+ runtimeBaseDir = path
+}
+
+// GetProfileNamespacePath returns the path to the pinned namespace file for a profile.
+func GetProfileNamespacePath(profile string) string {
+ return filepath.Join(runtimeBaseDir, "profiles", profile)
+}
+
+// GetPidsDirPath returns the path to the directory where process PIDs are tracked for a profile.
+func GetPidsDirPath(profile string) string {
+ return filepath.Join(GetProfileNamespacePath(profile), "pids")
+}
+
+// RegisterProcess marks the current process as using the specified profile.
+func RegisterProcess(profile string) error {
+ pidsDir := GetPidsDirPath(profile)
+ if err := os.MkdirAll(pidsDir, 0755); err != nil {
+ return fmt.Errorf("failed to create pids directory: %v", err)
+ }
+
+ pid := os.Getpid()
+ pidFile := filepath.Join(pidsDir, strconv.Itoa(pid))
+ if err := os.WriteFile(pidFile, []byte(""), 0644); err != nil {
+ return fmt.Errorf("failed to register process pid %d: %v", pid, err)
+ }
+ return nil
+}
+
+// UnregisterProcess removes the current process from the profile's tracking.
+func UnregisterProcess(profile string) error {
+ pid := os.Getpid()
+ pidFile := filepath.Join(GetPidsDirPath(profile), strconv.Itoa(pid))
+ if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("failed to unregister process pid %d: %v", pid, err)
+ }
+ return nil
+}
+
+// PruneStalePids removes PID files that no longer correspond to active processes.
+func PruneStalePids(profile string) error {
+ pidsDir := GetPidsDirPath(profile)
+ files, err := os.ReadDir(pidsDir)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+ return fmt.Errorf("failed to read pids directory: %v", err)
+ }
+
+ for _, file := range files {
+ pid, err := strconv.Atoi(file.Name())
+ if err != nil {
+ continue // Ignore non-numeric files
+ }
+
+ // Sending signal 0 checks if the process exists without actually killing it.
+ process, err := os.FindProcess(pid)
+ if err != nil {
+ os.Remove(filepath.Join(pidsDir, file.Name()))
+ continue
+ }
+
+ // On Unix, FindProcess always succeeds. We need to actually check if it's alive.
+ err = process.Signal(syscall.Signal(0))
+ if err != nil {
+ // Process is gone
+ os.Remove(filepath.Join(pidsDir, file.Name()))
+ }
+ }
+ return nil
+}
+
+// IsLastProcess checks if the current process is the only active user of the profile.
+func IsLastProcess(profile string) (bool, error) {
+ pidsDir := GetPidsDirPath(profile)
+ files, err := os.ReadDir(pidsDir)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return true, nil
+ }
+ return false, fmt.Errorf("failed to read pids directory: %w", err)
+ }
+
+ // We count how many PIDs are active, including ourselves.
+ activeCount := 0
+ for _, file := range files {
+ pid, err := strconv.Atoi(file.Name())
+ if err != nil {
+ continue
+ }
+ process, err := os.FindProcess(pid)
+ if err != nil {
+ continue
+ }
+ if process.Signal(syscall.Signal(0)) == nil {
+ activeCount++
+ }
+ }
+
+ return activeCount <= 1, nil
+}