Skip to content

Commit b7bf416

Browse files
committed
Add network policy filtering for user-v2 networking
Implements egress traffic filtering with: - Protocol, port, IP/CIDR, and domain-based rules - DNS packet snooping for domain-to-IP tracking - ICMP support (ICMPv4/ICMPv6) - partial - awaiting gvisor fix - Policy validation with strict error checking - DNS tracker with 10k domain limit and TTL expiration Usage: limactl network create NAME --policy policy.yaml Signed-off-by: Simon Kaegi <simon.kaegi@gmail.com>
1 parent e21b634 commit b7bf416

File tree

16 files changed

+2831
-1
lines changed

16 files changed

+2831
-1
lines changed

cmd/limactl/network.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"maps"
1111
"net"
1212
"os"
13+
"path/filepath"
1314
"slices"
1415
"strings"
1516
"text/tabwriter"
@@ -18,12 +19,17 @@ import (
1819
"github.com/spf13/cobra"
1920

2021
"github.com/lima-vm/lima/v2/pkg/networks"
22+
"github.com/lima-vm/lima/v2/pkg/networks/usernet"
23+
"github.com/lima-vm/lima/v2/pkg/networks/usernet/filter"
2124
"github.com/lima-vm/lima/v2/pkg/yqutil"
2225
)
2326

