diff options
Diffstat (limited to 'internal/namespace')
| -rw-r--r-- | internal/namespace/lifecycle.go | 45 | ||||
| -rw-r--r-- | internal/namespace/lifecycle_test.go | 45 | ||||
| -rw-r--r-- | internal/namespace/pinning.go | 35 | ||||
| -rw-r--r-- | internal/namespace/pinning_test.go | 53 |
4 files changed, 137 insertions, 41 deletions
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go index 4ca725f..47a804f 100644 --- a/internal/namespace/lifecycle.go +++ b/internal/namespace/lifecycle.go @@ -6,32 +6,23 @@ import ( "path/filepath" "strconv" "syscall" + + "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" ) // GetProfileNamespacePath returns the path to the pinned namespace file for a profile. -func GetProfileNamespacePath(baseDir, profile string) string { - if baseDir == "" { - baseDir = getRuntimeBaseDir() - } - return filepath.Join(baseDir, "profiles", profile) -} - -func getRuntimeBaseDir() string { - if envDir := os.Getenv("XDG_RUNTIME_DIR"); envDir != "" { - return envDir - } - uid := os.Getuid() - return fmt.Sprintf("/run/user/%d", uid) +func GetProfileNamespacePath(pm *paths.PathManager, profile string) string { + return pm.ProfileNamespacePath(profile) } // GetPidsDirPath returns the path to the directory where process PIDs are tracked for a profile. -func GetPidsDirPath(baseDir, profile string) string { - return filepath.Join(GetProfileNamespacePath(baseDir, profile), "pids") +func GetPidsDirPath(pm *paths.PathManager, profile string) string { + return pm.ProfilePidsDir(profile) } // RegisterProcess marks the current process as using the specified profile. -func RegisterProcess(baseDir, profile string) error { - pidsDir := GetPidsDirPath(baseDir, profile) +func RegisterProcess(pm *paths.PathManager, profile string) error { + pidsDir := GetPidsDirPath(pm, profile) if err := os.MkdirAll(pidsDir, 0755); err != nil { return fmt.Errorf("failed to create pids directory: %v", err) } @@ -45,9 +36,9 @@ func RegisterProcess(baseDir, profile string) error { } // UnregisterProcess removes the current process from the profile's tracking. -func UnregisterProcess(baseDir, profile string) error { +func UnregisterProcess(pm *paths.PathManager, profile string) error { pid := os.Getpid() - pidFile := filepath.Join(GetPidsDirPath(baseDir, profile), strconv.Itoa(pid)) + pidFile := filepath.Join(GetPidsDirPath(pm, profile), strconv.Itoa(pid)) if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to unregister process pid %d: %v", pid, err) } @@ -55,8 +46,8 @@ func UnregisterProcess(baseDir, profile string) error { } // PruneStalePids removes PID files that no longer correspond to active processes. -func PruneStalePids(baseDir, profile string) error { - pidsDir := GetPidsDirPath(baseDir, profile) +func PruneStalePids(pm *paths.PathManager, profile string) error { + pidsDir := GetPidsDirPath(pm, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { @@ -73,21 +64,25 @@ func PruneStalePids(baseDir, profile string) error { process, err := os.FindProcess(pid) if err != nil { - os.Remove(filepath.Join(pidsDir, file.Name())) + if err := os.Remove(filepath.Join(pidsDir, file.Name())); err != nil { + fmt.Printf("failed to remove stale pid file %s: %v\n", file.Name(), err) + } continue } err = process.Signal(syscall.Signal(0)) if err != nil { - os.Remove(filepath.Join(pidsDir, file.Name())) + if err := os.Remove(filepath.Join(pidsDir, file.Name())); err != nil { + fmt.Printf("failed to remove stale pid file %s: %v\n", file.Name(), err) + } } } return nil } // IsLastProcess checks if the current process is the only active user of the profile. -func IsLastProcess(baseDir, profile string) (bool, error) { - pidsDir := GetPidsDirPath(baseDir, profile) +func IsLastProcess(pm *paths.PathManager, profile string) (bool, error) { + pidsDir := GetPidsDirPath(pm, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { diff --git a/internal/namespace/lifecycle_test.go b/internal/namespace/lifecycle_test.go index db04e67..230e93a 100644 --- a/internal/namespace/lifecycle_test.go +++ b/internal/namespace/lifecycle_test.go @@ -5,27 +5,30 @@ import ( "path/filepath" "strconv" "testing" + + "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" ) func TestLifecycleReferenceCounting(t *testing.T) { // Use a temporary directory to avoid polluting the system tmpDir := t.TempDir() + pm := paths.NewPathManager("", tmpDir) profile := "test-vpn" t.Run("RegisterAndUnregister", func(t *testing.T) { - err := RegisterProcess(tmpDir, profile) + err := RegisterProcess(pm, profile) if err != nil { t.Fatalf("failed to register: %v", err) } - pidsDir := GetPidsDirPath(tmpDir, profile) + pidsDir := GetPidsDirPath(pm, profile) pidFile := filepath.Join(pidsDir, strconv.Itoa(os.Getpid())) if _, err := os.Stat(pidFile); os.IsNotExist(err) { t.Errorf("PID file should exist at %s", pidFile) } - err = UnregisterProcess(tmpDir, profile) + err = UnregisterProcess(pm, profile) if err != nil { t.Fatalf("failed to unregister: %v", err) } @@ -36,20 +39,22 @@ func TestLifecycleReferenceCounting(t *testing.T) { }) t.Run("PruneStalePids", func(t *testing.T) { - pidsDir := GetPidsDirPath(tmpDir, profile) + pidsDir := GetPidsDirPath(pm, profile) if err := os.MkdirAll(pidsDir, 0755); err != nil { t.Fatal(err) } - fakePid := "9999999" + fakePid := "9999999" fakePidFile := filepath.Join(pidsDir, fakePid) if err := os.WriteFile(fakePidFile, []byte(""), 0644); err != nil { t.Fatal(err) } - RegisterProcess(tmpDir, profile) + if err := RegisterProcess(pm, profile); err != nil { + t.Fatal(err) + } - err := PruneStalePids(tmpDir, profile) + err := PruneStalePids(pm, profile) if err != nil { t.Fatalf("prune failed: %v", err) } @@ -62,27 +67,35 @@ func TestLifecycleReferenceCounting(t *testing.T) { if _, err := os.Stat(currentPidFile); os.IsNotExist(err) { t.Errorf("Current PID file %s should not have been pruned", currentPidFile) } - - UnregisterProcess(tmpDir, profile) + + if err := UnregisterProcess(pm, profile); err != nil { + t.Fatal(err) + } }) t.Run("IsLastProcess", func(t *testing.T) { - pidsDir := GetPidsDirPath(tmpDir, profile) - os.RemoveAll(pidsDir) // Reset + pidsDir := GetPidsDirPath(pm, profile) + if err := os.RemoveAll(pidsDir); err != nil { + t.Fatal(err) + } - isLast, err := IsLastProcess(tmpDir, profile) + isLast, err := IsLastProcess(pm, profile) if err != nil || !isLast { t.Errorf("Expected IsLastProcess to be true for empty profile, got %v, err: %v", isLast, err) } - RegisterProcess(tmpDir, profile) - isLast, err = IsLastProcess(tmpDir, profile) + if err := RegisterProcess(pm, profile); err != nil { + t.Fatal(err) + } + isLast, err = IsLastProcess(pm, profile) if err != nil || !isLast { t.Errorf("Expected IsLastProcess to be true for single process, got %v, err: %v", isLast, err) } - os.WriteFile(filepath.Join(pidsDir, "1234567"), []byte(""), 0644) - isLast, err = IsLastProcess(tmpDir, profile) + if err := os.WriteFile(filepath.Join(pidsDir, "1234567"), []byte(""), 0644); err != nil { + t.Fatal(err) + } + isLast, err = IsLastProcess(pm, profile) if err != nil || !isLast { t.Errorf("Expected IsLastProcess to be true because 1234567 is dead, got %v, err: %v", isLast, err) } diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go new file mode 100644 index 0000000..cd81a38 --- /dev/null +++ b/internal/namespace/pinning.go @@ -0,0 +1,35 @@ +package namespace + +import ( + "fmt" + "os" + + "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" +) + +// UnpinNamespace removes the pinned namespace file from the filesystem. +// This allows the namespace to be destroyed once the last process exits. +func UnpinNamespace(pm *paths.PathManager, profile string) error { + nsPath := GetProfileNamespacePath(pm, profile) + + // We only want to unpin if there are no more active processes. + // The caller (cli.ExecuteCommand) is responsible for calling this + // when IsLastProcess returns true. + + if _, err := os.Stat(nsPath); os.IsNotExist(err) { + return nil + } + + // We also want to remove the pids directory if it's empty. + pidsDir := GetPidsDirPath(pm, profile) + + // Unlink the namespace file + if err := os.Remove(nsPath); err != nil { + return fmt.Errorf("failed to unpin namespace %s: %w", nsPath, err) + } + + // Try to remove pids directory + _ = os.Remove(pidsDir) + + return nil +} diff --git a/internal/namespace/pinning_test.go b/internal/namespace/pinning_test.go new file mode 100644 index 0000000..c65e1b1 --- /dev/null +++ b/internal/namespace/pinning_test.go @@ -0,0 +1,53 @@ +package namespace + +import ( + "os" + "path/filepath" + "testing" + + "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" +) + +func TestUnpinNamespace(t *testing.T) { + tmpDir := t.TempDir() + pm := paths.NewPathManager("", tmpDir) + profile := "test-profile" + nsPath := GetProfileNamespacePath(pm, profile) + + // Create the base profiles directory first + profilesDir := filepath.Dir(nsPath) + if err := os.MkdirAll(profilesDir, 0755); err != nil { + t.Fatalf("failed to create profiles dir: %v", err) + } + + // Create dummy namespace file + if err := os.WriteFile(nsPath, []byte("dummy"), 0644); err != nil { + t.Fatalf("failed to create ns file: %v", err) + } + + pidsDir := GetPidsDirPath(pm, profile) + if err := os.MkdirAll(pidsDir, 0755); err != nil { + t.Fatalf("failed to create pids dir: %v", err) + } + + t.Run("successfully unpins", func(t *testing.T) { + err := UnpinNamespace(pm, profile) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if _, err := os.Stat(nsPath); !os.IsNotExist(err) { + t.Errorf("namespace file should have been deleted") + } + if _, err := os.Stat(pidsDir); !os.IsNotExist(err) { + t.Errorf("pids directory should have been deleted") + } + }) + + t.Run("handles non-existent namespace", func(t *testing.T) { + err := UnpinNamespace(pm, profile) + if err != nil { + t.Errorf("unexpected error when unpinning non-existent namespace: %v", err) + } + }) +} |
