summaryrefslogtreecommitdiff
path: root/internal/namespace/pinning.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/namespace/pinning.go')
-rw-r--r--internal/namespace/pinning.go66
1 files changed, 8 insertions, 58 deletions
diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go
index 2433203..a522f17 100644
--- a/internal/namespace/pinning.go
+++ b/internal/namespace/pinning.go
@@ -6,12 +6,9 @@ import (
"fmt"
"os"
"path/filepath"
- "runtime"
"strconv"
- "syscall"
"git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
- "golang.org/x/sys/unix"
)
// PinNamespace touches the namespace path to indicate it is pinned/active.
@@ -53,77 +50,30 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
return nil
}
-// JoinExistingNamespace attempts to join the namespaces (user, mount, net)
-// of an already active process running under the same profile.
-// Returns true if a namespace was successfully joined, false if no active namespace exists.
-func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) {
+// FindActiveProfilePid looks for an active PID running under the specified profile.
+// Returns 0 if no active process is found.
+func FindActiveProfilePid(pm *paths.PathManager, profile string) (int, error) {
if err := PruneStalePids(pm, profile); err != nil {
- return false, fmt.Errorf("failed to prune stale pids: %w", err)
+ return 0, fmt.Errorf("failed to prune stale pids: %w", err)
}
pidsDir := GetPidsDirPath(pm, profile)
files, err := os.ReadDir(pidsDir)
if err != nil {
if os.IsNotExist(err) {
- return false, nil
+ return 0, nil
}
- return false, fmt.Errorf("failed to read pids dir: %w", err)
+ return 0, fmt.Errorf("failed to read pids dir: %w", err)
}
- var activePid int
for _, file := range files {
pid, err := strconv.Atoi(file.Name())
if err != nil {
continue
}
// Since we already pruned stale pids, the first file we find is an active pid!
- activePid = pid
- break
+ return pid, nil
}
- if activePid == 0 {
- return false, nil
- }
-
- // Lock the OS thread before joining namespaces to ensure this goroutine stays on the modified thread,
- // and that the thread is not reused for other goroutines (since we never unlock it).
- runtime.LockOSThread()
-
- // Join User Namespace first
- userNsPath := fmt.Sprintf("/proc/%d/ns/user", activePid)
- userFd, err := os.Open(userNsPath)
- if err != nil {
- return false, fmt.Errorf("failed to open user namespace: %w", err)
- }
- defer func() { _ = userFd.Close() }()
-
- if err := unix.Setns(int(userFd.Fd()), syscall.CLONE_NEWUSER); err != nil {
- return false, fmt.Errorf("failed to join user namespace: %w", err)
- }
-
- // Join Mount Namespace
- mntNsPath := fmt.Sprintf("/proc/%d/ns/mnt", activePid)
- mntFd, err := os.Open(mntNsPath)
- if err != nil {
- return false, fmt.Errorf("failed to open mount namespace: %w", err)
- }
- defer func() { _ = mntFd.Close() }()
-
- if err := unix.Setns(int(mntFd.Fd()), syscall.CLONE_NEWNS); err != nil {
- return false, fmt.Errorf("failed to join mount namespace: %w", err)
- }
-
- // Join Network Namespace
- netNsPath := fmt.Sprintf("/proc/%d/ns/net", activePid)
- netFd, err := os.Open(netNsPath)
- if err != nil {
- return false, fmt.Errorf("failed to open network namespace: %w", err)
- }
- defer func() { _ = netFd.Close() }()
-
- if err := unix.Setns(int(netFd.Fd()), syscall.CLONE_NEWNET); err != nil {
- return false, fmt.Errorf("failed to join network namespace: %w", err)
- }
-
- return true, nil
+ return 0, nil
}