summaryrefslogtreecommitdiff
path: root/internal/namespace/pinning.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/namespace/pinning.go')
-rw-r--r--internal/namespace/pinning.go97
1 files changed, 92 insertions, 5 deletions
diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go
index cd81a38..7976937 100644
--- a/internal/namespace/pinning.go
+++ b/internal/namespace/pinning.go
@@ -1,26 +1,42 @@
+//go:build linux
+
package namespace
import (
"fmt"
"os"
+ "path/filepath"
+ "strconv"
+ "syscall"
"git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+ "golang.org/x/sys/unix"
)
+// PinNamespace touches the namespace path to indicate it is pinned/active.
+func PinNamespace(pm *paths.PathManager, profile string) error {
+ nsPath := GetProfileNamespacePath(pm, profile)
+ profilesDir := filepath.Dir(nsPath)
+ if err := os.MkdirAll(profilesDir, 0755); err != nil {
+ return fmt.Errorf("failed to create profiles directory: %w", err)
+ }
+
+ // We write a placeholder file to indicate the profile namespace is pinned.
+ if err := os.WriteFile(nsPath, []byte("active"), 0644); err != nil {
+ return fmt.Errorf("failed to create namespace pin file: %w", err)
+ }
+ return nil
+}
+
// UnpinNamespace removes the pinned namespace file from the filesystem.
// This allows the namespace to be destroyed once the last process exits.
func UnpinNamespace(pm *paths.PathManager, profile string) error {
nsPath := GetProfileNamespacePath(pm, profile)
- // We only want to unpin if there are no more active processes.
- // The caller (cli.ExecuteCommand) is responsible for calling this
- // when IsLastProcess returns true.
-
if _, err := os.Stat(nsPath); os.IsNotExist(err) {
return nil
}
- // We also want to remove the pids directory if it's empty.
pidsDir := GetPidsDirPath(pm, profile)
// Unlink the namespace file
@@ -33,3 +49,74 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
return nil
}
+
+// JoinExistingNamespace attempts to join the namespaces (user, mount, net)
+// of an already active process running under the same profile.
+// Returns true if a namespace was successfully joined, false if no active namespace exists.
+func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) {
+ if err := PruneStalePids(pm, profile); err != nil {
+ return false, fmt.Errorf("failed to prune stale pids: %w", err)
+ }
+
+ pidsDir := GetPidsDirPath(pm, profile)
+ files, err := os.ReadDir(pidsDir)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, fmt.Errorf("failed to read pids dir: %w", err)
+ }
+
+ var activePid int
+ for _, file := range files {
+ pid, err := strconv.Atoi(file.Name())
+ if err != nil {
+ continue
+ }
+ // Since we already pruned stale pids, the first file we find is an active pid!
+ activePid = pid
+ break
+ }
+
+ if activePid == 0 {
+ return false, nil
+ }
+
+ // Join User Namespace first
+ userNsPath := fmt.Sprintf("/proc/%d/ns/user", activePid)
+ userFd, err := os.Open(userNsPath)
+ if err != nil {
+ return false, fmt.Errorf("failed to open user namespace: %w", err)
+ }
+ defer func() { _ = userFd.Close() }()
+
+ if err := unix.Setns(int(userFd.Fd()), syscall.CLONE_NEWUSER); err != nil {
+ return false, fmt.Errorf("failed to join user namespace: %w", err)
+ }
+
+ // Join Mount Namespace
+ mntNsPath := fmt.Sprintf("/proc/%d/ns/mnt", activePid)
+ mntFd, err := os.Open(mntNsPath)
+ if err != nil {
+ return false, fmt.Errorf("failed to open mount namespace: %w", err)
+ }
+ defer func() { _ = mntFd.Close() }()
+
+ if err := unix.Setns(int(mntFd.Fd()), syscall.CLONE_NEWNS); err != nil {
+ return false, fmt.Errorf("failed to join mount namespace: %w", err)
+ }
+
+ // Join Network Namespace
+ netNsPath := fmt.Sprintf("/proc/%d/ns/net", activePid)
+ netFd, err := os.Open(netNsPath)
+ if err != nil {
+ return false, fmt.Errorf("failed to open network namespace: %w", err)
+ }
+ defer func() { _ = netFd.Close() }()
+
+ if err := unix.Setns(int(netFd.Fd()), syscall.CLONE_NEWNET); err != nil {
+ return false, fmt.Errorf("failed to join network namespace: %w", err)
+ }
+
+ return true, nil
+}