summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/cli/cli.go178
-rw-r--r--internal/cli/cli_test.go4
-rw-r--r--internal/cli/profile_test.go125
-rw-r--r--internal/config/config.go18
-rw-r--r--internal/namespace/lifecycle.go45
-rw-r--r--internal/namespace/lifecycle_test.go45
-rw-r--r--internal/namespace/pinning.go35
-rw-r--r--internal/namespace/pinning_test.go53
-rw-r--r--internal/paths/paths.go62
-rw-r--r--pkg/wgconf/wgconf.go97
-rw-r--r--pkg/wgconf/wgconf_test.go66
-rw-r--r--tests/e2e/config_test.go10
-rw-r--r--tests/e2e/lifecycle_test.go10
13 files changed, 618 insertions, 130 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index 13a4a6b..66b5f79 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -5,12 +5,15 @@ import (
"fmt"
"os"
"os/exec"
+ "path/filepath"
+ "strings"
"git.theodohertyfamily.com/tools/wg-wrap/internal/config"
"git.theodohertyfamily.com/tools/wg-wrap/internal/namespace"
+ "git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+ "git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf"
)
-
type App struct {
Args []string
ConfigDir string // Optional override for profile storage location
@@ -21,8 +24,11 @@ func NewApp(args []string) *App {
return &App{Args: args}
}
+func (a *App) getPathManager() *paths.PathManager {
+ return paths.NewPathManager(a.ConfigDir, a.RuntimeBaseDir)
+}
+
func (a *App) Route() error {
- // 1. Validate arguments for null bytes to prevent exec failures in the C launcher
for i, arg := range a.Args {
for j := 0; j < len(arg); j++ {
if arg[j] == 0 {
@@ -31,7 +37,6 @@ func (a *App) Route() error {
}
}
- // Handle the internal diagnostic commands that should run on the HOST
if len(a.Args) > 1 {
switch a.Args[1] {
case "show-config":
@@ -39,18 +44,14 @@ func (a *App) Route() error {
}
}
- // Handle subcommands first (profile list, import, configure, delete, stop)
if len(a.Args) > 1 && a.Args[1] == "profile" {
return a.handleProfileCmd()
}
- // If we reach here, we are either wrapping a process or running a command
- // that requires isolation (e.g., test-ns, test-args).
return a.Run()
}
func (a *App) Run() error {
- // Handle the internal diagnostic commands that require ISOLATION
if len(a.Args) > 1 {
switch a.Args[1] {
case "test-ns":
@@ -78,7 +79,7 @@ func (a *App) Run() error {
cfg := &config.Config{}
fs := flag.NewFlagSet("wg-wrap", flag.ExitOnError)
- fs.StringVar(&cfg.Profile, "profile", "", "WireGuard profile to use (filename without extension in ~/.config/wg-wrap/profiles/)")
+ fs.StringVar(&cfg.Profile, "profile", "", "WireGuard profile to use")
fs.StringVar(&cfg.DNSServer, "dns-server", "", "Override DNS server to use")
args := a.Args[1:]
@@ -115,57 +116,46 @@ func (a *App) Run() error {
cfg.Profile = "default"
}
- // If we are already isolated, we enter the execution phase.
if namespace.IsIsolated() {
return a.ExecuteCommand(cfg)
}
- // If we are not isolated, we bootstrap.
- // The Bootstrap process will replace this process and restart it.
if err := namespace.Bootstrap(); err != nil {
return fmt.Errorf("bootstrap failed: %w", err)
}
- // This point is never reached because Bootstrap uses syscall.Exec
return nil
}
-// ExecuteCommand handles the isolated execution of the target application.
-// This is called after the bootstrap loop has successfully isolated the process.
func (a *App) ExecuteCommand(cfg *config.Config) error {
if !namespace.IsIsolated() {
return fmt.Errorf("ExecuteCommand called without namespace isolation")
}
- // 1. Prepare the namespace
- baseDir := a.RuntimeBaseDir
- if baseDir == "" {
- // Use XDG_RUNTIME_DIR or default via the namespace package
- // Since the namespace package now handles the default in GetProfileNamespacePath,
- // we can pass empty string if no override is present.
- baseDir = ""
- }
+ pm := a.getPathManager()
- namespace.PruneStalePids(baseDir, cfg.Profile)
- if err := namespace.RegisterProcess(baseDir, cfg.Profile); err != nil {
+ if err := namespace.PruneStalePids(pm, cfg.Profile); err != nil {
+ fmt.Printf("failed to prune stale pids: %v\n", err)
+ }
+ if err := namespace.RegisterProcess(pm, cfg.Profile); err != nil {
return fmt.Errorf("failed to register process: %w", err)
}
- // Ensure we unregister and check for cleanup on exit
defer func() {
- namespace.UnregisterProcess(baseDir, cfg.Profile)
- if last, err := namespace.IsLastProcess(baseDir, cfg.Profile); err == nil && last {
+ if err := namespace.UnregisterProcess(pm, cfg.Profile); err != nil {
+ fmt.Printf("failed to unregister process: %v\n", err)
+ }
+ if last, err := namespace.IsLastProcess(pm, cfg.Profile); err == nil && last {
fmt.Printf("Last process exiting. Cleaning up profile %s...\n", cfg.Profile)
- // Here we would call namespace.UnpinNamespace(baseDir, cfg.Profile)
- // and terminate the userspace WG process.
+ if err := namespace.UnpinNamespace(pm, cfg.Profile); err != nil {
+ fmt.Printf("failed to unpin namespace: %v\n", err)
+ }
}
}()
- // 2. VPN Setup (Stubbed)
fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile)
// TODO: Integrate with internal/wireguard to set up TUN and WG-Go
- // 3. Execute the target command
cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
@@ -187,69 +177,147 @@ func (a *App) handleProfileCmd() error {
subCmd := a.Args[2]
switch subCmd {
case "list":
- fmt.Println("Listing profiles...")
- return fmt.Errorf("profile list not yet implemented")
+ return a.handleProfileList()
case "import":
if len(a.Args) < 4 {
return fmt.Errorf("usage: wg-wrap profile import <path>")
}
- fmt.Printf("Importing profile from %s...\n", a.Args[3])
- return fmt.Errorf("profile import not yet implemented")
+ return a.handleProfileImport(a.Args[3])
case "configure":
if len(a.Args) < 4 {
return fmt.Errorf("usage: wg-wrap profile configure <name>")
}
- fmt.Printf("Configuring profile %s...\n", a.Args[3])
- return fmt.Errorf("profile configure not yet implemented")
+ return a.handleProfileConfigure(a.Args[3])
case "delete":
if len(a.Args) < 4 {
return fmt.Errorf("usage: wg-wrap profile delete <name>")
}
- fmt.Printf("Deleting profile %s...\n", a.Args[3])
- return fmt.Errorf("profile delete not yet implemented")
+ return a.handleProfileDelete(a.Args[3])
case "stop":
if len(a.Args) < 4 {
return fmt.Errorf("usage: wg-wrap profile stop <name>")
}
fmt.Printf("Stopping profile %s and unpinning namespace...\n", a.Args[3])
- return fmt.Errorf("profile stop not yet implemented")
+ pm := a.getPathManager()
+ if err := namespace.UnpinNamespace(pm, a.Args[3]); err != nil {
+ return fmt.Errorf("failed to stop profile: %w", err)
+ }
+ fmt.Printf("Profile %s stopped and unpinned.\n", a.Args[3])
+ return nil
default:
return fmt.Errorf("unknown profile subcommand: %s", subCmd)
}
}
+func (a *App) handleProfileConfigure(name string) error {
+ profilesDir := a.getPathManager().ConfigDir()
+ profilePath := filepath.Join(profilesDir, name+".conf")
+ if _, err := os.Stat(profilePath); os.IsNotExist(err) {
+ return fmt.Errorf("profile '%s' not found", name)
+ }
+
+ cfg, err := wgconf.Parse(profilePath)
+ if err != nil {
+ return fmt.Errorf("failed to parse profile %s: %w", name, err)
+ }
+
+ fmt.Printf("Editing profile %s...\n", name)
+ fmt.Println("DNS server (current: '" + cfg.DNS + "'):")
+
+ return fmt.Errorf("interactive configuration not supported in this environment, use a config file")
+}
+
+func (a *App) handleProfileList() error {
+ profilesDir := a.getPathManager().ConfigDir()
+ entries, err := os.ReadDir(profilesDir)
+ if err != nil {
+ return fmt.Errorf("failed to read profiles directory %s: %w", profilesDir, err)
+ }
+
+ fmt.Println("Available profiles:")
+ found := false
+ for _, entry := range entries {
+ if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".conf") {
+ name := strings.TrimSuffix(entry.Name(), ".conf")
+ fmt.Printf("- %s\n", name)
+ found = true
+ }
+ }
+
+ if !found {
+ fmt.Println(" (no profiles found)")
+ }
+
+ return nil
+}
+
+func (a *App) handleProfileImport(srcPath string) error {
+ profilesDir := a.getPathManager().ConfigDir()
+ if err := os.MkdirAll(profilesDir, 0755); err != nil {
+ return fmt.Errorf("failed to create profiles directory: %w", err)
+ }
+
+ if _, err := wgconf.Parse(srcPath); err != nil {
+ return fmt.Errorf("invalid WireGuard configuration at %s: %w", srcPath, err)
+ }
+
+ baseName := filepath.Base(srcPath)
+ name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
+ if name == "" {
+ return fmt.Errorf("invalid source filename")
+ }
+
+ destPath := filepath.Join(profilesDir, name+".conf")
+ data, err := os.ReadFile(srcPath)
+ if err != nil {
+ return fmt.Errorf("failed to read source file: %w", err)
+ }
+
+ if err := os.WriteFile(destPath, data, 0644); err != nil {
+ return fmt.Errorf("failed to write profile to %s: %w", destPath, err)
+ }
+
+ fmt.Printf("Profile '%s' imported successfully to %s\n", name, destPath)
+ return nil
+}
+
+func (a *App) handleProfileDelete(name string) error {
+ profilesDir := a.getPathManager().ConfigDir()
+ destPath := filepath.Join(profilesDir, name+".conf")
+
+ if _, err := os.Stat(destPath); os.IsNotExist(err) {
+ return fmt.Errorf("profile '%s' not found", name)
+ }
+
+ if err := os.Remove(destPath); err != nil {
+ return fmt.Errorf("failed to delete profile %s: %w", name, err)
+ }
+
+ fmt.Printf("Profile '%s' deleted successfully\n", name)
+ return nil
+}
+
func (a *App) showConfig() error {
cfg := &config.Config{}
fs := flag.NewFlagSet("wg-wrap", flag.ExitOnError)
fs.StringVar(&cfg.Profile, "profile", "default", "WireGuard profile to use")
fs.StringVar(&cfg.DNSServer, "dns-server", "", "Override DNS server to use")
- // Parse the arguments that follow 'show-config'
if len(a.Args) > 2 {
_ = fs.Parse(a.Args[2:])
}
- // Determine runtime base directory
- runtimeBase := a.RuntimeBaseDir
- if runtimeBase == "" {
- runtimeBase = os.Getenv("XDG_RUNTIME_DIR")
- if runtimeBase == "" {
- runtimeBase = fmt.Sprintf("/run/user/%d", os.Getuid())
- }
- }
-
- // Resolve paths
- profilePath := namespace.GetProfileNamespacePath(runtimeBase, cfg.Profile)
- pidsPath := namespace.GetPidsDirPath(runtimeBase, cfg.Profile)
+ pm := paths.NewPathManager(a.ConfigDir, a.RuntimeBaseDir)
+ profilePath := pm.ProfileNamespacePath(cfg.Profile)
+ pidsPath := pm.ProfilePidsDir(cfg.Profile)
fmt.Printf("Configuration:\n")
fmt.Printf(" Profile: %s\n", cfg.Profile)
fmt.Printf(" DNS Server: %s\n", cfg.DNSServer)
- fmt.Printf(" Runtime Base: %s\n", runtimeBase)
+ fmt.Printf(" Runtime Base: %s\n", pm.RuntimeBaseDir())
fmt.Printf(" Profile Path: %s\n", profilePath)
fmt.Printf(" PIDs Path: %s\n", pidsPath)
fmt.Printf(" Isolated: %v\n", namespace.IsIsolated())
fmt.Printf(" UID: %d\n", os.Getuid())
-
return nil
}
diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go
index 0274fbc..a0d6263 100644
--- a/internal/cli/cli_test.go
+++ b/internal/cli/cli_test.go
@@ -1,8 +1,8 @@
package cli
import (
- "testing"
"strings"
+ "testing"
)
func TestAppRun_ProfileDirInjection(t *testing.T) {
@@ -25,7 +25,7 @@ func TestAppRun_ProfileDirInjection(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := NewApp(tt.args)
- app.ConfigDir = tmpDir // Inject temporary directory
+ app.ConfigDir = tmpDir // Inject temporary directory
app.RuntimeBaseDir = tmpDir // Inject temporary directory for PID tracking
err := app.Run()
diff --git a/internal/cli/profile_test.go b/internal/cli/profile_test.go
new file mode 100644
index 0000000..d256cb0
--- /dev/null
+++ b/internal/cli/profile_test.go
@@ -0,0 +1,125 @@
+package cli
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestProfileList(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Create some dummy profile files
+ profiles := []string{"home.conf", "work.conf", "not-a-conf.txt"}
+ for _, p := range profiles {
+ err := os.WriteFile(filepath.Join(tmpDir, p), []byte("test content"), 0644)
+ if err != nil {
+ t.Fatalf("failed to create test profile %s: %v", p, err)
+ }
+ }
+
+ app := NewApp([]string{"wg-wrap", "profile", "list"})
+ app.ConfigDir = tmpDir
+
+ err := app.Route()
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+}
+
+func TestProfileImport(t *testing.T) {
+ tmpDir := t.TempDir()
+ profilesDir := filepath.Join(tmpDir, "profiles")
+ err := os.MkdirAll(profilesDir, 0755)
+ if err != nil {
+ t.Fatalf("failed to create profiles dir: %v", err)
+ }
+
+ srcFile := filepath.Join(tmpDir, "source.conf")
+ err = os.WriteFile(srcFile, []byte("[Interface]\nPrivateKey = test\n"), 0644)
+ if err != nil {
+ t.Fatalf("failed to create source conf: %v", err)
+ }
+
+ app := NewApp([]string{"wg-wrap", "profile", "import", srcFile})
+ app.ConfigDir = profilesDir
+
+ err = app.Route()
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+
+ // Verify the file was actually copied
+ destFile := filepath.Join(profilesDir, "source.conf")
+ if _, err := os.Stat(destFile); os.IsNotExist(err) {
+ t.Errorf("expected profile to be imported to %s", destFile)
+ }
+}
+
+func TestProfileDelete(t *testing.T) {
+ tmpDir := t.TempDir()
+ profilesDir := filepath.Join(tmpDir, "profiles")
+ err := os.MkdirAll(profilesDir, 0755)
+ if err != nil {
+ t.Fatalf("failed to create profiles dir: %v", err)
+ }
+
+ profileName := "test-profile"
+ profileFile := filepath.Join(profilesDir, profileName+".conf")
+ err = os.WriteFile(profileFile, []byte("[Interface]\nPrivateKey = test\n"), 0644)
+ if err != nil {
+ t.Fatalf("failed to create profile file: %v", err)
+ }
+
+ app := NewApp([]string{"wg-wrap", "profile", "delete", profileName})
+ app.ConfigDir = profilesDir
+
+ err = app.Route()
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+
+ if _, err := os.Stat(profileFile); !os.IsNotExist(err) {
+ t.Errorf("expected profile file %s to be deleted", profileFile)
+ }
+}
+
+func TestProfileDeleteNotFound(t *testing.T) {
+ tmpDir := t.TempDir()
+ app := NewApp([]string{"wg-wrap", "profile", "delete", "non-existent"})
+ app.ConfigDir = tmpDir
+
+ err := app.Route()
+ if err == nil {
+ t.Errorf("expected error when deleting non-existent profile, got nil")
+ }
+}
+
+func TestProfileConfigure(t *testing.T) {
+ // profile configure is intended to modify existing configs.
+ // For now, we just want to ensure it doesn't crash and we can
+ // eventually implement it.
+
+ tmpDir := t.TempDir()
+ profilesDir := filepath.Join(tmpDir, "profiles")
+ err := os.MkdirAll(profilesDir, 0755)
+ if err != nil {
+ t.Fatalf("failed to create profiles dir: %v", err)
+ }
+
+ profileName := "test-profile"
+ profileFile := filepath.Join(profilesDir, profileName+".conf")
+ err = os.WriteFile(profileFile, []byte("[Interface]\nPrivateKey = test\n"), 0644)
+ if err != nil {
+ t.Fatalf("failed to create profile file: %v", err)
+ }
+
+ app := NewApp([]string{"wg-wrap", "profile", "configure", profileName})
+ app.ConfigDir = profilesDir
+
+ err = app.Route()
+ // This will currently return "not yet implemented" error, which is expected for now.
+ if err == nil {
+ t.Errorf("expected 'not yet implemented' error, got nil")
+ }
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index d81a1f6..5aa8462 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -1,25 +1,7 @@
package config
-import (
- "os"
- "path/filepath"
-)
-
type Config struct {
Profile string
DNSServer string
Command []string
}
-
-// GetDefaultProfilesDir returns the standard XDG path for wg-wrap profiles.
-func GetDefaultProfilesDir() string {
- configHome := os.Getenv("XDG_CONFIG_HOME")
- if configHome == "" {
- home, err := os.UserHomeDir()
- if err != nil {
- return "etc/wg-wrap/profiles" // Fallback
- }
- configHome = filepath.Join(home, ".config")
- }
- return filepath.Join(configHome, "wg-wrap", "profiles")
-}
diff --git a/internal/namespace/lifecycle.go b/internal/namespace/lifecycle.go
index 4ca725f..47a804f 100644
--- a/internal/namespace/lifecycle.go
+++ b/internal/namespace/lifecycle.go
@@ -6,32 +6,23 @@ import (
"path/filepath"
"strconv"
"syscall"
+
+ "git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
)
// GetProfileNamespacePath returns the path to the pinned namespace file for a profile.
-func GetProfileNamespacePath(baseDir, profile string) string {
- if baseDir == "" {
- baseDir = getRuntimeBaseDir()
- }
- return filepath.Join(baseDir, "profiles", profile)
-}
-
-func getRuntimeBaseDir() string {
- if envDir := os.Getenv("XDG_RUNTIME_DIR"); envDir != "" {
- return envDir
- }
- uid := os.Getuid()
- return fmt.Sprintf("/run/user/%d", uid)
+func GetProfileNamespacePath(pm *paths.PathManager, profile string) string {
+ return pm.ProfileNamespacePath(profile)
}
// GetPidsDirPath returns the path to the directory where process PIDs are tracked for a profile.
-func GetPidsDirPath(baseDir, profile string) string {
- return filepath.Join(GetProfileNamespacePath(baseDir, profile), "pids")
+func GetPidsDirPath(pm *paths.PathManager, profile string) string {
+ return pm.ProfilePidsDir(profile)
}
// RegisterProcess marks the current process as using the specified profile.
-func RegisterProcess(baseDir, profile string) error {
- pidsDir := GetPidsDirPath(baseDir, profile)
+func RegisterProcess(pm *paths.PathManager, profile string) error {
+ pidsDir := GetPidsDirPath(pm, profile)
if err := os.MkdirAll(pidsDir, 0755); err != nil {
return fmt.Errorf("failed to create pids directory: %v", err)
}
@@ -45,9 +36,9 @@ func RegisterProcess(baseDir, profile string) error {
}
// UnregisterProcess removes the current process from the profile's tracking.
-func UnregisterProcess(baseDir, profile string) error {
+func UnregisterProcess(pm *paths.PathManager, profile string) error {
pid := os.Getpid()
- pidFile := filepath.Join(GetPidsDirPath(baseDir, profile), strconv.Itoa(pid))
+ pidFile := filepath.Join(GetPidsDirPath(pm, profile), strconv.Itoa(pid))
if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to unregister process pid %d: %v", pid, err)
}
@@ -55,8 +46,8 @@ func UnregisterProcess(baseDir, profile string) error {
}
// PruneStalePids removes PID files that no longer correspond to active processes.
-func PruneStalePids(baseDir, profile string) error {
- pidsDir := GetPidsDirPath(baseDir, profile)
+func PruneStalePids(pm *paths.PathManager, profile string) error {
+ pidsDir := GetPidsDirPath(pm, profile)
files, err := os.ReadDir(pidsDir)
if err != nil {
if os.IsNotExist(err) {
@@ -73,21 +64,25 @@ func PruneStalePids(baseDir, profile string) error {
process, err := os.FindProcess(pid)
if err != nil {
- os.Remove(filepath.Join(pidsDir, file.Name()))
+ if err := os.Remove(filepath.Join(pidsDir, file.Name())); err != nil {
+ fmt.Printf("failed to remove stale pid file %s: %v\n", file.Name(), err)
+ }
continue
}
err = process.Signal(syscall.Signal(0))
if err != nil {
- os.Remove(filepath.Join(pidsDir, file.Name()))
+ if err := os.Remove(filepath.Join(pidsDir, file.Name())); err != nil {
+ fmt.Printf("failed to remove stale pid file %s: %v\n", file.Name(), err)
+ }
}
}
return nil
}
// IsLastProcess checks if the current process is the only active user of the profile.
-func IsLastProcess(baseDir, profile string) (bool, error) {
- pidsDir := GetPidsDirPath(baseDir, profile)
+func IsLastProcess(pm *paths.PathManager, profile string) (bool, error) {
+ pidsDir := GetPidsDirPath(pm, profile)
files, err := os.ReadDir(pidsDir)
if err != nil {
if os.IsNotExist(err) {
diff --git a/internal/namespace/lifecycle_test.go b/internal/namespace/lifecycle_test.go
index db04e67..230e93a 100644
--- a/internal/namespace/lifecycle_test.go
+++ b/internal/namespace/lifecycle_test.go
@@ -5,27 +5,30 @@ import (
"path/filepath"
"strconv"
"testing"
+
+ "git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
)
func TestLifecycleReferenceCounting(t *testing.T) {
// Use a temporary directory to avoid polluting the system
tmpDir := t.TempDir()
+ pm := paths.NewPathManager("", tmpDir)
profile := "test-vpn"
t.Run("RegisterAndUnregister", func(t *testing.T) {
- err := RegisterProcess(tmpDir, profile)
+ err := RegisterProcess(pm, profile)
if err != nil {
t.Fatalf("failed to register: %v", err)
}
- pidsDir := GetPidsDirPath(tmpDir, profile)
+ pidsDir := GetPidsDirPath(pm, profile)
pidFile := filepath.Join(pidsDir, strconv.Itoa(os.Getpid()))
if _, err := os.Stat(pidFile); os.IsNotExist(err) {
t.Errorf("PID file should exist at %s", pidFile)
}
- err = UnregisterProcess(tmpDir, profile)
+ err = UnregisterProcess(pm, profile)
if err != nil {
t.Fatalf("failed to unregister: %v", err)
}
@@ -36,20 +39,22 @@ func TestLifecycleReferenceCounting(t *testing.T) {
})
t.Run("PruneStalePids", func(t *testing.T) {
- pidsDir := GetPidsDirPath(tmpDir, profile)
+ pidsDir := GetPidsDirPath(pm, profile)
if err := os.MkdirAll(pidsDir, 0755); err != nil {
t.Fatal(err)
}
- fakePid := "9999999"
+ fakePid := "9999999"
fakePidFile := filepath.Join(pidsDir, fakePid)
if err := os.WriteFile(fakePidFile, []byte(""), 0644); err != nil {
t.Fatal(err)
}
- RegisterProcess(tmpDir, profile)
+ if err := RegisterProcess(pm, profile); err != nil {
+ t.Fatal(err)
+ }
- err := PruneStalePids(tmpDir, profile)
+ err := PruneStalePids(pm, profile)
if err != nil {
t.Fatalf("prune failed: %v", err)
}
@@ -62,27 +67,35 @@ func TestLifecycleReferenceCounting(t *testing.T) {
if _, err := os.Stat(currentPidFile); os.IsNotExist(err) {
t.Errorf("Current PID file %s should not have been pruned", currentPidFile)
}
-
- UnregisterProcess(tmpDir, profile)
+
+ if err := UnregisterProcess(pm, profile); err != nil {
+ t.Fatal(err)
+ }
})
t.Run("IsLastProcess", func(t *testing.T) {
- pidsDir := GetPidsDirPath(tmpDir, profile)
- os.RemoveAll(pidsDir) // Reset
+ pidsDir := GetPidsDirPath(pm, profile)
+ if err := os.RemoveAll(pidsDir); err != nil {
+ t.Fatal(err)
+ }
- isLast, err := IsLastProcess(tmpDir, profile)
+ isLast, err := IsLastProcess(pm, profile)
if err != nil || !isLast {
t.Errorf("Expected IsLastProcess to be true for empty profile, got %v, err: %v", isLast, err)
}
- RegisterProcess(tmpDir, profile)
- isLast, err = IsLastProcess(tmpDir, profile)
+ if err := RegisterProcess(pm, profile); err != nil {
+ t.Fatal(err)
+ }
+ isLast, err = IsLastProcess(pm, profile)
if err != nil || !isLast {
t.Errorf("Expected IsLastProcess to be true for single process, got %v, err: %v", isLast, err)
}
- os.WriteFile(filepath.Join(pidsDir, "1234567"), []byte(""), 0644)
- isLast, err = IsLastProcess(tmpDir, profile)
+ if err := os.WriteFile(filepath.Join(pidsDir, "1234567"), []byte(""), 0644); err != nil {
+ t.Fatal(err)
+ }
+ isLast, err = IsLastProcess(pm, profile)
if err != nil || !isLast {
t.Errorf("Expected IsLastProcess to be true because 1234567 is dead, got %v, err: %v", isLast, err)
}
diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go
new file mode 100644
index 0000000..cd81a38
--- /dev/null
+++ b/internal/namespace/pinning.go
@@ -0,0 +1,35 @@
+package namespace
+
+import (
+ "fmt"
+ "os"
+
+ "git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+)
+
+// UnpinNamespace removes the pinned namespace file from the filesystem.
+// This allows the namespace to be destroyed once the last process exits.
+func UnpinNamespace(pm *paths.PathManager, profile string) error {
+ nsPath := GetProfileNamespacePath(pm, profile)
+
+ // We only want to unpin if there are no more active processes.
+ // The caller (cli.ExecuteCommand) is responsible for calling this
+ // when IsLastProcess returns true.
+
+ if _, err := os.Stat(nsPath); os.IsNotExist(err) {
+ return nil
+ }
+
+ // We also want to remove the pids directory if it's empty.
+ pidsDir := GetPidsDirPath(pm, profile)
+
+ // Unlink the namespace file
+ if err := os.Remove(nsPath); err != nil {
+ return fmt.Errorf("failed to unpin namespace %s: %w", nsPath, err)
+ }
+
+ // Try to remove pids directory
+ _ = os.Remove(pidsDir)
+
+ return nil
+}
diff --git a/internal/namespace/pinning_test.go b/internal/namespace/pinning_test.go
new file mode 100644
index 0000000..c65e1b1
--- /dev/null
+++ b/internal/namespace/pinning_test.go
@@ -0,0 +1,53 @@
+package namespace
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+)
+
+func TestUnpinNamespace(t *testing.T) {
+ tmpDir := t.TempDir()
+ pm := paths.NewPathManager("", tmpDir)
+ profile := "test-profile"
+ nsPath := GetProfileNamespacePath(pm, profile)
+
+ // Create the base profiles directory first
+ profilesDir := filepath.Dir(nsPath)
+ if err := os.MkdirAll(profilesDir, 0755); err != nil {
+ t.Fatalf("failed to create profiles dir: %v", err)
+ }
+
+ // Create dummy namespace file
+ if err := os.WriteFile(nsPath, []byte("dummy"), 0644); err != nil {
+ t.Fatalf("failed to create ns file: %v", err)
+ }
+
+ pidsDir := GetPidsDirPath(pm, profile)
+ if err := os.MkdirAll(pidsDir, 0755); err != nil {
+ t.Fatalf("failed to create pids dir: %v", err)
+ }
+
+ t.Run("successfully unpins", func(t *testing.T) {
+ err := UnpinNamespace(pm, profile)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+
+ if _, err := os.Stat(nsPath); !os.IsNotExist(err) {
+ t.Errorf("namespace file should have been deleted")
+ }
+ if _, err := os.Stat(pidsDir); !os.IsNotExist(err) {
+ t.Errorf("pids directory should have been deleted")
+ }
+ })
+
+ t.Run("handles non-existent namespace", func(t *testing.T) {
+ err := UnpinNamespace(pm, profile)
+ if err != nil {
+ t.Errorf("unexpected error when unpinning non-existent namespace: %v", err)
+ }
+ })
+}
diff --git a/internal/paths/paths.go b/internal/paths/paths.go
new file mode 100644
index 0000000..f512ad1
--- /dev/null
+++ b/internal/paths/paths.go
@@ -0,0 +1,62 @@
+package paths
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+)
+
+// PathManager handles the resolution of configuration and runtime directories.
+// By using a struct, we can instantiate different managers for parallel tests.
+type PathManager struct {
+ ConfigDirOverride string
+ RuntimeBaseOverride string
+}
+
+// NewPathManager creates a PathManager with the given overrides.
+func NewPathManager(configOverride, runtimeOverride string) *PathManager {
+ return &PathManager{
+ ConfigDirOverride: configOverride,
+ RuntimeBaseOverride: runtimeOverride,
+ }
+}
+
+// ConfigDir returns the persistent storage path for .conf files.
+func (pm *PathManager) ConfigDir() string {
+ if pm.ConfigDirOverride != "" {
+ return pm.ConfigDirOverride
+ }
+
+ configHome := os.Getenv("XDG_CONFIG_HOME")
+ if configHome == "" {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "/etc/wg-wrap/profiles" // Fallback
+ }
+ configHome = filepath.Join(home, ".config")
+ }
+ return filepath.Join(configHome, "wg-wrap", "profiles")
+}
+
+// RuntimeBaseDir returns the base ephemeral path.
+func (pm *PathManager) RuntimeBaseDir() string {
+ if pm.RuntimeBaseOverride != "" {
+ return pm.RuntimeBaseOverride
+ }
+
+ if envDir := os.Getenv("XDG_RUNTIME_DIR"); envDir != "" {
+ return envDir
+ }
+ uid := os.Getuid()
+ return fmt.Sprintf("/run/user/%d", uid)
+}
+
+// ProfileNamespacePath returns the specific path for a pinned namespace.
+func (pm *PathManager) ProfileNamespacePath(profile string) string {
+ return filepath.Join(pm.RuntimeBaseDir(), "profiles", profile+".ns")
+}
+
+// ProfilePidsDir returns the path for PID tracking.
+func (pm *PathManager) ProfilePidsDir(profile string) string {
+ return filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "pids")
+}
diff --git a/pkg/wgconf/wgconf.go b/pkg/wgconf/wgconf.go
new file mode 100644
index 0000000..2615892
--- /dev/null
+++ b/pkg/wgconf/wgconf.go
@@ -0,0 +1,97 @@
+package wgconf
+
+import (
+ "bufio"
+ "fmt"
+ "os"
+ "strings"
+)
+
+// Config represents a parsed WireGuard configuration file.
+type Config struct {
+ PrivateKey string
+ Address string
+ DNS string
+ Peers []Peer
+}
+
+// Peer represents a WireGuard peer.
+type Peer struct {
+ PublicKey string
+ Endpoint string
+ AllowedIPs []string
+}
+
+// Parse reads a WireGuard .conf file and returns a Config struct.
+func Parse(path string) (*Config, error) {
+ file, err := os.Open(path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open config file: %w", err)
+ }
+ defer func() {
+ if err := file.Close(); err != nil {
+ // We use a simple print here because we are in a defer
+ fmt.Printf("warning: failed to close config file %s: %v\n", path, err)
+ }
+ }()
+
+ cfg := &Config{}
+ var currentPeer *Peer
+
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" || strings.HasPrefix(line, "#") {
+ continue
+ }
+
+ if strings.HasPrefix(line, "[") {
+ section := strings.Trim(line, "[]")
+ if section == "Peer" {
+ if currentPeer != nil {
+ cfg.Peers = append(cfg.Peers, *currentPeer)
+ }
+ currentPeer = &Peer{}
+ }
+ continue
+ }
+
+ parts := strings.SplitN(line, "=", 2)
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("invalid line format: %s", line)
+ }
+
+ key := strings.TrimSpace(parts[0])
+ val := strings.TrimSpace(parts[1])
+
+ if currentPeer != nil {
+ switch key {
+ case "PublicKey":
+ currentPeer.PublicKey = val
+ case "Endpoint":
+ currentPeer.Endpoint = val
+ case "AllowedIPs":
+ currentPeer.AllowedIPs = strings.Split(val, ",")
+ }
+ } else {
+ switch key {
+ case "PrivateKey":
+ cfg.PrivateKey = val
+ case "Address":
+ cfg.Address = val
+ case "DNS":
+ cfg.DNS = val
+ }
+ }
+ }
+
+ if currentPeer != nil {
+ cfg.Peers = append(cfg.Peers, *currentPeer)
+ }
+
+ if err := scanner.Err(); err != nil {
+ return nil, fmt.Errorf("error reading config file: %w", err)
+ }
+
+ return cfg, nil
+}
diff --git a/pkg/wgconf/wgconf_test.go b/pkg/wgconf/wgconf_test.go
index d0bcb0b..805aeaa 100644
--- a/pkg/wgconf/wgconf_test.go
+++ b/pkg/wgconf/wgconf_test.go
@@ -1,15 +1,71 @@
package wgconf
import (
+ "os"
+ "path/filepath"
"testing"
)
func TestParseConfig(t *testing.T) {
- // Test that valid .conf files are parsed correctly and invalid ones return errors.
- t.Skip("not implemented")
+ content := `[Interface]
+PrivateKey = ABC123XYZ
+Address = 10.0.0.1/24
+DNS = 1.1.1.1
+
+[Peer]
+PublicKey = DEF456UVW
+Endpoint = 1.2.3.4:51820
+AllowedIPs = 0.0.0.0/0
+
+[Peer]
+PublicKey = GHI789TSR
+Endpoint = 5.6.7.8:51820
+AllowedIPs = 192.168.1.0/24, 192.168.2.0/24`
+
+ tmpFile := filepath.Join(t.TempDir(), "test.conf")
+ if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
+ t.Fatal(err)
+ }
+
+ cfg, err := Parse(tmpFile)
+ if err != nil {
+ t.Fatalf("Parse failed: %v", err)
+ }
+
+ if cfg.PrivateKey != "ABC123XYZ" {
+ t.Errorf("expected PrivateKey ABC123XYZ, got %s", cfg.PrivateKey)
+ }
+ if cfg.Address != "10.0.0.1/24" {
+ t.Errorf("expected Address 10.0.0.1/24, got %s", cfg.Address)
+ }
+ if cfg.DNS != "1.1.1.1" {
+ t.Errorf("expected DNS 1.1.1.1, got %s", cfg.DNS)
+ }
+ if len(cfg.Peers) != 2 {
+ t.Fatalf("expected 2 peers, got %d", len(cfg.Peers))
+ }
+
+ p1 := cfg.Peers[0]
+ if p1.PublicKey != "DEF456UVW" || p1.Endpoint != "1.2.3.4:51820" || len(p1.AllowedIPs) != 1 || p1.AllowedIPs[0] != "0.0.0.0/0" {
+ t.Errorf("Peer 1 mismatch: %+v", p1)
+ }
+
+ p2 := cfg.Peers[1]
+ if p2.PublicKey != "GHI789TSR" || p2.Endpoint != "5.6.7.8:51820" || len(p2.AllowedIPs) != 2 || p2.AllowedIPs[0] != "192.168.1.0/24" {
+ t.Errorf("Peer 2 mismatch: %+v", p2)
+ }
}
-func TestValidateProfile(t *testing.T) {
- // Test that profile names are resolved correctly to ~/.config/wg-wrap/profiles/*.conf.
- t.Skip("not implemented")
+func TestParseInvalidConfig(t *testing.T) {
+ content := `[Interface]
+InvalidLineWithoutEquals`
+ tmpFile := filepath.Join(t.TempDir(), "invalid.conf")
+ if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
+ t.Fatal(err)
+ }
+
+ _, err := Parse(tmpFile)
+ if err == nil {
+ t.Error("expected error for invalid line format, got nil")
+ }
}
diff --git a/tests/e2e/config_test.go b/tests/e2e/config_test.go
index 83cfc15..e613b13 100644
--- a/tests/e2e/config_test.go
+++ b/tests/e2e/config_test.go
@@ -38,12 +38,12 @@ func TestConfigPropagation(t *testing.T) {
// Test 2: Configuration after bootstrap (Isolated)
// We use 'test-ns' as a way to run a command that we know is isolated.
- // Actually, we can just run 'show-config' but the current 'Route'
- // handles 'show-config' BEFORE the bootstrap.
- // To test isolated config, we can't use 'show-config' because it's a diagnostic
+ // Actually, we can just run 'show-config' but the current 'Route'
+ // handles 'show-config' BEFORE the bootstrap.
+ // To test isolated config, we can't use 'show-config' because it's a diagnostic
// command designed to run outside isolation.
-
- // To verify what an isolated process sees, we can use a target command
+
+ // To verify what an isolated process sees, we can use a target command
// that prints the environment.
cmdIsolated := exec.Command(binaryPath, "--profile", profile, "--", "sh", "-c", "echo $XDG_RUNTIME_DIR")
cmdIsolated.Env = append(os.Environ(), fmt.Sprintf("XDG_RUNTIME_DIR=%s", tmpRuntimeDir))
diff --git a/tests/e2e/lifecycle_test.go b/tests/e2e/lifecycle_test.go
index 649dbc0..08887e1 100644
--- a/tests/e2e/lifecycle_test.go
+++ b/tests/e2e/lifecycle_test.go
@@ -48,9 +48,11 @@ func TestNamespaceLifecycleAutomation(t *testing.T) {
tmpRuntimeDir := t.TempDir()
profile := "e2e-lifecycle-test"
pidsDir := filepath.Join(tmpRuntimeDir, "profiles", profile, "pids")
-
+
// Clean up before starting
- os.RemoveAll(filepath.Join(tmpRuntimeDir, "profiles", profile))
+ if err := os.RemoveAll(filepath.Join(tmpRuntimeDir, "profiles", profile)); err != nil {
+ t.Fatalf("failed to remove profile directory: %v", err)
+ }
t.Run("ReferenceCounting", func(t *testing.T) {
// Start a process that exits quickly
@@ -75,7 +77,7 @@ func TestNamespaceLifecycleAutomation(t *testing.T) {
if err := cmd1.Wait(); err != nil {
t.Fatalf("cmd1 failed: %v", err)
}
-
+
// Poll for the count to drop back to 1
timeout := time.After(2 * time.Second)
found := false
@@ -97,7 +99,7 @@ func TestNamespaceLifecycleAutomation(t *testing.T) {
if err := cmd2.Wait(); err != nil {
t.Fatalf("cmd2 failed: %v", err)
}
-
+
// Verify a clean state (expect 0 files)
timeout = time.After(2 * time.Second)
found = false