diff options
| -rw-r--r-- | internal/cli/cli.go | 91 | ||||
| -rw-r--r-- | internal/cli/cli_test.go | 27 | ||||
| -rw-r--r-- | internal/cli/profile_test.go | 9 | ||||
| -rw-r--r-- | internal/namespace/launcher_src/launcher.c | 2 | ||||
| -rw-r--r-- | internal/namespace/lock_linux.go | 40 | ||||
| -rw-r--r-- | internal/namespace/lock_stub.go | 18 | ||||
| -rw-r--r-- | internal/namespace/namespace.go | 26 | ||||
| -rw-r--r-- | internal/namespace/pinning.go | 5 | ||||
| -rw-r--r-- | internal/wireguard/wireguard.go | 5 |
9 files changed, 213 insertions, 10 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go index af408c5..85b9ae3 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -117,17 +117,29 @@ func (a *App) Run() error { cfg.Profile = "default" } + if !IsValidProfileName(cfg.Profile) { + return fmt.Errorf("invalid profile name: %q (only alphanumeric characters, underscores, and hyphens are allowed)", cfg.Profile) + } + if namespace.IsIsolated() { return a.ExecuteCommand(cfg) } + pm := a.getPathManager() + + // Acquire startup lock to prevent concurrent bootstrap/joining races + lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) + if lockErr == nil { + defer namespace.ReleaseProfileLock(lockFile) + } + // Before bootstrapping, see if an active namespace/process for the profile exists. // If yes, we can join it! - pm := a.getPathManager() joined, err := namespace.JoinExistingNamespace(pm, cfg.Profile) if err == nil && joined { // We have joined the active namespace (user, mnt, net). - // We can now execute the command immediately in this context! + // Release the lock before executing the command to allow others to join + namespace.ReleaseProfileLock(lockFile) return a.ExecuteCommand(cfg) } @@ -145,10 +157,16 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { pm := a.getPathManager() + // Acquire execution lock during configuration and startup inside the namespace + lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile) + if err := namespace.PruneStalePids(pm, cfg.Profile); err != nil { fmt.Printf("failed to prune stale pids: %v\n", err) } if err := namespace.RegisterProcess(pm, cfg.Profile); err != nil { + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } return fmt.Errorf("failed to register process: %w", err) } @@ -174,6 +192,9 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { if _, err := os.Stat(profilePath); err == nil { wgCfg, err := wgconf.Parse(profilePath) if err != nil { + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err) } @@ -184,10 +205,29 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { } if dnsServer == "" { dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks + hasDefaultRoute := false + for _, peer := range wgCfg.Peers { + for _, ip := range peer.AllowedIPs { + trimmed := strings.TrimSpace(ip) + if trimmed == "0.0.0.0/0" || trimmed == "::/0" { + hasDefaultRoute = true + break + } + } + if hasDefaultRoute { + break + } + } + if !hasDefaultRoute { + fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n") + } } tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer) if err != nil { + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } return fmt.Errorf("failed to start WireGuard tunnel: %w", err) } defer tunnel.Close() @@ -199,11 +239,19 @@ func (a *App) ExecuteCommand(cfg *config.Config) error { } else { // If profile is not default or it was explicitly requested but doesn't exist, we error if cfg.Profile != "default" { + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } return fmt.Errorf("profile %s not found: %w", cfg.Profile, err) } fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n") } + // Setup and initialization are complete. We can now safely release the startup lock! + if lockErr == nil { + namespace.ReleaseProfileLock(lockFile) + } + cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -249,6 +297,9 @@ func (a *App) handleProfileCmd() error { if len(a.Args) < 4 { return fmt.Errorf("usage: wg-wrap profile stop <name>") } + if !IsValidProfileName(a.Args[3]) { + return fmt.Errorf("invalid profile name: %q", a.Args[3]) + } fmt.Printf("Stopping profile %s and unpinning namespace...\n", a.Args[3]) pm := a.getPathManager() if err := namespace.UnpinNamespace(pm, a.Args[3]); err != nil { @@ -262,6 +313,9 @@ func (a *App) handleProfileCmd() error { } func (a *App) handleProfileConfigure(name string) error { + if !IsValidProfileName(name) { + return fmt.Errorf("invalid profile name: %q", name) + } profilesDir := a.getPathManager().ConfigDir() profilePath := filepath.Join(profilesDir, name+".conf") if _, err := os.Stat(profilePath); os.IsNotExist(err) { @@ -329,7 +383,15 @@ func (a *App) handleProfileImport(srcPath string, name string) error { } } + if !IsValidProfileName(name) { + return fmt.Errorf("invalid profile name: %q", name) + } + destPath := filepath.Join(profilesDir, name+".conf") + if _, err := os.Stat(destPath); err == nil { + return fmt.Errorf("profile '%s' already exists", name) + } + data, err := os.ReadFile(srcPath) if err != nil { return fmt.Errorf("failed to read source file: %w", err) @@ -344,6 +406,9 @@ func (a *App) handleProfileImport(srcPath string, name string) error { } func (a *App) handleProfileDelete(name string) error { + if !IsValidProfileName(name) { + return fmt.Errorf("invalid profile name: %q", name) + } profilesDir := a.getPathManager().ConfigDir() destPath := filepath.Join(profilesDir, name+".conf") @@ -369,6 +434,10 @@ func (a *App) showConfig() error { _ = fs.Parse(a.Args[2:]) } + if !IsValidProfileName(cfg.Profile) { + return fmt.Errorf("invalid profile name: %q", cfg.Profile) + } + pm := paths.NewPathManager(a.ConfigDir, a.RuntimeBaseDir) profilePath := pm.ProfileNamespacePath(cfg.Profile) pidsPath := pm.ProfilePidsDir(cfg.Profile) @@ -384,3 +453,21 @@ func (a *App) showConfig() error { fmt.Printf(" UID: %d\n", os.Getuid()) return nil } + +// IsValidProfileName checks if a WireGuard profile name is safe and valid. +// It allows only alphanumeric characters, underscores, and hyphens, and prevents +// directory traversal attacks and hidden files. +func IsValidProfileName(name string) bool { + if name == "" { + return false + } + for _, r := range name { + if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '_' && r != '-' { + return false + } + } + if name == "." || name == ".." || strings.HasPrefix(name, "-") { + return false + } + return true +} diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index fcf489a..aea80f7 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -57,3 +57,30 @@ AllowedIPs = 10.0.0.0/24 }) } } + +func TestIsValidProfileName(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {"default", true}, + {"home", true}, + {"work-vpn", true}, + {"my_vpn_123", true}, + {"", false}, + {"..", false}, + {"../home", false}, + {"/etc/shadow", false}, + {"-profile", false}, + {"profile.conf", false}, // we append .conf so the name itself shouldn't have . + {"foo/bar", false}, + {"foo\\bar", false}, + } + + for _, tt := range tests { + got := IsValidProfileName(tt.name) + if got != tt.want { + t.Errorf("IsValidProfileName(%q) = %v; want %v", tt.name, got, tt.want) + } + } +} diff --git a/internal/cli/profile_test.go b/internal/cli/profile_test.go index c9b1274..e08ffc5 100644 --- a/internal/cli/profile_test.go +++ b/internal/cli/profile_test.go @@ -70,6 +70,15 @@ func TestProfileImport(t *testing.T) { if _, err := os.Stat(destCustomFile); os.IsNotExist(err) { t.Errorf("expected profile to be imported to %s", destCustomFile) } + + // 3. Test duplicate import (should fail) + appDup := NewApp([]string{"wg-wrap", "profile", "import", srcFile, customName}) + appDup.ConfigDir = profilesDir + + err = appDup.Route() + if err == nil { + t.Errorf("expected error when importing duplicate profile, got nil") + } } func TestProfileDelete(t *testing.T) { diff --git a/internal/namespace/launcher_src/launcher.c b/internal/namespace/launcher_src/launcher.c index e108da6..7bbbce7 100644 --- a/internal/namespace/launcher_src/launcher.c +++ b/internal/namespace/launcher_src/launcher.c @@ -9,7 +9,7 @@ int main(int argc, char **argv) { if (argc < 1) { - fprintf(stderr, "Usage: %s <command> [args...]\n", argv[0]); + fprintf(stderr, "Usage: launcher <command> [args...]\n"); return 1; } diff --git a/internal/namespace/lock_linux.go b/internal/namespace/lock_linux.go new file mode 100644 index 0000000..8da98f6 --- /dev/null +++ b/internal/namespace/lock_linux.go @@ -0,0 +1,40 @@ +//go:build linux + +package namespace + +import ( + "fmt" + "os" + "path/filepath" + + "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" + "golang.org/x/sys/unix" +) + +// AcquireProfileLock locks the profile to prevent concurrent startup races. +func AcquireProfileLock(pm *paths.PathManager, profile string) (*os.File, error) { + lockPath := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile+".lock") + if err := os.MkdirAll(filepath.Dir(lockPath), 0755); err != nil { + return nil, fmt.Errorf("failed to create lock directory: %w", err) + } + + file, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, fmt.Errorf("failed to open lock file: %w", err) + } + + if err := unix.Flock(int(file.Fd()), unix.LOCK_EX); err != nil { + _ = file.Close() + return nil, fmt.Errorf("failed to lock profile: %w", err) + } + + return file, nil +} + +// ReleaseProfileLock unlocks the profile. +func ReleaseProfileLock(file *os.File) { + if file != nil { + _ = unix.Flock(int(file.Fd()), unix.LOCK_UN) + _ = file.Close() + } +} diff --git a/internal/namespace/lock_stub.go b/internal/namespace/lock_stub.go new file mode 100644 index 0000000..2282852 --- /dev/null +++ b/internal/namespace/lock_stub.go @@ -0,0 +1,18 @@ +//go:build !linux + +package namespace + +import ( + "os" + + "git.theodohertyfamily.com/tools/wg-wrap/internal/paths" +) + +// AcquireProfileLock is a stub for non-Linux platforms. +func AcquireProfileLock(pm *paths.PathManager, profile string) (*os.File, error) { + return nil, nil +} + +// ReleaseProfileLock is a stub for non-Linux platforms. +func ReleaseProfileLock(file *os.File) { +} diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go index e2ef2f1..ab3797d 100644 --- a/internal/namespace/namespace.go +++ b/internal/namespace/namespace.go @@ -69,11 +69,20 @@ func VerifyArguments(args []string) error { // Bootstrap ensures the process is running in an isolated user and network namespace. // It writes the embedded C launcher to a temporary file and replaces the current process. -func Bootstrap() error { +func Bootstrap() (err error) { if IsIsolated() { return nil } + var fdsToClose []int + defer func() { + if err != nil { + for _, fd := range fdsToClose { + _ = syscall.Close(fd) + } + } + }() + // 0. Validate current arguments for null bytes before proceeding. // If any argument contains a null byte, syscall.Exec will fail with 'invalid argument'. for i, arg := range os.Args { @@ -118,6 +127,7 @@ func Bootstrap() error { _ = os.Remove(launcherPath) return fmt.Errorf("failed to open launcher for exec: %w", err) } + fdsToClose = append(fdsToClose, execFd) // Close the write file descriptor (to avoid ETXTBSY) _ = tmpFile.Close() @@ -152,6 +162,8 @@ func Bootstrap() error { if err != nil { return fmt.Errorf("failed to open host netns: %w", err) } + fdsToClose = append(fdsToClose, hostNetFd) + // Clear close-on-exec so it remains open across syscall.Exec if flags, err := unix.FcntlInt(uintptr(hostNetFd), unix.F_GETFD, 0); err == nil { _, _ = unix.FcntlInt(uintptr(hostNetFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC) @@ -160,15 +172,17 @@ func Bootstrap() error { env := append(os.Environ(), fmt.Sprintf("WG_WRAP_HOST_NETNS_FD=%d", hostNetFd)) // Open a host UDP socket on 0.0.0.0:0 before unsharing network namespace. - laddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0") - if err == nil { - if conn, err := net.ListenUDP("udp", laddr); err == nil { - if file, err := conn.File(); err == nil { + laddr, errAddr := net.ResolveUDPAddr("udp", "0.0.0.0:0") + if errAddr == nil { + if conn, errConn := net.ListenUDP("udp", laddr); errConn == nil { + if file, errFile := conn.File(); errFile == nil { hostSocketFd := file.Fd() - if flags, err := unix.FcntlInt(hostSocketFd, unix.F_GETFD, 0); err == nil { + if flags, fcntlErr := unix.FcntlInt(hostSocketFd, unix.F_GETFD, 0); fcntlErr == nil { _, _ = unix.FcntlInt(hostSocketFd, unix.F_SETFD, flags&^unix.FD_CLOEXEC) } env = append(env, fmt.Sprintf("WG_WRAP_HOST_SOCKET_FD=%d", hostSocketFd)) + fdsToClose = append(fdsToClose, int(hostSocketFd)) + _ = conn.Close() } } } diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go index eb0a376..2433203 100644 --- a/internal/namespace/pinning.go +++ b/internal/namespace/pinning.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "runtime" "strconv" "syscall" @@ -84,6 +85,10 @@ func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) return false, nil } + // Lock the OS thread before joining namespaces to ensure this goroutine stays on the modified thread, + // and that the thread is not reused for other goroutines (since we never unlock it). + runtime.LockOSThread() + // Join User Namespace first userNsPath := fmt.Sprintf("/proc/%d/ns/user", activePid) userFd, err := os.Open(userNsPath) diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 5bbc518..48bd562 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -297,7 +297,10 @@ func (h *HostBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, // Switch this thread back to the isolated network namespace if err := unix.Setns(isolatedFd, unix.CLONE_NEWNET); err != nil { _ = h.inner.Close() - return nil, 0, fmt.Errorf("failed to restore isolated netns: %w", err) + // CRITICAL: The thread is stuck in the host network namespace. Returning it to the Go runtime pool + // will cause other goroutines to run in the host namespace, breaching isolation. We must panic + // immediately to abort the process and prevent a namespace escape. + panic(fmt.Sprintf("CRITICAL: failed to restore isolated netns: %v", err)) } return fns, actualPort, nil |
