diff options
Diffstat (limited to 'internal/namespace')
| -rw-r--r-- | internal/namespace/lifecycle.go | 27 | ||||
| -rw-r--r-- | internal/namespace/pinning.go | 30 | ||||
| -rw-r--r-- | internal/namespace/pinning_test.go | 2 |
3 files changed, 51 insertions, 8 deletions
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go index 47a804f..99209d5 100644 --- a/internal/namespace/lifecycle.go +++ b/internal/namespace/lifecycle.go @@ -20,6 +20,11 @@ func GetPidsDirPath(pm *paths.PathManager, profile string) string { return pm.ProfilePidsDir(profile) } +// GetControllerPidPath returns the path to the file storing the PID of the tunnel controller. +func GetControllerPidPath(pm *paths.PathManager, profile string) string { + return filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "controller.pid") +} + // RegisterProcess marks the current process as using the specified profile. func RegisterProcess(pm *paths.PathManager, profile string) error { pidsDir := GetPidsDirPath(pm, profile) @@ -57,6 +62,9 @@ func PruneStalePids(pm *paths.PathManager, profile string) error { } for _, file := range files { + if file.Name() == "controller.pid" { + continue + } pid, err := strconv.Atoi(file.Name()) if err != nil { continue // Ignore non-numeric files @@ -108,3 +116,22 @@ func IsLastProcess(pm *paths.PathManager, profile string) (bool, error) { return activeCount <= 1, nil } + +// SetControllerPid records the current process as the owner of the namespace. +func SetControllerPid(pm *paths.PathManager, profile string) error { + path := GetControllerPidPath(pm, profile) + if err := os.WriteFile(path, []byte(strconv.Itoa(os.Getpid())), 0644); err != nil { + return fmt.Errorf("failed to write controller pid: %w", err) + } + return nil +} + +// GetControllerPid retrieves the PID of the process responsible for cleaning up the namespace. +func GetControllerPid(pm *paths.PathManager, profile string) (int, error) { + path := GetControllerPidPath(pm, profile) + data, err := os.ReadFile(path) + if err != nil { + return 0, err + } + return strconv.Atoi(string(data)) +} diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go index a522f17..e257187 100644 --- a/internal/namespace/pinning.go +++ b/internal/namespace/pinning.go @@ -9,9 +9,11 @@ import ( "strconv" "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" + "golang.org/x/sys/unix" ) -// PinNamespace touches the namespace path to indicate it is pinned/active. +// PinNamespace binds the current network namespace to the profile's namespace path. +// This prevents the kernel from destroying the namespace when all processes exit. func PinNamespace(pm *paths.PathManager, profile string) error { nsPath := GetProfileNamespacePath(pm, profile) profilesDir := filepath.Dir(nsPath) @@ -19,15 +21,21 @@ func PinNamespace(pm *paths.PathManager, profile string) error { return fmt.Errorf("failed to create profiles directory: %w", err) } - // We write a placeholder file to indicate the profile namespace is pinned. - if err := os.WriteFile(nsPath, []byte("active"), 0644); err != nil { + // 1. Create an empty file to serve as the mount point + if err := os.WriteFile(nsPath, []byte(""), 0644); err != nil { return fmt.Errorf("failed to create namespace pin file: %w", err) } + + // 2. Bind-mount the current network namespace to the file. + // This increments the kernel's reference count for the namespace. + if err := unix.Mount("/proc/self/ns/net", nsPath, "", unix.MS_BIND, ""); err != nil { + return fmt.Errorf("failed to bind-mount network namespace: %w", err) + } + return nil } -// UnpinNamespace removes the pinned namespace file from the filesystem. -// This allows the namespace to be destroyed once the last process exits. +// UnpinNamespace unmounts and removes the pinned namespace file. func UnpinNamespace(pm *paths.PathManager, profile string) error { nsPath := GetProfileNamespacePath(pm, profile) @@ -35,13 +43,19 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error { return nil } - pidsDir := GetPidsDirPath(pm, profile) + // 1. Unmount the namespace first. + // If this is the last reference to the namespace, the kernel will destroy it. + if err := unix.Unmount(nsPath, 0); err != nil { + return fmt.Errorf("failed to unmount namespace %s: %w", nsPath, err) + } - // Unlink the namespace file + // 2. Remove the mount point file. if err := os.Remove(nsPath); err != nil { - return fmt.Errorf("failed to unpin namespace %s: %w", nsPath, err) + return fmt.Errorf("failed to remove pin file %s: %w", nsPath, err) } + pidsDir := GetPidsDirPath(pm, profile) + // Try to remove pids directory and empty parent directories _ = os.Remove(pidsDir) _ = os.Remove(filepath.Dir(pidsDir)) diff --git a/internal/namespace/pinning_test.go b/internal/namespace/pinning_test.go index c65e1b1..18aba00 100644 --- a/internal/namespace/pinning_test.go +++ b/internal/namespace/pinning_test.go @@ -1,3 +1,5 @@ +//go:build linux && integration + package namespace import ( |
