diff options
Diffstat (limited to 'internal/namespace')
| -rw-r--r-- | internal/namespace/launcher_src/launcher.c | 63 | ||||
| -rw-r--r-- | internal/namespace/namespace.go | 95 | ||||
| -rw-r--r-- | internal/namespace/namespace_stub.go | 6 | ||||
| -rw-r--r-- | internal/namespace/pinning.go | 66 |
4 files changed, 169 insertions, 61 deletions
diff --git a/internal/namespace/launcher_src/launcher.c b/internal/namespace/launcher_src/launcher.c index 7bbbce7..60c6558 100644 --- a/internal/namespace/launcher_src/launcher.c +++ b/internal/namespace/launcher_src/launcher.c @@ -13,6 +13,69 @@ int main(int argc, char **argv) { return 1; } + // Check if we are joining an existing namespace + char *join_pid_str = getenv("WG_WRAP_JOIN_PID"); + if (join_pid_str != NULL && strlen(join_pid_str) > 0) { + int target_pid = atoi(join_pid_str); + if (target_pid > 0) { + char path[128]; + int fd; + + // 1. Join User Namespace first + snprintf(path, sizeof(path), "/proc/%d/ns/user", target_pid); + fd = open(path, O_RDONLY); + if (fd == -1) { + perror("open target user namespace"); + return 1; + } + if (setns(fd, CLONE_NEWUSER) == -1) { + perror("setns CLONE_NEWUSER"); + close(fd); + return 1; + } + close(fd); + + // 2. Join Mount Namespace + snprintf(path, sizeof(path), "/proc/%d/ns/mnt", target_pid); + fd = open(path, O_RDONLY); + if (fd == -1) { + perror("open target mount namespace"); + return 1; + } + if (setns(fd, CLONE_NEWNS) == -1) { + perror("setns CLONE_NEWNS"); + close(fd); + return 1; + } + close(fd); + + // 3. Join Network Namespace + snprintf(path, sizeof(path), "/proc/%d/ns/net", target_pid); + fd = open(path, O_RDONLY); + if (fd == -1) { + perror("open target network namespace"); + return 1; + } + if (setns(fd, CLONE_NEWNET) == -1) { + perror("setns CLONE_NEWNET"); + close(fd); + return 1; + } + close(fd); + + // Execute the target command + if (argv[0] == NULL) { + fprintf(stderr, "No target binary provided in argv[0]\n"); + return 1; + } + if (execv(argv[0], argv) == -1) { + perror("execv failed"); + return 1; + } + return 0; + } + } + // 1. Capture host identities BEFORE unsharing uid_t current_uid = getuid(); gid_t current_gid = getgid(); diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go index ab3797d..0f2618b 100644 --- a/internal/namespace/namespace.go +++ b/internal/namespace/namespace.go @@ -194,3 +194,98 @@ func Bootstrap() (err error) { return nil } + +// BootstrapJoin joins the namespaces of the target PID and replaces the current process. +func BootstrapJoin(targetPid int) (err error) { + if IsIsolated() { + return nil + } + + var fdsToClose []int + defer func() { + for _, fd := range fdsToClose { + _ = syscall.Close(fd) + } + }() + + // Validate current arguments for null bytes before proceeding. + for i, arg := range os.Args { + for j := 0; j < len(arg); j++ { + if arg[j] == 0 { + return fmt.Errorf("argument %d contains null byte at position %d", i, j) + } + } + } + + self, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + + // 1. Create a secure temporary file for the launcher binary. + tmpFile, err := os.CreateTemp("", "wg-wrap-launcher-") + if err != nil { + return fmt.Errorf("failed to create temp launcher file: %w", err) + } + launcherPath := tmpFile.Name() + + // 2. Write the embedded launcher binary to the temp file. + if _, err := tmpFile.Write(launcherBytes); err != nil { + _ = tmpFile.Close() + _ = os.Remove(launcherPath) + return fmt.Errorf("failed to write launcher binary: %w", err) + } + + // Ensure the binary is executable (0700) + if err := tmpFile.Chmod(0700); err != nil { + _ = tmpFile.Close() + _ = os.Remove(launcherPath) + return fmt.Errorf("failed to set launcher permissions: %w", err) + } + + // 2b. Open a read-only fd of the launcher to exec + execFd, err := syscall.Open(launcherPath, syscall.O_RDONLY, 0) + if err != nil { + _ = tmpFile.Close() + _ = os.Remove(launcherPath) + return fmt.Errorf("failed to open launcher for exec: %w", err) + } + fdsToClose = append(fdsToClose, execFd) + + // Close the write file descriptor (to avoid ETXTBSY) + _ = tmpFile.Close() + + // Unlink the file from disk (makes it invisible and ensures it is deleted on exit) + _ = os.Remove(launcherPath) + + // Clear close-on-exec so it remains open across syscall.Exec + if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil { + _, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC) + } + + // 3. Prepare arguments for the launcher. + args := []string{self} + args = append(args, os.Args[1:]...) + + for i, arg := range args { + for j := 0; j < len(arg); j++ { + if arg[j] == 0 { + return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) + } + } + } + + // Set environment variables to tell the C launcher to join, + // and to tell the second wg-wrap instance that we are in a joined session. + env := append(os.Environ(), + fmt.Sprintf("WG_WRAP_JOIN_PID=%d", targetPid), + "WG_WRAP_JOINED=1", + ) + + err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env) + if err != nil { + return fmt.Errorf("launcher exec failed: %w", err) + } + + return nil +} diff --git a/internal/namespace/namespace_stub.go b/internal/namespace/namespace_stub.go index 84946bf..db0ec24 100644 --- a/internal/namespace/namespace_stub.go +++ b/internal/namespace/namespace_stub.go @@ -17,7 +17,7 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error { return fmt.Errorf("namespaces are not supported on this platform") } -// JoinExistingNamespace attempts to join the namespaces (user, mount, net) of an already active process. -func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) { - return false, fmt.Errorf("namespaces are not supported on this platform") +// FindActiveProfilePid is a stub for non-Linux platforms. +func FindActiveProfilePid(pm *paths.PathManager, profile string) (int, error) { + return 0, fmt.Errorf("namespaces are not supported on this platform") } diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go index 2433203..a522f17 100644 --- a/internal/namespace/pinning.go +++ b/internal/namespace/pinning.go @@ -6,12 +6,9 @@ import ( "fmt" "os" "path/filepath" - "runtime" "strconv" - "syscall" "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" - "golang.org/x/sys/unix" ) // PinNamespace touches the namespace path to indicate it is pinned/active. @@ -53,77 +50,30 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error { return nil } -// JoinExistingNamespace attempts to join the namespaces (user, mount, net) -// of an already active process running under the same profile. -// Returns true if a namespace was successfully joined, false if no active namespace exists. -func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) { +// FindActiveProfilePid looks for an active PID running under the specified profile. +// Returns 0 if no active process is found. +func FindActiveProfilePid(pm *paths.PathManager, profile string) (int, error) { if err := PruneStalePids(pm, profile); err != nil { - return false, fmt.Errorf("failed to prune stale pids: %w", err) + return 0, fmt.Errorf("failed to prune stale pids: %w", err) } pidsDir := GetPidsDirPath(pm, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { - return false, nil + return 0, nil } - return false, fmt.Errorf("failed to read pids dir: %w", err) + return 0, fmt.Errorf("failed to read pids dir: %w", err) } - var activePid int for _, file := range files { pid, err := strconv.Atoi(file.Name()) if err != nil { continue } // Since we already pruned stale pids, the first file we find is an active pid! - activePid = pid - break + return pid, nil } - if activePid == 0 { - return false, nil - } - - // Lock the OS thread before joining namespaces to ensure this goroutine stays on the modified thread, - // and that the thread is not reused for other goroutines (since we never unlock it). - runtime.LockOSThread() - - // Join User Namespace first - userNsPath := fmt.Sprintf("/proc/%d/ns/user", activePid) - userFd, err := os.Open(userNsPath) - if err != nil { - return false, fmt.Errorf("failed to open user namespace: %w", err) - } - defer func() { _ = userFd.Close() }() - - if err := unix.Setns(int(userFd.Fd()), syscall.CLONE_NEWUSER); err != nil { - return false, fmt.Errorf("failed to join user namespace: %w", err) - } - - // Join Mount Namespace - mntNsPath := fmt.Sprintf("/proc/%d/ns/mnt", activePid) - mntFd, err := os.Open(mntNsPath) - if err != nil { - return false, fmt.Errorf("failed to open mount namespace: %w", err) - } - defer func() { _ = mntFd.Close() }() - - if err := unix.Setns(int(mntFd.Fd()), syscall.CLONE_NEWNS); err != nil { - return false, fmt.Errorf("failed to join mount namespace: %w", err) - } - - // Join Network Namespace - netNsPath := fmt.Sprintf("/proc/%d/ns/net", activePid) - netFd, err := os.Open(netNsPath) - if err != nil { - return false, fmt.Errorf("failed to open network namespace: %w", err) - } - defer func() { _ = netFd.Close() }() - - if err := unix.Setns(int(netFd.Fd()), syscall.CLONE_NEWNET); err != nil { - return false, fmt.Errorf("failed to join network namespace: %w", err) - } - - return true, nil + return 0, nil } |
