summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/cli/cli.go91
-rw-r--r--internal/cli/cli_test.go27
-rw-r--r--internal/cli/profile_test.go9
-rw-r--r--internal/namespace/launcher_src/launcher.c2
-rw-r--r--internal/namespace/lock_linux.go40
-rw-r--r--internal/namespace/lock_stub.go18
-rw-r--r--internal/namespace/namespace.go26
-rw-r--r--internal/namespace/pinning.go5
-rw-r--r--internal/wireguard/wireguard.go5
9 files changed, 213 insertions, 10 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index af408c5..85b9ae3 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -117,17 +117,29 @@ func (a *App) Run() error {
cfg.Profile = "default"
}
+ if !IsValidProfileName(cfg.Profile) {
+ return fmt.Errorf("invalid profile name: %q (only alphanumeric characters, underscores, and hyphens are allowed)", cfg.Profile)
+ }
+
if namespace.IsIsolated() {
return a.ExecuteCommand(cfg)
}
+ pm := a.getPathManager()
+
+ // Acquire startup lock to prevent concurrent bootstrap/joining races
+ lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile)
+ if lockErr == nil {
+ defer namespace.ReleaseProfileLock(lockFile)
+ }
+
// Before bootstrapping, see if an active namespace/process for the profile exists.
// If yes, we can join it!
- pm := a.getPathManager()
joined, err := namespace.JoinExistingNamespace(pm, cfg.Profile)
if err == nil && joined {
// We have joined the active namespace (user, mnt, net).
- // We can now execute the command immediately in this context!
+ // Release the lock before executing the command to allow others to join
+ namespace.ReleaseProfileLock(lockFile)
return a.ExecuteCommand(cfg)
}
@@ -145,10 +157,16 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
pm := a.getPathManager()
+ // Acquire execution lock during configuration and startup inside the namespace
+ lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile)
+
if err := namespace.PruneStalePids(pm, cfg.Profile); err != nil {
fmt.Printf("failed to prune stale pids: %v\n", err)
}
if err := namespace.RegisterProcess(pm, cfg.Profile); err != nil {
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
return fmt.Errorf("failed to register process: %w", err)
}
@@ -174,6 +192,9 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
if _, err := os.Stat(profilePath); err == nil {
wgCfg, err := wgconf.Parse(profilePath)
if err != nil {
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err)
}
@@ -184,10 +205,29 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
}
if dnsServer == "" {
dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks
+ hasDefaultRoute := false
+ for _, peer := range wgCfg.Peers {
+ for _, ip := range peer.AllowedIPs {
+ trimmed := strings.TrimSpace(ip)
+ if trimmed == "0.0.0.0/0" || trimmed == "::/0" {
+ hasDefaultRoute = true
+ break
+ }
+ }
+ if hasDefaultRoute {
+ break
+ }
+ }
+ if !hasDefaultRoute {
+ fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n")
+ }
}
tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer)
if err != nil {
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
}
defer tunnel.Close()
@@ -199,11 +239,19 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
} else {
// If profile is not default or it was explicitly requested but doesn't exist, we error
if cfg.Profile != "default" {
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
return fmt.Errorf("profile %s not found: %w", cfg.Profile, err)
}
fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n")
}
+ // Setup and initialization are complete. We can now safely release the startup lock!
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
+
cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
@@ -249,6 +297,9 @@ func (a *App) handleProfileCmd() error {
if len(a.Args) < 4 {
return fmt.Errorf("usage: wg-wrap profile stop <name>")
}
+ if !IsValidProfileName(a.Args[3]) {
+ return fmt.Errorf("invalid profile name: %q", a.Args[3])
+ }
fmt.Printf("Stopping profile %s and unpinning namespace...\n", a.Args[3])
pm := a.getPathManager()
if err := namespace.UnpinNamespace(pm, a.Args[3]); err != nil {
@@ -262,6 +313,9 @@ func (a *App) handleProfileCmd() error {
}
func (a *App) handleProfileConfigure(name string) error {
+ if !IsValidProfileName(name) {
+ return fmt.Errorf("invalid profile name: %q", name)
+ }
profilesDir := a.getPathManager().ConfigDir()
profilePath := filepath.Join(profilesDir, name+".conf")
if _, err := os.Stat(profilePath); os.IsNotExist(err) {
@@ -329,7 +383,15 @@ func (a *App) handleProfileImport(srcPath string, name string) error {
}
}
+ if !IsValidProfileName(name) {
+ return fmt.Errorf("invalid profile name: %q", name)
+ }
+
destPath := filepath.Join(profilesDir, name+".conf")
+ if _, err := os.Stat(destPath); err == nil {
+ return fmt.Errorf("profile '%s' already exists", name)
+ }
+
data, err := os.ReadFile(srcPath)
if err != nil {
return fmt.Errorf("failed to read source file: %w", err)
@@ -344,6 +406,9 @@ func (a *App) handleProfileImport(srcPath string, name string) error {
}
func (a *App) handleProfileDelete(name string) error {
+ if !IsValidProfileName(name) {
+ return fmt.Errorf("invalid profile name: %q", name)
+ }
profilesDir := a.getPathManager().ConfigDir()
destPath := filepath.Join(profilesDir, name+".conf")
@@ -369,6 +434,10 @@ func (a *App) showConfig() error {
_ = fs.Parse(a.Args[2:])
}
+ if !IsValidProfileName(cfg.Profile) {
+ return fmt.Errorf("invalid profile name: %q", cfg.Profile)
+ }
+
pm := paths.NewPathManager(a.ConfigDir, a.RuntimeBaseDir)
profilePath := pm.ProfileNamespacePath(cfg.Profile)
pidsPath := pm.ProfilePidsDir(cfg.Profile)
@@ -384,3 +453,21 @@ func (a *App) showConfig() error {
fmt.Printf(" UID: %d\n", os.Getuid())
return nil
}
+
+// IsValidProfileName checks if a WireGuard profile name is safe and valid.
+// It allows only alphanumeric characters, underscores, and hyphens, and prevents
+// directory traversal attacks and hidden files.
+func IsValidProfileName(name string) bool {
+ if name == "" {
+ return false
+ }
+ for _, r := range name {
+ if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '_' && r != '-' {
+ return false
+ }
+ }
+ if name == "." || name == ".." || strings.HasPrefix(name, "-") {
+ return false
+ }
+ return true
+}
diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go
index fcf489a..aea80f7 100644
--- a/internal/cli/cli_test.go
+++ b/internal/cli/cli_test.go
@@ -57,3 +57,30 @@ AllowedIPs = 10.0.0.0/24
})
}
}
+
+func TestIsValidProfileName(t *testing.T) {
+ tests := []struct {
+ name string
+ want bool
+ }{
+ {"default", true},
+ {"home", true},
+ {"work-vpn", true},
+ {"my_vpn_123", true},
+ {"", false},
+ {"..", false},
+ {"../home", false},
+ {"/etc/shadow", false},
+ {"-profile", false},
+ {"profile.conf", false}, // we append .conf so the name itself shouldn't have .
+ {"foo/bar", false},
+ {"foo\\bar", false},
+ }
+
+ for _, tt := range tests {
+ got := IsValidProfileName(tt.name)
+ if got != tt.want {
+ t.Errorf("IsValidProfileName(%q) = %v; want %v", tt.name, got, tt.want)
+ }
+ }
+}
diff --git a/internal/cli/profile_test.go b/internal/cli/profile_test.go
index c9b1274..e08ffc5 100644
--- a/internal/cli/profile_test.go
+++ b/internal/cli/profile_test.go
@@ -70,6 +70,15 @@ func TestProfileImport(t *testing.T) {
if _, err := os.Stat(destCustomFile); os.IsNotExist(err) {
t.Errorf("expected profile to be imported to %s", destCustomFile)
}
+
+ // 3. Test duplicate import (should fail)
+ appDup := NewApp([]string{"wg-wrap", "profile", "import", srcFile, customName})
+ appDup.ConfigDir = profilesDir
+
+ err = appDup.Route()
+ if err == nil {
+ t.Errorf("expected error when importing duplicate profile, got nil")
+ }
}
func TestProfileDelete(t *testing.T) {
diff --git a/internal/namespace/launcher_src/launcher.c b/internal/namespace/launcher_src/launcher.c
index e108da6..7bbbce7 100644
--- a/internal/namespace/launcher_src/launcher.c
+++ b/internal/namespace/launcher_src/launcher.c
@@ -9,7 +9,7 @@
int main(int argc, char **argv) {
if (argc < 1) {
- fprintf(stderr, "Usage: %s <command> [args...]\n", argv[0]);
+ fprintf(stderr, "Usage: launcher <command> [args...]\n");
return 1;
}
diff --git a/internal/namespace/lock_linux.go b/internal/namespace/lock_linux.go
new file mode 100644
index 0000000..8da98f6
--- /dev/null
+++ b/internal/namespace/lock_linux.go
@@ -0,0 +1,40 @@
+//go:build linux
+
+package namespace
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+ "golang.org/x/sys/unix"
+)
+
+// AcquireProfileLock locks the profile to prevent concurrent startup races.
+func AcquireProfileLock(pm *paths.PathManager, profile string) (*os.File, error) {
+ lockPath := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile+".lock")
+ if err := os.MkdirAll(filepath.Dir(lockPath), 0755); err != nil {
+ return nil, fmt.Errorf("failed to create lock directory: %w", err)
+ }
+
+ file, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0600)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open lock file: %w", err)
+ }
+
+ if err := unix.Flock(int(file.Fd()), unix.LOCK_EX); err != nil {
+ _ = file.Close()
+ return nil, fmt.Errorf("failed to lock profile: %w", err)
+ }
+
+ return file, nil
+}
+
+// ReleaseProfileLock unlocks the profile.
+func ReleaseProfileLock(file *os.File) {
+ if file != nil {
+ _ = unix.Flock(int(file.Fd()), unix.LOCK_UN)
+ _ = file.Close()
+ }
+}
diff --git a/internal/namespace/lock_stub.go b/internal/namespace/lock_stub.go
new file mode 100644
index 0000000..2282852
--- /dev/null
+++ b/internal/namespace/lock_stub.go
@@ -0,0 +1,18 @@
+//go:build !linux
+
+package namespace
+
+import (
+ "os"
+
+ "git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+)
+
+// AcquireProfileLock is a stub for non-Linux platforms.
+func AcquireProfileLock(pm *paths.PathManager, profile string) (*os.File, error) {
+ return nil, nil
+}
+
+// ReleaseProfileLock is a stub for non-Linux platforms.
+func ReleaseProfileLock(file *os.File) {
+}
diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go
index e2ef2f1..ab3797d 100644
--- a/internal/namespace/namespace.go
+++ b/internal/namespace/namespace.go
@@ -69,11 +69,20 @@ func VerifyArguments(args []string) error {
// Bootstrap ensures the process is running in an isolated user and network namespace.
// It writes the embedded C launcher to a temporary file and replaces the current process.
-func Bootstrap() error {
+func Bootstrap() (err error) {
if IsIsolated() {
return nil
}
+ var fdsToClose []int
+ defer func() {
+ if err != nil {
+ for _, fd := range fdsToClose {
+ _ = syscall.Close(fd)
+ }
+ }
+ }()
+
// 0. Validate current arguments for null bytes before proceeding.
// If any argument contains a null byte, syscall.Exec will fail with 'invalid argument'.
for i, arg := range os.Args {
@@ -118,6 +127,7 @@ func Bootstrap() error {
_ = os.Remove(launcherPath)
return fmt.Errorf("failed to open launcher for exec: %w", err)
}
+ fdsToClose = append(fdsToClose, execFd)
// Close the write file descriptor (to avoid ETXTBSY)
_ = tmpFile.Close()
@@ -152,6 +162,8 @@ func Bootstrap() error {
if err != nil {
return fmt.Errorf("failed to open host netns: %w", err)
}
+ fdsToClose = append(fdsToClose, hostNetFd)
+
// Clear close-on-exec so it remains open across syscall.Exec
if flags, err := unix.FcntlInt(uintptr(hostNetFd), unix.F_GETFD, 0); err == nil {
_, _ = unix.FcntlInt(uintptr(hostNetFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC)
@@ -160,15 +172,17 @@ func Bootstrap() error {
env := append(os.Environ(), fmt.Sprintf("WG_WRAP_HOST_NETNS_FD=%d", hostNetFd))
// Open a host UDP socket on 0.0.0.0:0 before unsharing network namespace.
- laddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
- if err == nil {
- if conn, err := net.ListenUDP("udp", laddr); err == nil {
- if file, err := conn.File(); err == nil {
+ laddr, errAddr := net.ResolveUDPAddr("udp", "0.0.0.0:0")
+ if errAddr == nil {
+ if conn, errConn := net.ListenUDP("udp", laddr); errConn == nil {
+ if file, errFile := conn.File(); errFile == nil {
hostSocketFd := file.Fd()
- if flags, err := unix.FcntlInt(hostSocketFd, unix.F_GETFD, 0); err == nil {
+ if flags, fcntlErr := unix.FcntlInt(hostSocketFd, unix.F_GETFD, 0); fcntlErr == nil {
_, _ = unix.FcntlInt(hostSocketFd, unix.F_SETFD, flags&^unix.FD_CLOEXEC)
}
env = append(env, fmt.Sprintf("WG_WRAP_HOST_SOCKET_FD=%d", hostSocketFd))
+ fdsToClose = append(fdsToClose, int(hostSocketFd))
+ _ = conn.Close()
}
}
}
diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go
index eb0a376..2433203 100644
--- a/internal/namespace/pinning.go
+++ b/internal/namespace/pinning.go
@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"path/filepath"
+ "runtime"
"strconv"
"syscall"
@@ -84,6 +85,10 @@ func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error)
return false, nil
}
+ // Lock the OS thread before joining namespaces to ensure this goroutine stays on the modified thread,
+ // and that the thread is not reused for other goroutines (since we never unlock it).
+ runtime.LockOSThread()
+
// Join User Namespace first
userNsPath := fmt.Sprintf("/proc/%d/ns/user", activePid)
userFd, err := os.Open(userNsPath)
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 5bbc518..48bd562 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -297,7 +297,10 @@ func (h *HostBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16,
// Switch this thread back to the isolated network namespace
if err := unix.Setns(isolatedFd, unix.CLONE_NEWNET); err != nil {
_ = h.inner.Close()
- return nil, 0, fmt.Errorf("failed to restore isolated netns: %w", err)
+ // CRITICAL: The thread is stuck in the host network namespace. Returning it to the Go runtime pool
+ // will cause other goroutines to run in the host namespace, breaching isolation. We must panic
+ // immediately to abort the process and prevent a namespace escape.
+ panic(fmt.Sprintf("CRITICAL: failed to restore isolated netns: %v", err))
}
return fns, actualPort, nil