summaryrefslogtreecommitdiff
path: root/internal/manager/manager.go
blob: 99a1a32a2dfdd9be0e07d2cbb0f66d891548f2f7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
// Package manager orchestrates the high-level lifecycle of WireGuard tunnels
// and their associated network namespaces.
//
// Architecture:
// wg-wrap provides a transparent data path:
// Linux Application -> Linux Kernel Routing -> TUN Device -> Userspace WireGuard -> UDP Socket -> Internet.
//
// Persistent Namespaces & Shared Sessions:
// To support multiple concurrent commands on the same WireGuard tunnel without re-establishing
// connections, wg-wrap employs session-based persistent, unprivileged namespaces.
//
// 1. Tracking: Process usage is tracked using active PID files inside the runtime base directory.
// 2. Ref-Counting & Cleanup: Stale PIDs are pruned when a process starts executing and again when it
//    exits. When the last active process exits, the namespace is unpinned and resources are reclaimed.
// 3. Setns Join: When a new process is executed on an active profile, it discovers an active PID
//    and attaches itself to the existing User, Mount, and Network namespaces of the active tunnel.
package manager

import (
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strings"

	"git.theodohertyfamily.com/wg-wrap/internal/config"
	"git.theodohertyfamily.com/wg-wrap/internal/namespace"
	"git.theodohertyfamily.com/wg-wrap/internal/paths"
	"git.theodohertyfamily.com/wg-wrap/internal/wireguard"
	"git.theodohertyfamily.com/wg-wrap/pkg/wgconf"
)

// Manager coordinates WireGuard tunnel sessions and the network namespaces
// they run in, from bootstrap through command execution to teardown.
type Manager struct {
	// PM resolves the configuration directory (where profiles are read from)
	// and the runtime directory used for session bookkeeping.
	PM *paths.PathManager
}

// New returns a Manager that resolves its directories through pm.
func New(pm *paths.PathManager) *Manager {
	mgr := new(Manager)
	mgr.PM = pm
	return mgr
}

// Bootstrap ensures the process is running in an isolated user and network namespace.
// If an active session already exists for the profile, it joins it.
//
// It returns nil immediately when the process is already isolated. Otherwise it
// takes the profile lock (best-effort) and looks for an active PID for
// cfg.Profile. If one exists, it registers this process and joins that PID's
// namespaces via namespace.BootstrapJoin. If none exists, it creates a fresh
// isolated environment via namespace.Bootstrap. Errors from either path are
// wrapped and returned.
func (m *Manager) Bootstrap(cfg *config.Config) error {
	if namespace.IsIsolated() {
		return nil
	}

	// Preserve the host runtime base dir in the environment before bootstrapping.
	// A Setenv failure is deliberately ignored; bootstrapping proceeds without it.
	_ = os.Setenv("WG_WRAP_HOST_RUNTIME_BASE_DIR", m.PM.RuntimeBaseDir())

	// Acquire the startup lock to prevent concurrent bootstrap/join races.
	// Locking is best-effort: if acquisition fails we continue unlocked.
	lockFile, lockErr := namespace.AcquireProfileLock(m.PM, cfg.Profile)
	locked := lockErr == nil
	// releaseLock releases the lock at most once, and never when it was not taken.
	releaseLock := func() {
		if locked {
			locked = false
			namespace.ReleaseProfileLock(lockFile)
		}
	}
	defer releaseLock()

	// Before bootstrapping, see if an active namespace/process for the profile exists.
	activePid, err := namespace.FindActiveProfilePid(m.PM, cfg.Profile)
	if err == nil && activePid > 0 {
		// Register while still holding the lock. A session process that is exiting
		// takes this same lock before checking whether it is the last registered
		// process. It therefore cannot slip in between discovery and registration
		// and unpin the namespace we are about to join.
		if err := namespace.RegisterProcess(m.PM, cfg.Profile); err != nil {
			fmt.Fprintf(os.Stderr, "warning: failed to register process: %v\n", err)
		}

		// Release the lock before joining so other invocations can join too.
		releaseLock()

		if err := namespace.BootstrapJoin(activePid); err != nil {
			return fmt.Errorf("failed to join existing namespace: %w", err)
		}
		return nil
	}

	if err := namespace.Bootstrap(); err != nil {
		return fmt.Errorf("bootstrap failed: %w", err)
	}

	return nil
}

