summaryrefslogtreecommitdiff
path: root/internal/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'internal/wireguard')
-rw-r--r--internal/wireguard/wireguard.go78
1 files changed, 38 insertions, 40 deletions
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 8ac7e63..45e6292 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -9,9 +9,12 @@ import (
"net"
"net/netip"
"os"
+ "path/filepath"
"strconv"
"strings"
+ "git.theodohertyfamily.com/wg-wrap/internal/namespace"
+ "git.theodohertyfamily.com/wg-wrap/internal/paths"
"git.theodohertyfamily.com/wg-wrap/pkg/wgconf"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
@@ -22,13 +25,13 @@ import (
// Tunnel represents an active Userspace WireGuard tunnel inside a network namespace.
type Tunnel struct {
- Device *device.Device
- Tun tun.Device
- dnsFile string
+ Device *device.Device
+ Tun tun.Device
+ dnsFile string
}
// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes.
-func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
+func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
var cleanups []func()
defer func() {
if err != nil {
@@ -46,7 +49,7 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
fmt.Printf("warning: failed to make mount namespace private: %v\n", err)
}
- if err := BlockHostServices(); err != nil {
+ if err := BlockHostServices(pm, profile); err != nil {
fmt.Printf("warning: failed to block host services: %v\n", err)
}
@@ -99,10 +102,14 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
}
var dnsFile string
- if path, err := ConfigureResolvConf(dnsServer); err != nil {
+ profileDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile)
+ if path, err := ConfigureResolvConf(dnsServer, profileDir); err != nil {
fmt.Printf("warning: failed to configure DNS resolver: %v\n", err)
} else {
dnsFile = path
+ cleanups = append(cleanups, func() {
+ UnmountResolvConf(dnsFile)
+ })
}
return &Tunnel{
@@ -232,13 +239,16 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
return ip.String(), nil
}
-func ConfigureResolvConf(dns string) (string, error) {
+func ConfigureResolvConf(dns string, profileDir string) (string, error) {
if dns == "" {
return "", nil
}
- tmpFile, err := os.CreateTemp("", "resolvconf")
+
+ // Create the temporary resolv.conf file within the profile directory
+ // so it can be cleaned up during namespace unpinning.
+ tmpFile, err := os.CreateTemp(profileDir, "resolvconf")
if err != nil {
- return "", fmt.Errorf("failed to create temp resolv.conf: %w", err)
+ return "", fmt.Errorf("failed to create temp resolv.conf in %s: %w", profileDir, err)
}
launcherPath := tmpFile.Name()
content := fmt.Sprintf("nameserver %s\n", dns)
@@ -263,50 +273,38 @@ func UnmountResolvConf(path string) error {
if path == "" {
return nil
}
- if err := unix.Unmount("/etc/resolv.conf", 0); err != nil {
- return fmt.Errorf("failed to unmount /etc/resolv.conf: %w", err)
- }
+
+ fmt.Printf("DEBUG: Unmounting resolv.conf file: %s\n", path)
+
+ // Attempt to unmount. If it fails, it might already be unmounted
+ // or the namespace might be gone.
+ _ = unix.Unmount("/etc/resolv.conf", unix.MNT_DETACH)
+
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err)
}
return nil
}
-func BlockHostServices() error {
- tmpDir, err := os.MkdirTemp("", "wg-wrap-block-")
+func BlockHostServices(pm *paths.PathManager, profile string) error {
+ blockDirBase := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block")
+ if err := os.MkdirAll(blockDirBase, 0755); err != nil {
+ return fmt.Errorf("failed to create block base directory: %w", err)
+ }
+
+ tmpDir, err := os.MkdirTemp(blockDirBase, "dir-")
if err != nil {
- return fmt.Errorf("failed to create temp dir: %w", err)
+ return fmt.Errorf("failed to create temp block dir: %w", err)
}
- defer func() {
- if err := os.RemoveAll(tmpDir); err != nil {
- fmt.Printf("warning: failed to remove temp dir %s: %v\n", tmpDir, err)
- }
- }()
- tmpFile, err := os.CreateTemp("", "wg-wrap-block-file-")
+ tmpFile, err := os.CreateTemp(blockDirBase, "file-")
if err != nil {
- return fmt.Errorf("failed to create temp file: %w", err)
+ return fmt.Errorf("failed to create temp block file: %w", err)
}
tmpFileName := tmpFile.Name()
_ = tmpFile.Close()
- defer func() {
- if err := os.Remove(tmpFileName); err != nil && !os.IsNotExist(err) {
- fmt.Printf("warning: failed to remove temp file %s: %v\n", tmpFileName, err)
- }
- }()
-
- pathsToBlock := []string{
- "/run/dbus/system_bus_socket",
- "/run/systemd/resolve/io.systemd.Resolve",
- "/run/systemd/resolve/io.systemd.Resolve.Monitor",
- "/run/nscd/socket",
- "/var/run/dbus/system_bus_socket",
- "/var/run/systemd/resolve/io.systemd.Resolve",
- "/var/run/systemd/resolve/io.systemd.Resolve.Monitor",
- "/var/run/nscd/socket",
- }
- for _, p := range pathsToBlock {
+ for _, p := range namespace.GetBlockPaths() {
stat, err := os.Stat(p)
if err == nil {
source := tmpFileName
@@ -337,7 +335,7 @@ func (h *HostBind) Close() error { return ni
func (h *HostBind) SetMark(mark uint32) error { return nil }
func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { return nil }
func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
-func (h *HostBind) BatchSize() int { return 0 }
+func (h *HostBind) BatchSize() int { return 0}
type FDBind struct {
originalFd int