summaryrefslogtreecommitdiff
path: root/internal/cli
diff options
context:
space:
mode:
Diffstat (limited to 'internal/cli')
-rw-r--r--internal/cli/cli.go91
-rw-r--r--internal/cli/cli_test.go27
-rw-r--r--internal/cli/profile_test.go9
3 files changed, 125 insertions, 2 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index af408c5..85b9ae3 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -117,17 +117,29 @@ func (a *App) Run() error {
cfg.Profile = "default"
}
+ if !IsValidProfileName(cfg.Profile) {
+ return fmt.Errorf("invalid profile name: %q (only alphanumeric characters, underscores, and hyphens are allowed)", cfg.Profile)
+ }
+
if namespace.IsIsolated() {
return a.ExecuteCommand(cfg)
}
+ pm := a.getPathManager()
+
+ // Acquire startup lock to prevent concurrent bootstrap/joining races
+ lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile)
+ if lockErr == nil {
+ defer namespace.ReleaseProfileLock(lockFile)
+ }
+
// Before bootstrapping, see if an active namespace/process for the profile exists.
// If yes, we can join it!
- pm := a.getPathManager()
joined, err := namespace.JoinExistingNamespace(pm, cfg.Profile)
if err == nil && joined {
// We have joined the active namespace (user, mnt, net).
- // We can now execute the command immediately in this context!
+ // Release the lock before executing the command to allow others to join
+ namespace.ReleaseProfileLock(lockFile)
return a.ExecuteCommand(cfg)
}
@@ -145,10 +157,16 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
pm := a.getPathManager()
+ // Acquire execution lock during configuration and startup inside the namespace
+ lockFile, lockErr := namespace.AcquireProfileLock(pm, cfg.Profile)
+
if err := namespace.PruneStalePids(pm, cfg.Profile); err != nil {
fmt.Printf("failed to prune stale pids: %v\n", err)
}
if err := namespace.RegisterProcess(pm, cfg.Profile); err != nil {
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
return fmt.Errorf("failed to register process: %w", err)
}
@@ -174,6 +192,9 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
if _, err := os.Stat(profilePath); err == nil {
wgCfg, err := wgconf.Parse(profilePath)
if err != nil {
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err)
}
@@ -184,10 +205,29 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
}
if dnsServer == "" {
dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks
+ hasDefaultRoute := false
+ for _, peer := range wgCfg.Peers {
+ for _, ip := range peer.AllowedIPs {
+ trimmed := strings.TrimSpace(ip)
+ if trimmed == "0.0.0.0/0" || trimmed == "::/0" {
+ hasDefaultRoute = true
+ break
+ }
+ }
+ if hasDefaultRoute {
+ break
+ }
+ }
+ if !hasDefaultRoute {
+ fmt.Printf("warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n")
+ }
}
tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer)
if err != nil {
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
}
defer tunnel.Close()
@@ -199,11 +239,19 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
} else {
// If profile is not default or it was explicitly requested but doesn't exist, we error
if cfg.Profile != "default" {
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
return fmt.Errorf("profile %s not found: %w", cfg.Profile, err)
}
fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n")
}
+ // Setup and initialization are complete. We can now safely release the startup lock!
+ if lockErr == nil {
+ namespace.ReleaseProfileLock(lockFile)
+ }
+
cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
@@ -249,6 +297,9 @@ func (a *App) handleProfileCmd() error {
if len(a.Args) < 4 {
return fmt.Errorf("usage: wg-wrap profile stop <name>")
}
+ if !IsValidProfileName(a.Args[3]) {
+ return fmt.Errorf("invalid profile name: %q", a.Args[3])
+ }
fmt.Printf("Stopping profile %s and unpinning namespace...\n", a.Args[3])
pm := a.getPathManager()
if err := namespace.UnpinNamespace(pm, a.Args[3]); err != nil {
@@ -262,6 +313,9 @@ func (a *App) handleProfileCmd() error {
}
func (a *App) handleProfileConfigure(name string) error {
+ if !IsValidProfileName(name) {
+ return fmt.Errorf("invalid profile name: %q", name)
+ }
profilesDir := a.getPathManager().ConfigDir()
profilePath := filepath.Join(profilesDir, name+".conf")
if _, err := os.Stat(profilePath); os.IsNotExist(err) {
@@ -329,7 +383,15 @@ func (a *App) handleProfileImport(srcPath string, name string) error {
}
}
+ if !IsValidProfileName(name) {
+ return fmt.Errorf("invalid profile name: %q", name)
+ }
+
destPath := filepath.Join(profilesDir, name+".conf")
+ if _, err := os.Stat(destPath); err == nil {
+ return fmt.Errorf("profile '%s' already exists", name)
+ }
+
data, err := os.ReadFile(srcPath)
if err != nil {
return fmt.Errorf("failed to read source file: %w", err)
@@ -344,6 +406,9 @@ func (a *App) handleProfileImport(srcPath string, name string) error {
}
func (a *App) handleProfileDelete(name string) error {
+ if !IsValidProfileName(name) {
+ return fmt.Errorf("invalid profile name: %q", name)
+ }
profilesDir := a.getPathManager().ConfigDir()
destPath := filepath.Join(profilesDir, name+".conf")
@@ -369,6 +434,10 @@ func (a *App) showConfig() error {
_ = fs.Parse(a.Args[2:])
}
+ if !IsValidProfileName(cfg.Profile) {
+ return fmt.Errorf("invalid profile name: %q", cfg.Profile)
+ }
+
pm := paths.NewPathManager(a.ConfigDir, a.RuntimeBaseDir)
profilePath := pm.ProfileNamespacePath(cfg.Profile)
pidsPath := pm.ProfilePidsDir(cfg.Profile)
@@ -384,3 +453,21 @@ func (a *App) showConfig() error {
fmt.Printf(" UID: %d\n", os.Getuid())
return nil
}
+
+// IsValidProfileName checks if a WireGuard profile name is safe and valid.
+// It allows only alphanumeric characters, underscores, and hyphens, and prevents
+// directory traversal attacks and hidden files.
+func IsValidProfileName(name string) bool {
+ if name == "" {
+ return false
+ }
+ for _, r := range name {
+ if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '_' && r != '-' {
+ return false
+ }
+ }
+ if name == "." || name == ".." || strings.HasPrefix(name, "-") {
+ return false
+ }
+ return true
+}
diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go
index fcf489a..aea80f7 100644
--- a/internal/cli/cli_test.go
+++ b/internal/cli/cli_test.go
@@ -57,3 +57,30 @@ AllowedIPs = 10.0.0.0/24
})
}
}
+
+func TestIsValidProfileName(t *testing.T) {
+ tests := []struct {
+ name string
+ want bool
+ }{
+ {"default", true},
+ {"home", true},
+ {"work-vpn", true},
+ {"my_vpn_123", true},
+ {"", false},
+ {"..", false},
+ {"../home", false},
+ {"/etc/shadow", false},
+ {"-profile", false},
+ {"profile.conf", false}, // we append .conf so the name itself shouldn't have .
+ {"foo/bar", false},
+ {"foo\\bar", false},
+ }
+
+ for _, tt := range tests {
+ got := IsValidProfileName(tt.name)
+ if got != tt.want {
+ t.Errorf("IsValidProfileName(%q) = %v; want %v", tt.name, got, tt.want)
+ }
+ }
+}
diff --git a/internal/cli/profile_test.go b/internal/cli/profile_test.go
index c9b1274..e08ffc5 100644
--- a/internal/cli/profile_test.go
+++ b/internal/cli/profile_test.go
@@ -70,6 +70,15 @@ func TestProfileImport(t *testing.T) {
if _, err := os.Stat(destCustomFile); os.IsNotExist(err) {
t.Errorf("expected profile to be imported to %s", destCustomFile)
}
+
+ // 3. Test duplicate import (should fail)
+ appDup := NewApp([]string{"wg-wrap", "profile", "import", srcFile, customName})
+ appDup.ConfigDir = profilesDir
+
+ err = appDup.Route()
+ if err == nil {
+ t.Errorf("expected error when importing duplicate profile, got nil")
+ }
}
func TestProfileDelete(t *testing.T) {