// Execute manages the full execution lifecycle inside an isolated namespace:
// lock acquisition, PID registration, tunnel initialization, command execution, and cleanup.
//
// The profile lock is taken best-effort. It is held while stale PIDs are pruned
// and this process is registered, and, for a fresh session, while the tunnel is
// started and the namespace pinned. It is released before cfg.Command runs so
// other invocations can join. It is re-acquired during cleanup, when this process
// is unregistered and the namespace is unpinned if no registered process remains.
//
// When WG_WRAP_JOINED=1 no tunnel is started. Otherwise the profile
// <ConfigDir>/<profile>.conf is parsed and the tunnel is started. A missing
// "default" profile falls back to running the command in bare isolation with a
// warning.
//
// Execute returns an error when:
//   - it is called outside isolation;
//   - cfg.Command is empty;
//   - the profile is missing, inaccessible or unparsable;
//   - the tunnel fails to start;
//   - the command fails (its error is wrapped with %w).
func (m *Manager) Execute(cfg *config.Config, verbose bool) error {
	if !namespace.IsIsolated() {
		return fmt.Errorf("Execute called without namespace isolation")
	}
	// Guard the cfg.Command[0] index below. An empty command would otherwise
	// panic, and only after the tunnel had already been set up.
	if len(cfg.Command) == 0 {
		return fmt.Errorf("no command specified")
	}

	// heldLock is the profile lock while this process holds it and nil otherwise.
	// That way the lock is never released twice, nor released when never taken.
	var heldLock *os.File
	if lf, err := namespace.AcquireProfileLock(m.PM, cfg.Profile); err == nil {
		heldLock = lf
	}
	releaseLock := func() {
		if heldLock != nil {
			namespace.ReleaseProfileLock(heldLock)
			heldLock = nil
		}
	}
	defer releaseLock()

	if err := namespace.PruneStalePids(m.PM, cfg.Profile); err != nil {
		fmt.Fprintf(os.Stderr, "failed to prune stale pids: %v\n", err)
	}
	if err := namespace.RegisterProcess(m.PM, cfg.Profile); err != nil {
		return fmt.Errorf("failed to register process: %w", err)
	}

	// This cleanup is deferred only once registration has succeeded, and it runs
	// after any tunnel.Close deferred below. A lock that is still held at that
	// point is handed over to releaseSession.
	defer func() {
		lock := heldLock
		heldLock = nil
		m.releaseSession(cfg.Profile, lock)
	}()

	if os.Getenv("WG_WRAP_JOINED") == "1" {
		if verbose {
			fmt.Printf("Joining active WireGuard tunnel session for profile %s...\n", cfg.Profile)
		}
	} else {
		if verbose {
			fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile)
		}

		profilePath := filepath.Join(m.PM.ConfigDir(), cfg.Profile+".conf")
		_, statErr := os.Stat(profilePath)

		switch {
		case statErr == nil:
			wgCfg, err := wgconf.Parse(profilePath)
			if err != nil {
				return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err)
			}

			// DNS precedence: explicit config, then the profile's DNS, then 1.1.1.1.
			dnsServer := cfg.DNSServer
			if dnsServer == "" {
				dnsServer = wgCfg.DNS
			}
			if dnsServer == "" {
				dnsServer = "1.1.1.1"
				routesAll := false
				for _, peer := range wgCfg.Peers {
					if containsDefaultRoute(peer.AllowedIPs) {
						routesAll = true
						break
					}
				}
				if !routesAll {
					fmt.Fprintf(os.Stderr, "warning: Falling back to 1.1.1.1, but your profile does not route all traffic (0.0.0.0/0). DNS resolution may fail.\n")
				}
			}

			tunnel, err := wireguard.StartTunnel(m.PM, cfg.Profile, wgCfg, dnsServer)
			if err != nil {
				return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
			}
			// NOTE(review): this process owns the tunnel and closes it when Execute
			// returns, even if other processes have joined the session. Confirm
			// that joined sessions are expected to lose connectivity at that point.
			defer tunnel.Close()

			if err := namespace.PinNamespace(m.PM, cfg.Profile); err != nil {
				fmt.Fprintf(os.Stderr, "warning: failed to pin namespace: %v\n", err)
			}
		case !os.IsNotExist(statErr):
			// The profile may exist but be unreadable. Do not fall back to running
			// without a tunnel.
			return fmt.Errorf("failed to access profile %s: %w", cfg.Profile, statErr)
		case cfg.Profile != "default":
			return fmt.Errorf("profile %s not found", cfg.Profile)
		default:
			fmt.Fprintf(os.Stderr, "warning: default profile configuration not found. Executing command in bare isolation.\n")
		}
	}

	// Setup is complete. Release the lock so other invocations can join the
	// session while the command runs; cleanup re-acquires it.
	releaseLock()

	// NOTE(review): unless signals are handled elsewhere, Go's default handling
	// of SIGINT/SIGTERM exits this process without running the deferred cleanup
	// above. Consider intercepting or forwarding signals while cmd runs.
	cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Env = os.Environ()

	if err := cmd.Run(); err != nil {
		return fmt.Errorf("command execution failed: %w", err)
	}

	return nil
}

