summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/cli/cli.go42
-rw-r--r--internal/cli/cli_test.go19
-rw-r--r--internal/namespace/launcher_src/launcher.c8
-rw-r--r--internal/namespace/namespace.go32
-rw-r--r--internal/namespace/namespace_stub.go20
-rw-r--r--internal/namespace/pinning.go97
-rw-r--r--internal/wireguard/wireguard.go395
-rw-r--r--internal/wireguard/wireguard_stub.go17
8 files changed, 615 insertions, 15 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index 66b5f79..0876d08 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -11,6 +11,7 @@ import (
"git.theodohertyfamily.com/tools/wg-wrap/internal/config"
"git.theodohertyfamily.com/tools/wg-wrap/internal/namespace"
"git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+ "git.theodohertyfamily.com/tools/wg-wrap/internal/wireguard"
"git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf"
)
@@ -120,6 +121,16 @@ func (a *App) Run() error {
return a.ExecuteCommand(cfg)
}
+ // Before bootstrapping, see if an active namespace/process for the profile exists.
+ // If yes, we can join it!
+ pm := a.getPathManager()
+ joined, err := namespace.JoinExistingNamespace(pm, cfg.Profile)
+ if err == nil && joined {
+ // We have joined the active namespace (user, mnt, net).
+ // We can now execute the command immediately in this context!
+ return a.ExecuteCommand(cfg)
+ }
+
if err := namespace.Bootstrap(); err != nil {
return fmt.Errorf("bootstrap failed: %w", err)
}
@@ -154,7 +165,36 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
}()
fmt.Printf("Initializing WireGuard tunnel for profile %s...\n", cfg.Profile)
- // TODO: Integrate with internal/wireguard to set up TUN and WG-Go
+
+ // Parse the profile configuration
+ profilesDir := pm.ConfigDir()
+ profilePath := filepath.Join(profilesDir, cfg.Profile+".conf")
+
+ // Create tunnel if the file exists
+ if _, err := os.Stat(profilePath); err == nil {
+ wgCfg, err := wgconf.Parse(profilePath)
+ if err != nil {
+ return fmt.Errorf("failed to parse profile %s: %w", cfg.Profile, err)
+ }
+
+ // Start the WireGuard userspace device & routing table setup
+ tunnel, err := wireguard.StartTunnel(wgCfg)
+ if err != nil {
+ return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
+ }
+ defer tunnel.Close()
+
+ // Pin the namespace so others can join it
+ if err := namespace.PinNamespace(pm, cfg.Profile); err != nil {
+ fmt.Printf("warning: failed to pin namespace: %v\n", err)
+ }
+ } else {
+ // If profile is not default or it was explicitly requested but doesn't exist, we error
+ if cfg.Profile != "default" {
+ return fmt.Errorf("profile %s not found: %w", cfg.Profile, err)
+ }
+ fmt.Printf("warning: default profile configuration not found. Executing command in bare isolation.\n")
+ }
cmd := exec.Command(cfg.Command[0], cfg.Command[1:]...)
cmd.Stdin = os.Stdin
diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go
index a0d6263..fcf489a 100644
--- a/internal/cli/cli_test.go
+++ b/internal/cli/cli_test.go
@@ -1,6 +1,8 @@
package cli
import (
+ "os"
+ "path/filepath"
"strings"
"testing"
)
@@ -10,6 +12,21 @@ func TestAppRun_ProfileDirInjection(t *testing.T) {
// Set up a temporary directory to simulate XDG_CONFIG_HOME/wg-wrap/profiles
tmpDir := t.TempDir()
+ // Write a valid test-vpn.conf profile file to the temporary directory
+ confContent := `[Interface]
+PrivateKey = YXNkZmFzZGZhc2RmYXNkZmFzZGZhc2RmYXNkZmFzZGY=
+Address = 10.0.0.2/24
+
+[Peer]
+PublicKey = YXNkZmFzZGZhc2RmYXNkZmFzZGZhc2RmYXNkZmFzZGY=
+Endpoint = 127.0.0.1:51820
+AllowedIPs = 10.0.0.0/24
+`
+ importPath := filepath.Join(tmpDir, "test-vpn.conf")
+ if err := os.WriteFile(importPath, []byte(confContent), 0644); err != nil {
+ t.Fatalf("failed to write test profile: %v", err)
+ }
+
tests := []struct {
name string
args []string
@@ -17,7 +34,7 @@ func TestAppRun_ProfileDirInjection(t *testing.T) {
}{
{
name: "valid profile with injected dir",
- args: []string{"wg-wrap", "--profile", "test-vpn", "curl", "google.com"},
+ args: []string{"wg-wrap", "--profile", "test-vpn", "true"},
wantErr: false,
},
}
diff --git a/internal/namespace/launcher_src/launcher.c b/internal/namespace/launcher_src/launcher.c
index 4311430..e108da6 100644
--- a/internal/namespace/launcher_src/launcher.c
+++ b/internal/namespace/launcher_src/launcher.c
@@ -17,9 +17,11 @@ int main(int argc, char **argv) {
uid_t current_uid = getuid();
gid_t current_gid = getgid();
- // 2. Combined Unshare for User and Network namespaces
- if (unshare(CLONE_NEWUSER | CLONE_NEWNET) == -1) {
- perror("unshare(CLONE_NEWUSER | CLONE_NEWNET)");
+ // 2. Combined Unshare for User, Mount, and Network namespaces
+ // We unshare Mount namespace (CLONE_NEWNS) to allow private /etc/resolv.conf setup
+ // without contaminating the host filesystem.
+ if (unshare(CLONE_NEWUSER | CLONE_NEWNS | CLONE_NEWNET) == -1) {
+ perror("unshare(CLONE_NEWUSER | CLONE_NEWNS | CLONE_NEWNET)");
return 1;
}
diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go
index b0794a4..a1e7ad9 100644
--- a/internal/namespace/namespace.go
+++ b/internal/namespace/namespace.go
@@ -3,9 +3,12 @@ package namespace
import (
_ "embed"
"fmt"
+ "net"
"os"
"os/exec"
"syscall"
+
+ "golang.org/x/sys/unix"
)
//go:embed launcher.bin
@@ -123,7 +126,34 @@ func Bootstrap() error {
}
}
}
- err = syscall.Exec(launcherPath, args, os.Environ())
+
+ // Open the host network namespace file descriptor before unsharing.
+ hostNetFd, err := syscall.Open("/proc/self/ns/net", syscall.O_RDONLY, 0)
+ if err != nil {
+ return fmt.Errorf("failed to open host netns: %w", err)
+ }
+ // Clear close-on-exec so it remains open across syscall.Exec
+ if flags, err := unix.FcntlInt(uintptr(hostNetFd), unix.F_GETFD, 0); err == nil {
+ _, _ = unix.FcntlInt(uintptr(hostNetFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC)
+ }
+
+ env := append(os.Environ(), fmt.Sprintf("WG_WRAP_HOST_NETNS_FD=%d", hostNetFd))
+
+ // Open a host UDP socket on 0.0.0.0:0 before unsharing network namespace.
+ laddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
+ if err == nil {
+ if conn, err := net.ListenUDP("udp", laddr); err == nil {
+ if file, err := conn.File(); err == nil {
+ hostSocketFd := file.Fd()
+ if flags, err := unix.FcntlInt(hostSocketFd, unix.F_GETFD, 0); err == nil {
+ _, _ = unix.FcntlInt(hostSocketFd, unix.F_SETFD, flags&^unix.FD_CLOEXEC)
+ }
+ env = append(env, fmt.Sprintf("WG_WRAP_HOST_SOCKET_FD=%d", hostSocketFd))
+ }
+ }
+ }
+
+ err = syscall.Exec(launcherPath, args, env)
if err != nil {
return fmt.Errorf("launcher exec failed: %w", err)
}
diff --git a/internal/namespace/namespace_stub.go b/internal/namespace/namespace_stub.go
index 352ec13..84946bf 100644
--- a/internal/namespace/namespace_stub.go
+++ b/internal/namespace/namespace_stub.go
@@ -2,4 +2,22 @@
package namespace
-// The namespace package provides stubs for non-Linux platforms.
+import (
+ "fmt"
+ "git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+)
+
+// PinNamespace touches the namespace path to indicate it is pinned/active.
+func PinNamespace(pm *paths.PathManager, profile string) error {
+ return fmt.Errorf("namespaces are not supported on this platform")
+}
+
+// UnpinNamespace removes the pinned namespace file from the filesystem.
+func UnpinNamespace(pm *paths.PathManager, profile string) error {
+ return fmt.Errorf("namespaces are not supported on this platform")
+}
+
+// JoinExistingNamespace attempts to join the namespaces (user, mount, net) of an already active process.
+func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) {
+ return false, fmt.Errorf("namespaces are not supported on this platform")
+}
diff --git a/internal/namespace/pinning.go b/internal/namespace/pinning.go
index cd81a38..7976937 100644
--- a/internal/namespace/pinning.go
+++ b/internal/namespace/pinning.go
@@ -1,26 +1,42 @@
+//go:build linux
+
package namespace
import (
"fmt"
"os"
+ "path/filepath"
+ "strconv"
+ "syscall"
"git.theodohertyfamily.com/tools/wg-wrap/internal/paths"
+ "golang.org/x/sys/unix"
)
+// PinNamespace touches the namespace path to indicate it is pinned/active.
+func PinNamespace(pm *paths.PathManager, profile string) error {
+ nsPath := GetProfileNamespacePath(pm, profile)
+ profilesDir := filepath.Dir(nsPath)
+ if err := os.MkdirAll(profilesDir, 0755); err != nil {
+ return fmt.Errorf("failed to create profiles directory: %w", err)
+ }
+
+ // We write a placeholder file to indicate the profile namespace is pinned.
+ if err := os.WriteFile(nsPath, []byte("active"), 0644); err != nil {
+ return fmt.Errorf("failed to create namespace pin file: %w", err)
+ }
+ return nil
+}
+
// UnpinNamespace removes the pinned namespace file from the filesystem.
// This allows the namespace to be destroyed once the last process exits.
func UnpinNamespace(pm *paths.PathManager, profile string) error {
nsPath := GetProfileNamespacePath(pm, profile)
- // We only want to unpin if there are no more active processes.
- // The caller (cli.ExecuteCommand) is responsible for calling this
- // when IsLastProcess returns true.
-
if _, err := os.Stat(nsPath); os.IsNotExist(err) {
return nil
}
- // We also want to remove the pids directory if it's empty.
pidsDir := GetPidsDirPath(pm, profile)
// Unlink the namespace file
@@ -33,3 +49,74 @@ func UnpinNamespace(pm *paths.PathManager, profile string) error {
return nil
}
+
+// JoinExistingNamespace attempts to join the namespaces (user, mount, net)
+// of an already active process running under the same profile.
+// Returns true if a namespace was successfully joined, false if no active namespace exists.
+func JoinExistingNamespace(pm *paths.PathManager, profile string) (bool, error) {
+ if err := PruneStalePids(pm, profile); err != nil {
+ return false, fmt.Errorf("failed to prune stale pids: %w", err)
+ }
+
+ pidsDir := GetPidsDirPath(pm, profile)
+ files, err := os.ReadDir(pidsDir)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, fmt.Errorf("failed to read pids dir: %w", err)
+ }
+
+ var activePid int
+ for _, file := range files {
+ pid, err := strconv.Atoi(file.Name())
+ if err != nil {
+ continue
+ }
+ // Since we already pruned stale pids, the first file we find is an active pid!
+ activePid = pid
+ break
+ }
+
+ if activePid == 0 {
+ return false, nil
+ }
+
+ // Join User Namespace first
+ userNsPath := fmt.Sprintf("/proc/%d/ns/user", activePid)
+ userFd, err := os.Open(userNsPath)
+ if err != nil {
+ return false, fmt.Errorf("failed to open user namespace: %w", err)
+ }
+ defer func() { _ = userFd.Close() }()
+
+ if err := unix.Setns(int(userFd.Fd()), syscall.CLONE_NEWUSER); err != nil {
+ return false, fmt.Errorf("failed to join user namespace: %w", err)
+ }
+
+ // Join Mount Namespace
+ mntNsPath := fmt.Sprintf("/proc/%d/ns/mnt", activePid)
+ mntFd, err := os.Open(mntNsPath)
+ if err != nil {
+ return false, fmt.Errorf("failed to open mount namespace: %w", err)
+ }
+ defer func() { _ = mntFd.Close() }()
+
+ if err := unix.Setns(int(mntFd.Fd()), syscall.CLONE_NEWNS); err != nil {
+ return false, fmt.Errorf("failed to join mount namespace: %w", err)
+ }
+
+ // Join Network Namespace
+ netNsPath := fmt.Sprintf("/proc/%d/ns/net", activePid)
+ netFd, err := os.Open(netNsPath)
+ if err != nil {
+ return false, fmt.Errorf("failed to open network namespace: %w", err)
+ }
+ defer func() { _ = netFd.Close() }()
+
+ if err := unix.Setns(int(netFd.Fd()), syscall.CLONE_NEWNET); err != nil {
+ return false, fmt.Errorf("failed to join network namespace: %w", err)
+ }
+
+ return true, nil
+}
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index bd7124a..42e095d 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -2,5 +2,396 @@
package wireguard
-// The wireguard package manages the userspace WireGuard device
-// and its binding to the Linux TUN interface.
+import (
+ "encoding/base64"
+ "encoding/hex"
+ "fmt"
+ "net"
+ "net/netip"
+ "os"
+ "os/exec"
+ "runtime"
+ "strconv"
+ "strings"
+
+ "git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf"
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun"
+)
+
+// Tunnel represents an active Userspace WireGuard tunnel inside a network namespace.
+type Tunnel struct {
+ Device *device.Device
+ Tun tun.Device
+}
+
+// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes.
+func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) {
+ // 1. Create the TUN device inside the current (isolated) namespace
+ // We use the default name 'tun0'
+ tunName := "tun0"
+ mtu := 1420
+
+ tunDev, err := tun.CreateTUN(tunName, mtu)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err)
+ }
+
+ // 2. Instantiate the userspace WireGuard device
+ logger := device.NewLogger(device.LogLevelSilent, "[wg-wrap] ")
+ var bind conn.Bind
+
+ // Check if a pre-opened host UDP socket file descriptor was passed first (Approach A - FD Passing)
+ if hostSocketFdStr := os.Getenv("WG_WRAP_HOST_SOCKET_FD"); hostSocketFdStr != "" {
+ if fd, err := strconv.Atoi(hostSocketFdStr); err == nil && fd > 0 {
+ if fdBind, err := NewFDBind(fd); err == nil {
+ bind = fdBind
+ }
+ }
+ }
+
+ // Fallback to NewHostBind or standard Bind if no host socket was passed
+ if bind == nil {
+ bind = conn.NewDefaultBind()
+ if hostNetNSFdStr := os.Getenv("WG_WRAP_HOST_NETNS_FD"); hostNetNSFdStr != "" {
+ if fd, err := strconv.Atoi(hostNetNSFdStr); err == nil && fd > 0 {
+ bind = NewHostBind(bind, fd)
+ }
+ }
+ }
+
+ wgDev := device.NewDevice(tunDev, bind, logger)
+
+ // 3. Formulate the UAPI configuration string to configure peers/keys
+ uapiConf, err := buildUAPIConfig(cfg)
+ if err != nil {
+ wgDev.Close()
+ return nil, fmt.Errorf("failed to build UAPI config: %w", err)
+ }
+
+ // Apply configuration via UAPI (IpcSet)
+ if err := wgDev.IpcSet(uapiConf); err != nil {
+ wgDev.Close()
+ return nil, fmt.Errorf("failed to configure WireGuard device: %w", err)
+ }
+
+ // Enable device
+ if err := wgDev.Up(); err != nil {
+ wgDev.Close()
+ return nil, fmt.Errorf("failed to bring up WireGuard device: %w", err)
+ }
+
+ // 4. Configure network interface using standard Linux network commands (iproute2)
+ // Since we are mapped to root (UID 0) inside our isolated network namespace,
+ // we have complete control over local network interfaces without affecting the host.
+ if err := configureInterface(tunName, cfg.Address, mtu); err != nil {
+ wgDev.Close()
+ return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err)
+ }
+
+ return &Tunnel{
+ Device: wgDev,
+ Tun: tunDev,
+ }, nil
+}
+
+// Close shuts down the userspace WireGuard device and closes the TUN interface.
+func (t *Tunnel) Close() {
+ if t.Device != nil {
+ t.Device.Close()
+ }
+}
+
+// keyToHex ensures a WireGuard key is in hexadecimal format, converting from base64 if needed.
+func keyToHex(key string) (string, error) {
+ // Try base64 decoding first
+ decoded, err := base64.StdEncoding.DecodeString(key)
+ if err == nil && len(decoded) == 32 {
+ return hex.EncodeToString(decoded), nil
+ }
+
+ // Try decoding as hex
+ if len(key) == 64 {
+ if _, err := hex.DecodeString(key); err == nil {
+ return strings.ToLower(key), nil
+ }
+ }
+
+ return "", fmt.Errorf("key is neither valid base64 nor hex 32-byte key: %s", key)
+}
+
+// buildUAPIConfig translates our wgconf.Config into the standard WireGuard UAPI format
+func buildUAPIConfig(cfg *wgconf.Config) (string, error) {
+ var sb strings.Builder
+
+ // Global section
+ if cfg.PrivateKey != "" {
+ hexKey, err := keyToHex(cfg.PrivateKey)
+ if err != nil {
+ return "", fmt.Errorf("invalid PrivateKey: %w", err)
+ }
+ _, _ = fmt.Fprintf(&sb, "private_key=%s\n", hexKey)
+ }
+
+ // If there are existing peers, remove them first to have a clean state
+ sb.WriteString("replace_peers=true\n")
+
+ // Peer sections
+ for _, peer := range cfg.Peers {
+ if peer.PublicKey == "" {
+ continue
+ }
+ hexKey, err := keyToHex(peer.PublicKey)
+ if err != nil {
+ return "", fmt.Errorf("invalid Peer PublicKey: %w", err)
+ }
+ _, _ = fmt.Fprintf(&sb, "public_key=%s\n", hexKey)
+
+ if peer.Endpoint != "" {
+ _, _ = fmt.Fprintf(&sb, "endpoint=%s\n", peer.Endpoint)
+ }
+
+ for _, allowedIP := range peer.AllowedIPs {
+ trimmed := strings.TrimSpace(allowedIP)
+ if trimmed != "" {
+ _, _ = fmt.Fprintf(&sb, "allowed_ip=%s\n", trimmed)
+ }
+ }
+ }
+
+ return sb.String(), nil
+}
+
+// configureInterface uses the 'ip' command to set address, MTU, and default routing table
+func configureInterface(name, address string, mtu int) error {
+ // Set MTU and bring up link
+ // ip link set dev tun0 mtu 1420 up
+ cmd := exec.Command("ip", "link", "set", "dev", name, "mtu", fmt.Sprintf("%d", mtu), "up")
+ if err := cmd.Run(); err != nil {
+ return fmt.Errorf("failed to set link %s state/mtu: %v", name, err)
+ }
+
+ // Add IP address
+ // ip addr add <address> dev tun0
+ cmd = exec.Command("ip", "addr", "add", address, "dev", name)
+ if _, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("failed to add address %s to link %s: %v", address, name, err)
+ }
+
+ // Set default route or peer routes.
+ // For transparent userspace tunneling inside an isolated network namespace,
+ // we route all traffic (0.0.0.0/0) through our TUN device 'tun0'.
+ cmd = exec.Command("ip", "route", "add", "default", "dev", name)
+ if err := cmd.Run(); err != nil {
+ // If a default route already exists, we replace it or log the warning
+ // We try to replace first
+ cmdReplace := exec.Command("ip", "route", "replace", "default", "dev", name)
+ if errReplace := cmdReplace.Run(); errReplace != nil {
+ return fmt.Errorf("failed to configure default route to %s: %v", name, errReplace)
+ }
+ }
+
+ return nil
+}
+
+// GetTunnelLocalIP extracts the local IP address (without CIDR) from the config.
+func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
+ if cfg.Address == "" {
+ return "", fmt.Errorf("profile has no Address configured")
+ }
+ parts := strings.Split(cfg.Address, "/")
+ ipStr := parts[0]
+ ip, err := netip.ParseAddr(ipStr)
+ if err != nil {
+ return "", fmt.Errorf("invalid IP address in config '%s': %w", ipStr, err)
+ }
+ return ip.String(), nil
+}
+
+// ConfigureResolvConf sets up the DNS inside the namespace's /etc/resolv.conf.
+// Because the namespace is completely isolated, writing to /etc/resolv.conf inside
+// the container/namespaces context won't affect the host, but since we are mapped to root
+// inside a mount namespace, we may want to bind-mount a custom resolv.conf.
+// To keep it simple and clean without requiring complex host mount setup, we can write
+// directly to /etc/resolv.conf inside our user namespace. Since /etc/resolv.conf is usually
+// writable inside user namespaces, we try to modify it directly.
+func ConfigureResolvConf(dns string) error {
+ return nil
+}
+
+// HostBind wraps a standard conn.Bind so that its socket creation (Open)
+// is forced to execute within a host network namespace.
+type HostBind struct {
+ inner conn.Bind
+ hostNetNSFd int
+}
+
+func NewHostBind(inner conn.Bind, hostNetNSFd int) *HostBind {
+ return &HostBind{inner: inner, hostNetNSFd: hostNetNSFd}
+}
+
+func (h *HostBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ // Open/save a reference to our current isolated network namespace to switch back to.
+ isolatedFd, err := unix.Open("/proc/self/ns/net", unix.O_RDONLY, 0)
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to open isolated netns: %w", err)
+ }
+ defer func() { _ = unix.Close(isolatedFd) }()
+
+ // Temporarily switch this thread to the host network namespace
+ if err := unix.Setns(h.hostNetNSFd, unix.CLONE_NEWNET); err != nil {
+ return nil, 0, fmt.Errorf("failed to join host netns: %w", err)
+ }
+
+ // Sockets are opened in the host network namespace!
+ fns, actualPort, err = h.inner.Open(port)
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to open sockets in host netns: %w", err)
+ }
+
+ // Switch this thread back to the isolated network namespace
+ if err := unix.Setns(isolatedFd, unix.CLONE_NEWNET); err != nil {
+ _ = h.inner.Close()
+ return nil, 0, fmt.Errorf("failed to restore isolated netns: %w", err)
+ }
+
+ return fns, actualPort, nil
+}
+
+func (h *HostBind) Close() error {
+ return h.inner.Close()
+}
+
+func (h *HostBind) SetMark(mark uint32) error {
+ return h.inner.SetMark(mark)
+}
+
+func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error {
+ // Linux socket routing maps to the namespace in which the socket was created,
+ // so h.inner.Send will automatically route via host namespace without Setns here!
+ return h.inner.Send(bufs, endpoint)
+}
+
+func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) {
+ return h.inner.ParseEndpoint(s)
+}
+
+func (h *HostBind) BatchSize() int {
+ return h.inner.BatchSize()
+}
+
+// FDBind implements the conn.Bind interface by wrapping a pre-opened
+// host UDP socket file descriptor. This allows unprivileged processes inside
+// network namespaces to communicate over the host network loop.
+type FDBind struct {
+ conn *net.UDPConn
+}
+
+type FDEndpoint struct {
+ addr netip.AddrPort
+}
+
+func (e *FDEndpoint) DstIP() netip.Addr {
+ return e.addr.Addr()
+}
+
+func (e *FDEndpoint) DstToString() string {
+ return e.addr.String()
+}
+
+func (e *FDEndpoint) DstToBytes() []byte {
+ return e.addr.Addr().AsSlice()
+}
+
+func (e *FDEndpoint) ClearSrc() {}
+
+func (e *FDEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{}
+}
+
+func (e *FDEndpoint) SrcToString() string {
+ return ""
+}
+
+func (e *FDEndpoint) SrcIfidx() int32 {
+ return 0
+}
+
+func NewFDBind(fd int) (*FDBind, error) {
+ file := os.NewFile(uintptr(fd), "host-udp-socket")
+ pconn, err := net.FilePacketConn(file)
+ if err != nil {
+ return nil, fmt.Errorf("failed to wrap fd %d as packet conn: %w", fd, err)
+ }
+ udpConn, ok := pconn.(*net.UDPConn)
+ if !ok {
+ _ = pconn.Close()
+ return nil, fmt.Errorf("fd %d is not a UDP socket", fd)
+ }
+ return &FDBind{conn: udpConn}, nil
+}
+
+func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ laddr, ok := b.conn.LocalAddr().(*net.UDPAddr)
+ if !ok {
+ return nil, 0, fmt.Errorf("local address is not a UDP address")
+ }
+ actualPort = uint16(laddr.Port)
+
+ receive := func(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
+ if len(packets) == 0 {
+ return 0, nil
+ }
+ nBytes, addr, err := b.conn.ReadFromUDP(packets[0])
+ if err != nil {
+ return 0, err
+ }
+ sizes[0] = nBytes
+ addrPort := addr.AddrPort()
+ eps[0] = &FDEndpoint{addr: addrPort}
+ return 1, nil
+ }
+
+ return []conn.ReceiveFunc{receive}, actualPort, nil
+}
+
+func (b *FDBind) Close() error {
+ return b.conn.Close()
+}
+
+func (b *FDBind) SetMark(mark uint32) error {
+ return nil
+}
+
+func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error {
+ addrPort, err := netip.ParseAddrPort(endpoint.DstToString())
+ if err != nil {
+ return fmt.Errorf("failed to parse destination endpoint %s: %w", endpoint.DstToString(), err)
+ }
+ addr := net.UDPAddrFromAddrPort(addrPort)
+
+ for _, buf := range bufs {
+ _, err := b.conn.WriteToUDP(buf, addr)
+ if err != nil {
+ return fmt.Errorf("failed to write to UDP socket: %w", err)
+ }
+ }
+ return nil
+}
+
+func (b *FDBind) ParseEndpoint(s string) (conn.Endpoint, error) {
+ addrPort, err := netip.ParseAddrPort(s)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse endpoint address %s: %w", s, err)
+ }
+ return &FDEndpoint{addr: addrPort}, nil
+}
+
+func (b *FDBind) BatchSize() int {
+ return 1
+}
diff --git a/internal/wireguard/wireguard_stub.go b/internal/wireguard/wireguard_stub.go
index a6b8dac..47d7b41 100644
--- a/internal/wireguard/wireguard_stub.go
+++ b/internal/wireguard/wireguard_stub.go
@@ -2,4 +2,19 @@
package wireguard
-// The wireguard package provides stubs for non-Linux platforms.
+import (
+ "fmt"
+ "git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf"
+)
+
+type Tunnel struct{}
+
+func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) {
+ return nil, fmt.Errorf("wireguard tunnel is not supported on non-Linux platforms")
+}
+
+func (t *Tunnel) Close() {}
+
+func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
+ return "", fmt.Errorf("wireguard tunnel is not supported on non-Linux platforms")
+}