summaryrefslogtreecommitdiff
path: root/internal/wireguard/wireguard.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/wireguard/wireguard.go')
-rw-r--r--internal/wireguard/wireguard.go275
1 files changed, 108 insertions, 167 deletions
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 3a2bfa3..3c293b4 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -9,12 +9,11 @@ import (
"net"
"net/netip"
"os"
- "os/exec"
- "runtime"
"strconv"
"strings"
"git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf"
+ "github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
@@ -23,8 +22,9 @@ import (
// Tunnel represents an active Userspace WireGuard tunnel inside a network namespace.
type Tunnel struct {
- Device *device.Device
- Tun tun.Device
+ Device *device.Device
+ Tun tun.Device
+ dnsFile string
}
// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes.
@@ -54,7 +54,11 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
if err != nil {
return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err)
}
- cleanups = append(cleanups, func() { tunDev.Close() })
+ cleanups = append(cleanups, func() {
+ if err := tunDev.Close(); err != nil {
+ fmt.Printf("warning: failed to close TUN device: %v\n", err)
+ }
+ })
// 2. Instantiate the userspace WireGuard device
logger := device.NewLogger(device.LogLevelSilent, "[wg-wrap] ")
@@ -69,12 +73,7 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
}
if bind == nil {
- bind = conn.NewDefaultBind()
- if hostNetNSFdStr := os.Getenv("WG_WRAP_HOST_NETNS_FD"); hostNetNSFdStr != "" {
- if fd, err := strconv.Atoi(hostNetNSFdStr); err == nil && fd > 0 {
- bind = NewHostBind(bind, fd)
- }
- }
+ return nil, fmt.Errorf("failed to acquire host socket FD: no valid WG_WRAP_HOST_SOCKET_FD provided")
}
wgDev := device.NewDevice(tunDev, bind, logger)
@@ -94,13 +93,15 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
return nil, fmt.Errorf("failed to bring up WireGuard device: %w", err)
}
- // 4. Configure network interface using standard Linux network commands (iproute2)
+ // 4. Configure network interface using netlink
if err := configureInterface(tunName, cfg.Address, mtu); err != nil {
return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err)
}
- if err := ConfigureResolvConf(dnsServer); err != nil {
+ if path, err := ConfigureResolvConf(dnsServer); err != nil {
fmt.Printf("warning: failed to configure DNS resolver: %v\n", err)
+ } else {
+ t.dnsFile = path
}
return &Tunnel{
@@ -114,31 +115,30 @@ func (t *Tunnel) Close() {
if t.Device != nil {
t.Device.Close()
}
+ if t.dnsFile != "" {
+ if err := UnmountResolvConf(t.dnsFile); err != nil {
+ fmt.Printf("warning: failed to unmount resolv.conf: %v\n", err)
+ }
+ }
}
// keyToHex ensures a WireGuard key is in hexadecimal format, converting from base64 if needed.
func keyToHex(key string) (string, error) {
- // Try base64 decoding first
decoded, err := base64.StdEncoding.DecodeString(key)
if err == nil && len(decoded) == 32 {
return hex.EncodeToString(decoded), nil
}
-
- // Try decoding as hex
if len(key) == 64 {
if _, err := hex.DecodeString(key); err == nil {
return strings.ToLower(key), nil
}
}
-
return "", fmt.Errorf("key is neither valid base64 nor hex 32-byte key: %s", key)
}
// buildUAPIConfig translates our wgconf.Config into the standard WireGuard UAPI format
func buildUAPIConfig(cfg *wgconf.Config) (string, error) {
var sb strings.Builder
-
- // Global section
if cfg.PrivateKey != "" {
hexKey, err := keyToHex(cfg.PrivateKey)
if err != nil {
@@ -146,11 +146,7 @@ func buildUAPIConfig(cfg *wgconf.Config) (string, error) {
}
_, _ = fmt.Fprintf(&sb, "private_key=%s\n", hexKey)
}
-
- // If there are existing peers, remove them first to have a clean state
sb.WriteString("replace_peers=true\n")
-
- // Peer sections
for _, peer := range cfg.Peers {
if peer.PublicKey == "" {
continue
@@ -160,11 +156,9 @@ func buildUAPIConfig(cfg *wgconf.Config) (string, error) {
return "", fmt.Errorf("invalid Peer PublicKey: %w", err)
}
_, _ = fmt.Fprintf(&sb, "public_key=%s\n", hexKey)
-
if peer.Endpoint != "" {
_, _ = fmt.Fprintf(&sb, "endpoint=%s\n", peer.Endpoint)
}
-
for _, allowedIP := range peer.AllowedIPs {
trimmed := strings.TrimSpace(allowedIP)
if trimmed != "" {
@@ -172,36 +166,43 @@ func buildUAPIConfig(cfg *wgconf.Config) (string, error) {
}
}
}
-
return sb.String(), nil
}
-// configureInterface uses the 'ip' command to set address, MTU, and default routing table
+// configureInterface uses netlink to set address, MTU, and default routing table.
func configureInterface(name, address string, mtu int) error {
- // Set MTU and bring up link
- // ip link set dev tun0 mtu 1420 up
- cmd := exec.Command("ip", "link", "set", "dev", name, "mtu", fmt.Sprintf("%d", mtu), "up")
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to set link %s state/mtu: %v", name, err)
- }
-
- // Add IP address
- // ip addr add <address> dev tun0
- cmd = exec.Command("ip", "addr", "add", address, "dev", name)
- if _, err := cmd.CombinedOutput(); err != nil {
- return fmt.Errorf("failed to add address %s to link %s: %v", address, name, err)
- }
-
- // Set default route or peer routes.
- // For transparent userspace tunneling inside an isolated network namespace,
- // we route all traffic (0.0.0.0/0) through our TUN device 'tun0'.
- cmd = exec.Command("ip", "route", "add", "default", "dev", name)
- if err := cmd.Run(); err != nil {
- // If a default route already exists, we replace it or log the warning
- // We try to replace first
- cmdReplace := exec.Command("ip", "route", "replace", "default", "dev", name)
- if errReplace := cmdReplace.Run(); errReplace != nil {
- return fmt.Errorf("failed to configure default route to %s: %v", name, errReplace)
+ link, err := netlink.LinkByName(name)
+ if err != nil {
+ return fmt.Errorf("failed to find link %s: %w", name, err)
+ }
+
+ if err := netlink.LinkSetMTU(link, mtu); err != nil {
+ return fmt.Errorf("failed to set MTU %d on link %s: %w", mtu, name, err)
+ }
+
+ if err := netlink.LinkSetUp(link); err != nil {
+ return fmt.Errorf("failed to bring up link %s: %w", name, err)
+ }
+
+ addr, err := netlink.ParseAddr(address)
+ if err != nil {
+ return fmt.Errorf("invalid IP address %s: %w", address, err)
+ }
+ if err := netlink.AddrAdd(link, addr); err != nil {
+ if !strings.Contains(err.Error(), "file exists") {
+ return fmt.Errorf("failed to add address %s to link %s: %w", address, name, err)
+ }
+ }
+
+ route := &netlink.Route{
+ Scope: netlink.SCOPE_UNIVERSE,
+ LinkIndex: link.Attrs().Index,
+ Dst: nil,
+ }
+
+ if err := netlink.RouteAdd(route); err != nil {
+ if err := netlink.RouteReplace(route); err != nil {
+ return fmt.Errorf("failed to configure default route via %s: %w", name, err)
}
}
@@ -222,46 +223,56 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
return ip.String(), nil
}
-func ConfigureResolvConf(dns string) error {
+func ConfigureResolvConf(dns string) (string, error) {
if dns == "" {
- return nil
+ return "", nil
}
-
tmpFile, err := os.CreateTemp("", "resolvconf")
if err != nil {
- return fmt.Errorf("failed to create temp resolv.conf: %w", err)
+ return "", fmt.Errorf("failed to create temp resolv.conf: %w", err)
}
launcherPath := tmpFile.Name()
- defer func() {
- _ = tmpFile.Close()
- _ = os.Remove(launcherPath)
- }()
-
content := fmt.Sprintf("nameserver %s\n", dns)
if _, err := tmpFile.WriteString(content); err != nil {
- return fmt.Errorf("failed to write to temp resolv.conf: %w", err)
+ _ = tmpFile.Close()
+ return "", fmt.Errorf("failed to write to temp resolv.conf: %w", err)
}
+ _ = tmpFile.Close()
if err := unix.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil {
- return fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", launcherPath, err)
+ return "", fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", launcherPath, err)
}
if err := unix.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil {
fmt.Printf("warning: failed to make /etc/resolv.conf mount private: %v\n", err)
}
+ return launcherPath, nil
+}
+
+func UnmountResolvConf(path string) error {
+ if path == "" {
+ return nil
+ }
+ if err := unix.Unmount("/etc/resolv.conf", 0); err != nil {
+ return fmt.Errorf("failed to unmount /etc/resolv.conf: %w", err)
+ }
+ if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err)
+ }
return nil
}
-// BlockHostServices blocks local D-Bus and name service cache daemon (nscd) sockets
-// inside the mount namespace. This prevents glibc from bypassing the network namespace
-// isolation via host services (e.g. systemd-resolved via D-Bus).
func BlockHostServices() error {
tmpDir, err := os.MkdirTemp("", "wg-wrap-block-")
if err != nil {
return fmt.Errorf("failed to create temp dir: %w", err)
}
- defer os.RemoveAll(tmpDir)
+ defer func() {
+ if err := os.RemoveAll(tmpDir); err != nil {
+ fmt.Printf("warning: failed to remove temp dir %s: %v\n", tmpDir, err)
+ }
+ }()
tmpFile, err := os.CreateTemp("", "wg-wrap-block-file-")
if err != nil {
@@ -269,7 +280,11 @@ func BlockHostServices() error {
}
tmpFileName := tmpFile.Name()
_ = tmpFile.Close()
- defer os.Remove(tmpFileName)
+ defer func() {
+ if err := os.Remove(tmpFileName); err != nil && !os.IsNotExist(err) {
+ fmt.Printf("warning: failed to remove temp file %s: %v\n", tmpFileName, err)
+ }
+ }()
pathsToBlock := []string{
"/run/dbus/system_bus_socket",
@@ -299,76 +314,22 @@ func BlockHostServices() error {
return nil
}
-// HostBind wraps a standard conn.Bind so that its socket creation (Open)
-// is forced to execute within a host network namespace.
-type HostBind struct {
- inner conn.Bind
- hostNetNSFd int
-}
+type HostBind struct{}
func NewHostBind(inner conn.Bind, hostNetNSFd int) *HostBind {
- return &HostBind{inner: inner, hostNetNSFd: hostNetNSFd}
+ return &HostBind{}
}
func (h *HostBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- // Open/save a reference to our current isolated network namespace to switch back to.
- isolatedFd, err := unix.Open("/proc/self/ns/net", unix.O_RDONLY, 0)
- if err != nil {
- return nil, 0, fmt.Errorf("failed to open isolated netns: %w", err)
- }
- defer func() { _ = unix.Close(isolatedFd) }()
-
- // Temporarily switch this thread to the host network namespace
- if err := unix.Setns(h.hostNetNSFd, unix.CLONE_NEWNET); err != nil {
- return nil, 0, fmt.Errorf("failed to join host netns: %w", err)
- }
-
- // Sockets are opened in the host network namespace!
- fns, actualPort, err = h.inner.Open(port)
- if err != nil {
- return nil, 0, fmt.Errorf("failed to open sockets in host netns: %w", err)
- }
-
- // Switch this thread back to the isolated network namespace
- if err := unix.Setns(isolatedFd, unix.CLONE_NEWNET); err != nil {
- _ = h.inner.Close()
- // CRITICAL: The thread is stuck in the host network namespace. Returning it to the Go runtime pool
- // will cause other goroutines to run in the host namespace, breaching isolation. We must panic
- // immediately to abort the process and prevent a namespace escape.
- panic(fmt.Sprintf("CRITICAL: failed to restore isolated netns: %v", err))
- }
-
- return fns, actualPort, nil
-}
-
-func (h *HostBind) Close() error {
- return h.inner.Close()
-}
-
-func (h *HostBind) SetMark(mark uint32) error {
- return h.inner.SetMark(mark)
-}
-
-func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error {
- // Linux socket routing maps to the namespace in which the socket was created,
- // so h.inner.Send will automatically route via host namespace without Setns here!
- return h.inner.Send(bufs, endpoint)
-}
-
-func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) {
- return h.inner.ParseEndpoint(s)
+ return nil, 0, fmt.Errorf("HostBind.Open is disabled for security reasons")
}
-func (h *HostBind) BatchSize() int {
- return h.inner.BatchSize()
-}
+func (h *HostBind) Close() error { return nil }
+func (h *HostBind) SetMark(mark uint32) error { return nil }
+func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { return nil }
+func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
+func (h *HostBind) BatchSize() int { return 0 }
-// FDBind implements the conn.Bind interface by wrapping a pre-opened
-// host UDP socket file descriptor. This allows unprivileged processes inside
-// network namespaces to communicate over the host network loop.
type FDBind struct {
originalFd int
conn *net.UDPConn
@@ -378,39 +339,19 @@ type FDEndpoint struct {
addr netip.AddrPort
}
-func (e *FDEndpoint) DstIP() netip.Addr {
- return e.addr.Addr()
-}
-
-func (e *FDEndpoint) DstToString() string {
- return e.addr.String()
-}
-
-func (e *FDEndpoint) DstToBytes() []byte {
- return e.addr.Addr().AsSlice()
-}
-
-func (e *FDEndpoint) ClearSrc() {}
-
-func (e *FDEndpoint) SrcIP() netip.Addr {
- return netip.Addr{}
-}
-
-func (e *FDEndpoint) SrcToString() string {
- return ""
-}
-
-func (e *FDEndpoint) SrcIfidx() int32 {
- return 0
-}
+func (e *FDEndpoint) DstIP() netip.Addr { return e.addr.Addr() }
+func (e *FDEndpoint) DstToString() string { return e.addr.String() }
+func (e *FDEndpoint) DstToBytes() []byte { return e.addr.Addr().AsSlice() }
+func (e *FDEndpoint) ClearSrc() {}
+func (e *FDEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
+func (e *FDEndpoint) SrcToString() string { return "" }
+func (e *FDEndpoint) SrcIfidx() int32 { return 0 }
func NewFDBind(fd int) (*FDBind, error) {
return &FDBind{originalFd: fd}, nil
}
func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
- // Duplicate the original fd so we can close the duplicated socket during
- // transitions or shutdown, while preserving the ability to re-open/re-bind it later.
dupFd, err := unix.Dup(b.originalFd)
if err != nil {
return nil, 0, fmt.Errorf("failed to duplicate host socket fd: %w", err)
@@ -443,14 +384,17 @@ func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, e
if b.conn == nil {
return 0, net.ErrClosed
}
- nBytes, addr, err := b.conn.ReadFromUDP(packets[0])
- if err != nil {
- return 0, err
+ for i := 0; i < len(packets); i++ {
+ nBytes, addr, err := b.conn.ReadFromUDP(packets[i])
+ if err != nil {
+ return i, err
+ }
+ sizes[i] = nBytes
+ addrPort := addr.AddrPort()
+ eps[i] = &FDEndpoint{addr: addrPort}
+ n++
}
- sizes[0] = nBytes
- addrPort := addr.AddrPort()
- eps[0] = &FDEndpoint{addr: addrPort}
- return 1, nil
+ return n, nil
}
return []conn.ReceiveFunc{receive}, actualPort, nil
@@ -465,9 +409,7 @@ func (b *FDBind) Close() error {
return nil
}
-func (b *FDBind) SetMark(mark uint32) error {
- return nil
-}
+func (b *FDBind) SetMark(mark uint32) error { return nil }
func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error {
if b.conn == nil {
@@ -478,7 +420,6 @@ func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error {
return fmt.Errorf("failed to parse destination endpoint %s: %w", endpoint.DstToString(), err)
}
addr := net.UDPAddrFromAddrPort(addrPort)
-
for _, buf := range bufs {
_, err := b.conn.WriteToUDP(buf, addr)
if err != nil {