//go:build linux package wireguard import ( "encoding/base64" "encoding/hex" "fmt" "net" "net/netip" "os" "os/exec" "runtime" "strconv" "strings" "git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" ) // Tunnel represents an active Userspace WireGuard tunnel inside a network namespace. type Tunnel struct { Device *device.Device Tun tun.Device } // StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes. func StartTunnel(cfg *wgconf.Config, dnsServer string) (*Tunnel, error) { // 1. Create the TUN device inside the current (isolated) namespace // We use the default name 'tun0' tunName := "tun0" mtu := 1420 // Ensure the mount namespace is private to prevent mount propagation to the host. // This is critical for the bind-mount of /etc/resolv.conf to work in rootless environments. if err := unix.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil { // We log this as a warning because some environments might not allow this, // but we can still try to proceed. fmt.Printf("warning: failed to make mount namespace private: %v\n", err) } // Block host services (D-Bus, nscd) to prevent name resolution leak bypasses if err := BlockHostServices(); err != nil { fmt.Printf("warning: failed to block host services: %v\n", err) } tunDev, err := tun.CreateTUN(tunName, mtu) if err != nil { return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err) } // 2. Instantiate the userspace WireGuard device logger := device.NewLogger(device.LogLevelSilent, "[wg-wrap] ") var bind conn.Bind // Check if a pre-opened host UDP socket file descriptor was passed first (Approach A - FD Passing) if hostSocketFdStr := os.Getenv("WG_WRAP_HOST_SOCKET_FD"); hostSocketFdStr != "" { if fd, err := strconv.Atoi(hostSocketFdStr); err == nil && fd > 0 { if fdBind, err := NewFDBind(fd); err == nil { bind = fdBind } } } // Fallback to NewHostBind or standard Bind if no host socket was passed if bind == nil { bind = conn.NewDefaultBind() if hostNetNSFdStr := os.Getenv("WG_WRAP_HOST_NETNS_FD"); hostNetNSFdStr != "" { if fd, err := strconv.Atoi(hostNetNSFdStr); err == nil && fd > 0 { bind = NewHostBind(bind, fd) } } } wgDev := device.NewDevice(tunDev, bind, logger) // 3. Formulate the UAPI configuration string to configure peers/keys uapiConf, err := buildUAPIConfig(cfg) if err != nil { wgDev.Close() return nil, fmt.Errorf("failed to build UAPI config: %w", err) } // Apply configuration via UAPI (IpcSet) if err := wgDev.IpcSet(uapiConf); err != nil { wgDev.Close() return nil, fmt.Errorf("failed to configure WireGuard device: %w", err) } // Enable device if err := wgDev.Up(); err != nil { wgDev.Close() return nil, fmt.Errorf("failed to bring up WireGuard device: %w", err) } // 4. Configure network interface using standard Linux network commands (iproute2) // Since we are mapped to root (UID 0) inside our isolated network namespace, // we have complete control over local network interfaces without affecting the host. if err := configureInterface(tunName, cfg.Address, mtu); err != nil { wgDev.Close() return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err) } // Configure DNS resolver inside the namespace if err := ConfigureResolvConf(dnsServer); err != nil { // We treat DNS failure as a warning rather than a fatal error to allow // the tunnel to function even if /etc/resolv.conf is read-only. fmt.Printf("warning: failed to configure DNS resolver: %v\n", err) } return &Tunnel{ Device: wgDev, Tun: tunDev, }, nil } // Close shuts down the userspace WireGuard device and closes the TUN interface. func (t *Tunnel) Close() { if t.Device != nil { t.Device.Close() } } // keyToHex ensures a WireGuard key is in hexadecimal format, converting from base64 if needed. func keyToHex(key string) (string, error) { // Try base64 decoding first decoded, err := base64.StdEncoding.DecodeString(key) if err == nil && len(decoded) == 32 { return hex.EncodeToString(decoded), nil } // Try decoding as hex if len(key) == 64 { if _, err := hex.DecodeString(key); err == nil { return strings.ToLower(key), nil } } return "", fmt.Errorf("key is neither valid base64 nor hex 32-byte key: %s", key) } // buildUAPIConfig translates our wgconf.Config into the standard WireGuard UAPI format func buildUAPIConfig(cfg *wgconf.Config) (string, error) { var sb strings.Builder // Global section if cfg.PrivateKey != "" { hexKey, err := keyToHex(cfg.PrivateKey) if err != nil { return "", fmt.Errorf("invalid PrivateKey: %w", err) } _, _ = fmt.Fprintf(&sb, "private_key=%s\n", hexKey) } // If there are existing peers, remove them first to have a clean state sb.WriteString("replace_peers=true\n") // Peer sections for _, peer := range cfg.Peers { if peer.PublicKey == "" { continue } hexKey, err := keyToHex(peer.PublicKey) if err != nil { return "", fmt.Errorf("invalid Peer PublicKey: %w", err) } _, _ = fmt.Fprintf(&sb, "public_key=%s\n", hexKey) if peer.Endpoint != "" { _, _ = fmt.Fprintf(&sb, "endpoint=%s\n", peer.Endpoint) } for _, allowedIP := range peer.AllowedIPs { trimmed := strings.TrimSpace(allowedIP) if trimmed != "" { _, _ = fmt.Fprintf(&sb, "allowed_ip=%s\n", trimmed) } } } return sb.String(), nil } // configureInterface uses the 'ip' command to set address, MTU, and default routing table func configureInterface(name, address string, mtu int) error { // Set MTU and bring up link // ip link set dev tun0 mtu 1420 up cmd := exec.Command("ip", "link", "set", "dev", name, "mtu", fmt.Sprintf("%d", mtu), "up") if err := cmd.Run(); err != nil { return fmt.Errorf("failed to set link %s state/mtu: %v", name, err) } // Add IP address // ip addr add
dev tun0 cmd = exec.Command("ip", "addr", "add", address, "dev", name) if _, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("failed to add address %s to link %s: %v", address, name, err) } // Set default route or peer routes. // For transparent userspace tunneling inside an isolated network namespace, // we route all traffic (0.0.0.0/0) through our TUN device 'tun0'. cmd = exec.Command("ip", "route", "add", "default", "dev", name) if err := cmd.Run(); err != nil { // If a default route already exists, we replace it or log the warning // We try to replace first cmdReplace := exec.Command("ip", "route", "replace", "default", "dev", name) if errReplace := cmdReplace.Run(); errReplace != nil { return fmt.Errorf("failed to configure default route to %s: %v", name, errReplace) } } return nil } // GetTunnelLocalIP extracts the local IP address (without CIDR) from the config. func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) { if cfg.Address == "" { return "", fmt.Errorf("profile has no Address configured") } parts := strings.Split(cfg.Address, "/") ipStr := parts[0] ip, err := netip.ParseAddr(ipStr) if err != nil { return "", fmt.Errorf("invalid IP address in config '%s': %w", ipStr, err) } return ip.String(), nil } func ConfigureResolvConf(dns string) error { if dns == "" { return nil } // To avoid modifying the host's /etc/resolv.conf, we use the private mount namespace. tmpFile, err := os.CreateTemp("", "resolvconf") if err != nil { return fmt.Errorf("failed to create temp resolv.conf: %w", err) } defer func() { _ = tmpFile.Close() }() content := fmt.Sprintf("nameserver %s\n", dns) if _, err := tmpFile.WriteString(content); err != nil { return fmt.Errorf("failed to write to temp resolv.conf: %w", err) } // 1. Bind-mount the temp file over /etc/resolv.conf if err := unix.Mount(tmpFile.Name(), "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil { _ = os.Remove(tmpFile.Name()) return fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", tmpFile.Name(), err) } // Unlink the temporary source file. Since /etc/resolv.conf is a bind mount, // the kernel will keep the inode alive, but the file is removed from /tmp. _ = os.Remove(tmpFile.Name()) // 2. Make the mount private to ensure it doesn't propagate back to the host // and to satisfy kernel requirements for mount transitions in some environments. // We do this by applying MS_PRIVATE in a separate mount call. if err := unix.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil { // If MS_PRIVATE fails, we can log a warning but proceed since / is already private fmt.Printf("warning: failed to make /etc/resolv.conf mount private: %v\n", err) } return nil } // BlockHostServices blocks local D-Bus and name service cache daemon (nscd) sockets // inside the mount namespace. This prevents glibc from bypassing the network namespace // isolation via host services (e.g. systemd-resolved via D-Bus). func BlockHostServices() error { tmpDir, err := os.MkdirTemp("", "wg-wrap-block-") if err != nil { return fmt.Errorf("failed to create temp dir: %w", err) } defer func() { _ = os.Remove(tmpDir) }() tmpFile, err := os.CreateTemp("", "wg-wrap-block-file-") if err != nil { return fmt.Errorf("failed to create temp file: %w", err) } tmpFileName := tmpFile.Name() _ = tmpFile.Close() defer func() { _ = os.Remove(tmpFileName) }() // Specific socket files and directories to block pathsToBlock := []string{ "/run/dbus/system_bus_socket", "/run/systemd/resolve/io.systemd.Resolve", "/run/systemd/resolve/io.systemd.Resolve.Monitor", "/run/nscd/socket", "/var/run/dbus/system_bus_socket", "/var/run/systemd/resolve/io.systemd.Resolve", "/var/run/systemd/resolve/io.systemd.Resolve.Monitor", "/var/run/nscd/socket", } for _, p := range pathsToBlock { stat, err := os.Stat(p) if err == nil { source := tmpFileName if stat.IsDir() { source = tmpDir } if err := unix.Mount(source, p, "", unix.MS_BIND, ""); err != nil { fmt.Printf("warning: failed to bind-mount block over %s: %v\n", p, err) } else { _ = unix.Mount("", p, "", unix.MS_PRIVATE, "") } } } return nil } // HostBind wraps a standard conn.Bind so that its socket creation (Open) // is forced to execute within a host network namespace. type HostBind struct { inner conn.Bind hostNetNSFd int } func NewHostBind(inner conn.Bind, hostNetNSFd int) *HostBind { return &HostBind{inner: inner, hostNetNSFd: hostNetNSFd} } func (h *HostBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { runtime.LockOSThread() defer runtime.UnlockOSThread() // Open/save a reference to our current isolated network namespace to switch back to. isolatedFd, err := unix.Open("/proc/self/ns/net", unix.O_RDONLY, 0) if err != nil { return nil, 0, fmt.Errorf("failed to open isolated netns: %w", err) } defer func() { _ = unix.Close(isolatedFd) }() // Temporarily switch this thread to the host network namespace if err := unix.Setns(h.hostNetNSFd, unix.CLONE_NEWNET); err != nil { return nil, 0, fmt.Errorf("failed to join host netns: %w", err) } // Sockets are opened in the host network namespace! fns, actualPort, err = h.inner.Open(port) if err != nil { return nil, 0, fmt.Errorf("failed to open sockets in host netns: %w", err) } // Switch this thread back to the isolated network namespace if err := unix.Setns(isolatedFd, unix.CLONE_NEWNET); err != nil { _ = h.inner.Close() // CRITICAL: The thread is stuck in the host network namespace. Returning it to the Go runtime pool // will cause other goroutines to run in the host namespace, breaching isolation. We must panic // immediately to abort the process and prevent a namespace escape. panic(fmt.Sprintf("CRITICAL: failed to restore isolated netns: %v", err)) } return fns, actualPort, nil } func (h *HostBind) Close() error { return h.inner.Close() } func (h *HostBind) SetMark(mark uint32) error { return h.inner.SetMark(mark) } func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { // Linux socket routing maps to the namespace in which the socket was created, // so h.inner.Send will automatically route via host namespace without Setns here! return h.inner.Send(bufs, endpoint) } func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) { return h.inner.ParseEndpoint(s) } func (h *HostBind) BatchSize() int { return h.inner.BatchSize() } // FDBind implements the conn.Bind interface by wrapping a pre-opened // host UDP socket file descriptor. This allows unprivileged processes inside // network namespaces to communicate over the host network loop. type FDBind struct { originalFd int conn *net.UDPConn } type FDEndpoint struct { addr netip.AddrPort } func (e *FDEndpoint) DstIP() netip.Addr { return e.addr.Addr() } func (e *FDEndpoint) DstToString() string { return e.addr.String() } func (e *FDEndpoint) DstToBytes() []byte { return e.addr.Addr().AsSlice() } func (e *FDEndpoint) ClearSrc() {} func (e *FDEndpoint) SrcIP() netip.Addr { return netip.Addr{} } func (e *FDEndpoint) SrcToString() string { return "" } func (e *FDEndpoint) SrcIfidx() int32 { return 0 } func NewFDBind(fd int) (*FDBind, error) { return &FDBind{originalFd: fd}, nil } func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { // Duplicate the original fd so we can close the duplicated socket during // transitions or shutdown, while preserving the ability to re-open/re-bind it later. dupFd, err := unix.Dup(b.originalFd) if err != nil { return nil, 0, fmt.Errorf("failed to duplicate host socket fd: %w", err) } file := os.NewFile(uintptr(dupFd), "host-udp-socket") pconn, err := net.FilePacketConn(file) if err != nil { _ = file.Close() return nil, 0, fmt.Errorf("failed to wrap fd %d as packet conn: %w", dupFd, err) } udpConn, ok := pconn.(*net.UDPConn) if !ok { _ = pconn.Close() return nil, 0, fmt.Errorf("fd %d is not a UDP socket", dupFd) } b.conn = udpConn laddr, ok := b.conn.LocalAddr().(*net.UDPAddr) if !ok { return nil, 0, fmt.Errorf("local address is not a UDP address") } actualPort = uint16(laddr.Port) receive := func(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { if len(packets) == 0 { return 0, nil } if b.conn == nil { return 0, net.ErrClosed } nBytes, addr, err := b.conn.ReadFromUDP(packets[0]) if err != nil { return 0, err } sizes[0] = nBytes addrPort := addr.AddrPort() eps[0] = &FDEndpoint{addr: addrPort} return 1, nil } return []conn.ReceiveFunc{receive}, actualPort, nil } func (b *FDBind) Close() error { if b.conn != nil { err := b.conn.Close() b.conn = nil return err } return nil } func (b *FDBind) SetMark(mark uint32) error { return nil } func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { if b.conn == nil { return net.ErrClosed } addrPort, err := netip.ParseAddrPort(endpoint.DstToString()) if err != nil { return fmt.Errorf("failed to parse destination endpoint %s: %w", endpoint.DstToString(), err) } addr := net.UDPAddrFromAddrPort(addrPort) for _, buf := range bufs { _, err := b.conn.WriteToUDP(buf, addr) if err != nil { return fmt.Errorf("failed to write to UDP socket: %w", err) } } return nil } func (b *FDBind) ParseEndpoint(s string) (conn.Endpoint, error) { addrPort, err := netip.ParseAddrPort(s) if err != nil { return nil, fmt.Errorf("failed to parse endpoint address %s: %w", s, err) } return &FDEndpoint{addr: addrPort}, nil } func (b *FDBind) BatchSize() int { return 1 }