2427
const networkCreateExample = ` Create a network:
2528
$ limactl network create foo --gateway 192.168.42.1/24
2629
30+
Create a network with policy filtering:
31+
$ limactl network create secure --gateway 192.168.42.1/24 --policy ~/policy.yaml
32+
2733
Connect VM instances to the newly created network:
2834
$ limactl create --network lima:foo --name vm1
2935
$ limactl create --network lima:foo --name vm2
@@ -144,6 +150,7 @@ func newNetworkCreateCommand() *cobra.Command {
144150
flags.String("gateway", "", "gateway, e.g., \"192.168.42.1/24\"")
145151
flags.String("interface", "", "interface for bridged mode")
146152
_ = cmd.RegisterFlagCompletionFunc("interface", bashFlagCompleteNetworkInterfaceNames)
153+
flags.String("policy", "", "path to policy file (YAML or JSON, user-v2 mode only)")
147154
return cmd
148155
}
149156

@@ -174,6 +181,38 @@ func networkCreateAction(cmd *cobra.Command, args []string) error {
174181
return err
175182
}
176183

184+
policyPath, err := flags.GetString("policy")
185+
if err != nil {
186+
return err
187+
}
188+
189+
// Handle policy file if provided
190+
if policyPath != "" {
191+
// Only user-v2 mode supports filtering
192+
if mode != networks.ModeUserV2 {
193+
logrus.Warnf("Policy filtering is only supported for mode 'user-v2', ignoring --policy flag")
194+
} else {
195+
// Load the policy to validate it
196+
pol, err := filter.LoadPolicy(policyPath)
197+
if err != nil {
198+
return fmt.Errorf("failed to load policy: %w", err)
199+
}
200+
201+
// Save as JSON in the network directory (~/.lima/_networks/<name>/policy.json)
202+
policyJSONPath, err := usernet.PolicyFile(name)
203+
if err != nil {
204+
return fmt.Errorf("failed to get policy path: %w", err)
205+
}
206+
// Ensure network directory exists (follows usernet convention)
207+
if err := os.MkdirAll(filepath.Dir(policyJSONPath), 0o755); err != nil {
208+
return fmt.Errorf("failed to create network directory: %w", err)
209+
}
210+
if err := filter.SavePolicyJSON(pol, policyJSONPath); err != nil {
211+
return fmt.Errorf("failed to save policy: %w", err)
212+
}
213+
}
214+
}
215+
177216
switch mode {
178217
case networks.ModeBridged:
179218
if gateway != "" {

cmd/limactl/usernet.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ import (
1111
"strconv"
1212
"syscall"
1313

14+
"github.com/sirupsen/logrus"
1415
"github.com/spf13/cobra"
1516

1617
"github.com/lima-vm/lima/v2/pkg/networks/usernet"
18+
"github.com/lima-vm/lima/v2/pkg/networks/usernet/filter"
1719
)
1820

1921
func newUsernetCommand() *cobra.Command {
@@ -31,6 +33,7 @@ func newUsernetCommand() *cobra.Command {
3133
hostagentCommand.Flags().String("subnet", "192.168.5.0/24", "Sets subnet value for the usernet network")
3234
hostagentCommand.Flags().Int("mtu", 1500, "mtu")
3335
hostagentCommand.Flags().StringToString("leases", nil, "Pass default static leases for startup. Eg: '192.168.104.1=52:55:55:b3:bc:d9,192.168.104.2=5a:94:ef:e4:0c:df' ")
36+
hostagentCommand.Flags().String("policy", "", "Path to policy JSON file")
3437
return hostagentCommand
3538
}
3639

@@ -75,6 +78,22 @@ func usernetAction(cmd *cobra.Command, _ []string) error {
7578
return err
7679
}
7780

81+
policyPath, err := cmd.Flags().GetString("policy")
82+
if err != nil {
83+
return err
84+
}
85+
86+
// Parse the policy at the CLI boundary (fail fast on invalid policy)
87+
var policy *filter.Policy
88+
if policyPath != "" {
89+
logrus.Debugf("Loading policy from: %s", policyPath)
90+
policy, err = filter.LoadPolicy(policyPath)
91+
if err != nil {
92+
return fmt.Errorf("failed to load policy: %w", err)
93+
}
94+
logrus.Debugf("Loaded policy with %d rules", len(policy.Rules))
95+
}
96+
7897
os.RemoveAll(endpoint)
7998
os.RemoveAll(qemuSocket)
8099
os.RemoveAll(fdSocket)
@@ -92,5 +111,6 @@ func usernetAction(cmd *cobra.Command, _ []string) error {
92111
FdSocket: fdSocket,
93112
Subnet: subnet,
94113
DefaultLeases: leases,
114+
Policy: policy,
95115
})
96116
}

pkg/networks/usernet/config.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,16 @@ func Leases(name string) (string, error) {
112112
return sockPath, nil
113113
}
114114

115+
// PolicyFile returns the path to the policy JSON file for the given network name.
116+
// For usernet, this is stored in ~/.lima/_networks/<name>/policy.json (not VarRun).
117+
func PolicyFile(name string) (string, error) {
118+
dir, err := dirnames.LimaNetworksDir()
119+
if err != nil {
120+
return "", err
121+
}
122+
return filepath.Join(dir, name, "policy.json"), nil
123+
}
124+
115125
func netmaskToCidr(baseIP, netMask net.IP) (net.IP, *net.IPNet, error) {
116126
size, _ := net.IPMask(netMask.To4()).Size()
117127
return net.ParseCIDR(fmt.Sprintf("%s/%d", baseIP.String(), size))

pkg/networks/usernet/filter/dns.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package filter
5+
6+
import (
7+
"net"
8+
"strings"
9+
"sync"
10+
"time"
11+
)
12+
13+
const (
14+
// MaxDNSRecords is the maximum number of DNS records to track.
15+
// This prevents unbounded memory growth in long-running processes.
16+
MaxDNSRecords = 10000
17+
)
18+
19+
// DNSRecord represents a DNS query result with TTL.
20+
type DNSRecord struct {
21+
Domain string
22+
IPs []net.IP
23+
ExpireAt time.Time
24+
}
25+
26+
// Tracker tracks domain to IP mappings from DNS queries.
27+
type Tracker struct {
28+
mu sync.RWMutex
29+
records map[string]*DNSRecord // domain -> record
30+
}
31+
32+
// NewTracker creates a new DNS tracker.
33+
func NewTracker() *Tracker {
34+
return &Tracker{
35+
records: make(map[string]*DNSRecord),
36+
}
37+
}
38+
39+
// SeedLimaInternalDomains pre-populates the tracker with Lima internal domains
40+
// These are special domains that map to Lima network infrastructure:
41+
// - subnet.lima.internal -> the entire Lima subnet (e.g., 192.168.100.0/24)
42+
// - host.lima.internal -> the Lima gateway (e.g., 192.168.100.2)
43+
func (t *Tracker) SeedLimaInternalDomains(subnet, gatewayIP string) error {
44+
if subnet == "" {
45+
return nil
46+
}
47+
48+
_, subnetNet, err := net.ParseCIDR(subnet)
49+
if err != nil {
50+
return err
51+
}
52+
53+
// Get all IPs in the subnet for subnet.lima.internal
54+
var subnetIPs []net.IP
55+
// For now, just add the network address
56+
// We could enumerate all IPs but that's expensive for large subnets
57+
subnetIPs = append(subnetIPs, subnetNet.IP)
58+
59+
// Add subnet.lima.internal -> subnet IPs
60+
// Use a very long TTL (24 hours) since these are static mappings
61+
t.AddRecord("subnet.lima.internal", subnetIPs, 24*time.Hour)
62+
63+
// Add host.lima.internal -> Lima gateway
64+
// This must be seeded because gvisor's internal DNS server resolves *.lima.internal
65+
// domains internally, so the DNS snooper never sees the responses
66+
if gatewayIP != "" {
67+
gateway := net.ParseIP(gatewayIP)
68+
if gateway != nil {
69+
t.AddRecord("host.lima.internal", []net.IP{gateway}, 24*time.Hour)
70+
}
71+
}
72+
73+
return nil
74+
}
75+
76+
// AddRecord adds or updates a DNS record.
77+
func (t *Tracker) AddRecord(domain string, ips []net.IP, ttl time.Duration) {
78+
t.mu.Lock()
79+
defer t.mu.Unlock()
80+
81+
domain = strings.ToLower(domain)
82+
83+
// If at capacity and this is a new domain, clean up expired entries first
84+
if _, exists := t.records[domain]; !exists && len(t.records) >= MaxDNSRecords {
85+
t.cleanExpiredLocked()
86+
87+
// If still at capacity after cleanup, remove oldest entry
88+
if len(t.records) >= MaxDNSRecords {
89+
t.removeOldestLocked()
90+
}
91+
}
92+
93+
t.records[domain] = &DNSRecord{
94+
Domain: domain,
95+
IPs: ips,
96+
ExpireAt: time.Now().Add(ttl),
97+
}
98+
}
99+
100+
// GetIPs returns all IPs for a domain, or nil if not found/expired.
101+
func (t *Tracker) GetIPs(domain string) []net.IP {
102+
t.mu.RLock()
103+
defer t.mu.RUnlock()
104+
105+
domain = strings.ToLower(domain)
106+
record, ok := t.records[domain]
107+
if !ok || time.Now().After(record.ExpireAt) {
108+
return nil
109+
}
110+
return record.IPs
111+
}
112+
113+
// GetIPsForPattern returns all IPs matching a domain pattern (supports wildcards).
114+
// Example: "*.example.com" matches "api.example.com", "cdn.example.com".
115+
func (t *Tracker) GetIPsForPattern(pattern string) []net.IP {
116+
t.mu.RLock()
117+
defer t.mu.RUnlock()
118+
119+
pattern = strings.ToLower(pattern)
120+
var allIPs []net.IP
121+
seenIPs := make(map[string]bool)
122+
123+
for domain, record := range t.records {
124+
// Skip expired records
125+
if time.Now().After(record.ExpireAt) {
126+
continue
127+
}
128+
129+
// Check if domain matches pattern
130+
if matchesPattern(domain, pattern) {
131+
for _, ip := range record.IPs {
132+
ipStr := ip.String()
133+
if !seenIPs[ipStr] {
134+
seenIPs[ipStr] = true
135+
allIPs = append(allIPs, ip)
136+
}
137+
}
138+
}
139+
}
140+
141+
return allIPs
142+
}
143+
144+
// GetDomainsForIP returns all domains that resolve to the given IP (reverse lookup).
145+
func (t *Tracker) GetDomainsForIP(ip net.IP) []string {
146+
t.mu.RLock()
147+
defer t.mu.RUnlock()
148+
149+
var domains []string
150+
now := time.Now()
151+
152+
for domain, record := range t.records {
153+
// Skip expired records
154+
if now.After(record.ExpireAt) {
155+
continue
156+
}
157+
158+
// Check if this domain resolves to the given IP
159+
for _, recordIP := range record.IPs {
160+
if recordIP.Equal(ip) {
161+
domains = append(domains, domain)
162+
break
163+
}
164+
}
165+
}
166+
167+
return domains
168+
}
169+
170+
// CleanExpired removes expired DNS records.
171+
func (t *Tracker) CleanExpired() {
172+
t.mu.Lock()
173+
defer t.mu.Unlock()
174+
t.cleanExpiredLocked()
175+
}
176+
177+
// cleanExpiredLocked removes expired DNS records (must hold lock).
178+
func (t *Tracker) cleanExpiredLocked() {
179+
now := time.Now()
180+
for domain, record := range t.records {
181+
if now.After(record.ExpireAt) {
182+
delete(t.records, domain)
183+
}
184+
}
185+
}
186+
187+
// removeOldestLocked removes the record with the earliest expiration time (must hold lock).
188+
func (t *Tracker) removeOldestLocked() {
189+
if len(t.records) == 0 {
190+
return
191+
}
192+
193+
var oldestDomain string
194+
var oldestExpireAt time.Time
195+
first := true
196+
197+
for domain, record := range t.records {
198+
if first || record.ExpireAt.Before(oldestExpireAt) {
199+
oldestDomain = domain
200+
oldestExpireAt = record.ExpireAt
201+
first = false
202+
}
203+
}
204+
205+
if oldestDomain != "" {
206+
delete(t.records, oldestDomain)
207+
}
208+
}
209+
210+
// matchesPattern checks if a domain matches a pattern with wildcard support.
211+
// Pattern examples:
212+
// - "example.com" matches exactly "example.com"
213+
// - "*.example.com" matches "api.example.com", "cdn.example.com", but NOT "example.com"
214+
// - "*" matches everything
215+
func matchesPattern(domain, pattern string) bool {
216+
// Exact match
217+
if domain == pattern {
218+
return true
219+
}
220+
221+
// Match all
222+
if pattern == "*" {
223+
return true
224+
}
225+
226+
// Wildcard pattern
227+
if strings.HasPrefix(pattern, "*.") {
228+
suffix := pattern[2:] // Remove "*."
229+
// Domain must end with the suffix and have at least one more label
230+
if strings.HasSuffix(domain, "."+suffix) {
231+
return true
232+
}
233+
}
234+
235+
return false
236+
}

0 commit comments

Comments
 (0)