//go:build linux package wireguard import ( "encoding/base64" "encoding/hex" "fmt" "net" "net/netip" "os" "path/filepath" "strconv" "strings" "git.theodohertyfamily.com/wg-wrap/internal/namespace" "git.theodohertyfamily.com/wg-wrap/internal/paths" "git.theodohertyfamily.com/wg-wrap/pkg/wgconf" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" ) // Tunnel represents an active Userspace WireGuard tunnel inside a network namespace. type Tunnel struct { Device *device.Device Tun tun.Device dnsFile string } // StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes. func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) { var cleanups []func() defer func() { if err != nil { for i := len(cleanups) - 1; i >= 0; i-- { cleanups[i]() } } }() // 1. Create the TUN device inside the current (isolated) namespace tunName := "tun0" mtu := 1420 if err := unix.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil { fmt.Printf("warning: failed to make mount namespace private: %v\n", err) } if err := BlockHostServices(pm, profile); err != nil { fmt.Printf("warning: failed to block host services: %v\n", err) } tunDev, err := tun.CreateTUN(tunName, mtu) if err != nil { return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err) } cleanups = append(cleanups, func() { if err := tunDev.Close(); err != nil { fmt.Printf("warning: failed to close TUN device: %v\n", err) } }) // 2. Instantiate the userspace WireGuard device logger := device.NewLogger(device.LogLevelSilent, "[wg-wrap] ") var bind conn.Bind if hostSocketFdStr := os.Getenv("WG_WRAP_HOST_SOCKET_FD"); hostSocketFdStr != "" { if fd, err := strconv.Atoi(hostSocketFdStr); err == nil && fd > 0 { if fdBind, err := NewFDBind(fd); err == nil { bind = fdBind } } } if bind == nil { return nil, fmt.Errorf("failed to acquire host socket FD: no valid WG_WRAP_HOST_SOCKET_FD provided") } wgDev := device.NewDevice(tunDev, bind, logger) cleanups = append(cleanups, func() { wgDev.Close() }) // 3. Formulate the UAPI configuration string to configure peers/keys uapiConf, err := buildUAPIConfig(cfg) if err != nil { return nil, fmt.Errorf("failed to build UAPI config: %w", err) } if err := wgDev.IpcSet(uapiConf); err != nil { return nil, fmt.Errorf("failed to configure WireGuard device: %w", err) } if err := wgDev.Up(); err != nil { return nil, fmt.Errorf("failed to bring up WireGuard device: %w", err) } // 4. Configure network interface using netlink if err := configureInterface(tunName, cfg.Address, mtu); err != nil { return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err) } var dnsFile string profileDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile) if path, err := ConfigureResolvConf(dnsServer, profileDir); err != nil { fmt.Printf("warning: failed to configure DNS resolver: %v\n", err) } else { dnsFile = path cleanups = append(cleanups, func() { UnmountResolvConf(dnsFile) }) } return &Tunnel{ Device: wgDev, Tun: tunDev, dnsFile: dnsFile, }, nil } // Close shuts down the userspace WireGuard device and closes the TUN interface. func (t *Tunnel) Close() { if t.Device != nil { t.Device.Close() } if t.dnsFile != "" { if err := UnmountResolvConf(t.dnsFile); err != nil { fmt.Printf("warning: failed to unmount resolv.conf: %v\n", err) } } } // keyToHex ensures a WireGuard key is in hexadecimal format, converting from base64 if needed. func keyToHex(key string) (string, error) { decoded, err := base64.StdEncoding.DecodeString(key) if err == nil && len(decoded) == 32 { return hex.EncodeToString(decoded), nil } if len(key) == 64 { if _, err := hex.DecodeString(key); err == nil { return strings.ToLower(key), nil } } return "", fmt.Errorf("key is neither valid base64 nor hex 32-byte key: %s", key) } // buildUAPIConfig translates our wgconf.Config into the standard WireGuard UAPI format func buildUAPIConfig(cfg *wgconf.Config) (string, error) { var sb strings.Builder if cfg.PrivateKey != "" { hexKey, err := keyToHex(cfg.PrivateKey) if err != nil { return "", fmt.Errorf("invalid PrivateKey: %w", err) } _, _ = fmt.Fprintf(&sb, "private_key=%s\n", hexKey) } sb.WriteString("replace_peers=true\n") for _, peer := range cfg.Peers { if peer.PublicKey == "" { continue } hexKey, err := keyToHex(peer.PublicKey) if err != nil { return "", fmt.Errorf("invalid Peer PublicKey: %w", err) } _, _ = fmt.Fprintf(&sb, "public_key=%s\n", hexKey) if peer.Endpoint != "" { _, _ = fmt.Fprintf(&sb, "endpoint=%s\n", peer.Endpoint) } for _, allowedIP := range peer.AllowedIPs { trimmed := strings.TrimSpace(allowedIP) if trimmed != "" { _, _ = fmt.Fprintf(&sb, "allowed_ip=%s\n", trimmed) } } } return sb.String(), nil } // configureInterface uses netlink to set address, MTU, and default routing table. func configureInterface(name, address string, mtu int) error { link, err := netlink.LinkByName(name) if err != nil { return fmt.Errorf("failed to find link %s: %w", name, err) } if err := netlink.LinkSetMTU(link, mtu); err != nil { return fmt.Errorf("failed to set MTU %d on link %s: %w", mtu, name, err) } if err := netlink.LinkSetUp(link); err != nil { return fmt.Errorf("failed to bring up link %s: %w", name, err) } addr, err := netlink.ParseAddr(address) if err != nil { return fmt.Errorf("invalid IP address %s: %w", address, err) } if err := netlink.AddrAdd(link, addr); err != nil { if !strings.Contains(err.Error(), "file exists") { return fmt.Errorf("failed to add address %s to link %s: %w", address, name, err) } } var dst *net.IPNet if addr.IP.To4() != nil { _, dst, _ = net.ParseCIDR("0.0.0.0/0") } else { _, dst, _ = net.ParseCIDR("::/0") } route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, LinkIndex: link.Attrs().Index, Dst: dst, } if err := netlink.RouteAdd(route); err != nil { if err := netlink.RouteReplace(route); err != nil { return fmt.Errorf("failed to configure default route via %s: %w", name, err) } } return nil } // GetTunnelLocalIP extracts the local IP address (without CIDR) from the config. func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) { if cfg.Address == "" { return "", fmt.Errorf("profile has no Address configured") } parts := strings.Split(cfg.Address, "/") ipStr := parts[0] ip, err := netip.ParseAddr(ipStr) if err != nil { return "", fmt.Errorf("invalid IP address in config '%s': %w", ipStr, err) } return ip.String(), nil } func ConfigureResolvConf(dns string, profileDir string) (string, error) { if dns == "" { return "", nil } // Create the temporary resolv.conf file within the profile directory // so it can be cleaned up during namespace unpinning. tmpFile, err := os.CreateTemp(profileDir, "resolvconf") if err != nil { return "", fmt.Errorf("failed to create temp resolv.conf in %s: %w", profileDir, err) } launcherPath := tmpFile.Name() content := fmt.Sprintf("nameserver %s\n", dns) if _, err := tmpFile.WriteString(content); err != nil { _ = tmpFile.Close() return "", fmt.Errorf("failed to write to temp resolv.conf: %w", err) } _ = tmpFile.Close() if err := unix.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil { return "", fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", launcherPath, err) } if err := unix.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil { fmt.Printf("warning: failed to make /etc/resolv.conf mount private: %v\n", err) } return launcherPath, nil } func UnmountResolvConf(path string) error { if path == "" { return nil } fmt.Printf("DEBUG: Unmounting resolv.conf file: %s\n", path) // Attempt to unmount. If it fails, it might already be unmounted // or the namespace might be gone. _ = unix.Unmount("/etc/resolv.conf", unix.MNT_DETACH) if err := os.Remove(path); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err) } return nil } func BlockHostServices(pm *paths.PathManager, profile string) error { blockDirBase := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block") if err := os.MkdirAll(blockDirBase, 0755); err != nil { return fmt.Errorf("failed to create block base directory: %w", err) } tmpDir, err := os.MkdirTemp(blockDirBase, "dir-") if err != nil { return fmt.Errorf("failed to create temp block dir: %w", err) } tmpFile, err := os.CreateTemp(blockDirBase, "file-") if err != nil { return fmt.Errorf("failed to create temp block file: %w", err) } tmpFileName := tmpFile.Name() _ = tmpFile.Close() for _, p := range namespace.GetBlockPaths() { stat, err := os.Stat(p) if err == nil { source := tmpFileName if stat.IsDir() { source = tmpDir } if err := unix.Mount(source, p, "", unix.MS_BIND, ""); err != nil { fmt.Printf("warning: failed to bind-mount block over %s: %v\n", p, err) } else { _ = unix.Mount("", p, "", unix.MS_PRIVATE, "") } } } return nil } type HostBind struct{} func NewHostBind(inner conn.Bind, hostNetNSFd int) *HostBind { return &HostBind{} } func (h *HostBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 0, fmt.Errorf("HostBind.Open is disabled for security reasons") } func (h *HostBind) Close() error { return nil } func (h *HostBind) SetMark(mark uint32) error { return nil } func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { return nil } func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } func (h *HostBind) BatchSize() int { return 0} type FDBind struct { originalFd int conn *net.UDPConn } type FDEndpoint struct { addr netip.AddrPort } func (e *FDEndpoint) DstIP() netip.Addr { return e.addr.Addr() } func (e *FDEndpoint) DstToString() string { return e.addr.String() } func (e *FDEndpoint) DstToBytes() []byte { return e.addr.Addr().AsSlice() } func (e *FDEndpoint) ClearSrc() {} func (e *FDEndpoint) SrcIP() netip.Addr { return netip.Addr{} } func (e *FDEndpoint) SrcToString() string { return "" } func (e *FDEndpoint) SrcIfidx() int32 { return 0 } func NewFDBind(fd int) (*FDBind, error) { return &FDBind{originalFd: fd}, nil } func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { dupFd, err := unix.Dup(b.originalFd) if err != nil { return nil, 0, fmt.Errorf("failed to duplicate host socket fd: %w", err) } file := os.NewFile(uintptr(dupFd), "host-udp-socket") pconn, err := net.FilePacketConn(file) if err != nil { _ = file.Close() return nil, 0, fmt.Errorf("failed to wrap fd %d as packet conn: %w", dupFd, err) } udpConn, ok := pconn.(*net.UDPConn) if !ok { _ = pconn.Close() return nil, 0, fmt.Errorf("fd %d is not a UDP socket", dupFd) } b.conn = udpConn laddr, ok := b.conn.LocalAddr().(*net.UDPAddr) if !ok { return nil, 0, fmt.Errorf("local address is not a UDP address") } actualPort = uint16(laddr.Port) receive := func(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { if len(packets) == 0 { return 0, nil } if b.conn == nil { return 0, net.ErrClosed } for i := 0; i < len(packets); i++ { nBytes, addr, err := b.conn.ReadFromUDP(packets[i]) if err != nil { return i, err } sizes[i] = nBytes addrPort := addr.AddrPort() eps[i] = &FDEndpoint{addr: addrPort} n++ } return n, nil } return []conn.ReceiveFunc{receive}, actualPort, nil } func (b *FDBind) Close() error { if b.conn != nil { err := b.conn.Close() b.conn = nil return err } return nil } func (b *FDBind) SetMark(mark uint32) error { return nil } func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { if b.conn == nil { return net.ErrClosed } addrPort, err := netip.ParseAddrPort(endpoint.DstToString()) if err != nil { return fmt.Errorf("failed to parse destination endpoint %s: %w", endpoint.DstToString(), err) } addr := net.UDPAddrFromAddrPort(addrPort) for _, buf := range bufs { _, err := b.conn.WriteToUDP(buf, addr) if err != nil { return fmt.Errorf("failed to write to UDP socket: %w", err) } } return nil } func (b *FDBind) ParseEndpoint(s string) (conn.Endpoint, error) { addrPort, err := netip.ParseAddrPort(s) if err != nil { return nil, fmt.Errorf("failed to parse endpoint address %s: %w", s, err) } return &FDEndpoint{addr: addrPort}, nil } func (b *FDBind) BatchSize() int { return 1 }