summaryrefslogtreecommitdiff
path: root/internal/namespace
diff options
context:
space:
mode:
Diffstat (limited to 'internal/namespace')
-rw-r--r--internal/namespace/lifecycle.go27
-rw-r--r--internal/namespace/pinning.go30
-rw-r--r--internal/namespace/pinning_test.go2
3 files changed, 51 insertions, 8 deletions
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go
index 47a804f..99209d5 100644
--- a/internal/namespace/lifecycle.go
+++ b/internal/namespace/lifecycle.go
@@ -20,6 +20,11 @@ func GetPidsDirPath(pm *paths.PathManager, profile string) string {
return pm.ProfilePidsDir(profile)
}
+// GetControllerPidPath returns the path to the file storing the PID of the tunnel controller.
+func GetControllerPidPath(pm *paths.PathManager, profile string) string {
+ return filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "controller.pid")
+}
+
// RegisterProcess marks the current process as using the specified profile.
func RegisterProcess(pm *paths.PathManager, profile string) error {
pidsDir := GetPidsDirPath(pm, profile)
@@ -57,6 +62,9 @@ func PruneStalePids(pm *paths.PathManager, profile string) error {
}
for _, file := range files {
+ if file.Name() == "controller.pid" {
+ continue
+ }
pid, err := strconv.Atoi(file.Name())
if err != nil {
continue // Ignore non-numeric files
@@ -108,3 +116,22 @@ func IsLastProcess(pm *paths.PathManager, profile string) (bool, error) {
return activeCount <= 1, nil
}
+
+// SetControllerPid records the current process as the owner of the namespace.
+func SetControllerPid(pm *paths.PathManager, profile string) error {
+ path := GetControllerPidPath(pm, profile)
+ if err := os.WriteFile(path, []byte(strconv.Itoa(os.Getpid())), 0644); err != nil {
+ return fmt.Errorf("failed to write controller pid: %w", err)
+ }
+ return nil
+}
+
+// GetControllerPid retrieves the PID of the process responsible for cleaning up the namespace.
+func GetControllerPid(pm *paths.PathManager, profile string) (int, error) {
+ path := GetControllerPidPath(pm, profile)
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return 0, err
+ }
+ return strconv.Atoi(string(data))
+}
diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go
index a522f17..e257187 100644
--- a/internal/namespace/pinning.go
+++ b/internal/namespace/pinning.go
@@ -9,9 +9,11 @@ import (
"strconv"
"git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+ "golang.org/x/sys/unix"
)
-// PinNamespace touches the namespace path to indicate it is pinned/active.
+// PinNamespace binds the current network namespace to the profile's namespace path.
+// This prevents the kernel from destroying the namespace when all processes exit.
func PinNamespace(pm *paths.PathManager, profile string) error {
nsPath := GetProfileNamespacePath(pm, profile)
profilesDir := filepath.Dir(nsPath)
@@ -19,15 +21,21 @@ func PinNamespace(pm *paths.PathManager, profile string) error {
return fmt.Errorf("failed to create profiles directory: %w", err)
}
- // We write a placeholder file to indicate the profile namespace is pinned.
- if err := os.WriteFile(nsPath, []byte("active"), 0644); err != nil {
+ // 1. Create an empty file to serve as the mount point
+ if err := os.WriteFile(nsPath, []byte(""), 0644); err != nil {
return fmt.Errorf("failed to create namespace pin file: %w", err)
}
+
+ // 2. Bind-mount the current network namespace to the file.
+ // This increments the kernel's reference count for the namespace.
+ if err := unix.Mount("/proc/self/ns/net", nsPath, "", unix.MS_BIND, ""); err != nil {
+ return fmt.Errorf("failed to bind-mount network namespace: %w", err)
+ }
+
return nil
}
-// UnpinNamespace removes the pinned namespace file from the filesystem.
-// This allows the namespace to be destroyed once the last process exits.
+// UnpinNamespace unmounts and removes the pinned namespace file.
func UnpinNamespace(pm *paths.PathManager, profile string) error {
nsPath := GetProfileNamespacePath(pm, profile)
@@ -35,13 +43,19 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
return nil
}
- pidsDir := GetPidsDirPath(pm, profile)
+ // 1. Unmount the namespace first.
+ // If this is the last reference to the namespace, the kernel will destroy it.
+ if err := unix.Unmount(nsPath, 0); err != nil {
+ return fmt.Errorf("failed to unmount namespace %s: %w", nsPath, err)
+ }
- // Unlink the namespace file
+ // 2. Remove the mount point file.
if err := os.Remove(nsPath); err != nil {
- return fmt.Errorf("failed to unpin namespace %s: %w", nsPath, err)
+ return fmt.Errorf("failed to remove pin file %s: %w", nsPath, err)
}
+ pidsDir := GetPidsDirPath(pm, profile)
+
// Try to remove pids directory and empty parent directories
_ = os.Remove(pidsDir)
_ = os.Remove(filepath.Dir(pidsDir))
diff --git a/internal/namespace/pinning_test.go b/internal/namespace/pinning_test.go
index c65e1b1..18aba00 100644
--- a/internal/namespace/pinning_test.go
+++ b/internal/namespace/pinning_test.go
@@ -1,3 +1,5 @@
+//go:build linux && integration
+
package namespace
import (