diff options
Diffstat (limited to 'internal/wireguard')
| -rw-r--r-- | internal/wireguard/wireguard.go | 395 | ||||
| -rw-r--r-- | internal/wireguard/wireguard_stub.go | 17 |
2 files changed, 409 insertions, 3 deletions
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index bd7124a..42e095d 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -2,5 +2,396 @@ package wireguard -// The wireguard package manages the userspace WireGuard device -// and its binding to the Linux TUN interface. +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "net" + "net/netip" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + + "git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" +) + +// Tunnel represents an active Userspace WireGuard tunnel inside a network namespace. +type Tunnel struct { + Device *device.Device + Tun tun.Device +} + +// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes. +func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) { + // 1. Create the TUN device inside the current (isolated) namespace + // We use the default name 'tun0' + tunName := "tun0" + mtu := 1420 + + tunDev, err := tun.CreateTUN(tunName, mtu) + if err != nil { + return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err) + } + + // 2. Instantiate the userspace WireGuard device + logger := device.NewLogger(device.LogLevelSilent, "[wg-wrap] ") + var bind conn.Bind + + // Check if a pre-opened host UDP socket file descriptor was passed first (Approach A - FD Passing) + if hostSocketFdStr := os.Getenv("WG_WRAP_HOST_SOCKET_FD"); hostSocketFdStr != "" { + if fd, err := strconv.Atoi(hostSocketFdStr); err == nil && fd > 0 { + if fdBind, err := NewFDBind(fd); err == nil { + bind = fdBind + } + } + } + + // Fallback to NewHostBind or standard Bind if no host socket was passed + if bind == nil { + bind = conn.NewDefaultBind() + if hostNetNSFdStr := os.Getenv("WG_WRAP_HOST_NETNS_FD"); hostNetNSFdStr != "" { + if fd, err := strconv.Atoi(hostNetNSFdStr); err == nil && fd > 0 { + bind = NewHostBind(bind, fd) + } + } + } + + wgDev := device.NewDevice(tunDev, bind, logger) + + // 3. Formulate the UAPI configuration string to configure peers/keys + uapiConf, err := buildUAPIConfig(cfg) + if err != nil { + wgDev.Close() + return nil, fmt.Errorf("failed to build UAPI config: %w", err) + } + + // Apply configuration via UAPI (IpcSet) + if err := wgDev.IpcSet(uapiConf); err != nil { + wgDev.Close() + return nil, fmt.Errorf("failed to configure WireGuard device: %w", err) + } + + // Enable device + if err := wgDev.Up(); err != nil { + wgDev.Close() + return nil, fmt.Errorf("failed to bring up WireGuard device: %w", err) + } + + // 4. Configure network interface using standard Linux network commands (iproute2) + // Since we are mapped to root (UID 0) inside our isolated network namespace, + // we have complete control over local network interfaces without affecting the host. + if err := configureInterface(tunName, cfg.Address, mtu); err != nil { + wgDev.Close() + return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err) + } + + return &Tunnel{ + Device: wgDev, + Tun: tunDev, + }, nil +} + +// Close shuts down the userspace WireGuard device and closes the TUN interface. +func (t *Tunnel) Close() { + if t.Device != nil { + t.Device.Close() + } +} + +// keyToHex ensures a WireGuard key is in hexadecimal format, converting from base64 if needed. +func keyToHex(key string) (string, error) { + // Try base64 decoding first + decoded, err := base64.StdEncoding.DecodeString(key) + if err == nil && len(decoded) == 32 { + return hex.EncodeToString(decoded), nil + } + + // Try decoding as hex + if len(key) == 64 { + if _, err := hex.DecodeString(key); err == nil { + return strings.ToLower(key), nil + } + } + + return "", fmt.Errorf("key is neither valid base64 nor hex 32-byte key: %s", key) +} + +// buildUAPIConfig translates our wgconf.Config into the standard WireGuard UAPI format +func buildUAPIConfig(cfg *wgconf.Config) (string, error) { + var sb strings.Builder + + // Global section + if cfg.PrivateKey != "" { + hexKey, err := keyToHex(cfg.PrivateKey) + if err != nil { + return "", fmt.Errorf("invalid PrivateKey: %w", err) + } + _, _ = fmt.Fprintf(&sb, "private_key=%s\n", hexKey) + } + + // If there are existing peers, remove them first to have a clean state + sb.WriteString("replace_peers=true\n") + + // Peer sections + for _, peer := range cfg.Peers { + if peer.PublicKey == "" { + continue + } + hexKey, err := keyToHex(peer.PublicKey) + if err != nil { + return "", fmt.Errorf("invalid Peer PublicKey: %w", err) + } + _, _ = fmt.Fprintf(&sb, "public_key=%s\n", hexKey) + + if peer.Endpoint != "" { + _, _ = fmt.Fprintf(&sb, "endpoint=%s\n", peer.Endpoint) + } + + for _, allowedIP := range peer.AllowedIPs { + trimmed := strings.TrimSpace(allowedIP) + if trimmed != "" { + _, _ = fmt.Fprintf(&sb, "allowed_ip=%s\n", trimmed) + } + } + } + + return sb.String(), nil +} + +// configureInterface uses the 'ip' command to set address, MTU, and default routing table +func configureInterface(name, address string, mtu int) error { + // Set MTU and bring up link + // ip link set dev tun0 mtu 1420 up + cmd := exec.Command("ip", "link", "set", "dev", name, "mtu", fmt.Sprintf("%d", mtu), "up") + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to set link %s state/mtu: %v", name, err) + } + + // Add IP address + // ip addr add <address> dev tun0 + cmd = exec.Command("ip", "addr", "add", address, "dev", name) + if _, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to add address %s to link %s: %v", address, name, err) + } + + // Set default route or peer routes. + // For transparent userspace tunneling inside an isolated network namespace, + // we route all traffic (0.0.0.0/0) through our TUN device 'tun0'. + cmd = exec.Command("ip", "route", "add", "default", "dev", name) + if err := cmd.Run(); err != nil { + // If a default route already exists, we replace it or log the warning + // We try to replace first + cmdReplace := exec.Command("ip", "route", "replace", "default", "dev", name) + if errReplace := cmdReplace.Run(); errReplace != nil { + return fmt.Errorf("failed to configure default route to %s: %v", name, errReplace) + } + } + + return nil +} + +// GetTunnelLocalIP extracts the local IP address (without CIDR) from the config. +func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) { + if cfg.Address == "" { + return "", fmt.Errorf("profile has no Address configured") + } + parts := strings.Split(cfg.Address, "/") + ipStr := parts[0] + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return "", fmt.Errorf("invalid IP address in config '%s': %w", ipStr, err) + } + return ip.String(), nil +} + +// ConfigureResolvConf sets up the DNS inside the namespace's /etc/resolv.conf. +// Because the namespace is completely isolated, writing to /etc/resolv.conf inside +// the container/namespaces context won't affect the host, but since we are mapped to root +// inside a mount namespace, we may want to bind-mount a custom resolv.conf. +// To keep it simple and clean without requiring complex host mount setup, we can write +// directly to /etc/resolv.conf inside our user namespace. Since /etc/resolv.conf is usually +// writable inside user namespaces, we try to modify it directly. +func ConfigureResolvConf(dns string) error { + return nil +} + +// HostBind wraps a standard conn.Bind so that its socket creation (Open) +// is forced to execute within a host network namespace. +type HostBind struct { + inner conn.Bind + hostNetNSFd int +} + +func NewHostBind(inner conn.Bind, hostNetNSFd int) *HostBind { + return &HostBind{inner: inner, hostNetNSFd: hostNetNSFd} +} + +func (h *HostBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + // Open/save a reference to our current isolated network namespace to switch back to. + isolatedFd, err := unix.Open("/proc/self/ns/net", unix.O_RDONLY, 0) + if err != nil { + return nil, 0, fmt.Errorf("failed to open isolated netns: %w", err) + } + defer func() { _ = unix.Close(isolatedFd) }() + + // Temporarily switch this thread to the host network namespace + if err := unix.Setns(h.hostNetNSFd, unix.CLONE_NEWNET); err != nil { + return nil, 0, fmt.Errorf("failed to join host netns: %w", err) + } + + // Sockets are opened in the host network namespace! + fns, actualPort, err = h.inner.Open(port) + if err != nil { + return nil, 0, fmt.Errorf("failed to open sockets in host netns: %w", err) + } + + // Switch this thread back to the isolated network namespace + if err := unix.Setns(isolatedFd, unix.CLONE_NEWNET); err != nil { + _ = h.inner.Close() + return nil, 0, fmt.Errorf("failed to restore isolated netns: %w", err) + } + + return fns, actualPort, nil +} + +func (h *HostBind) Close() error { + return h.inner.Close() +} + +func (h *HostBind) SetMark(mark uint32) error { + return h.inner.SetMark(mark) +} + +func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { + // Linux socket routing maps to the namespace in which the socket was created, + // so h.inner.Send will automatically route via host namespace without Setns here! + return h.inner.Send(bufs, endpoint) +} + +func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) { + return h.inner.ParseEndpoint(s) +} + +func (h *HostBind) BatchSize() int { + return h.inner.BatchSize() +} + +// FDBind implements the conn.Bind interface by wrapping a pre-opened +// host UDP socket file descriptor. This allows unprivileged processes inside +// network namespaces to communicate over the host network loop. +type FDBind struct { + conn *net.UDPConn +} + +type FDEndpoint struct { + addr netip.AddrPort +} + +func (e *FDEndpoint) DstIP() netip.Addr { + return e.addr.Addr() +} + +func (e *FDEndpoint) DstToString() string { + return e.addr.String() +} + +func (e *FDEndpoint) DstToBytes() []byte { + return e.addr.Addr().AsSlice() +} + +func (e *FDEndpoint) ClearSrc() {} + +func (e *FDEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e *FDEndpoint) SrcToString() string { + return "" +} + +func (e *FDEndpoint) SrcIfidx() int32 { + return 0 +} + +func NewFDBind(fd int) (*FDBind, error) { + file := os.NewFile(uintptr(fd), "host-udp-socket") + pconn, err := net.FilePacketConn(file) + if err != nil { + return nil, fmt.Errorf("failed to wrap fd %d as packet conn: %w", fd, err) + } + udpConn, ok := pconn.(*net.UDPConn) + if !ok { + _ = pconn.Close() + return nil, fmt.Errorf("fd %d is not a UDP socket", fd) + } + return &FDBind{conn: udpConn}, nil +} + +func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + laddr, ok := b.conn.LocalAddr().(*net.UDPAddr) + if !ok { + return nil, 0, fmt.Errorf("local address is not a UDP address") + } + actualPort = uint16(laddr.Port) + + receive := func(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + if len(packets) == 0 { + return 0, nil + } + nBytes, addr, err := b.conn.ReadFromUDP(packets[0]) + if err != nil { + return 0, err + } + sizes[0] = nBytes + addrPort := addr.AddrPort() + eps[0] = &FDEndpoint{addr: addrPort} + return 1, nil + } + + return []conn.ReceiveFunc{receive}, actualPort, nil +} + +func (b *FDBind) Close() error { + return b.conn.Close() +} + +func (b *FDBind) SetMark(mark uint32) error { + return nil +} + +func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { + addrPort, err := netip.ParseAddrPort(endpoint.DstToString()) + if err != nil { + return fmt.Errorf("failed to parse destination endpoint %s: %w", endpoint.DstToString(), err) + } + addr := net.UDPAddrFromAddrPort(addrPort) + + for _, buf := range bufs { + _, err := b.conn.WriteToUDP(buf, addr) + if err != nil { + return fmt.Errorf("failed to write to UDP socket: %w", err) + } + } + return nil +} + +func (b *FDBind) ParseEndpoint(s string) (conn.Endpoint, error) { + addrPort, err := netip.ParseAddrPort(s) + if err != nil { + return nil, fmt.Errorf("failed to parse endpoint address %s: %w", s, err) + } + return &FDEndpoint{addr: addrPort}, nil +} + +func (b *FDBind) BatchSize() int { + return 1 +} diff --git a/internal/wireguard/wireguard_stub.go b/internal/wireguard/wireguard_stub.go index a6b8dac..47d7b41 100644 --- a/internal/wireguard/wireguard_stub.go +++ b/internal/wireguard/wireguard_stub.go @@ -2,4 +2,19 @@ package wireguard -// The wireguard package provides stubs for non-Linux platforms. +import ( + "fmt" + "git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf" +) + +type Tunnel struct{} + +func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) { + return nil, fmt.Errorf("wireguard tunnel is not supported on non-Linux platforms") +} + +func (t *Tunnel) Close() {} + +func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) { + return "", fmt.Errorf("wireguard tunnel is not supported on non-Linux platforms") +} |
