summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/cli/cli.go103
-rw-r--r--internal/namespace/launcher_src/launcher.c63
-rw-r--r--internal/namespace/namespace.go95
-rw-r--r--internal/namespace/namespace_stub.go6
-rw-r--r--internal/namespace/pinning.go66
-rw-r--r--internal/paths/paths.go4
6 files changed, 229 insertions, 108 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index 9b3409e..87ee34f 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -127,6 +127,9 @@ func (a *App) Run() error {
pm := a.getPathManager()
+ // Preserve the host runtime base dir in the environment before bootstrapping
+ _ = os.Setenv("WG_WRAP_HOST_RUNTIME_BASE_DIR", pm.RuntimeBaseDir())
+
// Acquire startup lock to prevent concurrent bootstrap/joining races
lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile)
if lockErr == nil {
@@ -135,12 +138,14 @@ func (a *App) Run() error {
// Before bootstrapping, see if an active namespace/process for the profile exists.
// If yes, we can join it!
- joined, err := namespace.JoinExistingNamespace(pm, cfg.Profile)
- if err == nil && joined {
- // We have joined the active namespace (user, mnt, net).
+ activePid, err := namespace.FindActiveProfilePid(pm, cfg.Profile)
+ if err == nil && activePid > 0 {
// Release the lock before executing the command to allow others to join
namespace.ReleaseProfileLock(lockFile)
- return a.ExecuteCommand(cfg)
+ if err := namespace.BootstrapJoin(activePid); err != nil {
+ return fmt.Errorf("failed to join existing namespace: %w", err)
+ }
+ return nil
}
if err := namespace.Bootstrap(); err != nil {
@@ -196,60 +201,64 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
}
}()
- fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile)
+ if os.Getenv("WG_WRAP_JOINED") == "1" {
+ fmt.Printf("Joining active WireGuard tunnel session for profile %s...\n", cfg.Profile)
+ } else {
+ fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile)
- // Parse the profile configuration
- profilesDir := pm.ConfigDir()
- profilePath := filepath.Join(profilesDir, cfg.Profile+".conf")
+ // Parse the profile configuration
+ profilesDir := pm.ConfigDir()
+ profilePath := filepath.Join(profilesDir, cfg.Profile+".conf")
- // Create tunnel if the file exists
- if _, err := os.Stat(profilePath); err == nil {
- wgCfg, err := wgconf.Parse(profilePath)
- if err != nil {
- return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err)
- }
+ // Create tunnel if the file exists
+ if _, err := os.Stat(profilePath); err == nil {
+ wgCfg, err := wgconf.Parse(profilePath)
+ if err != nil {
+ return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err)
+ }
- // Start the WireGuard userspace device & routing table setup
- dnsServer := cfg.DNSServer
- if dnsServer == "" {
- dnsServer = wgCfg.DNS
- }
- if dnsServer == "" {
- dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks
- hasDefaultRoute := false
- for _, peer := range wgCfg.Peers {
- for _, ip := range peer.AllowedIPs {
- trimmed := strings.TrimSpace(ip)
- if trimmed == "0.0.0.0/0" || trimmed == "::/0" {
- hasDefaultRoute = true
+ // Start the WireGuard userspace device & routing table setup
+ dnsServer := cfg.DNSServer
+ if dnsServer == "" {
+ dnsServer = wgCfg.DNS
+ }
+ if dnsServer == "" {
+ dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks
+ hasDefaultRoute := false
+ for _, peer := range wgCfg.Peers {
+ for _, ip := range peer.AllowedIPs {
+ trimmed := strings.TrimSpace(ip)
+ if trimmed == "0.0.0.0/0" || trimmed == "::/0" {
+ hasDefaultRoute = true
+ break
+ }
+ }
+ if hasDefaultRoute {
break
}
}
- if hasDefaultRoute {
- break
+ if !hasDefaultRoute {
+ fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n")
}
}
- if !hasDefaultRoute {
- fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n")
- }
- }
- tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer)
- if err != nil {
- return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
- }
- defer tunnel.Close()
+ tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer)
+ if err != nil {
+ return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
+ }
+ defer tunnel.Close()
- // Pin the namespace so others can join it
- if err := namespace.PinNamespace(pm, cfg.Profile); err != nil {
- fmt.Printf("warning: failed to pin namespace: %v\n", err)
- }
- } else {
- // If profile is not default or it was explicitly requested but doesn't exist, we error
- if cfg.Profile != "default" {
- return fmt.Errorf("profile %s not found: %w", cfg.Profile, err)
+ // Pin the namespace so others can join it
+ if err := namespace.PinNamespace(pm, cfg.Profile); err != nil {
+ fmt.Printf("warning: failed to pin namespace: %v\n", err)
+ }
+ } else {
+ // If profile is not default or it was explicitly requested but doesn't exist, we error
+ if cfg.Profile != "default" {
+ return fmt.Errorf("profile %s not found: %w", cfg.Profile, err)
+ }
+ fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n")
}
- fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n")
}
// We can now release the startup lock and execute the command
diff --git a/internal/namespace/launcher_src/launcher.c b/internal/namespace/launcher_src/launcher.c
index 7bbbce7..60c6558 100644
--- a/internal/namespace/launcher_src/launcher.c
+++ b/internal/namespace/launcher_src/launcher.c
@@ -13,6 +13,69 @@ int main(int argc, char **argv) {
return 1;
}
+ // Check if we are joining an existing namespace
+ char *join_pid_str = getenv("WG_WRAP_JOIN_PID");
+ if (join_pid_str != NULL && strlen(join_pid_str) > 0) {
+ int target_pid = atoi(join_pid_str);
+ if (target_pid > 0) {
+ char path[128];
+ int fd;
+
+ // 1. Join User Namespace first
+ snprintf(path, sizeof(path), "/proc/%d/ns/user", target_pid);
+ fd = open(path, O_RDONLY);
+ if (fd == -1) {
+ perror("open target user namespace");
+ return 1;
+ }
+ if (setns(fd, CLONE_NEWUSER) == -1) {
+ perror("setns CLONE_NEWUSER");
+ close(fd);
+ return 1;
+ }
+ close(fd);
+
+ // 2. Join Mount Namespace
+ snprintf(path, sizeof(path), "/proc/%d/ns/mnt", target_pid);
+ fd = open(path, O_RDONLY);
+ if (fd == -1) {
+ perror("open target mount namespace");
+ return 1;
+ }
+ if (setns(fd, CLONE_NEWNS) == -1) {
+ perror("setns CLONE_NEWNS");
+ close(fd);
+ return 1;
+ }
+ close(fd);
+
+ // 3. Join Network Namespace
+ snprintf(path, sizeof(path), "/proc/%d/ns/net", target_pid);
+ fd = open(path, O_RDONLY);
+ if (fd == -1) {
+ perror("open target network namespace");
+ return 1;
+ }
+ if (setns(fd, CLONE_NEWNET) == -1) {
+ perror("setns CLONE_NEWNET");
+ close(fd);
+ return 1;
+ }
+ close(fd);
+
+ // Execute the target command
+ if (argv[0] == NULL) {
+ fprintf(stderr, "No target binary provided in argv[0]\n");
+ return 1;
+ }
+ if (execv(argv[0], argv) == -1) {
+ perror("execv failed");
+ return 1;
+ }
+ return 0;
+ }
+ }
+
// 1. Capture host identities BEFORE unsharing
uid_t current_uid = getuid();
gid_t current_gid = getgid();
diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go
index ab3797d..0f2618b 100644
--- a/internal/namespace/namespace.go
+++ b/internal/namespace/namespace.go
@@ -194,3 +194,98 @@ func Bootstrap() (err error) {
return nil
}
+
+// BootstrapJoin joins the namespaces of the target PID and replaces the current process.
+func BootstrapJoin(targetPid int) (err error) {
+ if IsIsolated() {
+ return nil
+ }
+
+ var fdsToClose []int
+ defer func() {
+ for _, fd := range fdsToClose {
+ _ = syscall.Close(fd)
+ }
+ }()
+
+ // Validate current arguments for null bytes before proceeding.
+ for i, arg := range os.Args {
+ for j := 0; j < len(arg); j++ {
+ if arg[j] == 0 {
+ return fmt.Errorf("argument %d contains null byte at position %d", i, j)
+ }
+ }
+ }
+
+ self, err := os.Executable()
+ if err != nil {
+ return fmt.Errorf("failed to get executable path: %w", err)
+ }
+
+ // 1. Create a secure temporary file for the launcher binary.
+ tmpFile, err := os.CreateTemp("", "wg-wrap-launcher-")
+ if err != nil {
+ return fmt.Errorf("failed to create temp launcher file: %w", err)
+ }
+ launcherPath := tmpFile.Name()
+
+ // 2. Write the embedded launcher binary to the temp file.
+ if _, err := tmpFile.Write(launcherBytes); err != nil {
+ _ = tmpFile.Close()
+ _ = os.Remove(launcherPath)
+ return fmt.Errorf("failed to write launcher binary: %w", err)
+ }
+
+ // Ensure the binary is executable (0700)
+ if err := tmpFile.Chmod(0700); err != nil {
+ _ = tmpFile.Close()
+ _ = os.Remove(launcherPath)
+ return fmt.Errorf("failed to set launcher permissions: %w", err)
+ }
+
+ // 2b. Open a read-only fd of the launcher to exec
+ execFd, err := syscall.Open(launcherPath, syscall.O_RDONLY, 0)
+ if err != nil {
+ _ = tmpFile.Close()
+ _ = os.Remove(launcherPath)
+ return fmt.Errorf("failed to open launcher for exec: %w", err)
+ }
+ fdsToClose = append(fdsToClose, execFd)
+
+ // Close the write file descriptor (to avoid ETXTBSY)
+ _ = tmpFile.Close()
+
+ // Unlink the file from disk (makes it invisible and ensures it is deleted on exit)
+ _ = os.Remove(launcherPath)
+
+ // Clear close-on-exec so it remains open across syscall.Exec
+ if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil {
+ _, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC)
+ }
+
+ // 3. Prepare arguments for the launcher.
+ args := []string{self}
+ args = append(args, os.Args[1:]...)
+
+ for i, arg := range args {
+ for j := 0; j < len(arg); j++ {
+ if arg[j] == 0 {
+ return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j)
+ }
+ }
+ }
+
+ // Set environment variables to tell the C launcher to join,
+ // and to tell the second wg-wrap instance that we are in a joined session.
+ env := append(os.Environ(),
+ fmt.Sprintf("WG_WRAP_JOIN_PID=%d", targetPid),
+ "WG_WRAP_JOINED=1",
+ )
+
+ err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env)
+ if err != nil {
+ return fmt.Errorf("launcher exec failed: %w", err)
+ }
+
+ return nil
+}
diff --git a/internal/namespace/namespace_stub.go b/internal/namespace/namespace_stub.go
index 84946bf..db0ec24 100644
--- a/internal/namespace/namespace_stub.go
+++ b/internal/namespace/namespace_stub.go
@@ -17,7 +17,7 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
return fmt.Errorf("namespaces are not supported on this platform")
}
-// JoinExistingNamespace attempts to join the namespaces (user, mount, net) of an already active process.
-func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) {
- return false, fmt.Errorf("namespaces are not supported on this platform")
+// FindActiveProfilePid is a stub for non-Linux platforms.
+func FindActiveProfilePid(pm *paths.PathManager, profile string) (int, error) {
+ return 0, fmt.Errorf("namespaces are not supported on this platform")
}
diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go
index 2433203..a522f17 100644
--- a/internal/namespace/pinning.go
+++ b/internal/namespace/pinning.go
@@ -6,12 +6,9 @@ import (
"fmt"
"os"
"path/filepath"
- "runtime"
"strconv"
- "syscall"
"git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
- "golang.org/x/sys/unix"
)
// PinNamespace touches the namespace path to indicate it is pinned/active.
@@ -53,77 +50,30 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
return nil
}
-// JoinExistingNamespace attempts to join the namespaces (user, mount, net)
-// of an already active process running under the same profile.
-// Returns true if a namespace was successfully joined, false if no active namespace exists.
-func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) {
+// FindActiveProfilePid looks for an active PID running under the specified profile.
+// Returns 0 if no active process is found.
+func FindActiveProfilePid(pm *paths.PathManager, profile string) (int, error) {
if err := PruneStalePids(pm, profile); err != nil {
- return false, fmt.Errorf("failed to prune stale pids: %w", err)
+ return 0, fmt.Errorf("failed to prune stale pids: %w", err)
}
pidsDir := GetPidsDirPath(pm, profile)
files, err := os.ReadDir(pidsDir)
if err != nil {
if os.IsNotExist(err) {
- return false, nil
+ return 0, nil
}
- return false, fmt.Errorf("failed to read pids dir: %w", err)
+ return 0, fmt.Errorf("failed to read pids dir: %w", err)
}
- var activePid int
for _, file := range files {
pid, err := strconv.Atoi(file.Name())
if err != nil {
continue
}
// Since we already pruned stale pids, the first file we find is an active pid!
- activePid = pid
- break
+ return pid, nil
}
- if activePid == 0 {
- return false, nil
- }
-
- // Lock the OS thread before joining namespaces to ensure this goroutine stays on the modified thread,
- // and that the thread is not reused for other goroutines (since we never unlock it).
- runtime.LockOSThread()
-
- // Join User Namespace first
- userNsPath := fmt.Sprintf("/proc/%d/ns/user", activePid)
- userFd, err := os.Open(userNsPath)
- if err != nil {
- return false, fmt.Errorf("failed to open user namespace: %w", err)
- }
- defer func() { _ = userFd.Close() }()
-
- if err := unix.Setns(int(userFd.Fd()), syscall.CLONE_NEWUSER); err != nil {
- return false, fmt.Errorf("failed to join user namespace: %w", err)
- }
-
- // Join Mount Namespace
- mntNsPath := fmt.Sprintf("/proc/%d/ns/mnt", activePid)
- mntFd, err := os.Open(mntNsPath)
- if err != nil {
- return false, fmt.Errorf("failed to open mount namespace: %w", err)
- }
- defer func() { _ = mntFd.Close() }()
-
- if err := unix.Setns(int(mntFd.Fd()), syscall.CLONE_NEWNS); err != nil {
- return false, fmt.Errorf("failed to join mount namespace: %w", err)
- }
-
- // Join Network Namespace
- netNsPath := fmt.Sprintf("/proc/%d/ns/net", activePid)
- netFd, err := os.Open(netNsPath)
- if err != nil {
- return false, fmt.Errorf("failed to open network namespace: %w", err)
- }
- defer func() { _ = netFd.Close() }()
-
- if err := unix.Setns(int(netFd.Fd()), syscall.CLONE_NEWNET); err != nil {
- return false, fmt.Errorf("failed to join network namespace: %w", err)
- }
-
- return true, nil
+ return 0, nil
}
diff --git a/internal/paths/paths.go b/internal/paths/paths.go
index f512ad1..c7bdd94 100644
--- a/internal/paths/paths.go
+++ b/internal/paths/paths.go
@@ -44,6 +44,10 @@ func (pm *PathManager) RuntimeBaseDir() string {
return pm.RuntimeBaseOverride
}
+ if envDir := os.Getenv("WG_WRAP_HOST_RUNTIME_BASE_DIR"); envDir != "" {
+ return envDir
+ }
+
if envDir := os.Getenv("XDG_RUNTIME_DIR"); envDir != "" {
return envDir
}