summaryrefslogtreecommitdiff
path: root/internal/manager/manager.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/manager/manager.go')
-rw-r--r--internal/manager/manager.go238
1 files changed, 238 insertions, 0 deletions
diff --git a/internal/manager/manager.go b/internal/manager/manager.go
new file mode 100644
index 0000000..99a1a32
--- /dev/null
+++ b/internal/manager/manager.go
@@ -0,0 +1,238 @@
+// Package manager orchestrates the high-level lifecycle of WireGuard tunnels
+// and their associated network namespaces.
+//
+// Architecture:
+// wg-wrap provides a transparent data path:
+// Linux Application -> Linux Kernel Routing -> TUN Device -> Userspace WireGuard -> UDP Socket -> Internet.
+//
+// Persistent Namespaces & Shared Sessions:
+// To support multiple concurrent commands on the same WireGuard tunnel without re-establishing
+// connections, wg-wrap employs session-based persistent, unprivileged namespaces.
+//
+// 1. Tracking: Process usage is tracked using active PID files inside the runtime base directory.
+// 2. Ref-Counting & Cleanup: Active PIDs are regularly pruned. When the last active process exits,
+// the namespace is unpinned and resources are reclaimed.
+// 3. Setns Join: When a new process is executed on an active profile, it discovers an active PID
+// and attaches itself to the existing User, Mount, and Network namespaces of the active tunnel.
+package manager
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+
+ "git.theodohertyfamily.com/wg-wrap/internal/config"
+ "git.theodohertyfamily.com/wg-wrap/internal/namespace"
+ "git.theodohertyfamily.com/wg-wrap/internal/paths"
+ "git.theodohertyfamily.com/wg-wrap/internal/wireguard"
+ "git.theodohertyfamily.com/wg-wrap/pkg/wgconf"
+)
+
+// Manager orchestrates the high-level lifecycle of WireGuard tunnels
+// and their associated network namespaces.
+type Manager struct {
+ // PM is the path manager used to resolve configuration and runtime directories.
+ PM *paths.PathManager
+}
+
+// New creates a new Manager with the given path manager.
+func New(pm *paths.PathManager) *Manager {
+ return &Manager{PM: pm}
+}
+
+// Bootstrap ensures the process is running in an isolated user and network namespace.
+// If an active session already exists for the profile, it joins it.
+func (m *Manager) Bootstrap(cfg *config.Config) error {
+ if namespace.IsIsolated() {
+ return nil
+ }
+
+ // Preserve the host runtime base dir in the environment before bootstrapping.
+ _ = os.Setenv("WG_WRAP_HOST_RUNTIME_BASE_DIR", m.PM.RuntimeBaseDir())
+
+ // Acquire startup lock to prevent concurrent bootstrap/joining races.
+ lockFile, lockErr := namespace.AcquireProfileLock(m.PM, cfg.Profile)
+ if lockErr == nil {
+ defer namespace.ReleaseProfileLock(lockFile)
+ }
+
+ // Before bootstrapping, see if an active namespace/process for the profile exists.
+ activePid, err := namespace.FindActiveProfilePid(m.PM, cfg.Profile)
+ if err == nil && activePid > 0 {
+ // Release the lock before executing the command to allow others to join.
+ namespace.ReleaseProfileLock(lockFile)
+
+ // Register this PID before joining to prevent the race where the joining process
+ // hasn't registered itself yet, causing the existing process to think it's the last one.
+ _ = namespace.RegisterProcess(m.PM, cfg.Profile)
+
+ if err := namespace.BootstrapJoin(activePid); err != nil {
+ return fmt.Errorf("failed to join existing namespace: %w", err)
+ }
+ return nil
+ }
+
+ if err := namespace.Bootstrap(); err != nil {
+ return fmt.Errorf("bootstrap failed: %w", err)
+ }
+
+ return nil
+}
+
+// Execute manages the full execution lifecycle inside an isolated namespace:
+// lock acquisition, PID registration, tunnel initialization, command execution, and cleanup.
+func (m *Manager) Execute(cfg *config.Config, verbose bool) error {
+ if !namespace.IsIsolated() {
+ return fmt.Errorf("Execute called without namespace isolation")
+ }
+
+ // Acquire execution lock during configuration and startup inside the namespace.
+ lockFile, lockErr := namespace.AcquireProfileLock(m.PM, cfg.Profile)
+ var lockFileReleased bool
+ if lockErr == nil {
+ defer func() {
+ if !lockFileReleased {
+ namespace.ReleaseProfileLock(lockFile)
+ }
+ }()
+ }
+
+ if err := namespace.PruneStalePids(m.PM, cfg.Profile); err != nil {
+ fmt.Fprintf(os.Stderr, "failed to prune stale pids: %v\n", err)
+ }
+ if err := namespace.RegisterProcess(m.PM, cfg.Profile); err != nil {
+ return fmt.Errorf("failed to register process: %w", err)
+ }
+
+ defer func() {
+ var cleanupLock *os.File
+ var cleanupErr error
+
+ if lockErr == nil && !lockFileReleased {
+ cleanupLock = lockFile
+ } else {
+ cleanupLock, cleanupErr = namespace.AcquireProfileLock(m.PM, cfg.Profile)
+ }
+
+ if cleanupErr == nil {
+ if err := namespace.UnregisterProcess(m.PM, cfg.Profile); err != nil {
+ fmt.Fprintf(os.Stderr, "failed to unregister process: %v\n", err)
+ }
+
+ if err := namespace.PruneStalePids(m.PM, cfg.Profile); err != nil {
+ fmt.Fprintf(os.Stderr, "failed to prune stale pids during cleanup: %v\n", err)
+ }
+
+ last, lastErr := namespace.IsLastProcess(m.PM, cfg.Profile)
+
+ if lastErr == nil && last {
+ if err := namespace.UnpinNamespace(m.PM, cfg.Profile); err != nil {
+ fmt.Fprintf(os.Stderr, "failed to unpin namespace: %v\n", err)
+ }
+ }
+ if lockErr == nil && !lockFileReleased {
+ lockFileReleased = true
+ }
+ namespace.ReleaseProfileLock(cleanupLock)
+ } else {
+ if err := namespace.UnregisterProcess(m.PM, cfg.Profile); err != nil {
+ fmt.Fprintf(os.Stderr, "failed to unregister process: %v\n", err)
+ }
+ }
+ }()
+
+ if os.Getenv("WG_WRAP_JOINED") == "1" {
+ if verbose {
+ fmt.Printf("Joining active WireGuard tunnel session for profile %s...\n", cfg.Profile)
+ }
+ } else {
+ if verbose {
+ fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile)
+ }
+
+ profilesDir := m.PM.ConfigDir()
+ profilePath := filepath.Join(profilesDir, cfg.Profile+".conf")
+
+ if _, err := os.Stat(profilePath); err == nil {
+ wgCfg, err := wgconf.Parse(profilePath)
+ if err != nil {
+ return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err)
+ }
+
+ dnsServer := cfg.DNSServer
+ if dnsServer == "" {
+ dnsServer = wgCfg.DNS
+ }
+ if dnsServer == "" {
+ dnsServer = "1.1.1.1"
+ hasDefaultRoute := false
+ for _, peer := range wgCfg.Peers {
+ for _, ip := range peer.AllowedIPs {
+ trimmed := strings.TrimSpace(ip)
+ if trimmed == "0.0.0.0/0" || trimmed == "::/0" {
+ hasDefaultRoute = true
+ break
+ }
+ }
+ if hasDefaultRoute {
+ break
+ }
+ }
+ if !hasDefaultRoute {
+ fmt.Fprintf(os.Stderr, "warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n")
+ }
+ }
+
+ tunnel, err := wireguard.StartTunnel(m.PM, cfg.Profile, wgCfg, dnsServer)
+ if err != nil {
+ return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
+ }
+ defer tunnel.Close()
+
+ if err := namespace.PinNamespace(m.PM, cfg.Profile); err != nil {
+ fmt.Fprintf(os.Stderr, "warning: failed to pin namespace: %v\n", err)
+ }
+ } else {
+ if cfg.Profile != "default" {
+ return fmt.Errorf("profile %s not found", cfg.Profile)
+ }
+ fmt.Fprintf(os.Stderr, "warning: default profile configuration not found. Executing command in bare isolation.\n")
+ }
+ }
+
+ lockFileReleased = true
+ namespace.ReleaseProfileLock(lockFile)
+
+ cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ cmd.Env = os.Environ()
+
+ if err := cmd.Run(); err != nil {
+ return fmt.Errorf("command execution failed: %w", err)
+ }
+
+ return nil
+}
+
+// StopProfile stops a profile session by unpinning its namespace.
+func (m *Manager) StopProfile(profile string) error {
+ if err := namespace.UnpinNamespace(m.PM, profile); err != nil {
+ return fmt.Errorf("failed to stop profile: %w", err)
+ }
+ return nil
+}
+
+// VerifyLifecycle checks for an active session for the given profile.
+func (m *Manager) VerifyLifecycle(profile string) error {
+ activePid, err := namespace.FindActiveProfilePid(m.PM, profile)
+ if err != nil || activePid <= 0 {
+ return fmt.Errorf("no active session found for profile %s", profile)
+ }
+
+ fmt.Printf("Active session found for profile %s (PID: %d)\n", profile, activePid)
+ return nil
+}