diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/cli/cli.go | 103 | ||||
| -rw-r--r-- | internal/namespace/launcher_src/launcher.c | 63 | ||||
| -rw-r--r-- | internal/namespace/namespace.go | 95 | ||||
| -rw-r--r-- | internal/namespace/namespace_stub.go | 6 | ||||
| -rw-r--r-- | internal/namespace/pinning.go | 66 | ||||
| -rw-r--r-- | internal/paths/paths.go | 4 |
6 files changed, 229 insertions, 108 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 9b3409e..87ee34f 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -127,6 +127,9 @@ func (a *App) Run() error { pm := a.getPathManager() + // Preserve the host runtime base dir in the environment before bootstrapping + _ = os.Setenv("WG_WRAP_HOST_RUNTIME_BASE_DIR", pm.RuntimeBaseDir()) + // Acquire startup lock to prevent concurrent bootstrap/joining races lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) if lockErr == nil { @@ -135,12 +138,14 @@ func (a *App) Run() error { // Before bootstrapping, see if an active namespace/process for the profile exists. // If yes, we can join it! - joined, err := namespace.JoinExistingNamespace(pm, cfg.Profile) - if err == nil && joined { - // We have joined the active namespace (user, mnt, net). + activePid, err := namespace.FindActiveProfilePid(pm, cfg.Profile) + if err == nil && activePid > 0 { // Release the lock before executing the command to allow others to join namespace.ReleaseProfileLock(lockFile) - return a.ExecuteCommand(cfg) + if err := namespace.BootstrapJoin(activePid); err != nil { + return fmt.Errorf("failed to join existing namespace: %w", err) + } + return nil } if err := namespace.Bootstrap(); err != nil { @@ -196,60 +201,64 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { } }() - fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile) + if os.Getenv("WG_WRAP_JOINED") == "1" { + fmt.Printf("Joining active WireGuard tunnel session for profile %s...\n", cfg.Profile) + } else { + fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile) - // Parse the profile configuration - profilesDir := pm.ConfigDir() - profilePath := filepath.Join(profilesDir, cfg.Profile+".conf") + // Parse the profile configuration + profilesDir := pm.ConfigDir() + profilePath := filepath.Join(profilesDir, cfg.Profile+".conf") - // Create tunnel if the file exists - if _, err := os.Stat(profilePath); err == nil { - wgCfg, err := wgconf.Parse(profilePath) - if err != nil { - return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) - } + // Create tunnel if the file exists + if _, err := os.Stat(profilePath); err == nil { + wgCfg, err := wgconf.Parse(profilePath) + if err != nil { + return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) + } - // Start the WireGuard userspace device & routing table setup - dnsServer := cfg.DNSServer - if dnsServer == "" { - dnsServer = wgCfg.DNS - } - if dnsServer == "" { - dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks - hasDefaultRoute := false - for _, peer := range wgCfg.Peers { - for _, ip := range peer.AllowedIPs { - trimmed := strings.TrimSpace(ip) - if trimmed == "0.0.0.0/0" || trimmed == "::/0" { - hasDefaultRoute = true + // Start the WireGuard userspace device & routing table setup + dnsServer := cfg.DNSServer + if dnsServer == "" { + dnsServer = wgCfg.DNS + } + if dnsServer == "" { + dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks + hasDefaultRoute := false + for _, peer := range wgCfg.Peers { + for _, ip := range peer.AllowedIPs { + trimmed := strings.TrimSpace(ip) + if trimmed == "0.0.0.0/0" || trimmed == "::/0" { + hasDefaultRoute = true + break + } + } + if hasDefaultRoute { break } } - if hasDefaultRoute { - break + if !hasDefaultRoute { + fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n") } } - if !hasDefaultRoute { - fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n") - } - } - tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) - if err != nil { - return fmt.Errorf("failed to start WireGuard tunnel: %w", err) - } - defer tunnel.Close() + tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) + if err != nil { + return fmt.Errorf("failed to start WireGuard tunnel: %w", err) + } + defer tunnel.Close() - // Pin the namespace so others can join it - if err := namespace.PinNamespace(pm, cfg.Profile); err != nil { - fmt.Printf("warning: failed to pin namespace: %v\n", err) - } - } else { - // If profile is not default or it was explicitly requested but doesn't exist, we error - if cfg.Profile != "default" { - return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) + // Pin the namespace so others can join it + if err := namespace.PinNamespace(pm, cfg.Profile); err != nil { + fmt.Printf("warning: failed to pin namespace: %v\n", err) + } + } else { + // If profile is not default or it was explicitly requested but doesn't exist, we error + if cfg.Profile != "default" { + return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) + } + fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } - fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } // We can now release the startup lock and execute the command diff --git a/internal/namespace/launcher_src/launcher.c b/internal/namespace/launcher_src/launcher.c index 7bbbce7..60c6558 100644 --- a/internal/namespace/launcher_src/launcher.c +++ b/internal/namespace/launcher_src/launcher.c @@ -13,6 +13,69 @@ int main(int argc, char **argv) { return 1; } + // Check if we are joining an existing namespace + char *join_pid_str = getenv("WG_WRAP_JOIN_PID"); + if (join_pid_str != NULL && strlen(join_pid_str) > 0) { + int target_pid = atoi(join_pid_str); + if (target_pid > 0) { + char path[128]; + int fd; + + // 1. Join User Namespace first + snprintf(path, sizeof(path), "/proc/%d/ns/user", target_pid); + fd = open(path, O_RDONLY); + if (fd == -1) { + perror("open target user namespace"); + return 1; + } + if (setns(fd, CLONE_NEWUSER) == -1) { + perror("setns CLONE_NEWUSER"); + close(fd); + return 1; + } + close(fd); + + // 2. Join Mount Namespace + snprintf(path, sizeof(path), "/proc/%d/ns/mnt", target_pid); + fd = open(path, O_RDONLY); + if (fd == -1) { + perror("open target mount namespace"); + return 1; + } + if (setns(fd, CLONE_NEWNS) == -1) { + perror("setns CLONE_NEWNS"); + close(fd); + return 1; + } + close(fd); + + // 3. Join Network Namespace + snprintf(path, sizeof(path), "/proc/%d/ns/net", target_pid); + fd = open(path, O_RDONLY); + if (fd == -1) { + perror("open target network namespace"); + return 1; + } + if (setns(fd, CLONE_NEWNET) == -1) { + perror("setns CLONE_NEWNET"); + close(fd); + return 1; + } + close(fd); + + // Execute the target command + if (argv[0] == NULL) { + fprintf(stderr, "No target binary provided in argv[0]\n"); + return 1; + } + if (execv(argv[0], argv) == -1) { + perror("execv failed"); + return 1; + } + return 0; + } + } + // 1. Capture host identities BEFORE unsharing uid_t current_uid = getuid(); gid_t current_gid = getgid(); diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go index ab3797d..0f2618b 100644 --- a/internal/namespace/namespace.go +++ b/internal/namespace/namespace.go @@ -194,3 +194,98 @@ func Bootstrap() (err error) { return nil } + +// BootstrapJoin joins the namespaces of the target PID and replaces the current process. +func BootstrapJoin(targetPid int) (err error) { + if IsIsolated() { + return nil + } + + var fdsToClose []int + defer func() { + for _, fd := range fdsToClose { + _ = syscall.Close(fd) + } + }() + + // Validate current arguments for null bytes before proceeding. + for i, arg := range os.Args { + for j := 0; j < len(arg); j++ { + if arg[j] == 0 { + return fmt.Errorf("argument %d contains null byte at position %d", i, j) + } + } + } + + self, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + + // 1. Create a secure temporary file for the launcher binary. + tmpFile, err := os.CreateTemp("", "wg-wrap-launcher-") + if err != nil { + return fmt.Errorf("failed to create temp launcher file: %w", err) + } + launcherPath := tmpFile.Name() + + // 2. Write the embedded launcher binary to the temp file. + if _, err := tmpFile.Write(launcherBytes); err != nil { + _ = tmpFile.Close() + _ = os.Remove(launcherPath) + return fmt.Errorf("failed to write launcher binary: %w", err) + } + + // Ensure the binary is executable (0700) + if err := tmpFile.Chmod(0700); err != nil { + _ = tmpFile.Close() + _ = os.Remove(launcherPath) + return fmt.Errorf("failed to set launcher permissions: %w", err) + } + + // 2b. Open a read-only fd of the launcher to exec + execFd, err := syscall.Open(launcherPath, syscall.O_RDONLY, 0) + if err != nil { + _ = tmpFile.Close() + _ = os.Remove(launcherPath) + return fmt.Errorf("failed to open launcher for exec: %w", err) + } + fdsToClose = append(fdsToClose, execFd) + + // Close the write file descriptor (to avoid ETXTBSY) + _ = tmpFile.Close() + + // Unlink the file from disk (makes it invisible and ensures it is deleted on exit) + _ = os.Remove(launcherPath) + + // Clear close-on-exec so it remains open across syscall.Exec + if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil { + _, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC) + } + + // 3. Prepare arguments for the launcher. + args := []string{self} + args = append(args, os.Args[1:]...) + + for i, arg := range args { + for j := 0; j < len(arg); j++ { + if arg[j] == 0 { + return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) + } + } + } + + // Set environment variables to tell the C launcher to join, + // and to tell the second wg-wrap instance that we are in a joined session. + env := append(os.Environ(), + fmt.Sprintf("WG_WRAP_JOIN_PID=%d", targetPid), + "WG_WRAP_JOINED=1", + ) + + err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env) + if err != nil { + return fmt.Errorf("launcher exec failed: %w", err) + } + + return nil +} diff --git a/internal/namespace/namespace_stub.go b/internal/namespace/namespace_stub.go index 84946bf..db0ec24 100644 --- a/internal/namespace/namespace_stub.go +++ b/internal/namespace/namespace_stub.go @@ -17,7 +17,7 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error { return fmt.Errorf("namespaces are not supported on this platform") } -// JoinExistingNamespace attempts to join the namespaces (user, mount, net) of an already active process. -func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) { - return false, fmt.Errorf("namespaces are not supported on this platform") +// FindActiveProfilePid is a stub for non-Linux platforms. +func FindActiveProfilePid(pm *paths.PathManager, profile string) (int, error) { + return 0, fmt.Errorf("namespaces are not supported on this platform") } diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go index 2433203..a522f17 100644 --- a/internal/namespace/pinning.go +++ b/internal/namespace/pinning.go @@ -6,12 +6,9 @@ import ( "fmt" "os" "path/filepath" - "runtime" "strconv" - "syscall" "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" - "golang.org/x/sys/unix" ) // PinNamespace touches the namespace path to indicate it is pinned/active. @@ -53,77 +50,30 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error { return nil } -// JoinExistingNamespace attempts to join the namespaces (user, mount, net) -// of an already active process running under the same profile. -// Returns true if a namespace was successfully joined, false if no active namespace exists. -func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) { +// FindActiveProfilePid looks for an active PID running under the specified profile. +// Returns 0 if no active process is found. +func FindActiveProfilePid(pm *paths.PathManager, profile string) (int, error) { if err := PruneStalePids(pm, profile); err != nil { - return false, fmt.Errorf("failed to prune stale pids: %w", err) + return 0, fmt.Errorf("failed to prune stale pids: %w", err) } pidsDir := GetPidsDirPath(pm, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { - return false, nil + return 0, nil } - return false, fmt.Errorf("failed to read pids dir: %w", err) + return 0, fmt.Errorf("failed to read pids dir: %w", err) } - var activePid int for _, file := range files { pid, err := strconv.Atoi(file.Name()) if err != nil { continue } // Since we already pruned stale pids, the first file we find is an active pid! - activePid = pid - break + return pid, nil } - if activePid == 0 { - return false, nil - } - - // Lock the OS thread before joining namespaces to ensure this goroutine stays on the modified thread, - // and that the thread is not reused for other goroutines (since we never unlock it). - runtime.LockOSThread() - - // Join User Namespace first - userNsPath := fmt.Sprintf("/proc/%d/ns/user", activePid) - userFd, err := os.Open(userNsPath) - if err != nil { - return false, fmt.Errorf("failed to open user namespace: %w", err) - } - defer func() { _ = userFd.Close() }() - - if err := unix.Setns(int(userFd.Fd()), syscall.CLONE_NEWUSER); err != nil { - return false, fmt.Errorf("failed to join user namespace: %w", err) - } - - // Join Mount Namespace - mntNsPath := fmt.Sprintf("/proc/%d/ns/mnt", activePid) - mntFd, err := os.Open(mntNsPath) - if err != nil { - return false, fmt.Errorf("failed to open mount namespace: %w", err) - } - defer func() { _ = mntFd.Close() }() - - if err := unix.Setns(int(mntFd.Fd()), syscall.CLONE_NEWNS); err != nil { - return false, fmt.Errorf("failed to join mount namespace: %w", err) - } - - // Join Network Namespace - netNsPath := fmt.Sprintf("/proc/%d/ns/net", activePid) - netFd, err := os.Open(netNsPath) - if err != nil { - return false, fmt.Errorf("failed to open network namespace: %w", err) - } - defer func() { _ = netFd.Close() }() - - if err := unix.Setns(int(netFd.Fd()), syscall.CLONE_NEWNET); err != nil { - return false, fmt.Errorf("failed to join network namespace: %w", err) - } - - return true, nil + return 0, nil } diff --git a/internal/paths/paths.go b/internal/paths/paths.go index f512ad1..c7bdd94 100644 --- a/internal/paths/paths.go +++ b/internal/paths/paths.go @@ -44,6 +44,10 @@ func (pm *PathManager) RuntimeBaseDir() string { return pm.RuntimeBaseOverride } + if envDir := os.Getenv("WG_WRAP_HOST_RUNTIME_BASE_DIR"); envDir != "" { + return envDir + } + if envDir := os.Getenv("XDG_RUNTIME_DIR"); envDir != "" { return envDir } |