// releaseSession unregisters this process from profile's session. If no other
// registered process remains, it also unpins the profile namespace.
//
// lock is the profile lock when the caller still holds it; releaseSession then
// takes ownership and releases it. When lock is nil, the lock is re-acquired
// for the bookkeeping. If that acquisition fails, only the unregistration is
// attempted.
func (m *Manager) releaseSession(profile string, lock *os.File) {
	if lock == nil {
		acquired, err := namespace.AcquireProfileLock(m.PM, profile)
		if err != nil {
			// Without the lock, pruning and the last-process check are skipped,
			// presumably to avoid racing a process that is joining the session.
			if err := namespace.UnregisterProcess(m.PM, profile); err != nil {
				fmt.Fprintf(os.Stderr, "failed to unregister process: %v\n", err)
			}
			return
		}
		lock = acquired
	}
	defer namespace.ReleaseProfileLock(lock)

	if err := namespace.UnregisterProcess(m.PM, profile); err != nil {
		fmt.Fprintf(os.Stderr, "failed to unregister process: %v\n", err)
	}
	if err := namespace.PruneStalePids(m.PM, profile); err != nil {
		fmt.Fprintf(os.Stderr, "failed to prune stale pids during cleanup: %v\n", err)
	}
	if last, err := namespace.IsLastProcess(m.PM, profile); err == nil && last {
		if err := namespace.UnpinNamespace(m.PM, profile); err != nil {
			fmt.Fprintf(os.Stderr, "failed to unpin namespace: %v\n", err)
		}
	}
}

// containsDefaultRoute reports whether cidrs includes an IPv4 or IPv6 default
// route ("0.0.0.0/0" or "::/0"), ignoring surrounding whitespace.
func containsDefaultRoute(cidrs []string) bool {
	for _, c := range cidrs {
		switch strings.TrimSpace(c) {
		case "0.0.0.0/0", "::/0":
			return true
		}
	}
	return false
}

// StopProfile ends profile's session by unpinning its namespace. Any failure
// from namespace.UnpinNamespace is wrapped and returned.
func (m *Manager) StopProfile(profile string) error {
	unpinErr := namespace.UnpinNamespace(m.PM, profile)
	if unpinErr == nil {
		return nil
	}
	return fmt.Errorf("failed to stop profile: %w", unpinErr)
}

// VerifyLifecycle checks whether profile has an active session. If it does, it
// prints the session's PID to stdout and returns nil. Any lookup failure or
// non-positive PID is reported as "no active session"; the underlying lookup
// error is not included.
func (m *Manager) VerifyLifecycle(profile string) error {
	pid, lookupErr := namespace.FindActiveProfilePid(m.PM, profile)
	if found := lookupErr == nil && pid > 0; !found {
		return fmt.Errorf("no active session found for profile %s", profile)
	}

	fmt.Printf("Active session found for profile %s (PID: %d)\n", profile, pid)
	return nil
}