diff options
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | AGENTS.md | 4 | ||||
| -rw-r--r-- | cmd/wg-wrap/main.go | 27 | ||||
| -rw-r--r-- | internal/cli/cli.go | 103 | ||||
| -rw-r--r-- | internal/namespace/launcher_src/launcher.c | 63 | ||||
| -rw-r--r-- | internal/namespace/namespace.go | 95 | ||||
| -rw-r--r-- | internal/namespace/namespace_stub.go | 6 | ||||
| -rw-r--r-- | internal/namespace/pinning.go | 66 | ||||
| -rw-r--r-- | internal/paths/paths.go | 4 | ||||
| -rw-r--r-- | tests/e2e/sharing_test.go | 108 |
10 files changed, 367 insertions, 111 deletions
@@ -1,6 +1,6 @@ # Binaries bin/ -wg-wrap +/wg-wrap **/launcher **/launcher.bin **/launcher_test @@ -83,7 +83,7 @@ We employ a three-tier testing approach to balance speed and reliability: ### 3. Namespace Lifecycle - **Creation**: `CLONE_NEWUSER` $\rightarrow$ `CLONE_NEWNS` $\rightarrow$ `CLONE_NEWNET` inside an embedded C launcher. -- **Persistence & Sharing**: Namespaces are pinned and shared rootlessly. Processes record active runs inside a profile's `pids/` directory. Subsequent wrapping calls use `setns` (via `unix.Setns`) to enter the existing namespace context in $\approx 10\text{ms}$. +- **Persistence & Sharing**: Namespaces are pinned and shared rootlessly. Processes record active runs inside a profile's `pids/` directory. Subsequent wrapping calls discover the active PID and re-execute through our single-threaded C launcher to call `setns` (joining User, Mount, and Network namespaces) in $\approx 10\text{ms}$ before the Go runtime starts, bypassing Go's multi-threaded `CLONE_NEWNS` limitation. - **Cleanup**: When the last active process registers its exit, the reference counting detects 0 remaining sessions, automatically unpins state files, and releases resources cleanly. ## System Assumptions @@ -99,4 +99,4 @@ The project assumes the target environment is a modern Linux system configured f 3. **Host Socket Preservation**: Open UDP sockets on the host before isolation and pass them (`WG_WRAP_HOST_SOCKET_FD`) to `wireguard-go` using `FDBind` to bypass kernel security boundaries. 4. **Data Path**: Integrate `wireguard-go` with `tun` devices seamlessly inside the namespace. 5. **Routing**: Automatically build default routing gateway tables in the isolated network namespace. -6. **Namespace Sharing**: Connect concurrent wrapping runs to the active tunnel rootlessly via `setns`. +6. **Namespace Sharing**: Connect concurrent wrapping runs to the active tunnel rootlessly via `setns` inside the single-threaded C launcher. diff --git a/cmd/wg-wrap/main.go b/cmd/wg-wrap/main.go new file mode 100644 index 0000000..7e2de9f --- /dev/null +++ b/cmd/wg-wrap/main.go @@ -0,0 +1,27 @@ +package main + +import ( + "errors" + "fmt" + "os" + "os/exec" + + "git.theodohertyfamily.com/tools/wg-wrap/internal/cli" +) + +func main() { + app := cli.NewApp(os.Args) + + // 1. Routing Phase: Handle diagnostic and management commands first. + // These should run in the current namespace/context. + if err := app.Route(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + + // Propagate the exact exit code if the wrapped command failed with a non-zero exit status. + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + os.Exit(exitErr.ExitCode()) + } + os.Exit(1) + } +} diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 9b3409e..87ee34f 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -127,6 +127,9 @@ func (a *App) Run() error { pm := a.getPathManager() + // Preserve the host runtime base dir in the environment before bootstrapping + _ = os.Setenv("WG_WRAP_HOST_RUNTIME_BASE_DIR", pm.RuntimeBaseDir()) + // Acquire startup lock to prevent concurrent bootstrap/joining races lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) if lockErr == nil { @@ -135,12 +138,14 @@ func (a *App) Run() error { // Before bootstrapping, see if an active namespace/process for the profile exists. // If yes, we can join it! - joined, err := namespace.JoinExistingNamespace(pm, cfg.Profile) - if err == nil && joined { - // We have joined the active namespace (user, mnt, net). + activePid, err := namespace.FindActiveProfilePid(pm, cfg.Profile) + if err == nil && activePid > 0 { // Release the lock before executing the command to allow others to join namespace.ReleaseProfileLock(lockFile) - return a.ExecuteCommand(cfg) + if err := namespace.BootstrapJoin(activePid); err != nil { + return fmt.Errorf("failed to join existing namespace: %w", err) + } + return nil } if err := namespace.Bootstrap(); err != nil { @@ -196,60 +201,64 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { } }() - fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile) + if os.Getenv("WG_WRAP_JOINED") == "1" { + fmt.Printf("Joining active WireGuard tunnel session for profile %s...\n", cfg.Profile) + } else { + fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile) - // Parse the profile configuration - profilesDir := pm.ConfigDir() - profilePath := filepath.Join(profilesDir, cfg.Profile+".conf") + // Parse the profile configuration + profilesDir := pm.ConfigDir() + profilePath := filepath.Join(profilesDir, cfg.Profile+".conf") - // Create tunnel if the file exists - if _, err := os.Stat(profilePath); err == nil { - wgCfg, err := wgconf.Parse(profilePath) - if err != nil { - return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) - } + // Create tunnel if the file exists + if _, err := os.Stat(profilePath); err == nil { + wgCfg, err := wgconf.Parse(profilePath) + if err != nil { + return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) + } - // Start the WireGuard userspace device & routing table setup - dnsServer := cfg.DNSServer - if dnsServer == "" { - dnsServer = wgCfg.DNS - } - if dnsServer == "" { - dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks - hasDefaultRoute := false - for _, peer := range wgCfg.Peers { - for _, ip := range peer.AllowedIPs { - trimmed := strings.TrimSpace(ip) - if trimmed == "0.0.0.0/0" || trimmed == "::/0" { - hasDefaultRoute = true + // Start the WireGuard userspace device & routing table setup + dnsServer := cfg.DNSServer + if dnsServer == "" { + dnsServer = wgCfg.DNS + } + if dnsServer == "" { + dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks + hasDefaultRoute := false + for _, peer := range wgCfg.Peers { + for _, ip := range peer.AllowedIPs { + trimmed := strings.TrimSpace(ip) + if trimmed == "0.0.0.0/0" || trimmed == "::/0" { + hasDefaultRoute = true + break + } + } + if hasDefaultRoute { break } } - if hasDefaultRoute { - break + if !hasDefaultRoute { + fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n") } } - if !hasDefaultRoute { - fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n") - } - } - tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) - if err != nil { - return fmt.Errorf("failed to start WireGuard tunnel: %w", err) - } - defer tunnel.Close() + tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) + if err != nil { + return fmt.Errorf("failed to start WireGuard tunnel: %w", err) + } + defer tunnel.Close() - // Pin the namespace so others can join it - if err := namespace.PinNamespace(pm, cfg.Profile); err != nil { - fmt.Printf("warning: failed to pin namespace: %v\n", err) - } - } else { - // If profile is not default or it was explicitly requested but doesn't exist, we error - if cfg.Profile != "default" { - return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) + // Pin the namespace so others can join it + if err := namespace.PinNamespace(pm, cfg.Profile); err != nil { + fmt.Printf("warning: failed to pin namespace: %v\n", err) + } + } else { + // If profile is not default or it was explicitly requested but doesn't exist, we error + if cfg.Profile != "default" { + return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) + } + fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } - fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } // We can now release the startup lock and execute the command diff --git a/internal/namespace/launcher_src/launcher.c b/internal/namespace/launcher_src/launcher.c index 7bbbce7..60c6558 100644 --- a/internal/namespace/launcher_src/launcher.c +++ b/internal/namespace/launcher_src/launcher.c @@ -13,6 +13,69 @@ int main(int argc, char **argv) { return 1; } + // Check if we are joining an existing namespace + char *join_pid_str = getenv("WG_WRAP_JOIN_PID"); + if (join_pid_str != NULL && strlen(join_pid_str) > 0) { + int target_pid = atoi(join_pid_str); + if (target_pid > 0) { + char path[128]; + int fd; + + // 1. Join User Namespace first + snprintf(path, sizeof(path), "/proc/%d/ns/user", target_pid); + fd = open(path, O_RDONLY); + if (fd == -1) { + perror("open target user namespace"); + return 1; + } + if (setns(fd, CLONE_NEWUSER) == -1) { + perror("setns CLONE_NEWUSER"); + close(fd); + return 1; + } + close(fd); + + // 2. Join Mount Namespace + snprintf(path, sizeof(path), "/proc/%d/ns/mnt", target_pid); + fd = open(path, O_RDONLY); + if (fd == -1) { + perror("open target mount namespace"); + return 1; + } + if (setns(fd, CLONE_NEWNS) == -1) { + perror("setns CLONE_NEWNS"); + close(fd); + return 1; + } + close(fd); + + // 3. Join Network Namespace + snprintf(path, sizeof(path), "/proc/%d/ns/net", target_pid); + fd = open(path, O_RDONLY); + if (fd == -1) { + perror("open target network namespace"); + return 1; + } + if (setns(fd, CLONE_NEWNET) == -1) { + perror("setns CLONE_NEWNET"); + close(fd); + return 1; + } + close(fd); + + // Execute the target command + if (argv[0] == NULL) { + fprintf(stderr, "No target binary provided in argv[0]\n"); + return 1; + } + if (execv(argv[0], argv) == -1) { + perror("execv failed"); + return 1; + } + return 0; + } + } + // 1. Capture host identities BEFORE unsharing uid_t current_uid = getuid(); gid_t current_gid = getgid(); diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go index ab3797d..0f2618b 100644 --- a/internal/namespace/namespace.go +++ b/internal/namespace/namespace.go @@ -194,3 +194,98 @@ func Bootstrap() (err error) { return nil } + +// BootstrapJoin joins the namespaces of the target PID and replaces the current process. +func BootstrapJoin(targetPid int) (err error) { + if IsIsolated() { + return nil + } + + var fdsToClose []int + defer func() { + for _, fd := range fdsToClose { + _ = syscall.Close(fd) + } + }() + + // Validate current arguments for null bytes before proceeding. + for i, arg := range os.Args { + for j := 0; j < len(arg); j++ { + if arg[j] == 0 { + return fmt.Errorf("argument %d contains null byte at position %d", i, j) + } + } + } + + self, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + + // 1. Create a secure temporary file for the launcher binary. + tmpFile, err := os.CreateTemp("", "wg-wrap-launcher-") + if err != nil { + return fmt.Errorf("failed to create temp launcher file: %w", err) + } + launcherPath := tmpFile.Name() + + // 2. Write the embedded launcher binary to the temp file. + if _, err := tmpFile.Write(launcherBytes); err != nil { + _ = tmpFile.Close() + _ = os.Remove(launcherPath) + return fmt.Errorf("failed to write launcher binary: %w", err) + } + + // Ensure the binary is executable (0700) + if err := tmpFile.Chmod(0700); err != nil { + _ = tmpFile.Close() + _ = os.Remove(launcherPath) + return fmt.Errorf("failed to set launcher permissions: %w", err) + } + + // 2b. Open a read-only fd of the launcher to exec + execFd, err := syscall.Open(launcherPath, syscall.O_RDONLY, 0) + if err != nil { + _ = tmpFile.Close() + _ = os.Remove(launcherPath) + return fmt.Errorf("failed to open launcher for exec: %w", err) + } + fdsToClose = append(fdsToClose, execFd) + + // Close the write file descriptor (to avoid ETXTBSY) + _ = tmpFile.Close() + + // Unlink the file from disk (makes it invisible and ensures it is deleted on exit) + _ = os.Remove(launcherPath) + + // Clear close-on-exec so it remains open across syscall.Exec + if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil { + _, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC) + } + + // 3. Prepare arguments for the launcher. + args := []string{self} + args = append(args, os.Args[1:]...) + + for i, arg := range args { + for j := 0; j < len(arg); j++ { + if arg[j] == 0 { + return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) + } + } + } + + // Set environment variables to tell the C launcher to join, + // and to tell the second wg-wrap instance that we are in a joined session. + env := append(os.Environ(), + fmt.Sprintf("WG_WRAP_JOIN_PID=%d", targetPid), + "WG_WRAP_JOINED=1", + ) + + err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env) + if err != nil { + return fmt.Errorf("launcher exec failed: %w", err) + } + + return nil +} diff --git a/internal/namespace/namespace_stub.go b/internal/namespace/namespace_stub.go index 84946bf..db0ec24 100644 --- a/internal/namespace/namespace_stub.go +++ b/internal/namespace/namespace_stub.go @@ -17,7 +17,7 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error { return fmt.Errorf("namespaces are not supported on this platform") } -// JoinExistingNamespace attempts to join the namespaces (user, mount, net) of an already active process. -func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) { - return false, fmt.Errorf("namespaces are not supported on this platform") +// FindActiveProfilePid is a stub for non-Linux platforms. +func FindActiveProfilePid(pm *paths.PathManager, profile string) (int, error) { + return 0, fmt.Errorf("namespaces are not supported on this platform") } diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go index 2433203..a522f17 100644 --- a/internal/namespace/pinning.go +++ b/internal/namespace/pinning.go @@ -6,12 +6,9 @@ import ( "fmt" "os" "path/filepath" - "runtime" "strconv" - "syscall" "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" - "golang.org/x/sys/unix" ) // PinNamespace touches the namespace path to indicate it is pinned/active. @@ -53,77 +50,30 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error { return nil } -// JoinExistingNamespace attempts to join the namespaces (user, mount, net) -// of an already active process running under the same profile. -// Returns true if a namespace was successfully joined, false if no active namespace exists. -func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) { +// FindActiveProfilePid looks for an active PID running under the specified profile. +// Returns 0 if no active process is found. +func FindActiveProfilePid(pm *paths.PathManager, profile string) (int, error) { if err := PruneStalePids(pm, profile); err != nil { - return false, fmt.Errorf("failed to prune stale pids: %w", err) + return 0, fmt.Errorf("failed to prune stale pids: %w", err) } pidsDir := GetPidsDirPath(pm, profile) files, err := os.ReadDir(pidsDir) if err != nil { if os.IsNotExist(err) { - return false, nil + return 0, nil } - return false, fmt.Errorf("failed to read pids dir: %w", err) + return 0, fmt.Errorf("failed to read pids dir: %w", err) } - var activePid int for _, file := range files { pid, err := strconv.Atoi(file.Name()) if err != nil { continue } // Since we already pruned stale pids, the first file we find is an active pid! - activePid = pid - break + return pid, nil } - if activePid == 0 { - return false, nil - } - - // Lock the OS thread before joining namespaces to ensure this goroutine stays on the modified thread, - // and that the thread is not reused for other goroutines (since we never unlock it). - runtime.LockOSThread() - - // Join User Namespace first - userNsPath := fmt.Sprintf("/proc/%d/ns/user", activePid) - userFd, err := os.Open(userNsPath) - if err != nil { - return false, fmt.Errorf("failed to open user namespace: %w", err) - } - defer func() { _ = userFd.Close() }() - - if err := unix.Setns(int(userFd.Fd()), syscall.CLONE_NEWUSER); err != nil { - return false, fmt.Errorf("failed to join user namespace: %w", err) - } - - // Join Mount Namespace - mntNsPath := fmt.Sprintf("/proc/%d/ns/mnt", activePid) - mntFd, err := os.Open(mntNsPath) - if err != nil { - return false, fmt.Errorf("failed to open mount namespace: %w", err) - } - defer func() { _ = mntFd.Close() }() - - if err := unix.Setns(int(mntFd.Fd()), syscall.CLONE_NEWNS); err != nil { - return false, fmt.Errorf("failed to join mount namespace: %w", err) - } - - // Join Network Namespace - netNsPath := fmt.Sprintf("/proc/%d/ns/net", activePid) - netFd, err := os.Open(netNsPath) - if err != nil { - return false, fmt.Errorf("failed to open network namespace: %w", err) - } - defer func() { _ = netFd.Close() }() - - if err := unix.Setns(int(netFd.Fd()), syscall.CLONE_NEWNET); err != nil { - return false, fmt.Errorf("failed to join network namespace: %w", err) - } - - return true, nil + return 0, nil } diff --git a/internal/paths/paths.go b/internal/paths/paths.go index f512ad1..c7bdd94 100644 --- a/internal/paths/paths.go +++ b/internal/paths/paths.go @@ -44,6 +44,10 @@ func (pm *PathManager) RuntimeBaseDir() string { return pm.RuntimeBaseOverride } + if envDir := os.Getenv("WG_WRAP_HOST_RUNTIME_BASE_DIR"); envDir != "" { + return envDir + } + if envDir := os.Getenv("XDG_RUNTIME_DIR"); envDir != "" { return envDir } diff --git a/tests/e2e/sharing_test.go b/tests/e2e/sharing_test.go new file mode 100644 index 0000000..b0971f9 --- /dev/null +++ b/tests/e2e/sharing_test.go @@ -0,0 +1,108 @@ +package e2e + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +func TestNamespaceSharing(t *testing.T) { + binaryPath, err := GetBinaryPath() + if err != nil { + t.Skipf("Skipping test: %v", err) + } + + tmpRuntimeDir := t.TempDir() + tmpConfigDir := t.TempDir() + profile := "sharing-test" + + // Write a valid dummy profile so it doesn't run in bare isolation + profilesDir := filepath.Join(tmpConfigDir, "wg-wrap", "profiles") + if err := os.MkdirAll(profilesDir, 0755); err != nil { + t.Fatal(err) + } + profileConfPath := filepath.Join(profilesDir, profile+".conf") + conf := `[Interface] +Address = 10.0.0.2/24 +PrivateKey = 0000000000000000000000000000000000000000000000000000000000000000 +DNS = 1.1.1.1 + +[Peer] +PublicKey = 0000000000000000000000000000000000000000000000000000000000000000 +AllowedIPs = 0.0.0.0/0 +Endpoint = 1.1.1.1:51820 +` + if err := os.WriteFile(profileConfPath, []byte(conf), 0644); err != nil { + t.Fatal(err) + } + + pidsDir := filepath.Join(tmpRuntimeDir, "profiles", profile, "pids") + + // Start Process A running a command that outputs its netns and sleeps + cmdA := exec.Command(binaryPath, "--profile", profile, "--", "sh", "-c", "readlink /proc/self/ns/net && sleep 5") + cmdA.Env = append(os.Environ(), + fmt.Sprintf("XDG_RUNTIME_DIR=%s", tmpRuntimeDir), + fmt.Sprintf("XDG_CONFIG_HOME=%s", tmpConfigDir), + ) + + outA, err := cmdA.StdoutPipe() + if err != nil { + t.Fatalf("Failed to create stdout pipe for Process A: %v", err) + } + + if err := cmdA.Start(); err != nil { + t.Fatalf("Failed to start Process A: %v", err) + } + defer func() { _ = cmdA.Process.Kill() }() + + // Wait for Process A to output its netns ID line by line + var parsedNetnsA string + scannerA := bufio.NewScanner(outA) + for scannerA.Scan() { + line := strings.TrimSpace(scannerA.Text()) + if strings.HasPrefix(line, "net:[") { + parsedNetnsA = line + break + } + } + + if parsedNetnsA == "" { + t.Fatalf("Failed to get netns ID from Process A") + } + + // Wait for Process A's PID to be registered on the host + waitForPids(t, pidsDir, 1) + + // Start Process B to check its netns ID + cmdB := exec.Command(binaryPath, "--profile", profile, "--", "readlink", "/proc/self/ns/net") + cmdB.Env = append(os.Environ(), + fmt.Sprintf("XDG_RUNTIME_DIR=%s", tmpRuntimeDir), + fmt.Sprintf("XDG_CONFIG_HOME=%s", tmpConfigDir), + ) + + outB, err := cmdB.CombinedOutput() + if err != nil { + t.Fatalf("Process B failed to execute: %v\nOutput: %s", err, string(outB)) + } + + var parsedNetnsB string + for _, line := range strings.Split(string(outB), "\n") { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "net:[") { + parsedNetnsB = trimmed + break + } + } + + if parsedNetnsB == "" { + t.Fatalf("Invalid netns ID format from Process B: %q", string(outB)) + } + + if parsedNetnsA != parsedNetnsB { + t.Errorf("BUG: Process A and Process B do not share the same network namespace!\nProcess A: %s\nProcess B: %s", parsedNetnsA, parsedNetnsB) + } +} |
