From 3b06b38c5b6f63acde8677b8d0acd92c99dfbfaa Mon Sep 17 00:00:00 2001
From: Tingmao Wang
Date: Mon, 10 Nov 2025 10:23:04 +0000
Subject: [PATCH 01/12] Merged PR 12588213: Various enforcement fixes

[Cherry-picked from ba849d8464ff019fe751872d458168bc9aec2256]

While the entire PR contains several (perhaps unrelated) changes, the
individual commits can be reviewed on their own, and there is some
background in the commit messages.

I've tested this with azcri-containerd as well as with the functional
tests in hcsshim; some of the tests fail for unrelated reasons (can't do
VPMem, no GPU, WCOW), and I haven't noticed anything that could be
caused by these changes. I also deployed these changes to a VM and was
able to run some confidential workloads manually.

Outstanding work that will not be done in this PR:

- Deny unmounting of in-use layers (this should already be impossible
  due to the directories being in use, but attempting it can leave the
  rego metadata inconsistent)

Missing go tests that will not be part of this PR - I have a local stash
for these, but they are not fully working. These can be done separately:

- Functional test checking that a bad OCISpec.Rootfs is rejected
  (tested manually)
- Functional test checking that a bad containerID is rejected
  (tested manually)

Signed-off-by: Tingmao Wang
---
 internal/guest/runtime/hcsv2/uvm.go | 148 +++-
 internal/guestpath/paths.go | 14 +-
 internal/layers/lcow.go | 4 +-
 internal/uvm/start.go | 5 +-
 pkg/securitypolicy/api.rego | 38 +-
 pkg/securitypolicy/api_test.rego | 8 +-
 pkg/securitypolicy/framework.rego | 161 ++++-
 pkg/securitypolicy/open_door.rego | 2 +
 pkg/securitypolicy/policy.rego | 2 +
 .../policy_v0.10.0_api_test.rego | 71 ++
 pkg/securitypolicy/rego_utils_test.go | 122 +++-
 pkg/securitypolicy/regopolicy_linux_test.go | 630 ++++++++++++++++--
 pkg/securitypolicy/securitypolicyenforcer.go | 18 +
 .../securitypolicyenforcer_rego.go | 56 +-
 test/functional/lcow_policy_test.go | 22 +-
 15 files changed, 1144 insertions(+), 157 deletions(-)
 create mode 100644 pkg/securitypolicy/policy_v0.10.0_api_test.rego

diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go
index 4436c73c4f..c57079cfb0 100644
--- a/internal/guest/runtime/hcsv2/uvm.go
+++ b/internal/guest/runtime/hcsv2/uvm.go
@@ -14,6 +14,7 @@ import (
 	"os/exec"
 	"path"
 	"path/filepath"
+	"regexp"
 	"strings"
 	"sync"
 	"syscall"
@@ -40,6 +41,7 @@ import (
 	"github.com/Microsoft/hcsshim/internal/guest/storage/pmem"
 	"github.com/Microsoft/hcsshim/internal/guest/storage/scsi"
 	"github.com/Microsoft/hcsshim/internal/guest/transport"
+	"github.com/Microsoft/hcsshim/internal/guestpath"
 	"github.com/Microsoft/hcsshim/internal/log"
 	"github.com/Microsoft/hcsshim/internal/logfields"
 	"github.com/Microsoft/hcsshim/internal/oci"
@@ -54,6 +56,30 @@ import (
 // for V2 where the specific message is targeted at the UVM itself.
 const UVMContainerID = "00000000-0000-0000-0000-000000000000"
 
+// Prevent path traversal via malformed container / sandbox IDs. Container IDs
+// can be either UVMContainerID or a 64-character hex string. This is also used
+// to check that sandbox IDs, which also appear in paths and have the same
+// format, are valid.
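+//
+// For example, "4541dd26d3e8e12ea6d21d3e987e92bd4bb873dbc0cc949d7c4e4c49eafd0b1b"
+// (64 hex characters) would be accepted, while an ID containing "../" would
+// not be.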
+const validContainerIDRegexRaw = `[0-9a-fA-F]{64}`
+var validContainerIDRegex = regexp.MustCompile("^" + validContainerIDRegexRaw + "$")
+
+// isSandboxId only changes the wording of the error message
+func checkValidContainerID(id string, isSandboxId bool) error {
+	if id == UVMContainerID {
+		return nil
+	}
+
+	if !validContainerIDRegex.MatchString(id) {
+		idtype := "container"
+		if isSandboxId {
+			idtype = "sandbox"
+		}
+		return errors.Errorf("invalid %s id: %s (must match %s)", idtype, id, validContainerIDRegex.String())
+	}
+
+	return nil
+}
+
 // VirtualPod represents a virtual pod that shares a UVM/Sandbox with other pods
 type VirtualPod struct {
 	VirtualSandboxID string
@@ -245,7 +271,54 @@ func setupSandboxLogDir(sandboxID, virtualSandboxID string) error {
 // TODO: unify workload and standalone logic for non-sandbox features (e.g., block devices, huge pages, uVM mounts)
 // TODO(go1.24): use [os.Root] instead of `!strings.HasPrefix(, )`
 
+// HasSecurityPolicy returns whether this host has a security policy set, i.e.
+// whether it is running confidential containers.
+func (h *Host) HasSecurityPolicy() bool {
+	return len(h.securityOptions.PolicyEnforcer.EncodedSecurityPolicy()) > 0
+}
+
+// For confidential containers, make sure that the host can't use unexpected
+// bundle paths, scratch dirs, or rootfs paths.
+func checkContainerSettings(sandboxID, containerID string, settings *prot.VMHostedContainerSettingsV2) error {
+	if settings.OCISpecification == nil {
+		return errors.Errorf("OCISpecification is nil")
+	}
+	if settings.OCISpecification.Root == nil {
+		return errors.Errorf("OCISpecification.Root is nil")
+	}
+
+	// matches CreateContainer / createLinuxContainerDocument in internal/hcsoci
+	containerRootInUVM := path.Join(guestpath.LCOWRootPrefixInUVM, containerID)
+	if settings.OCIBundlePath != containerRootInUVM {
+		return errors.Errorf("OCIBundlePath %q must equal expected %q",
+			settings.OCIBundlePath, containerRootInUVM)
+	}
+	expectedContainerRootfs := path.Join(containerRootInUVM, guestpath.RootfsPath)
+	if settings.OCISpecification.Root.Path != expectedContainerRootfs {
+		return errors.Errorf("OCISpecification.Root.Path %q must equal expected %q",
+			settings.OCISpecification.Root.Path, expectedContainerRootfs)
+	}
+
+	// matches MountLCOWLayers
+	scratchDirPath := settings.ScratchDirPath
+	expectedScratchDirPathNonShared := path.Join(containerRootInUVM, guestpath.ScratchDir, containerID)
+	expectedScratchDirPathShared := path.Join(guestpath.LCOWRootPrefixInUVM, sandboxID, guestpath.ScratchDir, containerID)
+	if scratchDirPath != expectedScratchDirPathNonShared &&
+		scratchDirPath != expectedScratchDirPathShared {
+		return errors.Errorf("ScratchDirPath %q must be either %q or %q",
+			scratchDirPath, expectedScratchDirPathNonShared, expectedScratchDirPathShared)
+	}
+
+	return nil
+}
+
 func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VMHostedContainerSettingsV2) (_ *Container, err error) {
+	if h.HasSecurityPolicy() {
+		if err = checkValidContainerID(id, false); err != nil {
+			return nil, err
+		}
+	}
+
 	criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType]
 
 	// Check for virtual pod annotation
@@ -393,6 +466,11 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM
 	case "container":
 		sid, ok := settings.OCISpecification.Annotations[annotations.KubernetesSandboxID]
 		sandboxID = sid
+		if h.HasSecurityPolicy() {
+			if err = checkValidContainerID(sid, true); err != nil {
+				return nil, err
+			}
+		}
 		if !ok || sid == "" {
 			return 
nil, errors.Errorf("unsupported 'io.kubernetes.cri.sandbox-id': '%s'", sid) } @@ -402,7 +480,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM // Add SEV device when security policy is not empty, except when privileged annotation is // set to "true", in which case all UVMs devices are added. - if len(h.securityOptions.PolicyEnforcer.EncodedSecurityPolicy()) > 0 && !oci.ParseAnnotationsBool(ctx, + if h.HasSecurityPolicy() && !oci.ParseAnnotationsBool(ctx, settings.OCISpecification.Annotations, annotations.LCOWPrivileged, false) { if err := specGuest.AddDevSev(ctx, settings.OCISpecification); err != nil { log.G(ctx).WithError(err).Debug("failed to add SEV device") @@ -448,6 +526,12 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM }) } + if h.HasSecurityPolicy() { + if err = checkContainerSettings(sandboxID, id, settings); err != nil { + return nil, err + } + } + user, groups, umask, err := h.securityOptions.PolicyEnforcer.GetUserInfo(settings.OCISpecification.Process, settings.OCISpecification.Root.Path) if err != nil { return nil, err @@ -605,6 +689,12 @@ func writeSpecToFile(ctx context.Context, configFile string, spec *specs.Spec) e } func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) (retErr error) { + if h.HasSecurityPolicy() { + if err := checkValidContainerID(containerID, false); err != nil { + return err + } + } + switch req.ResourceType { case guestresource.ResourceTypeSCSIDevice: return modifySCSIDevice(ctx, req.RequestType, req.Settings.(*guestresource.SCSIDevice)) @@ -689,6 +779,12 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * } func (h *Host) modifyContainerSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) error { + if h.HasSecurityPolicy() { + if err := checkValidContainerID(containerID, false); err != nil { + return err + } + } + c, err := h.GetCreatedContainer(containerID) if err != nil { return err @@ -1060,6 +1156,9 @@ func modifyMappedVirtualDisk( if err != nil { return err } + if mvd.Filesystem != "" && mvd.Filesystem != "ext4" { + return errors.Errorf("filesystem must be ext4 for read-only scsi mounts") + } } } switch rt { @@ -1076,6 +1175,11 @@ func modifyMappedVirtualDisk( if err != nil { return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) } + } else { + err = securityPolicy.EnforceRWDeviceMountPolicy(ctx, mvd.MountPath, mvd.Encrypted, mvd.EnsureFilesystem, mvd.Filesystem) + if err != nil { + return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) + } } config := &scsi.Config{ Encrypted: mvd.Encrypted, @@ -1094,6 +1198,10 @@ func modifyMappedVirtualDisk( if err := securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) } + } else { + if err := securityPolicy.EnforceRWDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { + return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) + } } config := &scsi.Config{ Encrypted: mvd.Encrypted, @@ -1192,8 +1300,42 @@ func modifyCombinedLayers( scratchEncrypted bool, securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { + isConfidential := len(securityPolicy.EncodedSecurityPolicy()) > 0 + containerID := 
cl.ContainerID
+
 	switch rt {
 	case guestrequest.RequestTypeAdd:
+		if isConfidential {
+			if err := checkValidContainerID(containerID, false); err != nil {
+				return err
+			}
+
+			// We check this regardless of what the policy says, as long as we're in
+			// confidential mode. This matches checkContainerSettings, which is called
+			// for the container creation request.
+			expectedContainerRootfs := path.Join(guestpath.LCOWRootPrefixInUVM, containerID, guestpath.RootfsPath)
+			if cl.ContainerRootPath != expectedContainerRootfs {
+				return fmt.Errorf("combined layers target %q does not match expected path %q",
+					cl.ContainerRootPath, expectedContainerRootfs)
+			}
+
+			if cl.ScratchPath != "" {
+				// At this point, we do not yet know what the sandbox ID will be, so we
+				// have to allow anything reasonable.
+				scratchDirRegexStr := fmt.Sprintf(
+					"^%s/%s/%s/%s$",
+					guestpath.LCOWRootPrefixInUVM,
+					validContainerIDRegexRaw,
+					guestpath.ScratchDir,
+					containerID,
+				)
+				scratchDirRegex := regexp.MustCompile(scratchDirRegexStr)
+				if !scratchDirRegex.MatchString(cl.ScratchPath) {
+					return fmt.Errorf("scratch path %q must match regex %q",
+						cl.ScratchPath, scratchDirRegexStr)
+				}
+			}
+		}
 		layerPaths := make([]string, len(cl.Layers))
 		for i, layer := range cl.Layers {
 			layerPaths[i] = layer.Path
@@ -1214,12 +1356,14 @@ func modifyCombinedLayers(
 		}
 	}
 
-	if err := securityPolicy.EnforceOverlayMountPolicy(ctx, cl.ContainerID, layerPaths, cl.ContainerRootPath); err != nil {
+	if err := securityPolicy.EnforceOverlayMountPolicy(ctx, containerID, layerPaths, cl.ContainerRootPath); err != nil {
 		return fmt.Errorf("overlay creation denied by policy: %w", err)
 	}
 
 	return overlay.MountLayer(ctx, layerPaths, upperdirPath, workdirPath, cl.ContainerRootPath, readonly)
 	case guestrequest.RequestTypeRemove:
+		// cl.ContainerID is not set on remove requests, but rego checks that we can
+		// only unmount previously mounted targets anyway
 		if err := securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath); err != nil {
 			return errors.Wrap(err, "overlay removal denied by policy")
 		}
diff --git a/internal/guestpath/paths.go b/internal/guestpath/paths.go
index aab9ed1053..1852bc2454 100644
--- a/internal/guestpath/paths.go
+++ b/internal/guestpath/paths.go
@@ -27,15 +27,17 @@ const (
 	// LCOWMountPathPrefixFmt is the path format in the LCOW UVM where
 	// non-global mounts, such as Plan9 mounts are added
 	LCOWMountPathPrefixFmt = "/mounts/m%d"
-	// LCOWGlobalMountPrefixFmt is the path format in the LCOW UVM where global
-	// mounts are added
-	LCOWGlobalMountPrefixFmt = "/run/mounts/m%d"
+	// LCOWGlobalScsiMountPrefixFmt is the path format in the LCOW UVM where
+	// global disk mounts are added
+	LCOWGlobalScsiMountPrefixFmt = "/run/mounts/scsi/m%d"
 	// LCOWGlobalDriverPrefixFmt is the path format in the LCOW UVM where drivers
 	// are mounted as read/write
 	LCOWGlobalDriverPrefixFmt = "/run/drivers/%s"
-	// WCOWGlobalMountPrefixFmt is the path prefix format in the WCOW UVM where
-	// mounts are added
-	WCOWGlobalMountPrefixFmt = "C:\\mounts\\m%d"
+	// WCOWGlobalScsiMountPrefixFmt is the path prefix format in the WCOW UVM
+	// where global disk mounts are added
+	WCOWGlobalScsiMountPrefixFmt = `c:\mounts\scsi\m%d`
 	// RootfsPath is part of the container's rootfs path
 	RootfsPath = "rootfs"
+	// ScratchDir is the name of the directory used for overlay upper and work dirs
+	ScratchDir = "scratch"
 )
diff --git a/internal/layers/lcow.go b/internal/layers/lcow.go
index dccd994e87..b1385fa4ac 100644
--- a/internal/layers/lcow.go
+++ b/internal/layers/lcow.go
@@ 
-159,7 +159,9 @@ func MountLCOWLayers( // handles the case where we want to share a scratch disk for multiple containers instead // of mounting a new one. Pass a unique value for `ScratchPath` to avoid container upper and // work directories colliding in the UVM. - containerScratchPathInUVM := ospath.Join("linux", scsiMount.GuestPath(), "scratch", containerID) + // Note that in the shared scratch case, AddVirtualDisk above is a no-op and + // will return the existing mount. + containerScratchPathInUVM := ospath.Join("linux", scsiMount.GuestPath(), guestpath.ScratchDir, containerID) defer func() { if err != nil { diff --git a/internal/uvm/start.go b/internal/uvm/start.go index 781bc3c417..c921ace550 100644 --- a/internal/uvm/start.go +++ b/internal/uvm/start.go @@ -21,6 +21,7 @@ import ( "github.com/Microsoft/hcsshim/internal/gcs" "github.com/Microsoft/hcsshim/internal/gcs/prot" + "github.com/Microsoft/hcsshim/internal/guestpath" "github.com/Microsoft/hcsshim/internal/hcs" "github.com/Microsoft/hcsshim/internal/hcs/schema1" hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" @@ -357,9 +358,9 @@ func (uvm *UtilityVM) Start(ctx context.Context) (err error) { } else { gb = scsi.NewHCSGuestBackend(uvm.hcsSystem, uvm.OS()) } - guestMountFmt := `c:\mounts\scsi\m%d` + guestMountFmt := guestpath.WCOWGlobalScsiMountPrefixFmt if uvm.OS() == "linux" { - guestMountFmt = "/run/mounts/scsi/m%d" + guestMountFmt = guestpath.LCOWGlobalScsiMountPrefixFmt } mgr, err := scsi.NewManager( scsi.NewHCSHostBackend(uvm.hcsSystem), diff --git a/pkg/securitypolicy/api.rego b/pkg/securitypolicy/api.rego index 36a197ebc2..e7bc653ac4 100644 --- a/pkg/securitypolicy/api.rego +++ b/pkg/securitypolicy/api.rego @@ -3,22 +3,24 @@ package api version := "@@API_VERSION@@" enforcement_points := { - "mount_device": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}}, - "mount_overlay": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}}, - "mount_cims": {"introducedVersion": "0.11.0", "default_results": {"allowed": false}}, - "create_container": {"introducedVersion": "0.1.0", "default_results": {"allowed": false, "env_list": null, "allow_stdio_access": false}}, - "unmount_device": {"introducedVersion": "0.2.0", "default_results": {"allowed": true}}, - "unmount_overlay": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}}, - "exec_in_container": {"introducedVersion": "0.2.0", "default_results": {"allowed": true, "env_list": null}}, - "exec_external": {"introducedVersion": "0.3.0", "default_results": {"allowed": true, "env_list": null, "allow_stdio_access": false}}, - "shutdown_container": {"introducedVersion": "0.4.0", "default_results": {"allowed": true}}, - "signal_container_process": {"introducedVersion": "0.5.0", "default_results": {"allowed": true}}, - "plan9_mount": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}}, - "plan9_unmount": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}}, - "get_properties": {"introducedVersion": "0.7.0", "default_results": {"allowed": true}}, - "dump_stacks": {"introducedVersion": "0.7.0", "default_results": {"allowed": true}}, - "runtime_logging": {"introducedVersion": "0.8.0", "default_results": {"allowed": true}}, - "load_fragment": {"introducedVersion": "0.9.0", "default_results": {"allowed": false, "add_module": false}}, - "scratch_mount": {"introducedVersion": "0.10.0", "default_results": {"allowed": true}}, - "scratch_unmount": {"introducedVersion": "0.10.0", "default_results": 
{"allowed": true}}, + "mount_device": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}, "use_framework": false}, + "rw_mount_device": {"introducedVersion": "0.11.0", "default_results": {}, "use_framework": true}, + "mount_overlay": {"introducedVersion": "0.1.0", "default_results": {"allowed": false}, "use_framework": false}, + "mount_cims": {"introducedVersion": "0.11.0", "default_results": {"allowed": false}, "use_framework": false}, + "create_container": {"introducedVersion": "0.1.0", "default_results": {"allowed": false, "env_list": null, "allow_stdio_access": false}, "use_framework": false}, + "unmount_device": {"introducedVersion": "0.2.0", "default_results": {"allowed": true}, "use_framework": false}, + "rw_unmount_device": {"introducedVersion": "0.11.0", "default_results": {}, "use_framework": true}, + "unmount_overlay": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}, "use_framework": false}, + "exec_in_container": {"introducedVersion": "0.2.0", "default_results": {"allowed": true, "env_list": null}, "use_framework": false}, + "exec_external": {"introducedVersion": "0.3.0", "default_results": {"allowed": true, "env_list": null, "allow_stdio_access": false}, "use_framework": false}, + "shutdown_container": {"introducedVersion": "0.4.0", "default_results": {"allowed": true}, "use_framework": false}, + "signal_container_process": {"introducedVersion": "0.5.0", "default_results": {"allowed": true}, "use_framework": false}, + "plan9_mount": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}, "use_framework": false}, + "plan9_unmount": {"introducedVersion": "0.6.0", "default_results": {"allowed": true}, "use_framework": false}, + "get_properties": {"introducedVersion": "0.7.0", "default_results": {"allowed": true}, "use_framework": false}, + "dump_stacks": {"introducedVersion": "0.7.0", "default_results": {"allowed": true}, "use_framework": false}, + "runtime_logging": {"introducedVersion": "0.8.0", "default_results": {"allowed": true}, "use_framework": false}, + "load_fragment": {"introducedVersion": "0.9.0", "default_results": {"allowed": false, "add_module": false}, "use_framework": false}, + "scratch_mount": {"introducedVersion": "0.10.0", "default_results": {"allowed": true}, "use_framework": false}, + "scratch_unmount": {"introducedVersion": "0.10.0", "default_results": {"allowed": true}, "use_framework": false}, } diff --git a/pkg/securitypolicy/api_test.rego b/pkg/securitypolicy/api_test.rego index 2d2de733c6..767c506e58 100644 --- a/pkg/securitypolicy/api_test.rego +++ b/pkg/securitypolicy/api_test.rego @@ -3,8 +3,8 @@ package api version := "0.0.2" enforcement_points := { - "__fixture_for_future_test__": {"introducedVersion": "100.0.0", "default_results": {"allowed": true}}, - "__fixture_for_allowed_test_true__": {"introducedVersion": "0.0.2", "default_results": {"allowed": true}}, - "__fixture_for_allowed_test_false__": {"introducedVersion": "0.0.2", "default_results": {"allowed": false}}, - "__fixture_for_allowed_extra__": {"introducedVersion": "0.0.1", "default_results": {"allowed": false, "__test__": "test"}} + "__fixture_for_future_test__": {"introducedVersion": "100.0.0", "default_results": {"allowed": true}, "use_framework": false}, + "__fixture_for_allowed_test_true__": {"introducedVersion": "0.0.2", "default_results": {"allowed": true}, "use_framework": false}, + "__fixture_for_allowed_test_false__": {"introducedVersion": "0.0.2", "default_results": {"allowed": false}, "use_framework": false}, + 
"__fixture_for_allowed_extra__": {"introducedVersion": "0.0.1", "default_results": {"allowed": false, "__test__": "test"}, "use_framework": false} } diff --git a/pkg/securitypolicy/framework.rego b/pkg/securitypolicy/framework.rego index 8a28f3e312..20e0e8b067 100644 --- a/pkg/securitypolicy/framework.rego +++ b/pkg/securitypolicy/framework.rego @@ -5,10 +5,28 @@ import future.keywords.in version := "@@FRAMEWORK_VERSION@@" +# Add ^ and $ to regex patterns that doesn't have them. +# This forces the regex to match the entire string, which is safer. +# Policies should include .* explicitly at the beginning or end if partial +# matches are to be allowed. + +anchor_pattern(p) := p { + startswith(p, "^") + endswith(p, "$") +} else := concat("", ["^", p]) { + endswith(p, "$") +} else := concat("", [p, "$"]) { + startswith(p, "^") +} else := concat("", ["^", p, "$"]) + device_mounted(target) { data.metadata.devices[target] } +device_mounted(target) { + data.metadata.rw_devices[target] +} + default deviceHash_ok := false # test if a device hash exists as a layer in a policy container @@ -27,9 +45,14 @@ deviceHash_ok { default mount_device := {"allowed": false} +mount_target_ok { + regex.match(anchor_pattern(input.mountPathRegex), input.target) +} + mount_device := {"metadata": [addDevice], "allowed": true} { not device_mounted(input.target) deviceHash_ok + mount_target_ok addDevice := { "name": "devices", "action": "add", @@ -38,10 +61,38 @@ mount_device := {"metadata": [addDevice], "allowed": true} { } } +allowed_scratch_fs("ext4") +allowed_scratch_fs("xfs") + +rwmount_device_encrypt_ok { + input.encrypted +} + +rwmount_device_encrypt_ok { + allow_unencrypted_scratch +} + +default rw_mount_device := {"allowed": false} + +rw_mount_device := {"metadata": [addDevice], "allowed": true} { + not device_mounted(input.target) + rwmount_device_encrypt_ok + input.ensureFilesystem + allowed_scratch_fs(input.filesystem) + mount_target_ok + addDevice := { + "name": "rw_devices", + "action": "add", + "key": input.target, + "value": true, + } +} + default unmount_device := {"allowed": false} unmount_device := {"metadata": [removeDevice], "allowed": true} { - device_mounted(input.unmountTarget) + data.metadata.devices[input.unmountTarget] + removeDevice := { "name": "devices", "action": "remove", @@ -49,6 +100,18 @@ unmount_device := {"metadata": [removeDevice], "allowed": true} { } } +default rw_unmount_device := {"allowed": false} + +rw_unmount_device := {"metadata": [removeRWDevice], "allowed": true} { + data.metadata.rw_devices[input.unmountTarget] + + removeRWDevice := { + "name": "rw_devices", + "action": "remove", + "key": input.unmountTarget, + } +} + layerPaths_ok(layers) { length := count(layers) count(input.layerPaths) == length @@ -127,6 +190,10 @@ default mount_overlay := {"allowed": false} mount_overlay := {"metadata": [addMatches, addOverlayTarget], "allowed": true} { not overlay_exists + # sanity check, but due to checks in the Go code, this should always pass if + # `not overlay_exists` passes. 
+ not overlay_mounted(input.target) + containers := [container | container := candidate_containers[_] layerPaths_ok(container.layers) @@ -171,30 +238,7 @@ env_ok(pattern, "string", value) { } env_ok(pattern, "re2", value) { - anchored := anchor_pattern(pattern) - regex.match(anchored, value) -} - -anchor_pattern(p) := anchored { - startswith_leading := startswith(p, "^") - endswith_trailing := endswith(p, "$") - - anchored = sprintf("%s%s%s", [ - add_leading_trailing_chars(startswith_leading, "", "^"), # Add ^ only if missing - p, - add_leading_trailing_chars(endswith_trailing, "", "$") # Add $ only if missing - ]) -} - -# Function to return one of two values depending on a boolean condition -add_leading_trailing_chars(cond, ifTrue, ifFalse) := result { - cond - result = ifTrue -} - -add_leading_trailing_chars(cond, ifTrue, ifFalse) := result { - not cond - result = ifFalse + regex.match(anchor_pattern(pattern), value) } rule_ok(rule, env) { @@ -316,7 +360,7 @@ idName_ok(pattern, "name", value) { } idName_ok(pattern, "re2", value) { - regex.match(pattern, value.name) + regex.match(anchor_pattern(pattern), value.name) } user_ok(user) { @@ -682,13 +726,13 @@ security_ok(current_container) { mountSource_ok(constraint, source) { startswith(constraint, data.sandboxPrefix) newConstraint := replace(constraint, data.sandboxPrefix, input.sandboxDir) - regex.match(newConstraint, source) + regex.match(anchor_pattern(newConstraint), source) } mountSource_ok(constraint, source) { startswith(constraint, data.hugePagesPrefix) newConstraint := replace(constraint, data.hugePagesPrefix, input.hugePagesDir) - regex.match(newConstraint, source) + regex.match(anchor_pattern(newConstraint), source) } mountSource_ok(constraint, source) { @@ -918,7 +962,7 @@ default plan9_mount := {"allowed": false} plan9_mount := {"metadata": [addPlan9Target], "allowed": true} { not plan9_mounted(input.target) some containerID, _ in data.metadata.matches - pattern := concat("", [input.rootPrefix, "/", containerID, input.mountPathPrefix]) + pattern := concat("", ["^", input.rootPrefix, "/", containerID, input.mountPathPrefix, "$"]) regex.match(pattern, input.target) addPlan9Target := { "name": "p9mounts", @@ -940,20 +984,28 @@ plan9_unmount := {"metadata": [removePlan9Target], "allowed": true} { } -default enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": true, "invalid": false, "version_missing": false} +default enforcement_point_info := { + "available": false, + "default_results": {"allow": false}, + "unknown": true, + "invalid": false, + "version_missing": false, + "use_framework": false +} -enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": false, "invalid": false, "version_missing": true} { +enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": false, "invalid": false, "version_missing": true, "use_framework": false} { policy_api_version == null } -enforcement_point_info := {"available": available, "default_results": default_results, "unknown": false, "invalid": false, "version_missing": false} { +enforcement_point_info := {"available": available, "default_results": default_results, "unknown": false, "invalid": false, "version_missing": false, "use_framework": use_framework} { enforcement_point := data.api.enforcement_points[input.name] semver.compare(data.api.version, enforcement_point.introducedVersion) >= 0 available := semver.compare(policy_api_version, enforcement_point.introducedVersion) >= 
0
 	default_results := enforcement_point.default_results
+	use_framework := enforcement_point.use_framework
 }
 
-enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": false, "invalid": true, "version_missing": false} {
+enforcement_point_info := {"available": false, "default_results": {"allow": false}, "unknown": false, "invalid": true, "version_missing": false, "use_framework": false} {
 	enforcement_point := data.api.enforcement_points[input.name]
 	semver.compare(data.api.version, enforcement_point.introducedVersion) < 0
 }
@@ -1250,9 +1302,50 @@ errors["device already mounted at path"] {
 	device_mounted(input.target)
 }
 
+errors["mountpoint invalid"] {
+	input.rule in ["mount_device", "rw_mount_device"]
+	not mount_target_ok
+}
+
 errors["no device at path to unmount"] {
 	input.rule == "unmount_device"
-	not device_mounted(input.unmountTarget)
+	not data.metadata.devices[input.unmountTarget]
+	not data.metadata.rw_devices[input.unmountTarget]
+}
+
+errors["received read-only unmount request, but device provided is read-write"] {
+	input.rule == "unmount_device"
+	not data.metadata.devices[input.unmountTarget]
+	data.metadata.rw_devices[input.unmountTarget]
+}
+
+errors["no device at path to unmount"] {
+	input.rule == "rw_unmount_device"
+	not data.metadata.devices[input.unmountTarget]
+	not data.metadata.rw_devices[input.unmountTarget]
+}
+
+errors["received read-write unmount request, but device provided is read-only"] {
+	input.rule == "rw_unmount_device"
+	not data.metadata.rw_devices[input.unmountTarget]
+	data.metadata.devices[input.unmountTarget]
+}
+
+# Error string tested in azcri-containerd Test_RunPodSandboxNotAllowed_WithPolicy_EncryptedScratchPolicy
+errors["unencrypted scratch not allowed, non-readonly mount request for SCSI disk must request encryption"] {
+	input.rule == "rw_mount_device"
+	not allow_unencrypted_scratch
+	not input.encrypted
+}
+
+errors["ensureFilesystem must be set on rw device mounts"] {
+	input.rule == "rw_mount_device"
+	not input.ensureFilesystem
+}
+
+errors["rw device mount uses a filesystem that is not allowed"] {
+	input.rule == "rw_mount_device"
+	not allowed_scratch_fs(input.filesystem)
 }
 
 errors["container already started"] {
diff --git a/pkg/securitypolicy/open_door.rego b/pkg/securitypolicy/open_door.rego
index a8e283092d..23c35f9b04 100644
--- a/pkg/securitypolicy/open_door.rego
+++ b/pkg/securitypolicy/open_door.rego
@@ -3,10 +3,12 @@ package policy
 
 api_version := "@@API_VERSION@@"
 
 mount_device := {"allowed": true}
+rw_mount_device := {"allowed": true}
 mount_overlay := {"allowed": true}
 create_container := {"allowed": true, "env_list": null, "allow_stdio_access": true}
 mount_cims := {"allowed": true}
 unmount_device := {"allowed": true}
+rw_unmount_device := {"allowed": true}
 unmount_overlay := {"allowed": true}
 exec_in_container := {"allowed": true, "env_list": null}
 exec_external := {"allowed": true, "env_list": null, "allow_stdio_access": true}
diff --git a/pkg/securitypolicy/policy.rego b/pkg/securitypolicy/policy.rego
index 9414116c19..03a71094bd 100644
--- a/pkg/securitypolicy/policy.rego
+++ b/pkg/securitypolicy/policy.rego
@@ -6,7 +6,9 @@ framework_version := "@@FRAMEWORK_VERSION@@"
 
 @@OBJECTS@@
 
 mount_device := data.framework.mount_device
+rw_mount_device := data.framework.rw_mount_device
 unmount_device := data.framework.unmount_device
+rw_unmount_device := data.framework.rw_unmount_device
 mount_overlay := data.framework.mount_overlay
 unmount_overlay := data.framework.unmount_overlay
 mount_cims:= 
data.framework.mount_cims diff --git a/pkg/securitypolicy/policy_v0.10.0_api_test.rego b/pkg/securitypolicy/policy_v0.10.0_api_test.rego new file mode 100644 index 0000000000..a5ea9c56d4 --- /dev/null +++ b/pkg/securitypolicy/policy_v0.10.0_api_test.rego @@ -0,0 +1,71 @@ +package policy + +api_version := "0.10.0" +framework_version := "0.3.0" + +containers := [ + { + "allow_elevated": false, + "allow_stdio_access": true, + "capabilities": { + "ambient": [], + "bounding": [], + "effective": [], + "inheritable": [], + "permitted": [] + }, + "command": [ "bash" ], + "env_rules": [], + "exec_processes": [], + "layers": [ + "@@CONTAINER_LAYER_HASH@@", + ], + "mounts": [], + "no_new_privileges": false, + "seccomp_profile_sha256": "", + "signals": [], + "user": { + "group_idnames": [ + { + "pattern": "", + "strategy": "any" + } + ], + "umask": "0022", + "user_idname": { + "pattern": "", + "strategy": "any" + } + }, + "working_dir": "/" + } +] + +allow_properties_access := true +allow_dump_stacks := false +allow_runtime_logging := false +allow_environment_variable_dropping := true +allow_unencrypted_scratch := false +allow_capability_dropping := true + +mount_device := data.framework.mount_device +unmount_device := data.framework.unmount_device +mount_overlay := data.framework.mount_overlay +unmount_overlay := data.framework.unmount_overlay +create_container := data.framework.create_container +exec_in_container := data.framework.exec_in_container +exec_external := {"allowed": true, + "allow_stdio_access": true, + "env_list": input.envList} +shutdown_container := data.framework.shutdown_container +signal_container_process := data.framework.signal_container_process +plan9_mount := data.framework.plan9_mount +plan9_unmount := data.framework.plan9_unmount +get_properties := data.framework.get_properties +dump_stacks := data.framework.dump_stacks +runtime_logging := data.framework.runtime_logging +load_fragment := data.framework.load_fragment +scratch_mount := data.framework.scratch_mount +scratch_unmount := data.framework.scratch_unmount + +reason := {"errors": data.framework.errors} diff --git a/pkg/securitypolicy/rego_utils_test.go b/pkg/securitypolicy/rego_utils_test.go index dbe016098d..2e357f071a 100644 --- a/pkg/securitypolicy/rego_utils_test.go +++ b/pkg/securitypolicy/rego_utils_test.go @@ -6,6 +6,7 @@ package securitypolicy import ( "context" _ "embed" + "encoding/hex" "encoding/json" "fmt" "math/rand" @@ -15,6 +16,7 @@ import ( "sort" "strconv" "strings" + "sync/atomic" "syscall" "testing" "time" @@ -34,7 +36,6 @@ const ( maxExternalProcessesInGeneratedConstraints = 16 maxFragmentsInGeneratedConstraints = 4 maxGeneratedExternalProcesses = 12 - maxGeneratedSandboxIDLength = 32 maxGeneratedEnforcementPointLength = 64 maxGeneratedPlan9Mounts = 8 maxGeneratedFragmentFeedLength = 256 @@ -46,7 +47,6 @@ const ( minStringLength = 10 maxContainersInGeneratedConstraints = 32 maxLayersInGeneratedContainer = 32 - maxGeneratedContainerID = 1000000 maxGeneratedCommandLength = 128 maxGeneratedCommandArgs = 12 maxGeneratedEnvironmentVariables = 16 @@ -355,8 +355,17 @@ func mountImageForContainer(policy *regoEnforcer, container *securityPolicyConta return "", fmt.Errorf("error creating valid overlay: %w", err) } + scratchDisk := getScratchDiskMountTarget(containerID) + err = policy.EnforceRWDeviceMountPolicy(ctx, scratchDisk, true, true, "xfs") + if err != nil { + return "", fmt.Errorf("error mounting scratch disk: %w", err) + } + + overlayTarget := getOverlayMountTarget(containerID) + // see NOTE_TESTCOPY - 
err = policy.EnforceOverlayMountPolicy(ctx, containerID, copyStrings(layerPaths), testDataGenerator.uniqueMountTarget())
+	err = policy.EnforceOverlayMountPolicy(
+		ctx, containerID, copyStrings(layerPaths), overlayTarget)
 	if err != nil {
 		return "", fmt.Errorf("error mounting filesystem: %w", err)
 	}
@@ -1333,7 +1342,8 @@ func selectFragmentsFromConstraints(gc *generatedConstraints, numFragments int,
 }
 
 func generateSandboxID(r *rand.Rand) string {
-	return randVariableString(r, maxGeneratedSandboxIDLength)
+	// Sandbox IDs have the same format as container IDs
+	return generateContainerID(r)
 }
 
 func generateEnforcementPoint(r *rand.Rand) string {
@@ -1615,6 +1625,13 @@ func copyStrings(values []string) []string {
 
 //go:embed api_test.rego
 var apiTestCode string
 
+//go:embed policy_v0.10.0_api_test.rego
+var policyWith_0_10_0_apiTestCode string
+
+func getPolicyCode_0_10_0(layerHash string) string {
+	return strings.Replace(policyWith_0_10_0_apiTestCode, "@@CONTAINER_LAYER_HASH@@", layerHash, 1)
+}
+
 func (p *regoEnforcer) injectTestAPI() error {
 	p.rego.RemoveModule("api.rego")
 	p.rego.AddModule("api.rego", &rpi.RegoModule{Namespace: "api", Code: apiTestCode})
@@ -2030,7 +2047,7 @@ func assertDecisionJSONContains(t *testing.T, err error, expectedValues ...strin
 
 	for _, expected := range expectedValues {
 		if !strings.Contains(policyDecision, expected) {
-			t.Errorf("expected error to contain %q", expected)
+			t.Errorf("expected error to contain %q, but got %q", expected, policyDecision)
 			return false
 		}
 	}
@@ -2492,7 +2509,6 @@ func buildEnvironmentVariablesFromEnvRules(rules []EnvRuleConfig, r *rand.Rand)
 	// Build in all required rules, this isn't a setup method of "missing item"
 	// tests
 	for _, rule := range rules {
-
 		if rule.Required {
 			if rule.Strategy != EnvVarRuleRegex {
 				vars = append(vars, rule.Rule)
@@ -2529,12 +2545,14 @@ func buildEnvironmentVariablesFromEnvRules(rules []EnvRuleConfig, r *rand.Rand)
 			usedIndexes[anIndex] = struct{}{}
 		}
 		numberOfMatches--
-
 	}
 
 	return vars
 }
 
+// Only used for random mount targets or for the standard enforcer. Rego policy
+// enforces proper targets, e.g. ones created from
+// guestpath.LCOWGlobalScsiMountPrefixFmt.
 func generateMountTarget(r *rand.Rand) string {
 	return randVariableString(r, maxGeneratedMountTargetLength)
 }
@@ -2563,8 +2581,12 @@ func selectRootHashFromConstraints(constraints *generatedConstraints, r *rand.Ra
 }
 
 func generateContainerID(r *rand.Rand) string {
-	id := atLeastOneAtMost(r, maxGeneratedContainerID)
-	return strconv.FormatInt(int64(id), 10)
+	idbytes := make([]byte, 32)
+	_, err := r.Read(idbytes)
+	if err != nil {
+		panic(fmt.Errorf("failed to generate random container ID: %w", err))
+	}
+	return hex.EncodeToString(idbytes)
 }
 
 func generateMounts(r *rand.Rand) []mountInternal {
@@ -2654,26 +2676,28 @@ func selectContainerFromContainerList(containers []*securityPolicyContainer, r *
 }
 
 type dataGenerator struct {
-	rng                *rand.Rand
-	mountTargets       stringSet
-	containerIDs       stringSet
-	sandboxIDs         stringSet
-	enforcementPoints  stringSet
-	fragmentIssuers    stringSet
-	fragmentFeeds      stringSet
-	fragmentNamespaces stringSet
+	rng                  *rand.Rand
+	layerMountTarget     stringSet
+	nextLayerMountTarget atomic.Uint64
+	containerIDs         stringSet
+	sandboxIDs           stringSet
+	enforcementPoints    stringSet
+	fragmentIssuers      stringSet
+	fragmentFeeds        stringSet
+	fragmentNamespaces   stringSet
 }
 
 func newDataGenerator(rng *rand.Rand) *dataGenerator {
 	return &dataGenerator{
-		rng:                rng,
-		mountTargets:       make(stringSet),
-		containerIDs:       make(stringSet),
-		sandboxIDs:         make(stringSet),
-		enforcementPoints:  make(stringSet),
-		fragmentIssuers:    make(stringSet),
-		fragmentFeeds:      make(stringSet),
-		fragmentNamespaces: make(stringSet),
+		rng:                  rng,
+		layerMountTarget:     make(stringSet),
+		nextLayerMountTarget: atomic.Uint64{},
+		containerIDs:         make(stringSet),
+		sandboxIDs:           make(stringSet),
+		enforcementPoints:    make(stringSet),
+		fragmentIssuers:      make(stringSet),
+		fragmentFeeds:        make(stringSet),
+		fragmentNamespaces:   make(stringSet),
 	}
 }
@@ -2687,21 +2711,36 @@ func (s *stringSet) randUnique(r *rand.Rand, generator func(*rand.Rand) string)
 	}
 }
 
-func (gen *dataGenerator) uniqueMountTarget() string {
-	return gen.mountTargets.randUnique(gen.rng, generateMountTarget)
+// Generate a purely random mount target. This will be rejected by rego.
+func (gen *dataGenerator) uniqueRandomMountTarget() string {
+	return gen.layerMountTarget.randUnique(gen.rng, generateMountTarget)
 }
 
 func (gen *dataGenerator) uniqueContainerID() string {
 	return gen.containerIDs.randUnique(gen.rng, generateContainerID)
 }
 
+func (gen *dataGenerator) uniqueLayerMountTarget() string {
+	idx := gen.nextLayerMountTarget.Add(1)
+	return fmt.Sprintf(guestpath.LCOWGlobalScsiMountPrefixFmt, idx)
+}
+
+func getScratchDiskMountTarget(containerID string) string {
+	return path.Join(guestpath.LCOWRootPrefixInUVM, containerID)
+}
+
+// Returns the rootfs path of a container.
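+// e.g. /run/gcs/c/<containerID>/rootfs, assuming guestpath.LCOWRootPrefixInUVM
+// is "/run/gcs/c".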
+func getOverlayMountTarget(containerID string) string { + return path.Join(guestpath.LCOWRootPrefixInUVM, containerID, guestpath.RootfsPath) +} + func (gen *dataGenerator) createValidOverlayForContainer(enforcer SecurityPolicyEnforcer, container *securityPolicyContainer) ([]string, error) { ctx := context.Background() // storage for our mount paths overlay := make([]string, len(container.Layers)) for i := 0; i < len(container.Layers); i++ { - mount := gen.uniqueMountTarget() + mount := gen.uniqueLayerMountTarget() err := enforcer.EnforceDeviceMountPolicy(ctx, mount, container.Layers[i]) if err != nil { return overlay, err @@ -2714,14 +2753,16 @@ func (gen *dataGenerator) createValidOverlayForContainer(enforcer SecurityPolicy } func (gen *dataGenerator) createInvalidOverlayForContainer(enforcer SecurityPolicyEnforcer, container *securityPolicyContainer) ([]string, error) { - method := gen.rng.Intn(3) + method := gen.rng.Intn(4) switch method { case 0: return gen.invalidOverlaySameSizeWrongMounts(enforcer, container) case 1: return gen.invalidOverlayCorrectDevicesWrongOrderSomeMissing(enforcer, container) - default: + case 2: return gen.invalidOverlayRandomJunk(enforcer, container) + default: + return gen.invalidOverlayRandomNoMount(enforcer, container) } } @@ -2731,14 +2772,14 @@ func (gen *dataGenerator) invalidOverlaySameSizeWrongMounts(enforcer SecurityPol overlay := make([]string, len(container.Layers)) for i := 0; i < len(container.Layers); i++ { - mount := gen.uniqueMountTarget() + mount := gen.uniqueLayerMountTarget() err := enforcer.EnforceDeviceMountPolicy(ctx, mount, container.Layers[i]) if err != nil { return overlay, err } // generate a random new mount point to cause an error - overlay[len(overlay)-i-1] = gen.uniqueMountTarget() + overlay[len(overlay)-i-1] = gen.uniqueLayerMountTarget() } return overlay, nil @@ -2754,7 +2795,7 @@ func (gen *dataGenerator) invalidOverlayCorrectDevicesWrongOrderSomeMissing(enfo var overlay []string for i := 0; i < len(container.Layers); i++ { - mount := gen.uniqueMountTarget() + mount := gen.uniqueLayerMountTarget() err := enforcer.EnforceDeviceMountPolicy(ctx, mount, container.Layers[i]) if err != nil { return overlay, err @@ -2775,12 +2816,12 @@ func (gen *dataGenerator) invalidOverlayRandomJunk(enforcer SecurityPolicyEnforc overlay := make([]string, layersToCreate) for i := 0; i < int(layersToCreate); i++ { - overlay[i] = gen.uniqueMountTarget() + overlay[i] = generateMountTarget(gen.rng) } // setup entirely different and "correct" expected mounting for i := 0; i < len(container.Layers); i++ { - mount := gen.uniqueMountTarget() + mount := gen.uniqueLayerMountTarget() err := enforcer.EnforceDeviceMountPolicy(ctx, mount, container.Layers[i]) if err != nil { return overlay, err @@ -2790,6 +2831,17 @@ func (gen *dataGenerator) invalidOverlayRandomJunk(enforcer SecurityPolicyEnforc return overlay, nil } +func (gen *dataGenerator) invalidOverlayRandomNoMount(enforcer SecurityPolicyEnforcer, container *securityPolicyContainer) ([]string, error) { + layersToCreate := gen.rng.Int31n(maxLayersInGeneratedContainer) + overlay := make([]string, layersToCreate) + + for i := 0; i < int(layersToCreate); i++ { + overlay[i] = gen.uniqueLayerMountTarget() + } + + return overlay, nil +} + func randVariableString(r *rand.Rand, maxLen int32) string { return randString(r, atLeastOneAtMost(r, maxLen)) } diff --git a/pkg/securitypolicy/regopolicy_linux_test.go b/pkg/securitypolicy/regopolicy_linux_test.go index 94cfd8d031..bec4f27f07 100644 --- 
a/pkg/securitypolicy/regopolicy_linux_test.go +++ b/pkg/securitypolicy/regopolicy_linux_test.go @@ -9,6 +9,7 @@ import ( "fmt" "math/rand" "os" + "path" "path/filepath" "slices" "strconv" @@ -106,7 +107,7 @@ func Test_MarshalRego_Policy(t *testing.T) { _, err = newRegoPolicy(expected, defaultMounts, privilegedMounts, testOSType) if err != nil { - t.Errorf("unable to convert policy to rego: %v", err) + t.Errorf("cannot make rego policy from constraints: %v", err) return false } @@ -193,11 +194,11 @@ func Test_Rego_EnforceDeviceMountPolicy_No_Matches(t *testing.T) { policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) if err != nil { - t.Errorf("unable to convert policy to rego: %v", err) + t.Errorf("cannot make rego policy from constraints: %v", err) return false } - target := testDataGenerator.uniqueMountTarget() + target := testDataGenerator.uniqueLayerMountTarget() rootHash := generateInvalidRootHash(testRand) err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) @@ -219,11 +220,11 @@ func Test_Rego_EnforceDeviceMountPolicy_Matches(t *testing.T) { policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) if err != nil { - t.Errorf("unable to convert policy to rego: %v", err) + t.Errorf("cannot make rego policy from constraints: %v", err) return false } - target := testDataGenerator.uniqueMountTarget() + target := testDataGenerator.uniqueLayerMountTarget() rootHash := selectRootHashFromConstraints(p, testRand) err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) @@ -237,7 +238,7 @@ func Test_Rego_EnforceDeviceMountPolicy_Matches(t *testing.T) { } } -func Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries(t *testing.T) { +func Test_Rego_EnforceDeviceUnmountPolicy_Removes_Device_Entries(t *testing.T) { f := func(p *generatedConstraints) bool { securityPolicy := p.toPolicy() policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) @@ -247,7 +248,7 @@ func Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries(t *testing.T) { return false } - target := testDataGenerator.uniqueMountTarget() + target := testDataGenerator.uniqueLayerMountTarget() rootHash := selectRootHashFromConstraints(p, testRand) err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) @@ -272,7 +273,36 @@ func Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries(t *testing.T) { } if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { - t.Errorf("Test_Rego_EnforceDeviceUmountPolicy_Removes_Device_Entries failed: %v", err) + t.Errorf("Test_Rego_EnforceDeviceUnmountPolicy_Removes_Device_Entries failed: %v", err) + } +} + +func Test_Rego_EnforceDeviceUnmountPolicy_No_Matches(t *testing.T) { + f := func(p *generatedConstraints) bool { + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Error(err) + return false + } + + target := testDataGenerator.uniqueLayerMountTarget() + err = policy.EnforceDeviceUnmountPolicy(p.ctx, target) + if !assertDecisionJSONContains(t, err, "no device at path to unmount") { + return false + } + + target = getScratchDiskMountTarget(testDataGenerator.uniqueContainerID()) + err = policy.EnforceRWDeviceUnmountPolicy(p.ctx, target) + if !assertDecisionJSONContains(t, err, "no device at path to unmount") { + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, 
Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceDeviceUnmountPolicy_No_Matches failed: %v", err) } } @@ -282,11 +312,11 @@ func Test_Rego_EnforceDeviceMountPolicy_Duplicate_Device_Target(t *testing.T) { policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) if err != nil { - t.Errorf("unable to convert policy to rego: %v", err) + t.Errorf("cannot make rego policy from constraints: %v", err) return false } - target := testDataGenerator.uniqueMountTarget() + target := testDataGenerator.uniqueLayerMountTarget() rootHash := selectRootHashFromConstraints(p, testRand) err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) if err != nil { @@ -309,6 +339,342 @@ func Test_Rego_EnforceDeviceMountPolicy_Duplicate_Device_Target(t *testing.T) { } } +func Test_Rego_EnforceDeviceMountPolicy_InvalidMountTarget(t *testing.T) { + f := func(p *generatedConstraints) bool { + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + target := testDataGenerator.uniqueRandomMountTarget() + rootHash := selectRootHashFromConstraints(p, testRand) + + err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) + + return assertDecisionJSONContains(t, err, "mountpoint invalid") + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceDeviceMountPolicy_InvalidMountTarget failed: %v", err) + } +} + +func Test_Rego_EnforceDeviceMountPolicy_InvalidMountTarget_PathTraversal(t *testing.T) { + p := generateConstraints(testRand, 1) + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return + } + + target := testDataGenerator.uniqueLayerMountTarget() + "/../../../../.." 
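+	// e.g. "/run/mounts/scsi/m3/../../../../.." - this should no longer match
+	// the anchored mountPathRegex, so the mount must be rejected as
+	// "mountpoint invalid".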
+ rootHash := selectRootHashFromConstraints(p, testRand) + + err = policy.EnforceDeviceMountPolicy(p.ctx, target, rootHash) + + assertDecisionJSONContains(t, err, "mountpoint invalid") +} + +func deviceMountUnmountTest(t *testing.T, p *generatedConstraints, policy *regoEnforcer, mountScratchFirst, unmountScratchFirst, testInvalidUnmount bool) bool { + container := selectContainerFromContainerList(p.containers, testRand) + containerID := testDataGenerator.uniqueContainerID() + rotarget := testDataGenerator.uniqueLayerMountTarget() + rwtarget := getScratchDiskMountTarget(containerID) + + var err error + + mountScratch := func() bool { + err = policy.EnforceRWDeviceMountPolicy(p.ctx, rwtarget, true, true, "xfs") + if err != nil { + t.Errorf("unable to mount rw device: %v", err) + return false + } + return true + } + + mountLayer := func() bool { + err = policy.EnforceDeviceMountPolicy(p.ctx, rotarget, container.Layers[0]) + if err != nil { + t.Errorf("unable to mount ro device: %v", err) + return false + } + return true + } + + if mountScratchFirst { + if !mountScratch() || !mountLayer() { + return false + } + } else { + if !mountLayer() || !mountScratch() { + return false + } + } + + unmountScratch := func() bool { + err = policy.EnforceRWDeviceUnmountPolicy(p.ctx, rwtarget) + if err != nil { + t.Errorf("unable to unmount rw device: %v", err) + return false + } + return true + } + + unmountLayer := func() bool { + err = policy.EnforceDeviceUnmountPolicy(p.ctx, rotarget) + if err != nil { + t.Errorf("unable to unmount ro device: %v", err) + return false + } + return true + } + + if unmountScratchFirst { + if !unmountScratch() || !unmountLayer() { + return false + } + } else { + if !unmountLayer() || !unmountScratch() { + return false + } + } + + if testInvalidUnmount { + err = policy.EnforceDeviceUnmountPolicy(p.ctx, rotarget) + if !assertDecisionJSONContains(t, err, "no device at path to unmount") { + return false + } + + err = policy.EnforceRWDeviceUnmountPolicy(p.ctx, rwtarget) + if !assertDecisionJSONContains(t, err, "no device at path to unmount") { + return false + } + } + + return true +} + +func Test_Rego_EnforceRWDeviceMountPolicy_MountAndUnmount(t *testing.T) { + f := func(p *generatedConstraints, mountScratchFirst, unmountScratchFirst bool) bool { + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + return deviceMountUnmountTest(t, p, policy, mountScratchFirst, unmountScratchFirst, true) + } + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceRWDeviceMountPolicy_MountAndUnmount failed: %v", err) + } +} + +func Test_Rego_EnforceRWDeviceMountPolicy_InvalidTarget(t *testing.T) { + f := func(p *generatedConstraints, encrypted bool, ensureFileSystem bool) bool { + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + target := testDataGenerator.uniqueRandomMountTarget() + filesystem := "xfs" + + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem) + + return assertDecisionJSONContains(t, err, "mountpoint invalid") + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + 
t.Errorf("Test_Rego_EnforceRWDeviceMountPolicy_Matches failed: %v", err) + } +} + +func Test_Rego_EnforceRWDeviceMountPolicy_MissingEnsureFilesystem(t *testing.T) { + f := func(p *generatedConstraints, encrypted bool) bool { + p.allowUnencryptedScratch = !encrypted + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + target := getScratchDiskMountTarget(testDataGenerator.uniqueContainerID()) + filesystem := "xfs" + + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, false, filesystem) + + return assertDecisionJSONContains(t, err, "ensureFilesystem must be set on rw device mounts") + } + + if err := quick.Check(f, &quick.Config{MaxCount: 10, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceRWDeviceMountPolicy_Matches failed: %v", err) + } +} + +func Test_Rego_EnforceRWDeviceMountPolicy_DontAllowUnencrypted(t *testing.T) { + p := generateConstraints(testRand, 1) + p.allowUnencryptedScratch = false + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return + } + + target := getScratchDiskMountTarget(testDataGenerator.uniqueContainerID()) + filesystem := "xfs" + + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, false, true, filesystem) + + assertDecisionJSONContains(t, err, "unencrypted scratch not allowed, non-readonly mount request for SCSI disk must request encryption") +} + +func Test_Rego_EnforceRWDeviceMountPolicy_InvalidFilesystem(t *testing.T) { + p := generateConstraints(testRand, 1) + securityPolicy := p.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return + } + + target := getScratchDiskMountTarget(testDataGenerator.uniqueContainerID()) + dangerousFilesystems := []string{ + "9p", + "overlay", + "nfs", + "cifs", + } + + for _, filesystem := range dangerousFilesystems { + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, true, true, filesystem) + assertDecisionJSONContains(t, err, "rw device mounts uses a filesystem that is not allowed") + } +} + +// Test that for an older allow all policy (api version < 0.11.0) that does not +// have rw_mount_device, the use_framework passthrough is done correctly, +// allowing enforcing rw mounts. 
+func Test_Rego_EnforceRWDeviceMountPolicy_Compat_0_10_0_allow_all(t *testing.T) {
+	p := generateConstraints(testRand, 1)
+	regoPolicy := `
+package policy
+
+api_version := "0.10.0"
+framework_version := "0.3.0"
+
+mount_device := {"allowed": true}
+mount_overlay := {"allowed": true}
+create_container := {"allowed": true, "env_list": null, "allow_stdio_access": true}
+unmount_device := {"allowed": true}
+unmount_overlay := {"allowed": true}
+exec_in_container := {"allowed": true, "env_list": null}
+exec_external := {"allowed": true, "env_list": null, "allow_stdio_access": true}
+shutdown_container := {"allowed": true}
+signal_container_process := {"allowed": true}
+plan9_mount := {"allowed": true}
+plan9_unmount := {"allowed": true}
+get_properties := {"allowed": true}
+dump_stacks := {"allowed": true}
+runtime_logging := {"allowed": true}
+load_fragment := {"allowed": true}
+scratch_mount := {"allowed": true}
+scratch_unmount := {"allowed": true}
+`
+	for _, b1 := range []bool{true, false} {
+		for _, b2 := range []bool{true, false} {
+			policy, err := newRegoPolicy(regoPolicy, []oci.Mount{}, []oci.Mount{}, testOSType)
+			if err != nil {
+				t.Errorf("cannot compile rego policy: %v", err)
+				return
+			}
+
+			t.Run(fmt.Sprintf("mountScratchFirst=%t, unmountScratchFirst=%t", b1, b2), func(t *testing.T) {
+				deviceMountUnmountTest(t, p, policy, b1, b2, false)
+			})
+		}
+	}
+}
+
+// Test that for an older policy (API version < 0.11.0) that does not have
+// rw_mount_device, the use_framework passthrough is done correctly, allowing
+// rw mounts to be enforced.
+func Test_Rego_EnforceRWDeviceMountPolicy_Compat_0_10_0(t *testing.T) {
+	p := generateConstraints(testRand, 1)
+	regoPolicy := getPolicyCode_0_10_0(p.containers[0].Layers[0])
+	for _, b1 := range []bool{true, false} {
+		for _, b2 := range []bool{true, false} {
+			policy, err := newRegoPolicy(regoPolicy, []oci.Mount{}, []oci.Mount{}, testOSType)
+			if err != nil {
+				t.Errorf("cannot compile rego policy: %v", err)
+				return
+			}
+
+			t.Run(fmt.Sprintf("mountScratchFirst=%t, unmountScratchFirst=%t", b1, b2), func(t *testing.T) {
+				deviceMountUnmountTest(t, p, policy, b1, b2, true)
+			})
+		}
+	}
+
+	policy, err := newRegoPolicy(regoPolicy, []oci.Mount{}, []oci.Mount{}, testOSType)
+	if err != nil {
+		t.Errorf("cannot compile rego policy: %v", err)
+		return
+	}
+
+	// Invalid mount target
+	target := testDataGenerator.uniqueRandomMountTarget()
+	filesystem := "xfs"
+	encrypted := true
+	ensureFileSystem := true
+	err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem)
+	assertDecisionJSONContains(t, err, "mountpoint invalid")
+
+	// Missing ensureFilesystem
+	ensureFileSystem = false
+	target = getScratchDiskMountTarget(testDataGenerator.uniqueContainerID())
+	err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem)
+	assertDecisionJSONContains(t, err, "ensureFilesystem must be set on rw device mounts")
+
+	// Unencrypted scratch not allowed
+	ensureFileSystem = true
+	encrypted = false
+	err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem)
+	assertDecisionJSONContains(t, err, "unencrypted scratch not allowed, non-readonly mount request for SCSI disk must request encryption")
+}
+
+func Test_Rego_EnforceRWDeviceMountPolicy_OpenDoor(t *testing.T) {
+	p := generateConstraints(testRand, 1)
+	policy, err := newRegoPolicy(openDoorRego, []oci.Mount{}, []oci.Mount{}, testOSType)
+	if err != nil {
+		t.Errorf("cannot compile open door rego policy: 
%v", err) + return + } + + deviceMountUnmountTest(t, p, policy, true, true, false) + + ensureFileSystem := false + encrypted := false + filesystem := "zfs" + target := "/bin" + err = policy.EnforceRWDeviceMountPolicy(p.ctx, target, encrypted, ensureFileSystem, filesystem) + if err != nil { + t.Errorf("unexpected error mounting rw device: %v", err) + } + + err = policy.EnforceRWDeviceUnmountPolicy(p.ctx, target) + if err != nil { + t.Errorf("unexpected error unmounting rw device: %v", err) + } +} + // Verify that RegoSecurityPolicyEnforcer.EnforceOverlayMountPolicy will // return an error when there's no matching overlay targets. func Test_Rego_EnforceOverlayMountPolicy_No_Matches(t *testing.T) { @@ -319,7 +685,8 @@ func Test_Rego_EnforceOverlayMountPolicy_No_Matches(t *testing.T) { return false } - err = tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy( + p.ctx, tc.containerID, tc.layers, getOverlayMountTarget(tc.containerID)) if err == nil { return false @@ -348,7 +715,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Matches(t *testing.T) { return false } - err = tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy( + p.ctx, tc.containerID, tc.layers, getOverlayMountTarget(tc.containerID)) // getting an error means something is broken return err == nil @@ -388,7 +756,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_With_Same_Root_Hash(t *testing.T t.Fatalf("error creating valid overlay: %v", err) } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, layers, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, layers, getOverlayMountTarget(containerID)) if err != nil { t.Fatalf("Unable to create an overlay where root hashes are the same") } @@ -428,7 +797,7 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_Shared_Layers(t *testing.T) { sharedMount := "" for i := 0; i < len(containerOne.Layers); i++ { - mount := testDataGenerator.uniqueMountTarget() + mount := testDataGenerator.uniqueLayerMountTarget() err := policy.EnforceDeviceMountPolicy(constraints.ctx, mount, containerOne.Layers[i]) if err != nil { t.Fatalf("Unexpected error mounting overlay device: %v", err) @@ -440,13 +809,14 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_Shared_Layers(t *testing.T) { containerOneOverlay[len(containerOneOverlay)-i-1] = mount } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, containerOneOverlay, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, containerOneOverlay, getOverlayMountTarget(containerID)) if err != nil { t.Fatalf("Unexpected error mounting overlay: %v", err) } // - // Mount our second contaniers overlay. This should all work. + // Mount our second container overlay. This should all work. 
// containerID = testDataGenerator.uniqueContainerID() @@ -456,7 +826,7 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_Shared_Layers(t *testing.T) { for i := 0; i < len(containerTwo.Layers); i++ { var mount string if i != sharedLayerIndex { - mount = testDataGenerator.uniqueMountTarget() + mount = testDataGenerator.uniqueLayerMountTarget() err := policy.EnforceDeviceMountPolicy(constraints.ctx, mount, containerTwo.Layers[i]) if err != nil { @@ -469,7 +839,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Layers_Shared_Layers(t *testing.T) { containerTwoOverlay[len(containerTwoOverlay)-i-1] = mount } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, containerTwoOverlay, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, containerTwoOverlay, getOverlayMountTarget(containerID)) if err != nil { t.Fatalf("Unexpected error mounting overlay: %v", err) } @@ -490,12 +861,16 @@ func Test_Rego_EnforceOverlayMountPolicy_Overlay_Single_Container_Twice(t *testi return false } - if err := tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()); err != nil { + overlayTarget := getOverlayMountTarget(tc.containerID) + + if err := tc.policy.EnforceOverlayMountPolicy( + p.ctx, tc.containerID, tc.layers, overlayTarget); err != nil { t.Errorf("expected nil error got: %v", err) return false } - if err := tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()); err == nil { + if err := tc.policy.EnforceOverlayMountPolicy( + p.ctx, tc.containerID, tc.layers, overlayTarget); err == nil { t.Errorf("able to create overlay for the same container twice") return false } else { @@ -536,7 +911,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Reusing_ID_Across_Overlays(t *testing.T t.Fatalf("Unexpected error creating valid overlay: %v", err) } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, layerPaths, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, layerPaths, getOverlayMountTarget(containerID)) if err != nil { t.Fatalf("Unexpected error mounting overlay filesystem: %v", err) } @@ -547,7 +923,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Reusing_ID_Across_Overlays(t *testing.T t.Fatalf("Unexpected error creating valid overlay: %v", err) } - err = policy.EnforceOverlayMountPolicy(constraints.ctx, containerID, layerPaths, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, containerID, layerPaths, getOverlayMountTarget(containerID)) if err == nil { t.Fatalf("Unexpected success mounting overlay filesystem") } @@ -588,7 +965,8 @@ func Test_Rego_EnforceOverlayMountPolicy_Multiple_Instances_Same_Container(t *te } id := testDataGenerator.uniqueContainerID() - err = policy.EnforceOverlayMountPolicy(constraints.ctx, id, layerPaths, testDataGenerator.uniqueMountTarget()) + err = policy.EnforceOverlayMountPolicy( + constraints.ctx, id, layerPaths, getOverlayMountTarget(id)) if err != nil { t.Fatalf("failed with %d containers", containersToCreate) } @@ -604,7 +982,7 @@ func Test_Rego_EnforceOverlayUnmountPolicy(t *testing.T) { return false } - target := testDataGenerator.uniqueMountTarget() + target := getOverlayMountTarget(tc.containerID) err = tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, target) if err != nil { t.Errorf("Failure setting up overlay for testing: %v", err) @@ -633,14 +1011,14 
@@ func Test_Rego_EnforceOverlayUnmountPolicy_No_Matches(t *testing.T) { return false } - target := testDataGenerator.uniqueMountTarget() + target := getOverlayMountTarget(tc.containerID) err = tc.policy.EnforceOverlayMountPolicy(p.ctx, tc.containerID, tc.layers, target) if err != nil { t.Errorf("Failure setting up overlay for testing: %v", err) return false } - badTarget := testDataGenerator.uniqueMountTarget() + badTarget := getOverlayMountTarget(generateContainerID(testRand)) err = tc.policy.EnforceOverlayUnmountPolicy(p.ctx, badTarget) if err == nil { t.Errorf("Unexpected policy enforcement success: %v", err) @@ -774,6 +1152,125 @@ func Test_Rego_EnforceEnvironmentVariablePolicy_NotAllMatches(t *testing.T) { } } +func Test_Rego_EnforceEnvironmentVariablePolicy_RegexPatterns(t *testing.T) { + testCases := []struct { + rule string + expectMatches []string + expectNotMatches []string + skipAddAnchors bool + }{ + { + rule: "PREFIX_.+=.+", + expectMatches: []string{"PREFIX_FOO=BAR"}, + expectNotMatches: []string{"PREFIX_FOO=", "SOMETHING=ELSE", "SOMETHING_PREFIX_FOO=BAR"}, + }, + { + rule: "PREFIX_.+=.+BAR", + expectMatches: []string{"PREFIX_FOO=FOO_BAR"}, + expectNotMatches: []string{"PREFIX_FOO=BAR_FOO"}, + }, + { + rule: "SIMPLE_VAR=.+", + expectMatches: []string{"SIMPLE_VAR=FOO"}, + expectNotMatches: []string{"SIMPLE_VAR=", "SOMETHING=ELSE", "SOMETHING=ELSE:SIMPLE_VAR=FOO", "SIMPLE_VAR_FOO=BAR", "SIMPLE_VAR"}, + }, + { + rule: "SIMPLE_VAR=.*", + expectMatches: []string{"SIMPLE_VAR=FOO", "SIMPLE_VAR="}, + expectNotMatches: []string{"SIMPLE_VAR"}, + }, + { + rule: "SIMPLE_VAR=", + expectMatches: []string{"SIMPLE_VAR="}, + expectNotMatches: []string{"SIMPLE_VAR", "SIMPLE_VAR=FOO"}, + }, + { + rule: "", + expectMatches: []string{}, + expectNotMatches: []string{"ANYTHING", "ANYTHING=ELSE"}, + }, + { + rule: "(^PREFIX1|^PREFIX2)=.+$", + expectMatches: []string{"PREFIX1=FOO", "PREFIX2=BAR"}, + expectNotMatches: []string{"PREFIX3_FOO=BAR", "PREFIX1=", "SOMETHING=ELSE", ""}, + skipAddAnchors: true, + }, + } + + testRule := func(rule string, expectMatches, expectNotMatches []string) { + testName := rule + if testName == "" { + testName = "(empty)" + } + t.Run(testName, func(t *testing.T) { + gc := generateConstraints(testRand, 1) + container := selectContainerFromContainerList(gc.containers, testRand) + container.EnvRules = append(container.EnvRules, EnvRuleConfig{ + Strategy: EnvVarRuleRegex, + Rule: rule, + }) + gc.allowEnvironmentVariableDropping = false + + for _, env := range expectMatches { + tc, err := setupRegoCreateContainerTest(gc, container, false) + if err != nil { + t.Error(err) + return + } + + tc.envList = append(tc.envList, env) + envsToKeep, _, _, err := tc.policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + + // getting an error means something is broken + if err != nil { + t.Errorf("Expected container creation to be allowed for env %s. 
It wasn't: %v", env, err) + return + } + + if !areStringArraysEqual(envsToKeep, tc.envList) { + t.Errorf("Expected env %s to be kept, but it was not in the returned envs: %v", env, envsToKeep) + return + } + } + + for _, env := range expectNotMatches { + tc, err := setupRegoCreateContainerTest(gc, container, false) + if err != nil { + t.Error(err) + return + } + + tc.envList = append(tc.envList, env) + _, _, _, err = tc.policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + + // not getting an error means something is broken + if err == nil { + t.Errorf("Expected container creation not to be allowed for env %s. It was allowed: %v", env, err) + return + } + + envName := strings.Split(env, "=")[0] + assertDecisionJSONContains(t, err, "invalid env list", envName) + } + }) + } + + for _, testCase := range testCases { + if !testCase.skipAddAnchors { + for _, rule := range []string{ + testCase.rule, + "^" + testCase.rule, + testCase.rule + "$", + "^" + testCase.rule + "$", + } { + testRule(rule, testCase.expectMatches, testCase.expectNotMatches) + } + } else { + testRule(testCase.rule, testCase.expectMatches, testCase.expectNotMatches) + } + } +} + func Test_Rego_EnforceEnvironmentVariablePolicy_DropEnvs(t *testing.T) { testFunc := func(gc *generatedConstraints) bool { gc.allowEnvironmentVariableDropping = true @@ -1880,8 +2377,8 @@ func Test_Rego_Enforcement_Point_Allowed(t *testing.T) { t.Fatal(err) } - input := make(map[string]interface{}) - results, err := policy.applyDefaults("__fixture_for_allowed_test_false__", input) + results := make(rpi.RegoQueryResult) + results, err = policy.applyDefaults("__fixture_for_allowed_test_false__", nil, results) if err != nil { t.Fatalf("applied defaults for an enforcement point receieved an error: %v", err) } @@ -1896,8 +2393,8 @@ func Test_Rego_Enforcement_Point_Allowed(t *testing.T) { t.Fatal("result of allowed for an available enforcement point was not the specified default (false)") } - input = make(map[string]interface{}) - results, err = policy.applyDefaults("__fixture_for_allowed_test_true__", input) + results = make(rpi.RegoQueryResult) + results, err = policy.applyDefaults("__fixture_for_allowed_test_true__", nil, results) if err != nil { t.Fatalf("applied defaults for an enforcement point receieved an error: %v", err) } @@ -3326,9 +3823,7 @@ func Test_Rego_Plan9MountPolicy_No_Matches(t *testing.T) { tc.seccomp, ) - if err == nil { - t.Fatal("Policy enforcement unexpectedly was allowed") - } + assertDecisionJSONContains(t, err, "invalid mount list") } func Test_Rego_Plan9MountPolicy_Invalid(t *testing.T) { @@ -3346,6 +3841,21 @@ func Test_Rego_Plan9MountPolicy_Invalid(t *testing.T) { } } +func Test_Rego_Plan9MountPolicy_Invalid_PathTraversal(t *testing.T) { + gc := generateConstraints(testRand, maxContainersInGeneratedConstraints) + + tc, err := setupPlan9MountTest(gc) + if err != nil { + t.Fatalf("unable to setup test: %v", err) + } + + mount := tc.uvmPathForShare + "/../../bin" + err = tc.policy.EnforcePlan9MountPolicy(gc.ctx, mount) + if err == nil { + t.Fatal("Policy enforcement unexpectedly was allowed", err) + } +} + func Test_Rego_Plan9UnmountPolicy(t *testing.T) { gc := generateConstraints(testRand, maxContainersInGeneratedConstraints) @@ -3381,9 +3891,7 @@ func Test_Rego_Plan9UnmountPolicy(t *testing.T) { tc.seccomp, ) - if err == nil { - t.Fatal("Policy enforcement unexpectedly was 
allowed") - } + assertDecisionJSONContains(t, err, "invalid mount list") } func Test_Rego_Plan9UnmountPolicy_No_Matches(t *testing.T) { @@ -3402,9 +3910,7 @@ func Test_Rego_Plan9UnmountPolicy_No_Matches(t *testing.T) { badMount := randString(testRand, maxPlan9MountTargetLength) err = tc.policy.EnforcePlan9UnmountPolicy(gc.ctx, badMount) - if err == nil { - t.Fatalf("Policy enforcement unexpectedly was allowed") - } + assertDecisionJSONContains(t, err, "no device at path to unmount") } func Test_Rego_GetPropertiesPolicy_On(t *testing.T) { @@ -4226,6 +4732,9 @@ func Test_Rego_Scratch_Mount_Policy(t *testing.T) { failureExpected: false, }, } { + + filesystem := "xfs" + t.Run(fmt.Sprintf("UnencryptedAllowed_%t_And_Encrypted_%t", tc.unencryptedAllowed, tc.encrypted), func(t *testing.T) { gc := generateConstraints(testRand, maxContainersInGeneratedConstraints) smConfig, err := setupRegoScratchMountTest(gc, tc.unencryptedAllowed) @@ -4233,15 +4742,29 @@ func Test_Rego_Scratch_Mount_Policy(t *testing.T) { t.Fatalf("unable to setup test: %s", err) } - scratchPath := generateMountTarget(testRand) + containerId := testDataGenerator.uniqueContainerID() + scratchDiskMount := getScratchDiskMountTarget(containerId) + + err = smConfig.policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchDiskMount, tc.encrypted, true, filesystem) + if tc.failureExpected { + if err == nil { + t.Fatal("mounting should've been denied") + } + } else { + if err != nil { + t.Fatalf("mounting unexpectedly was denied: %s", err) + } + } + + scratchPath := path.Join(scratchDiskMount, guestpath.ScratchDir, containerId) err = smConfig.policy.EnforceScratchMountPolicy(gc.ctx, scratchPath, tc.encrypted) if tc.failureExpected { if err == nil { - t.Fatal("policy enforcement should've been denied") + t.Fatal("scratch mount should've been denied") } } else { if err != nil { - t.Fatalf("policy enforcement unexpectedly was denied: %s", err) + t.Fatalf("scratch mount unexpectedly was denied: %s", err) } } }) @@ -4280,7 +4803,15 @@ func Test_Rego_Scratch_Unmount_Policy(t *testing.T) { t.Fatalf("unable to setup test: %s", err) } - scratchPath := generateMountTarget(testRand) + containerId := testDataGenerator.uniqueContainerID() + scratchDiskMount := getScratchDiskMountTarget(containerId) + + err = smConfig.policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchDiskMount, tc.encrypted, true, "xfs") + if err != nil { + t.Fatalf("mounting unexpectedly was denied: %s", err) + } + + scratchPath := path.Join(scratchDiskMount, guestpath.ScratchDir, containerId) err = smConfig.policy.EnforceScratchMountPolicy(gc.ctx, scratchPath, tc.encrypted) if err != nil { t.Fatalf("scratch_mount policy enforcement unexpectedly was denied: %s", err) @@ -4290,6 +4821,11 @@ func Test_Rego_Scratch_Unmount_Policy(t *testing.T) { if err != nil { t.Fatalf("scratch_unmount policy enforcement unexpectedly was denied: %s", err) } + + err = smConfig.policy.EnforceRWDeviceUnmountPolicy(gc.ctx, scratchDiskMount) + if err != nil { + t.Fatalf("device_unmount policy enforcement unexpectedly was denied: %s", err) + } }) } } @@ -4798,7 +5334,7 @@ func Test_FrameworkVersion_Missing(t *testing.T) { layerPaths, err := testDataGenerator.createValidOverlayForContainer(tc.policy, c) - err = tc.policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layerPaths, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layerPaths, testDataGenerator.uniqueLayerMountTarget()) if err == nil { t.Error("unexpected success. 
Missing framework_version should trigger an error.") } @@ -4834,7 +5370,8 @@ func Test_FrameworkVersion_In_Future(t *testing.T) { layerPaths, err := testDataGenerator.createValidOverlayForContainer(tc.policy, c) - err = tc.policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layerPaths, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy( + gc.ctx, containerID, layerPaths, getOverlayMountTarget(containerID)) if err == nil { t.Error("unexpected success. Future framework_version should trigger an error.") } @@ -5799,7 +6336,8 @@ func Test_Rego_ErrorTruncation_Unable(t *testing.T) { maxErrorMessageLength := 32 tc.policy.maxErrorMessageLength = maxErrorMessageLength - err = tc.policy.EnforceOverlayMountPolicy(gc.ctx, tc.containerID, tc.layers, testDataGenerator.uniqueMountTarget()) + err = tc.policy.EnforceOverlayMountPolicy( + gc.ctx, tc.containerID, tc.layers, getOverlayMountTarget(tc.containerID)) if err == nil { t.Fatal("Policy did not throw the expected error") diff --git a/pkg/securitypolicy/securitypolicyenforcer.go b/pkg/securitypolicy/securitypolicyenforcer.go index 59c3780638..c9782e0d63 100644 --- a/pkg/securitypolicy/securitypolicyenforcer.go +++ b/pkg/securitypolicy/securitypolicyenforcer.go @@ -59,7 +59,9 @@ func init() { type SecurityPolicyEnforcer interface { EnforceDeviceMountPolicy(ctx context.Context, target string, deviceHash string) (err error) + EnforceRWDeviceMountPolicy(ctx context.Context, target string, encrypted, ensureFilesystem bool, filesystem string) (err error) EnforceDeviceUnmountPolicy(ctx context.Context, unmountTarget string) (err error) + EnforceRWDeviceUnmountPolicy(ctx context.Context, unmountTarget string) (err error) EnforceOverlayMountPolicy(ctx context.Context, containerID string, layerPaths []string, target string) (err error) EnforceOverlayUnmountPolicy(ctx context.Context, target string) (err error) EnforceCreateContainerPolicy( @@ -200,10 +202,18 @@ func (OpenDoorSecurityPolicyEnforcer) EnforceDeviceMountPolicy(context.Context, return nil } +func (OpenDoorSecurityPolicyEnforcer) EnforceRWDeviceMountPolicy(context.Context, string, bool, bool, string) error { + return nil +} + func (OpenDoorSecurityPolicyEnforcer) EnforceDeviceUnmountPolicy(context.Context, string) error { return nil } +func (OpenDoorSecurityPolicyEnforcer) EnforceRWDeviceUnmountPolicy(context.Context, string) error { + return nil +} + func (OpenDoorSecurityPolicyEnforcer) EnforceOverlayMountPolicy(context.Context, string, []string, string) error { return nil } @@ -317,10 +327,18 @@ func (ClosedDoorSecurityPolicyEnforcer) EnforceDeviceMountPolicy(context.Context return errors.New("mounting is denied by policy") } +func (ClosedDoorSecurityPolicyEnforcer) EnforceRWDeviceMountPolicy(context.Context, string, bool, bool, string) error { + return errors.New("Read-write device mounting is denied by policy") +} + func (ClosedDoorSecurityPolicyEnforcer) EnforceDeviceUnmountPolicy(context.Context, string) error { return errors.New("unmounting is denied by policy") } +func (ClosedDoorSecurityPolicyEnforcer) EnforceRWDeviceUnmountPolicy(context.Context, string) error { + return errors.New("Read-write device unmounting is denied by policy") +} + func (ClosedDoorSecurityPolicyEnforcer) EnforceOverlayMountPolicy(context.Context, string, []string, string) error { return errors.New("creating an overlay fs is denied by policy") } diff --git a/pkg/securitypolicy/securitypolicyenforcer_rego.go b/pkg/securitypolicy/securitypolicyenforcer_rego.go index 
bb2fc27530..276fb7b9bf 100644 --- a/pkg/securitypolicy/securitypolicyenforcer_rego.go +++ b/pkg/securitypolicy/securitypolicyenforcer_rego.go @@ -170,7 +170,7 @@ func newRegoPolicy(code string, defaultMounts []oci.Mount, privilegedMounts []oc return policy, nil } -func (policy *regoEnforcer) applyDefaults(enforcementPoint string, results rpi.RegoQueryResult) (rpi.RegoQueryResult, error) { +func (policy *regoEnforcer) applyDefaults(enforcementPoint string, input inputData, results rpi.RegoQueryResult) (rpi.RegoQueryResult, error) { deny := rpi.RegoQueryResult{"allowed": false} info, err := policy.queryEnforcementPoint(enforcementPoint) if err != nil { @@ -182,12 +182,22 @@ func (policy *regoEnforcer) applyDefaults(enforcementPoint string, results rpi.R return deny, fmt.Errorf("rule for %s is missing from policy", enforcementPoint) } + if results.IsEmpty() && info.useFramework { + rule := "data.framework." + enforcementPoint + result, err := policy.rego.Query(rule, input) + if err != nil { + result = nil + } + return result, err + } + return info.defaultResults.Union(results), nil } type enforcementPointInfo struct { availableByPolicyVersion bool defaultResults rpi.RegoQueryResult + useFramework bool } func (policy *regoEnforcer) queryEnforcementPoint(enforcementPoint string) (*enforcementPointInfo, error) { @@ -230,17 +240,23 @@ func (policy *regoEnforcer) queryEnforcementPoint(enforcementPoint string) (*enf defaultResults, err := result.Object("default_results") if err != nil { - return nil, errors.New("enforcement point result missing defaults") + return nil, fmt.Errorf("enforcement point %s result missing defaults", enforcementPoint) } availableByPolicyVersion, err := result.Bool("available") if err != nil { - return nil, errors.New("enforcement point result missing availability info") + return nil, fmt.Errorf("enforcement point %s result missing availability info", enforcementPoint) + } + + useFramework, err := result.Bool("use_framework") + if err != nil { + return nil, fmt.Errorf("enforcement point %s result missing use_framework info", enforcementPoint) } return &enforcementPointInfo{ availableByPolicyVersion: availableByPolicyVersion, defaultResults: defaultResults, + useFramework: useFramework, }, nil } @@ -251,7 +267,7 @@ func (policy *regoEnforcer) enforce(ctx context.Context, enforcementPoint string return nil, policy.denyWithError(ctx, err, input) } - result, err = policy.applyDefaults(enforcementPoint, result) + result, err = policy.applyDefaults(enforcementPoint, input, result) if err != nil { return result, policy.denyWithError(ctx, err, input) } @@ -486,15 +502,34 @@ func (policy *regoEnforcer) redactSensitiveData(input inputData) inputData { } func (policy *regoEnforcer) EnforceDeviceMountPolicy(ctx context.Context, target string, deviceHash string) error { + mountPathRegex := strings.Replace(guestpath.LCOWGlobalScsiMountPrefixFmt, "%d", "[0-9]+", 1) input := inputData{ - "target": target, - "deviceHash": deviceHash, + "target": target, + "deviceHash": deviceHash, + "mountPathRegex": mountPathRegex, } _, err := policy.enforce(ctx, "mount_device", input) return err } +func (policy *regoEnforcer) EnforceRWDeviceMountPolicy(ctx context.Context, target string, encrypted, ensureFilesystem bool, filesystem string) error { + // At this point we do not know what the container ID would be, so we allow + // any valid IDs. 
+ containerIdRegex := "[0-9a-fA-F]{64}" + mountPathRegex := guestpath.LCOWRootPrefixInUVM + "/" + containerIdRegex + input := inputData{ + "target": target, + "encrypted": encrypted, + "ensureFilesystem": ensureFilesystem, + "filesystem": filesystem, + "mountPathRegex": mountPathRegex, + } + + _, err := policy.enforce(ctx, "rw_mount_device", input) + return err +} + func (policy *regoEnforcer) EnforceOverlayMountPolicy(ctx context.Context, containerID string, layerPaths []string, target string) error { input := inputData{ "containerID": containerID, @@ -768,6 +803,15 @@ func (policy *regoEnforcer) EnforceDeviceUnmountPolicy(ctx context.Context, unmo return err } +func (policy *regoEnforcer) EnforceRWDeviceUnmountPolicy(ctx context.Context, unmountTarget string) error { + input := inputData{ + "unmountTarget": unmountTarget, + } + + _, err := policy.enforce(ctx, "rw_unmount_device", input) + return err +} + func appendMountData(mountData []interface{}, mounts []oci.Mount) []interface{} { for _, mount := range mounts { mountData = append(mountData, inputData{ diff --git a/test/functional/lcow_policy_test.go b/test/functional/lcow_policy_test.go index 43c5fd2090..cf0c857e08 100644 --- a/test/functional/lcow_policy_test.go +++ b/test/functional/lcow_policy_test.go @@ -4,7 +4,9 @@ package functional import ( "context" + "encoding/hex" "fmt" + "math/rand" "testing" ctrdoci "github.com/containerd/containerd/v2/pkg/oci" @@ -22,6 +24,15 @@ import ( testuvm "github.com/Microsoft/hcsshim/test/pkg/uvm" ) +func genValidContainerID(t *testing.T, rng *rand.Rand) string { + t.Helper() + randBytes := make([]byte, 32) + if _, err := rng.Read(randBytes); err != nil { + t.Fatalf("failed to generate random bytes for container ID: %v", err) + } + return hex.EncodeToString(randBytes) +} + func setupScratchTemplate(ctx context.Context, tb testing.TB) string { tb.Helper() opts := defaultLCOWOptions(ctx, tb) @@ -43,6 +54,8 @@ func TestGetProperties_WithPolicy(t *testing.T) { ctx := util.Context(namespacedContext(context.Background()), t) scratchPath := setupScratchTemplate(ctx, t) + rng := rand.New(rand.NewSource(0)) + ls := linuxImageLayers(ctx, t) for _, allowProperties := range []bool{true, false} { t.Run(fmt.Sprintf("AllowPropertiesAccess_%t", allowProperties), func(t *testing.T) { @@ -61,21 +74,24 @@ func TestGetProperties_WithPolicy(t *testing.T) { ) opts.SecurityPolicyEnforcer = "rego" opts.SecurityPolicy = policy + // VPMem is not currently supported for C-LCOW. + opts.VPMemDeviceCount = 0 - cleanName := util.CleanName(t) + containerID := genValidContainerID(t, rng) vm := testuvm.CreateAndStartLCOWFromOpts(ctx, t, opts) spec := testoci.CreateLinuxSpec( ctx, t, - cleanName, + containerID, testoci.DefaultLinuxSpecOpts( "", ctrdoci.WithProcessArgs("/bin/sh", "-c", testoci.TailNullArgs), + ctrdoci.WithEnv(testoci.DefaultUnixEnv), testoci.WithWindowsLayerFolders(append(ls, scratchPath)), )..., ) - c, _, cleanup := testcontainer.Create(ctx, t, vm, spec, cleanName, hcsOwner) + c, _, cleanup := testcontainer.Create(ctx, t, vm, spec, containerID, hcsOwner) t.Cleanup(cleanup) init := testcontainer.Start(ctx, t, c, nil) From 7ceddb4d3ca68ff8fc97dabc96009634a963a018 Mon Sep 17 00:00:00 2001 From: Tingmao Wang Date: Mon, 10 Nov 2025 14:51:27 +0000 Subject: [PATCH 02/12] Extend and apply checkValidContainerID to virtual pod IDs as well Since the virtual pod ID is used to construct paths, it needs to be validated in the same way we validate container/sandbox IDs. This only applies to confidential. 
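
To illustrate the class of bug this closes, here is a standalone sketch (not code from this change; the prefix constant and helper are illustrative stand-ins for the real values in uvm.go). An ID containing ".." segments escapes the per-container root as soon as it is joined into a guest path, while the 64-hex-digit check rejects it outright:

```go
package main

import (
	"fmt"
	"path"
	"regexp"
)

// rootPrefix stands in for guestpath.LCOWRootPrefixInUVM; validID mirrors
// the 64-hex-digit rule enforced by checkValidContainerID.
const rootPrefix = "/run/gcs/c"

var validID = regexp.MustCompile(`^[0-9a-fA-F]{64}$`)

func containerRoot(id string) (string, error) {
	if !validID.MatchString(id) {
		return "", fmt.Errorf("invalid id: %q", id)
	}
	return path.Join(rootPrefix, id), nil
}

func main() {
	// Unvalidated, a host-controlled "ID" escapes the prefix entirely,
	// because path.Join cleans away the ".." segments.
	fmt.Println(path.Join(rootPrefix, "../../etc")) // prints /run/etc

	// With the check in place, the same input is rejected.
	if _, err := containerRoot("../../etc"); err != nil {
		fmt.Println(err)
	}
}
```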
Signed-off-by: Tingmao Wang --- internal/guest/runtime/hcsv2/uvm.go | 36 +++++++++++++++-------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index c57079cfb0..45386582f4 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -61,20 +61,17 @@ const UVMContainerID = "00000000-0000-0000-0000-000000000000" // to check that sandbox IDs (which is also used in paths) are valid, which has // the same format. const validContainerIDRegexRaw = `[0-9a-fA-F]{64}` + var validContainerIDRegex = regexp.MustCompile("^" + validContainerIDRegexRaw + "$") -// isSandboxId just changes the error message -func checkValidContainerID(id string, isSandboxId bool) error { +// idType just changes the error message +func checkValidContainerID(id string, idType string) error { if id == UVMContainerID { return nil } if !validContainerIDRegex.MatchString(id) { - idtype := "container" - if isSandboxId { - idtype = "sandbox" - } - return errors.Errorf("invalid %s id: %s (must match %s)", idtype, id, validContainerIDRegex.String()) + return errors.Errorf("invalid %s id: %s (must match %s)", idType, id, validContainerIDRegex.String()) } return nil @@ -313,17 +310,22 @@ func checkContainerSettings(sandboxID, containerID string, settings *prot.VMHost } func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VMHostedContainerSettingsV2) (_ *Container, err error) { - if h.HasSecurityPolicy() { - if err = checkValidContainerID(id, false); err != nil { - return nil, err - } - } - criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType] // Check for virtual pod annotation virtualPodID, isVirtualPod := settings.OCISpecification.Annotations[annotations.VirtualPodID] + if h.HasSecurityPolicy() { + if err = checkValidContainerID(id, "container"); err != nil { + return nil, err + } + if virtualPodID != "" { + if err = checkValidContainerID(virtualPodID, "virtual pod"); err != nil { + return nil, err + } + } + } + // Special handling for virtual pod sandbox containers: // The first container in a virtual pod (containerID == virtualPodID) should be treated as a sandbox // even if the CRI annotation might indicate otherwise due to host-side UVM setup differences @@ -467,7 +469,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM sid, ok := settings.OCISpecification.Annotations[annotations.KubernetesSandboxID] sandboxID = sid if h.HasSecurityPolicy() { - if err = checkValidContainerID(sid, true); err != nil { + if err = checkValidContainerID(sid, "sandbox"); err != nil { return nil, err } } @@ -690,7 +692,7 @@ func writeSpecToFile(ctx context.Context, configFile string, spec *specs.Spec) e func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) (retErr error) { if h.HasSecurityPolicy() { - if err := checkValidContainerID(containerID, false); err != nil { + if err := checkValidContainerID(containerID, "container"); err != nil { return err } } @@ -780,7 +782,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * func (h *Host) modifyContainerSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) error { if h.HasSecurityPolicy() { - if err := checkValidContainerID(containerID, false); err != nil { + if err := checkValidContainerID(containerID, "container"); err != nil { return err } } @@ -1306,7 
+1308,7 @@ func modifyCombinedLayers(
 	switch rt {
 	case guestrequest.RequestTypeAdd:
 		if isConfidential {
-			if err := checkValidContainerID(containerID, false); err != nil {
+			if err := checkValidContainerID(containerID, "container"); err != nil {
 				return err
 			}

From 304c4308071c60c11eaad37d879c885ed61f1e2b Mon Sep 17 00:00:00 2001
From: Tingmao Wang
Date: Mon, 14 Jul 2025 21:13:39 +0100
Subject: [PATCH 03/12] Merged PR 12916199: hcsv2/uvm: Check that
 OCISpecification.Hooks is nil for confidential

[cherry-picked from 73a7151ad8f15aad9f0e4eef6b86ed1a1a5a0953]

This fixes a vulnerability where the host can provide arbitrary commands
as hooks in the OCISpecification, which will get executed when passed to
runc.

Normally, ociSpec.Hooks is cleared explicitly by the host, and so we
should never get passed hooks from a genuine host. The only case where
OCI hooks would be set is if we are running a GPU container, in which
case guest-side code in GCS will add hooks to set up the device - see
addNvidiaDeviceHook. This restriction will also apply to the hooks we
set this way ourselves; however, this is fine for now, as we do not
support Confidential GPUs. When we do support GPU confidential
containers, we will need to pull this enforcement into Rego, and allow
the policy to determine whether the hook should be allowed (along with
enforcing which devices are allowed to be exposed to the container).

Even after this fix, we still have unenforced fields in the
OCISpecification that we receive from the host, which may allow the host
to, for example, add device nodes, and we should harden this further in
a future release.

Tested using the following host-side reproduction patch:

```diff
diff --git a/internal/hcsoci/hcsdoc_lcow.go b/internal/hcsoci/hcsdoc_lcow.go
index db94d73df..a270c13ac 100644
--- a/internal/hcsoci/hcsdoc_lcow.go
+++ b/internal/hcsoci/hcsdoc_lcow.go
@@ -94,6 +94,22 @@ func createLinuxContainerDocument(ctx context.Context, coi *createOptionsInterna
 		return nil, err
 	}

+	isSandbox := spec.Annotations[annotations.KubernetesContainerType] == "sandbox"
+
+	if !isSandbox {
+		if spec.Hooks == nil {
+			spec.Hooks = &specs.Hooks{}
+		}
+		timeout := 5000000
+		spec.Hooks.StartContainer = append(spec.Hooks.StartContainer, specs.Hook{
+			Path:    "/bin/sh",
+			Args:    []string{"/bin/sh", "-c", "echo hacked by hooks > /hacked.txt"},
+			Env:     []string{"PATH=/usr/local/bin:/usr/bin:/bin:/sbin"},
+			Timeout: &timeout,
+		})
+		log.G(ctx).Info("Injected hook into OCISpec")
+	}
+
 	log.G(ctx).WithField("guestRoot", guestRoot).Debug("hcsshim::createLinuxContainerDoc")
 	return &linuxHostedSystem{
 		SchemaVersion: schemaversion.SchemaV21(),
```

Non-confidential scenarios should not be affected.
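
A guest-side unit test for the new check could look roughly like the
sketch below. This is hypothetical: it assumes it sits next to uvm.go in
the hcsv2 package, the settings literal only fills in the fields that
checkContainerSettings reads, and the ID values are arbitrary 64-hex
strings:

```go
package hcsv2

import (
	"path"
	"strings"
	"testing"

	"github.com/Microsoft/hcsshim/internal/guest/prot"
	"github.com/Microsoft/hcsshim/internal/guestpath"
	specs "github.com/opencontainers/runtime-spec/specs-go"
)

func Test_CheckContainerSettings_RejectsHooks(t *testing.T) {
	sandboxID := strings.Repeat("a", 64)
	containerID := strings.Repeat("b", 64)
	containerRoot := path.Join(guestpath.LCOWRootPrefixInUVM, containerID)

	// Settings that satisfy the bundle/rootfs/scratch checks, so the test
	// exercises only the Hooks enforcement.
	settings := &prot.VMHostedContainerSettingsV2{
		OCIBundlePath:  containerRoot,
		ScratchDirPath: path.Join(containerRoot, guestpath.ScratchDir, containerID),
		OCISpecification: &specs.Spec{
			Root: &specs.Root{Path: path.Join(containerRoot, guestpath.RootfsPath)},
		},
	}
	if err := checkContainerSettings(sandboxID, containerID, settings); err != nil {
		t.Fatalf("baseline settings unexpectedly rejected: %v", err)
	}

	// Any non-nil Hooks, even an empty struct, must be rejected.
	settings.OCISpecification.Hooks = &specs.Hooks{}
	if err := checkContainerSettings(sandboxID, containerID, settings); err == nil {
		t.Fatal("settings with non-nil Hooks unexpectedly accepted")
	}
}
```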
Signed-off-by: Tingmao Wang
---
 internal/guest/runtime/hcsv2/uvm.go | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go
index 45386582f4..e748260650 100644
--- a/internal/guest/runtime/hcsv2/uvm.go
+++ b/internal/guest/runtime/hcsv2/uvm.go
@@ -306,6 +306,10 @@ func checkContainerSettings(sandboxID, containerID string, settings *prot.VMHost
 			scratchDirPath, expectedScratchDirPathNonShared, expectedScratchDirPathShared)
 	}

+	if settings.OCISpecification.Hooks != nil {
+		return errors.Errorf("OCISpecification.Hooks must be nil.")
+	}
+
 	return nil
 }

From 9ee9acd2f740db4723629771f4c1750e58a28617 Mon Sep 17 00:00:00 2001
From: Tingmao Wang
Date: Mon, 10 Nov 2025 15:10:23 +0000
Subject: [PATCH 04/12] Merged PR 12936349: rego: Harden fragments loading

[cherry-picked from 5469e5c013e08dfca8aaf31dae8f617831fd5ac7]

For an arbitrary fragment offered by the host, currently we need to
interpret the fragment code as a new Rego module in order for the
framework to check it. However, we did not check that the namespace of
the fragment, which is used as the namespace of the module, does not
conflict with any of our own namespaces (e.g. framework, policy, or
api). This means that as soon as we load it for policy checking, a
malicious fragment could override enforcement points defined in the
framework, including "load_fragment", and allow itself in, even if we
would otherwise have denied it because of wrong issuer, etc.

This can be tested with the below fragment:

package framework

svn := "1"
framework_version := "0.3.1"

load_fragment := {"allowed": true, "add_module": true}
enforcement_point_info := {
    "available": true,
    "unknown": false,
    "invalid": false,
    "version_missing": false,
    "default_results": {"allowed": true},
    "use_framework": true
}

mount_device := {"allowed": true}
rw_mount_device := {"allowed": true}
mount_overlay := {"allowed": true}
create_container := {"allowed": true, "env_list": null, "allow_stdio_access": true}
unmount_device := {"allowed": true}
rw_unmount_device := {"allowed": true}
unmount_overlay := {"allowed": true}
exec_in_container := {"allowed": true, "env_list": null}
exec_external := {"allowed": true, "env_list": null, "allow_stdio_access": true}
shutdown_container := {"allowed": true}
signal_container_process := {"allowed": true}
plan9_mount := {"allowed": true}
plan9_unmount := {"allowed": true}
get_properties := {"allowed": true}
dump_stacks := {"allowed": true}
runtime_logging := {"allowed": true}
scratch_mount := {"allowed": true}
scratch_unmount := {"allowed": true}

However, our namespace parsing was also susceptible to attacks, and we
might end up concluding that the fragment namespace is safe even though,
when the Rego engine parses it, it sees a package of "framework". This
can be demonstrated with the following fragment:

package #aa
framework

svn := "0"
framework_version := "0.3.1"

load_fragment := {"allowed": true, "add_module": true}

And so we also ensure that the `package ...` line can't contain anything
unusual.

Furthermore, since it can still be risky to inject arbitrary, untrusted
Rego code into our execution context, we add another check prior to
loading the fragment as a module, where we make sure that the fragment
first passes the issuer and feed checks, which do not require loading
it.
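
Condensed to its essence, the hardened package-line check is one
anchored regular expression plus a reserved-name list. The sketch below
is illustrative (it mirrors, but is not, the parseNamespace
implementation in securitypolicyenforcer_rego.go) and shows how the two
fragments quoted above are rejected:

```go
package main

import (
	"fmt"
	"regexp"
	"slices"
	"strings"
)

// The whole first line must be "package <identifier>", nothing more, so
// there is no room for the parser-confusion tricks shown above.
var firstLine = regexp.MustCompile(`^package +([a-zA-Z_][a-zA-Z0-9_]*)\s*$`)

var reserved = []string{"framework", "api", "policy", "metadata"}

func namespaceOf(rego string) (string, error) {
	head := strings.SplitN(rego, "\n", 2)[0]
	m := firstLine.FindStringSubmatch(head)
	if m == nil {
		return "", fmt.Errorf("valid package definition required on first line, got %q", head)
	}
	if slices.Contains(reserved, m[1]) {
		return "", fmt.Errorf("namespace %q is reserved", m[1])
	}
	return m[1], nil
}

func main() {
	for _, src := range []string{
		"package my_fragment\nsvn := \"1\"",    // ordinary fragment: accepted
		"package framework\nsvn := \"1\"",      // first attack above: reserved, rejected
		"package #aa\nframework\nsvn := \"0\"", // second attack above: malformed line, rejected
	} {
		fmt.Println(namespaceOf(src))
	}
}
```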
Signed-off-by: Tingmao Wang --- pkg/securitypolicy/framework.rego | 34 +- .../policy_v0.10.0_api_test.rego | 13 + .../policy_v0.10.0_api_test_allow_all.rego | 22 + pkg/securitypolicy/rego_utils_test.go | 11 +- pkg/securitypolicy/regopolicy_linux_test.go | 535 +++++++++++++++++- .../securitypolicyenforcer_rego.go | 75 ++- 6 files changed, 644 insertions(+), 46 deletions(-) create mode 100644 pkg/securitypolicy/policy_v0.10.0_api_test_allow_all.rego diff --git a/pkg/securitypolicy/framework.rego b/pkg/securitypolicy/framework.rego index 20e0e8b067..1be2625048 100644 --- a/pkg/securitypolicy/framework.rego +++ b/pkg/securitypolicy/framework.rego @@ -1209,8 +1209,6 @@ candidate_fragments := fragments { fragments := array.concat(policy_fragments, fragment_fragments) } -default load_fragment := {"allowed": false} - svn_ok(svn, minimum_svn) { # deprecated semver.is_valid(svn) @@ -1222,15 +1220,32 @@ svn_ok(svn, minimum_svn) { to_number(svn) >= to_number(minimum_svn) } -fragment_ok(fragment) { +fragment_issuer_feed_ok(fragment) { input.issuer == fragment.issuer input.feed == fragment.feed - svn_ok(data[input.namespace].svn, fragment.minimum_svn) +} + +default load_fragment := {"allowed": false} + +# load_fragment gets called twice - first before loading the fragment as a Rego +# module, with input.fragment_loaded set to false, in which case we do not yet +# have access to anything under data[fragment.namespace] yet, and so we only +# check that the fragment issuer and feed is valid, but does not actually load +# the fragment into metadata. It will then be called a second time, at which +# point we can check the SVN defined in the fragment is valid, and if +# successful, add the fragment to the metadata. + +load_fragment := {"allowed": true} { + not input.fragment_loaded + some fragment in candidate_fragments + fragment_issuer_feed_ok(fragment) } load_fragment := {"metadata": [updateIssuer], "add_module": add_module, "allowed": true} { + input.fragment_loaded some fragment in candidate_fragments - fragment_ok(fragment) + fragment_issuer_feed_ok(fragment) + svn_ok(data[input.namespace].svn, fragment.minimum_svn) issuer := update_issuer(fragment.includes) updateIssuer := { @@ -1641,6 +1656,7 @@ default fragment_version_is_valid := false fragment_version_is_valid { some fragment in candidate_fragments + input.fragment_loaded fragment.issuer == input.issuer fragment.feed == input.feed svn_ok(data[input.namespace].svn, fragment.minimum_svn) @@ -1652,6 +1668,7 @@ svn_mismatch { some fragment in candidate_fragments fragment.issuer == input.issuer fragment.feed == input.feed + input.fragment_loaded to_number(data[input.namespace].svn) semver.is_valid(fragment.minimum_svn) } @@ -1660,6 +1677,7 @@ svn_mismatch { some fragment in candidate_fragments fragment.issuer == input.issuer fragment.feed == input.feed + input.fragment_loaded semver.is_valid(data[input.namespace].svn) to_number(fragment.minimum_svn) } @@ -1667,6 +1685,7 @@ svn_mismatch { errors["fragment svn is below the specified minimum"] { input.rule == "load_fragment" fragment_feed_matches + input.fragment_loaded not svn_mismatch not fragment_version_is_valid } @@ -1674,6 +1693,7 @@ errors["fragment svn is below the specified minimum"] { errors["fragment svn and the specified minimum are different types"] { input.rule == "load_fragment" fragment_feed_matches + input.fragment_loaded svn_mismatch } @@ -1704,12 +1724,16 @@ errors[framework_version_error] { } errors[fragment_framework_version_error] { + input.rule == "load_fragment" + 
input.fragment_loaded input.namespace fragment_framework_version == null fragment_framework_version_error := concat(" ", ["fragment framework_version is missing. Current version:", version]) } errors[fragment_framework_version_error] { + input.rule == "load_fragment" + input.fragment_loaded input.namespace semver.compare(fragment_framework_version, version) > 0 fragment_framework_version_error := concat(" ", ["fragment framework_version is ahead of the current version:", fragment_framework_version, "is greater than", version]) diff --git a/pkg/securitypolicy/policy_v0.10.0_api_test.rego b/pkg/securitypolicy/policy_v0.10.0_api_test.rego index a5ea9c56d4..407c3ee8ff 100644 --- a/pkg/securitypolicy/policy_v0.10.0_api_test.rego +++ b/pkg/securitypolicy/policy_v0.10.0_api_test.rego @@ -3,6 +3,19 @@ package policy api_version := "0.10.0" framework_version := "0.3.0" +fragments := [ + { + "feed": "@@FRAGMENT_FEED@@", + "includes": [ + "containers", + "fragments" + ], + "issuer": "@@FRAGMENT_ISSUER@@", + "minimum_svn": "0" + } +] + + containers := [ { "allow_elevated": false, diff --git a/pkg/securitypolicy/policy_v0.10.0_api_test_allow_all.rego b/pkg/securitypolicy/policy_v0.10.0_api_test_allow_all.rego new file mode 100644 index 0000000000..dccdba0dec --- /dev/null +++ b/pkg/securitypolicy/policy_v0.10.0_api_test_allow_all.rego @@ -0,0 +1,22 @@ +package policy + +api_version := "0.10.0" +framework_version := "0.3.0" + +mount_device := {"allowed": true} +mount_overlay := {"allowed": true} +create_container := {"allowed": true, "env_list": null, "allow_stdio_access": true} +unmount_device := {"allowed": true} +unmount_overlay := {"allowed": true} +exec_in_container := {"allowed": true, "env_list": null} +exec_external := {"allowed": true, "env_list": null, "allow_stdio_access": true} +shutdown_container := {"allowed": true} +signal_container_process := {"allowed": true} +plan9_mount := {"allowed": true} +plan9_unmount := {"allowed": true} +get_properties := {"allowed": true} +dump_stacks := {"allowed": true} +runtime_logging := {"allowed": true} +load_fragment := {"allowed": true} +scratch_mount := {"allowed": true} +scratch_unmount := {"allowed": true} diff --git a/pkg/securitypolicy/rego_utils_test.go b/pkg/securitypolicy/rego_utils_test.go index 2e357f071a..5ac12a5a0a 100644 --- a/pkg/securitypolicy/rego_utils_test.go +++ b/pkg/securitypolicy/rego_utils_test.go @@ -1628,8 +1628,15 @@ var apiTestCode string //go:embed policy_v0.10.0_api_test.rego var policyWith_0_10_0_apiTestCode string -func getPolicyCode_0_10_0(layerHash string) string { - return strings.Replace(policyWith_0_10_0_apiTestCode, "@@CONTAINER_LAYER_HASH@@", layerHash, 1) +//go:embed policy_v0.10.0_api_test_allow_all.rego +var policyWith_0_10_0_apiTestAllowAllCode string + +func getPolicyCode_0_10_0(layerHash, fragmentIssuer, fragmentFeed string) string { + s := policyWith_0_10_0_apiTestCode + s = strings.Replace(s, "@@CONTAINER_LAYER_HASH@@", layerHash, 1) + s = strings.Replace(s, "@@FRAGMENT_ISSUER@@", fragmentIssuer, 1) + s = strings.Replace(s, "@@FRAGMENT_FEED@@", fragmentFeed, 1) + return s } func (p *regoEnforcer) injectTestAPI() error { diff --git a/pkg/securitypolicy/regopolicy_linux_test.go b/pkg/securitypolicy/regopolicy_linux_test.go index bec4f27f07..e83a6d30e4 100644 --- a/pkg/securitypolicy/regopolicy_linux_test.go +++ b/pkg/securitypolicy/regopolicy_linux_test.go @@ -564,30 +564,7 @@ func Test_Rego_EnforceRWDeviceMountPolicy_InvalidFilesystem(t *testing.T) { // allowing enforcing rw mounts. 
func Test_Rego_EnforceRWDeviceMountPolicy_Compat_0_10_0_allow_all(t *testing.T) { p := generateConstraints(testRand, 1) - regoPolicy := ` -package policy - -api_version := "0.10.0" -framework_version := "0.3.0" - -mount_device := {"allowed": true} -mount_overlay := {"allowed": true} -create_container := {"allowed": true, "env_list": null, "allow_stdio_access": true} -unmount_device := {"allowed": true} -unmount_overlay := {"allowed": true} -exec_in_container := {"allowed": true, "env_list": null} -exec_external := {"allowed": true, "env_list": null, "allow_stdio_access": true} -shutdown_container := {"allowed": true} -signal_container_process := {"allowed": true} -plan9_mount := {"allowed": true} -plan9_unmount := {"allowed": true} -get_properties := {"allowed": true} -dump_stacks := {"allowed": true} -runtime_logging := {"allowed": true} -load_fragment := {"allowed": true} -scratch_mount := {"allowed": true} -scratch_unmount := {"allowed": true} -` + regoPolicy := policyWith_0_10_0_apiTestAllowAllCode for _, b1 := range []bool{true, false} { for _, b2 := range []bool{true, false} { policy, err := newRegoPolicy(regoPolicy, []oci.Mount{}, []oci.Mount{}, testOSType) @@ -608,7 +585,7 @@ scratch_unmount := {"allowed": true} // enforcing rw mounts. func Test_Rego_EnforceRWDeviceMountPolicy_Compat_0_10_0(t *testing.T) { p := generateConstraints(testRand, 1) - regoPolicy := getPolicyCode_0_10_0(p.containers[0].Layers[0]) + regoPolicy := getPolicyCode_0_10_0(p.containers[0].Layers[0], testDataGenerator.uniqueFragmentIssuer(), testDataGenerator.uniqueFragmentFeed()) for _, b1 := range []bool{true, false} { for _, b2 := range []bool{true, false} { policy, err := newRegoPolicy(regoPolicy, []oci.Mount{}, []oci.Mount{}, testOSType) @@ -4088,6 +4065,134 @@ func Test_Rego_LoadFragment_Container(t *testing.T) { } } +// Make sure we don't break fragment loading for old policies +func Test_Rego_LoadFragment_Container_Compat_0_10_0(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupRegoFragmentTestConfigWithIncludes(p, []string{"containers"}) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + container := tc.containers[0] + rego := getPolicyCode_0_10_0(container.container.Layers[0], fragment.info.issuer, fragment.info.feed) + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Fatalf("unable to create Rego policy: %v", err) + } + tc.policy = policy + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, fragment.code) + if err != nil { + t.Error("unable to load fragment: %w", err) + return false + } + + containerID, err := mountImageForContainer(tc.policy, container.container) + if err != nil { + t.Error("unable to mount image for fragment container: %w", err) + return false + } + + _, _, _, err = tc.policy.EnforceCreateContainerPolicy(p.ctx, + container.sandboxID, + containerID, + copyStrings(container.container.Command), + copyStrings(container.envList), + container.container.WorkingDir, + copyMounts(container.mounts), + false, + container.container.NoNewPrivileges, + container.user, + container.groups, + container.container.User.Umask, + container.capabilities, + container.seccomp, + ) + + if err != nil { + t.Error("unable to create container from fragment: %w", err) + return false + } + + if tc.policy.rego.IsModuleActive(rpi.ModuleID(fragment.info.issuer, fragment.info.feed)) { + t.Error("module not removed after load") + return false + } + + return true + } + + if err := 
quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_Container_Compat_0_10_0: %v", err) + } +} + +// Make sure we don't break fragment loading for old allow all policies +func Test_Rego_LoadFragment_Container_Compat_0_10_0_allow_all(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupRegoFragmentTestConfigWithIncludes(p, []string{"containers"}) + if err != nil { + t.Error(err) + return false + } + + rego := policyWith_0_10_0_apiTestAllowAllCode + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Fatalf("unable to create Rego policy: %v", err) + } + tc.policy = policy + + fragment := tc.fragments[0] + container := tc.containers[0] + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, fragment.code) + if err != nil { + t.Error("unable to load fragment: %w", err) + return false + } + + containerID, err := mountImageForContainer(tc.policy, container.container) + if err != nil { + t.Error("unable to mount image for fragment container: %w", err) + return false + } + + _, _, _, err = tc.policy.EnforceCreateContainerPolicy(p.ctx, + container.sandboxID, + containerID, + copyStrings(container.container.Command), + copyStrings(container.envList), + container.container.WorkingDir, + copyMounts(container.mounts), + false, + container.container.NoNewPrivileges, + container.user, + container.groups, + container.container.User.Umask, + container.capabilities, + container.seccomp, + ) + + if err != nil { + t.Error("unable to create container from fragment: %w", err) + return false + } + + if tc.policy.rego.IsModuleActive(rpi.ModuleID(fragment.info.issuer, fragment.info.feed)) { + t.Error("module not removed after load") + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_Container_Compat_0_10_0: %v", err) + } +} + func Test_Rego_LoadFragment_Fragment(t *testing.T) { f := func(p *generatedConstraints) bool { tc, err := setupRegoFragmentTestConfigWithIncludes(p, []string{"fragments"}) @@ -4236,6 +4341,173 @@ func Test_Rego_LoadFragment_BadFeed(t *testing.T) { } } +func Test_Rego_parseNamespace(t *testing.T) { + type testCase struct { + inputs []string + expected string + expectFail bool + } + testCases := []testCase{ + { + inputs: []string{ + "package a\nanything-else", + "package a \n\n", + "package a ", + }, + expected: "a", + }, + { + inputs: []string{ + "package aaa", + "package aaa ", + "package aaa\n# anything", + }, + expected: "aaa", + }, + { + inputs: []string{ + "package", + "package\n", + "package ", + "package ", + "package$", + "package aa#bb\nframework", + "package\naa\n", + }, + expectFail: true, + }, + { + inputs: []string{ + "package framework", + "package api", + }, + expectFail: true, + }, + } + + for _, tc := range testCases { + for _, input := range tc.inputs { + result, err := parseNamespace(input) + if tc.expectFail && err == nil { + t.Errorf("Expected failure for input %q, but got success", input) + } else if !tc.expectFail && err != nil { + t.Errorf("Unexpected error for input %q: %v", input, err) + } else if !tc.expectFail && result != tc.expected { + t.Errorf("Expected to parse namespace %q for input %q, but got %q", tc.expected, input, result) + } + } + } +} + +func expectFragmentNotLoaded(t *testing.T, policy *regoEnforcer, issuer, feed string) bool { + if policy.rego.IsModuleActive(rpi.ModuleID(issuer, feed)) { + 
t.Errorf("fragment module is present") + return false + } + mtdIssuer, err := policy.rego.GetMetadata("issuers", issuer) + if err != nil && !strings.Contains(err.Error(), "value not found") && + !strings.Contains(err.Error(), "metadata not found for name issuers") { + t.Errorf("unexpected error when checking issuer metadata: %v", err) + return false + } + if mtdIssuer != nil || err == nil { + t.Errorf("fragment issuer metadata is present") + return false + } + return true +} + +func Test_Rego_LoadFragment_BadNamespace(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + code := fmt.Sprintf(`package framework + +svn := "%s" +framework_version := "%s" + +load_fragment := {"allowed": true, "add_module": true} +enforcement_point_info := { + "available": true, + "unknown": false, + "invalid": false, + "version_missing": false, + "default_results": {"allowed": true}, + "use_framework": true +} +`, fragment.info.minimumSVN, frameworkVersion) + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, code) + + if err == nil { + t.Error("expected to be unable to load fragment due to bad namespace") + return false + } + + if !strings.Contains(err.Error(), "namespace \"framework\" is reserved") { + t.Errorf("expected error string to contain 'namespace \"framework\" is reserved', but got %q", err.Error()) + return false + } + + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadNamespace: %v", err) + } +} + +func Test_Rego_LoadFragment_BadNamespace2(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + code := fmt.Sprintf(`package #aa +framework + +svn := "%s" +framework_version := "%s" + +load_fragment := {"allowed": true, "add_module": true} +`, fragment.info.minimumSVN, frameworkVersion) + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, code) + + if err == nil { + t.Error("expected to be unable to load fragment due to invalid namespace") + return false + } + + if !strings.Contains(err.Error(), "valid package definition required on first line") { + t.Errorf("expected error string to contain 'valid package definition required on first line', but got %q", err.Error()) + return false + } + + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadNamespace: %v", err) + } +} + func Test_Rego_LoadFragment_InvalidSVN(t *testing.T) { f := func(p *generatedConstraints) bool { tc, err := setupRegoFragmentSVNErrorTestConfig(p) @@ -4256,7 +4528,7 @@ func Test_Rego_LoadFragment_InvalidSVN(t *testing.T) { return false } - if tc.policy.rego.IsModuleActive(rpi.ModuleID(fragment.info.issuer, fragment.info.feed)) { + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { t.Error("module not removed upon failure") return false } @@ -4369,7 +4641,7 @@ func Test_Rego_LoadFragment_SVNMismatch(t *testing.T) { return false } - if 
tc.policy.rego.IsModuleActive(rpi.ModuleID(fragment.info.issuer, fragment.info.feed)) { + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { t.Error("module not removed upon failure") return false } @@ -4671,10 +4943,16 @@ framework_version := "%s" default load_fragment := {"allowed": false} +check_svn_if_loaded { + not input.fragment_loaded +} else { + data[input.namespace].svn >= 1 +} + load_fragment := {"allowed": true, "add_module": true} { input.issuer == "%s" input.feed == "%s" - data[input.namespace].svn >= 1 + check_svn_if_loaded } mount_device := data.fragment.mount_device @@ -4705,6 +4983,207 @@ mount_device := data.fragment.mount_device } } +func Test_Rego_LoadFragment_BadIssuer_AttemptOverrideFrameworkItems(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + expectedIssuer := fragment.info.issuer + actualIssuer := testDataGenerator.uniqueFragmentIssuer() + code := fmt.Sprintf(`package fragment + +svn := "%s" +framework_version := "%s" + +load_fragment := {"allowed": true, "add_module": true} +data.framework.load_fragment := {"allowed": true, "add_module": true} +input.issuer := "%s" +data.framework.input.issuer := "%s" +`, fragment.info.minimumSVN, frameworkVersion, expectedIssuer, expectedIssuer) + + err = tc.policy.LoadFragment(p.ctx, actualIssuer, fragment.info.feed, code) + + if !assertDecisionJSONContains(t, err, "invalid fragment issuer") { + return false + } + + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadIssuer_AttemptOverrideFrameworkItems: %v", err) + } +} + +// The intent of this test is really to check that Rego module names are +// case-sensitive, since we do not deny a fragment from having a namespace +// "Framework" or the like. We use svn mismatch here since otherwise the +// enforcer will not even try to load the fragment module at all if issuer or +// feed is wrong. But in reality, if an attacker can sign fragments with the +// correct issuer, they can make the fragment have any SVN they want. 
+func Test_Rego_LoadFragment_BadSvn_FrameworkNamespaceCaseConfusion(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupRegoFragmentSVNErrorTestConfig(p) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + code := fmt.Sprintf(`package Framework + +svn := "%s" +framework_version := "%s" + +load_fragment := {"allowed": true, "add_module": true} +enforcement_point_info := { + "available": true, + "unknown": false, + "invalid": false, + "version_missing": false, + "default_results": {"allowed": true}, + "use_framework": true +} +data.framework.load_fragment := load_fragment +`, fragment.constraints.svn, frameworkVersion) + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, fragment.info.feed, code) + + if !assertDecisionJSONContains(t, err, "fragment svn is below the specified minimum") { + return false + } + + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, fragment.info.feed) { + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadSvn_FrameworkNamespaceCaseConfusion: %v", err) + } +} + +func Test_Rego_LoadFragment_BadIssuer_MustNotTryToLoadRego(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + actualIssuer := testDataGenerator.uniqueFragmentIssuer() + code := "package fragment\n!invalid!rego" + + err = tc.policy.LoadFragment(p.ctx, actualIssuer, fragment.info.feed, code) + + if strings.Contains(err.Error(), "error when compiling module") || + !assertDecisionJSONDoesNotContain(t, err, "error when compiling module") { + t.Errorf("expected error to not contain 'error when compiling module', got: %s", err.Error()) + return false + } + if !assertDecisionJSONDoesNotContain(t, err, "fragment framework_version is missing") { + return false + } + if !assertDecisionJSONContains(t, err, "invalid fragment issuer") { + return false + } + if !expectFragmentNotLoaded(t, tc.policy, actualIssuer, fragment.info.feed) { + return false + } + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadIssuer_MustNotTryToLoadRego: %v", err) + } +} + +func Test_Rego_LoadFragment_BadFeed_MustNotTryToLoadRego(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } + + fragment := tc.fragments[0] + actualFeed := testDataGenerator.uniqueFragmentFeed() + code := "package fragment\n!invalid!rego" + + err = tc.policy.LoadFragment(p.ctx, fragment.info.issuer, actualFeed, code) + + if strings.Contains(err.Error(), "error when compiling module") || + !assertDecisionJSONDoesNotContain(t, err, "error when compiling module") { + t.Errorf("expected error to not contain 'error when compiling module', got: %s", err.Error()) + return false + } + if !assertDecisionJSONDoesNotContain(t, err, "fragment framework_version is missing") { + return false + } + if !assertDecisionJSONContains(t, err, "invalid fragment feed") { + return false + } + if !expectFragmentNotLoaded(t, tc.policy, fragment.info.issuer, actualFeed) { + return false + } + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadFeed_MustNotTryToLoadRego: %v", err) 
+ } +} + +func Test_Rego_LoadFragment_BadIssuer_MustNotTryToLoadRego_Compat_0_10_0(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoFragmentTestConfig(p) + if err != nil { + t.Error(err) + return false + } + rego := getPolicyCode_0_10_0(tc.containers[0].container.Layers[0], tc.fragments[0].info.issuer, tc.fragments[0].info.feed) + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Fatalf("unable to create Rego policy: %v", err) + } + tc.policy = policy + + fragment := tc.fragments[0] + actualIssuer := testDataGenerator.uniqueFragmentIssuer() + code := "package fragment\n!invalid!rego" + + err = tc.policy.LoadFragment(p.ctx, actualIssuer, fragment.info.feed, code) + + if strings.Contains(err.Error(), "error when compiling module") || + !assertDecisionJSONDoesNotContain(t, err, "error when compiling module") { + t.Errorf("expected error to not contain 'error when compiling module', got: %s", err.Error()) + return false + } + if !assertDecisionJSONDoesNotContain(t, err, "fragment framework_version is missing") { + return false + } + if !assertDecisionJSONContains(t, err, "invalid fragment issuer") { + return false + } + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 25, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_LoadFragment_BadIssuer_MustNotTryToLoadRego_Compat_0_10_0: %v", err) + } +} + func Test_Rego_Scratch_Mount_Policy(t *testing.T) { for _, tc := range []struct { unencryptedAllowed bool diff --git a/pkg/securitypolicy/securitypolicyenforcer_rego.go b/pkg/securitypolicy/securitypolicyenforcer_rego.go index 276fb7b9bf..075d5450c5 100644 --- a/pkg/securitypolicy/securitypolicyenforcer_rego.go +++ b/pkg/securitypolicy/securitypolicyenforcer_rego.go @@ -9,6 +9,8 @@ import ( "encoding/base64" "encoding/json" "fmt" + "regexp" + "slices" "strings" "syscall" @@ -1044,14 +1046,45 @@ func (policy *regoEnforcer) EnforceRuntimeLoggingPolicy(ctx context.Context) err return err } +// Rego identifier is a letter or underscore, followed by any number of letters, +// underscores, or digits. See open-policy-agent/opa +// ast/internal/scanner/scanner.go :: scanIdentifier, isLetter +// Technically it also allows other unicode digit characters (but not letters) +// but we do not allow those, for simplicity. +var validNamespaceRegex = `[a-zA-Z_][a-zA-Z0-9_]*` + +// First line of the fragment Rego source code must be a package definition +// without any potential for confusion attacks. We thus limit it to exactly +// "package" followed by one or more spaces, then a valid Rego identifier, then +// optionally more spaces. We do not check if the namespace is a Rego keyword +// (e.g. "in", "every" etc) but it would fail Rego compilation anyway. +var validFirstLine = regexp.MustCompile(`^package +(` + validNamespaceRegex + `)\s*$`) + +// These namespaces must not be overridden by a fragment +var reservedNamespaces []string = []string{ + // Built-in modules + "framework", + "api", + "policy", + // This is not a module, but to prevent confusion since framework uses + // data.metadata to access those, we block it as well. 
+ "metadata", +} + func parseNamespace(rego string) (string, error) { lines := strings.Split(rego, "\n") - parts := strings.Split(lines[0], " ") - if parts[0] != "package" { - return "", errors.New("package definition required on first line") + if lines[0] == "" { + return "", errors.New("Fragment Rego is empty") } - - return strings.TrimSpace(parts[1]), nil + match := validFirstLine.FindStringSubmatch(lines[0]) + if match == nil { + return "", errors.Errorf("valid package definition required on first line, got %q", lines[0]) + } + namespace := match[1] + if slices.Contains(reservedNamespaces, namespace) { + return "", errors.Errorf("namespace %q is reserved and cannot be used for fragments", namespace) + } + return namespace, nil } func (policy *regoEnforcer) LoadFragment(ctx context.Context, issuer string, feed string, rego string) error { @@ -1067,22 +1100,42 @@ func (policy *regoEnforcer) LoadFragment(ctx context.Context, issuer string, fee Namespace: namespace, } - policy.rego.AddModule(fragment.ID(), fragment) - input := inputData{ - "issuer": issuer, - "feed": feed, - "namespace": namespace, + "issuer": issuer, + "feed": feed, + "namespace": namespace, + "fragment_loaded": false, } + // Check that the fragment is signed by the expected issuer before loading + // its Rego code. + _, err = policy.enforce(ctx, "load_fragment", input) + if err != nil { + return err + } + + // At this point we need to add the fragment code as a new Rego module in + // order for the framework (or any user defined policies) to check the SVN, + // and potentially other information defined by its Rego code. We've already + // checked that the fragment is signed correctly, and the namespace is safe + // to load (won't override framework or other built-in modules). Once we + // added the module, we must make sure the module is removed if we return + // with error (or if add_module returned from Rego is false). + policy.rego.AddModule(fragment.ID(), fragment) + input["fragment_loaded"] = true + results, err := policy.enforce(ctx, "load_fragment", input) + if err != nil { + policy.rego.RemoveModule(fragment.ID()) + return err + } addModule, _ := results.Bool("add_module") if !addModule { policy.rego.RemoveModule(fragment.ID()) } - return err + return nil } func (policy *regoEnforcer) EnforceScratchMountPolicy(ctx context.Context, scratchPath string, encrypted bool) error { From e6000878ce88fd4e28e99156e3f0e0b75c230d56 Mon Sep 17 00:00:00 2001 From: Tingmao Wang Date: Mon, 10 Nov 2025 15:14:07 +0000 Subject: [PATCH 05/12] Merged PR 12878106: Small rego fixes [cherry-picked from 0ca40bb4f130b3508f4a130011463070328d40d0] - rego: Fix missing error reason when mounting a rw device to an existing mount point. This fixes a missing error message introduced in the last round of security fixes. It's not hugely important, but eases debugging if we get policy denials on mounting the scratch, for whatever reason. Also adds test for it. - Remove a no-op from rego Checked with @ earlier that this basically does nothing and is just something left over. However I will not actually add a remove op for `metadata.started` for now. This PR is targeting the conf-aci branch on ADO because the commit being fixed is not on main yet. This should be backported to main together with the fixes from last month. 
Signed-off-by: Tingmao Wang
---
 pkg/securitypolicy/framework.rego           |  4 ++--
 pkg/securitypolicy/regopolicy_linux_test.go | 16 ++++++++++++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/pkg/securitypolicy/framework.rego b/pkg/securitypolicy/framework.rego
index 1be2625048..ca6721c5dc 100644
--- a/pkg/securitypolicy/framework.rego
+++ b/pkg/securitypolicy/framework.rego
@@ -901,7 +901,7 @@ exec_in_container := {"metadata": [updateMatches],
 
 default shutdown_container := {"allowed": false}
 
-shutdown_container := {"started": remove, "metadata": [remove], "allowed": true} {
+shutdown_container := {"metadata": [remove], "allowed": true} {
     container_started
     remove := {
         "name": "matches",
@@ -1313,7 +1313,7 @@ errors["deviceHash not found"] {
 }
 
 errors["device already mounted at path"] {
-    input.rule == "mount_device"
+    input.rule in ["mount_device", "rw_mount_device"]
     device_mounted(input.target)
 }
 
diff --git a/pkg/securitypolicy/regopolicy_linux_test.go b/pkg/securitypolicy/regopolicy_linux_test.go
index e83a6d30e4..8dd409fccf 100644
--- a/pkg/securitypolicy/regopolicy_linux_test.go
+++ b/pkg/securitypolicy/regopolicy_linux_test.go
@@ -378,7 +378,7 @@ func Test_Rego_EnforceDeviceMountPolicy_InvalidMountTarget_PathTraversal(t *test
 	assertDecisionJSONContains(t, err, "mountpoint invalid")
 }
 
-func deviceMountUnmountTest(t *testing.T, p *generatedConstraints, policy *regoEnforcer, mountScratchFirst, unmountScratchFirst, testInvalidUnmount bool) bool {
+func deviceMountUnmountTest(t *testing.T, p *generatedConstraints, policy *regoEnforcer, mountScratchFirst, unmountScratchFirst, testDenials bool) bool {
 	container := selectContainerFromContainerList(p.containers, testRand)
 	containerID := testDataGenerator.uniqueContainerID()
 	rotarget := testDataGenerator.uniqueLayerMountTarget()
@@ -414,6 +414,18 @@ func deviceMountUnmountTest(t *testing.T, p *generatedConstraints, policy *regoE
 		}
 	}
 
+	if testDenials {
+		err = policy.EnforceRWDeviceMountPolicy(p.ctx, rwtarget, true, true, "xfs")
+		if !assertDecisionJSONContains(t, err, "device already mounted at path") {
+			return false
+		}
+
+		err = policy.EnforceDeviceMountPolicy(p.ctx, rotarget, container.Layers[0])
+		if !assertDecisionJSONContains(t, err, "device already mounted at path") {
+			return false
+		}
+	}
+
 	unmountScratch := func() bool {
 		err = policy.EnforceRWDeviceUnmountPolicy(p.ctx, rwtarget)
 		if err != nil {
@@ -442,7 +454,7 @@ func deviceMountUnmountTest(t *testing.T, p *generatedConstraints, policy *regoE
 		}
 	}
 
-	if testInvalidUnmount {
+	if testDenials {
 		err = policy.EnforceDeviceUnmountPolicy(p.ctx, rotarget)
 		if !assertDecisionJSONContains(t, err, "no device at path to unmount") {
 			return false

From d51d004606725cdc4a6cc7f060be300aba8d4f47 Mon Sep 17 00:00:00 2001
From: Tingmao Wang
Date: Mon, 10 Nov 2025 16:08:47 +0000
Subject: [PATCH 06/12] Merged PR 12878455: Fix for the Rego "metadata desync"
 bug

[cherry-picked from 421b12249544a334e36df33dc4846673b2a88279]

This set of changes fixes the [Metadata Desync with UVM
State](https://msazure.visualstudio.com/One/_workitems/edit/33232631/) bug, by
reverting the Rego policy state on mount and some types of unmount failures.

For mounts, minor cleanup code is added to ensure we close down the dm-crypt
device if we fail to mount it. Aside from this, it is relatively
straightforward - if we get a failure, the cleanup functions will remove the
directory, remove any dm-devices, and we will revert the Rego metadata.
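To make the revert mechanics concrete, the pattern applied at each mount site
looks roughly like the following (a sketch only; the real wiring is in the
uvm.go changes below, using the names introduced by this patch):

	// Save the current Rego metadata; the enforcer holds a lock until the
	// section is committed or rolled back.
	rev, err := securityPolicy.StartRevertableSection()
	if err != nil {
		return err
	}
	// The deferred helper commits the metadata if *err is nil when we
	// return, and restores the saved copy otherwise.
	defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err)
	// ... enforce the mount policy, then perform the actual mount ...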
For unmounts, careful consideration is needed: if the directory has been
unmounted successfully (or even only partially?), and we then get an error, we
cannot just revert the policy state, as this may allow the host to use a
broken / empty mount as one of the image layers. See 615c9a0bdf's commit
message for more detailed thoughts.

The solution I opted for is that for non-trivial unmount failures (i.e. not a
policy denial, not an invalid mountpoint), we block all further mount,
unmount, container creation and deletion attempts. I think this is OK since we
really do not expect unmounts to fail - especially so in our case, since the
only writable disk mount we have is the shared scratch disk, which we do not
unmount at all unless we're about to kill the UVM.

Testing
-------

The "Rollback policy state on mount errors" commit message has instructions
for making a deliberately corrupted VHD (one that still passes verity info
extraction) that will cause a mount error.

The other way we could make mount / unmount fail, and thus test this change,
is by mounting a tmpfs or ro bind in relevant places:

To make unmount fail:

	mkdir /run/gcs/c/.../rootfs/a && mount -t tmpfs none /run/gcs/c/.../rootfs/a

	or

	mkdir /run/gcs/mounts/scsi/m1/a && mount -t tmpfs none /run/gcs/mounts/scsi/m1/a

To make mount fail:

	mount -o ro --bind /run/mounts/scsi /run/mounts/scsi

	or

	mount --bind -o ro /run/gcs/c /run/gcs/c

Once failure is triggered, one can make them work again by `umount`ing the
tmpfs or ro bind.

What about other operations?
----------------------------

This PR covers mount and unmount of disks, overlays and 9p. Aside from setting
`metadata.matches` as part of the narrowing scheme, and except for
`metadata.started` to prevent re-using a container ID, Rego does not use
persistent state for any other operations.

Since it's not clear whether reverting the state would be semantically correct
(we would need to carefully consider exactly what the side effects are of,
say, failing to execute a process, start a container, or send a signal), and
adding the revert to those operations does not currently affect much
behaviour, I've opted not to apply the metadata revert to those for now.

Signed-off-by: Tingmao Wang
---
 internal/gcs/unrecoverable_error.go          |  49 ++++
 internal/guest/runtime/hcsv2/uvm.go          | 252 ++++++++++++++--
 internal/guest/storage/mount.go              |   2 +
 internal/guest/storage/overlay/overlay.go    |   3 +-
 internal/guest/storage/scsi/scsi.go          |  15 +-
 internal/guest/storage/scsi/scsi_test.go     |   6 +
 .../regopolicyinterpreter.go                 |  66 ++++-
 .../regopolicyinterpreter_test.go            | 164 +++++++++++
 pkg/securitypolicy/rego_utils_test.go        |  38 ++-
 pkg/securitypolicy/regopolicy_linux_test.go  | 273 ++++++++++++++++++
 pkg/securitypolicy/securitypolicyenforcer.go |  22 ++
 .../securitypolicyenforcer_rego.go           |  84 ++++++
 12 files changed, 932 insertions(+), 42 deletions(-)
 create mode 100644 internal/gcs/unrecoverable_error.go

diff --git a/internal/gcs/unrecoverable_error.go b/internal/gcs/unrecoverable_error.go
new file mode 100644
index 0000000000..dbb7240266
--- /dev/null
+++ b/internal/gcs/unrecoverable_error.go
@@ -0,0 +1,49 @@
+//go:build linux
+// +build linux
+
+package gcs
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"runtime"
+	"time"
+
+	"github.com/Microsoft/hcsshim/internal/log"
+	"github.com/Microsoft/hcsshim/pkg/amdsevsnp"
+	"github.com/sirupsen/logrus"
+)
+
+// UnrecoverableError logs the error and then puts the current thread into an
+// infinite sleep loop.
This is to be used instead of panicking, as the +// behaviour of GCS panics is unpredictable. This function can be extended to, +// for example, try to shutdown the VM cleanly. +func UnrecoverableError(err error) { + buf := make([]byte, 300*(1<<10)) + stackSize := runtime.Stack(buf, true) + stackTrace := string(buf[:stackSize]) + + errPrint := fmt.Sprintf( + "Unrecoverable error in GCS: %v\n%s", + err, stackTrace, + ) + isSnp := amdsevsnp.IsSNP() + if isSnp { + errPrint += "\nThis thread will now enter an infinite loop." + } + log.G(context.Background()).WithError(err).Logf( + logrus.FatalLevel, + "%s", + errPrint, + ) + + if !isSnp { + panic("Unrecoverable error in GCS: " + err.Error()) + } else { + fmt.Fprintf(os.Stderr, "%s\n", errPrint) + for { + time.Sleep(time.Hour) + } + } +} diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index e748260650..694ced53e5 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -17,6 +17,7 @@ import ( "regexp" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -113,6 +114,16 @@ type Host struct { // hostMounts keeps the state of currently mounted devices and file systems, // which is used for GCS hardening. hostMounts *hostMounts + // A permanent flag to indicate that further mounts, unmounts and container + // creation should not be allowed. This is set when, because of a failure + // during an unmount operation, we end up in a state where the policy + // enforcer's state is out of sync with what we have actually done, but we + // cannot safely revert its state. + // + // Not used in non-confidential mode. + mountsBroken atomic.Bool + // A user-friendly error message for why mountsBroken was set. + mountsBrokenCausedBy string } func NewHost(rtime runtime.Runtime, vsock transport.Transport, initialEnforcer securitypolicy.SecurityPolicyEnforcer, logWriter io.Writer) *Host { @@ -132,6 +143,7 @@ func NewHost(rtime runtime.Runtime, vsock transport.Transport, initialEnforcer s devNullTransport: &transport.DevNullTransport{}, hostMounts: newHostMounts(), securityOptions: securityPolicyOptions, + mountsBroken: atomic.Bool{}, } } @@ -313,7 +325,44 @@ func checkContainerSettings(sandboxID, containerID string, settings *prot.VMHost return nil } +// Returns an error if h.mountsBroken is set (and we're in a confidential +// container host) +func (h *Host) checkMountsNotBroken() error { + if h.HasSecurityPolicy() && h.mountsBroken.Load() { + return errors.Errorf( + "Mount, unmount, container creation and deletion has been disabled in this UVM due to a previous error (%q)", + h.mountsBrokenCausedBy, + ) + } + return nil +} + +func (h *Host) setMountsBrokenIfConfidential(cause string) { + if !h.HasSecurityPolicy() { + return + } + h.mountsBroken.Store(true) + h.mountsBrokenCausedBy = cause + log.G(context.Background()).WithFields(logrus.Fields{ + "cause": cause, + }).Error("Host::mountsBroken set to true. 
All further mounts/unmounts, container creation and deletion will fail.") +} + +func checkExists(path string) (error, bool) { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return nil, false + } + return errors.Wrapf(err, "failed to determine if path '%s' exists", path), false + } + return nil, true +} + func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VMHostedContainerSettingsV2) (_ *Container, err error) { + if err = h.checkMountsNotBroken(); err != nil { + return nil, err + } + criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType] // Check for virtual pod annotation @@ -705,6 +754,10 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * case guestresource.ResourceTypeSCSIDevice: return modifySCSIDevice(ctx, req.RequestType, req.Settings.(*guestresource.SCSIDevice)) case guestresource.ResourceTypeMappedVirtualDisk: + if err := h.checkMountsNotBroken(); err != nil { + return err + } + mvd := req.Settings.(*guestresource.LCOWMappedVirtualDisk) // find the actual controller number on the bus and update the incoming request. var cNum uint8 @@ -742,18 +795,30 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * }() } } - return modifyMappedVirtualDisk(ctx, req.RequestType, mvd, h.securityOptions.PolicyEnforcer) + return h.modifyMappedVirtualDisk(ctx, req.RequestType, mvd) case guestresource.ResourceTypeMappedDirectory: - return modifyMappedDirectory(ctx, h.vsock, req.RequestType, req.Settings.(*guestresource.LCOWMappedDirectory), h.securityOptions.PolicyEnforcer) + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + return h.modifyMappedDirectory(ctx, h.vsock, req.RequestType, req.Settings.(*guestresource.LCOWMappedDirectory)) case guestresource.ResourceTypeVPMemDevice: - return modifyMappedVPMemDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPMemDevice), h.securityOptions.PolicyEnforcer) + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + return h.modifyMappedVPMemDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPMemDevice)) case guestresource.ResourceTypeCombinedLayers: + if err := h.checkMountsNotBroken(); err != nil { + return err + } + cl := req.Settings.(*guestresource.LCOWCombinedLayers) // when cl.ScratchPath == "", we mount overlay as read-only, in which case // we don't really care about scratch encryption, since the host already // knows about the layers and the overlayfs. 
encryptedScratch := cl.ScratchPath != "" && h.hostMounts.IsEncrypted(cl.ScratchPath) - return modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers), encryptedScratch, h.securityOptions.PolicyEnforcer) + return h.modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers), encryptedScratch) case guestresource.ResourceTypeNetwork: return modifyNetwork(ctx, req.RequestType, req.Settings.(*guestresource.LCOWNetworkAdapter)) case guestresource.ResourceTypeVPCIDevice: @@ -1141,19 +1206,19 @@ func modifySCSIDevice( } } -func modifyMappedVirtualDisk( +func (h *Host) modifyMappedVirtualDisk( ctx context.Context, rt guestrequest.RequestType, mvd *guestresource.LCOWMappedVirtualDisk, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { var verityInfo *guestresource.DeviceVerityInfo + securityPolicy := h.securityOptions.PolicyEnforcer if mvd.ReadOnly { // The only time the policy is empty, and we want it to be empty // is when no policy is provided, and we default to open door // policy. In any other case, e.g. explicit open door or any // other rego policy we would like to mount layers with verity. - if len(securityPolicy.EncodedSecurityPolicy()) > 0 { + if h.HasSecurityPolicy() { devPath, err := scsi.GetDevicePath(ctx, mvd.Controller, mvd.Lun, mvd.Partition) if err != nil { return err @@ -1167,6 +1232,17 @@ func modifyMappedVirtualDisk( } } } + + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. + var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: mountCtx, cancel := context.WithTimeout(ctx, time.Second*5) @@ -1194,6 +1270,12 @@ func modifyMappedVirtualDisk( Filesystem: mvd.Filesystem, BlockDev: mvd.BlockDev, } + // Since we're rolling back the policy metadata (via the revertable + // section) on failure, we need to ensure that we have reverted all + // the side effects from this failed mount attempt, otherwise the + // Rego metadata is technically still inconsistent with reality. + // Mount cleans up the created directory and dm devices on failure, + // so we're good. return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, mvd.ReadOnly, mvd.Options, config) } @@ -1201,14 +1283,33 @@ func modifyMappedVirtualDisk( case guestrequest.RequestTypeRemove: if mvd.MountPath != "" { if mvd.ReadOnly { - if err := securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { + if err = securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) } } else { - if err := securityPolicy.EnforceRWDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { + if err = securityPolicy.EnforceRWDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) } } + // Check that the directory actually exists first, and if it does + // not then we just refuse to do anything, without closing the dm + // device or setting the mountsBroken flag. 
Policy metadata is + // still reverted to reflect the fact that we have not done + // anything. + // + // Note: we should not do this check before calling the policy + // enforcer, as otherwise we might inadvertently allow the host to + // find out whether an arbitrary path (which may point to sensitive + // data within a container rootfs) exists or not + if h.HasSecurityPolicy() { + err, exists := checkExists(mvd.MountPath) + if err != nil { + return err + } + if !exists { + return errors.Errorf("unmounting scsi device at %s failed: directory does not exist", mvd.MountPath) + } + } config := &scsi.Config{ Encrypted: mvd.Encrypted, VerityInfo: verityInfo, @@ -1216,8 +1317,11 @@ func modifyMappedVirtualDisk( Filesystem: mvd.Filesystem, BlockDev: mvd.BlockDev, } - if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, - mvd.MountPath, config); err != nil { + err = scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, config) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting scsi device at %s failed: %v", mvd.MountPath, err), + ) return err } } @@ -1227,13 +1331,23 @@ func modifyMappedVirtualDisk( } } -func modifyMappedDirectory( +func (h *Host) modifyMappedDirectory( ctx context.Context, vsock transport.Transport, rt guestrequest.RequestType, md *guestresource.LCOWMappedDirectory, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { + securityPolicy := h.securityOptions.PolicyEnforcer + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. + var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: err = securityPolicy.EnforcePlan9MountPolicy(ctx, md.MountPath) @@ -1241,6 +1355,9 @@ func modifyMappedDirectory( return errors.Wrapf(err, "mounting plan9 device at %s denied by policy", md.MountPath) } + // Similar to the reasoning in modifyMappedVirtualDisk, since we're + // rolling back the policy metadata, plan9.Mount here must clean up + // everything if it fails, which it does do. return plan9.Mount(ctx, vsock, md.MountPath, md.ShareName, uint32(md.Port), md.ReadOnly) case guestrequest.RequestTypeRemove: err = securityPolicy.EnforcePlan9UnmountPolicy(ctx, md.MountPath) @@ -1248,20 +1365,28 @@ func modifyMappedDirectory( return errors.Wrapf(err, "unmounting plan9 device at %s denied by policy", md.MountPath) } - return storage.UnmountPath(ctx, md.MountPath, true) + // Note: storage.UnmountPath is nop if path does not exist. 
+ err = storage.UnmountPath(ctx, md.MountPath, true) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting plan9 device at %s failed: %v", md.MountPath, err), + ) + return err + } + return nil default: return newInvalidRequestTypeError(rt) } } -func modifyMappedVPMemDevice(ctx context.Context, +func (h *Host) modifyMappedVPMemDevice(ctx context.Context, rt guestrequest.RequestType, vpd *guestresource.LCOWMappedVPMemDevice, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { var verityInfo *guestresource.DeviceVerityInfo + securityPolicy := h.securityOptions.PolicyEnforcer var deviceHash string - if len(securityPolicy.EncodedSecurityPolicy()) > 0 { + if h.HasSecurityPolicy() { if vpd.MappingInfo != nil { return fmt.Errorf("multi mapping is not supported with verity") } @@ -1271,6 +1396,17 @@ func modifyMappedVPMemDevice(ctx context.Context, } deviceHash = verityInfo.RootDigest } + + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. + var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: err = securityPolicy.EnforceDeviceMountPolicy(ctx, vpd.MountPath, deviceHash) @@ -1278,13 +1414,39 @@ func modifyMappedVPMemDevice(ctx context.Context, return errors.Wrapf(err, "mounting pmem device %d onto %s denied by policy", vpd.DeviceNumber, vpd.MountPath) } + // Similar to the reasoning in modifyMappedVirtualDisk, since we're + // rolling back the policy metadata, pmem.Mount here must clean up + // everything if it fails, which it does do. return pmem.Mount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) case guestrequest.RequestTypeRemove: - if err := securityPolicy.EnforceDeviceUnmountPolicy(ctx, vpd.MountPath); err != nil { + if err = securityPolicy.EnforceDeviceUnmountPolicy(ctx, vpd.MountPath); err != nil { return errors.Wrapf(err, "unmounting pmem device from %s denied by policy", vpd.MountPath) } - return pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) + // Check that the directory actually exists first, and if it does not + // then we just refuse to do anything, without closing the dm-linear or + // dm-verity device or setting the mountsBroken flag. + // + // Similar to the reasoning in modifyMappedVirtualDisk, we should not do + // this check before calling the policy enforcer. 
+ if h.HasSecurityPolicy() { + err, exists := checkExists(vpd.MountPath) + if err != nil { + return err + } + if !exists { + return errors.Errorf("unmounting pmem device at %s failed: directory does not exist", vpd.MountPath) + } + } + + err = pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting pmem device at %s failed: %v", vpd.MountPath, err), + ) + return err + } + return nil default: return newInvalidRequestTypeError(rt) } @@ -1299,19 +1461,28 @@ func modifyMappedVPCIDevice(ctx context.Context, rt guestrequest.RequestType, vp } } -func modifyCombinedLayers( +func (h *Host) modifyCombinedLayers( ctx context.Context, rt guestrequest.RequestType, cl *guestresource.LCOWCombinedLayers, scratchEncrypted bool, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { - isConfidential := len(securityPolicy.EncodedSecurityPolicy()) > 0 + securityPolicy := h.securityOptions.PolicyEnforcer containerID := cl.ContainerID + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. + var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: - if isConfidential { + if h.HasSecurityPolicy() { if err := checkValidContainerID(containerID, "container"); err != nil { return err } @@ -1366,15 +1537,27 @@ func modifyCombinedLayers( return fmt.Errorf("overlay creation denied by policy: %w", err) } + // Correctness for policy revertable section: + // MountLayer does two things - mkdir, then mount. On mount failure, the + // target directory is cleaned up. Therefore we're clean in terms of + // side effects. return overlay.MountLayer(ctx, layerPaths, upperdirPath, workdirPath, cl.ContainerRootPath, readonly) case guestrequest.RequestTypeRemove: // cl.ContainerID is not set on remove requests, but rego checks that we can // only umount previously mounted targets anyway - if err := securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath); err != nil { + if err = securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath); err != nil { return errors.Wrap(err, "overlay removal denied by policy") } - return storage.UnmountPath(ctx, cl.ContainerRootPath, true) + // Note: storage.UnmountPath is a no-op if the path does not exist. + err = storage.UnmountPath(ctx, cl.ContainerRootPath, true) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting overlay at %s failed: %v", cl.ContainerRootPath, err), + ) + return err + } + return nil default: return newInvalidRequestTypeError(rt) } @@ -1664,3 +1847,22 @@ func setupVirtualPodHugePageMountsPath(virtualSandboxID string) error { return storage.MountRShared(mountPath) } + +// If *err is not nil, the section is rolled back, otherwise it is committed. +func (h *Host) commitOrRollbackPolicyRevSection( + ctx context.Context, + rev securitypolicy.RevertableSectionHandle, + err *error, +) { + if !h.HasSecurityPolicy() { + // Don't produce bogus log entries if we aren't in confidential mode, + // even though rev.Rollback would have been no-op. 
+		return
+	}
+	if *err != nil {
+		rev.Rollback()
+		logrus.WithContext(ctx).WithError(*err).Warn("rolling back security policy revertable section due to error")
+	} else {
+		rev.Commit()
+	}
+}
diff --git a/internal/guest/storage/mount.go b/internal/guest/storage/mount.go
index a3d10a3b25..142f0ccbbc 100644
--- a/internal/guest/storage/mount.go
+++ b/internal/guest/storage/mount.go
@@ -16,6 +16,7 @@ import (
 	"go.opencensus.io/trace"
 	"golang.org/x/sys/unix"
 
+	"github.com/Microsoft/hcsshim/internal/log"
 	"github.com/Microsoft/hcsshim/internal/oc"
 )
 
@@ -126,6 +127,7 @@ func UnmountPath(ctx context.Context, target string, removeTarget bool) (err err
 
 	if _, err := osStat(target); err != nil {
 		if os.IsNotExist(err) {
+			log.G(ctx).WithField("target", target).Warnf("UnmountPath called for non-existent path")
 			return nil
 		}
 		return errors.Wrapf(err, "failed to determine if path '%s' exists", target)
diff --git a/internal/guest/storage/overlay/overlay.go b/internal/guest/storage/overlay/overlay.go
index aa4877508f..84bf8fa529 100644
--- a/internal/guest/storage/overlay/overlay.go
+++ b/internal/guest/storage/overlay/overlay.go
@@ -56,8 +56,7 @@ func processErrNoSpace(ctx context.Context, path string, err error) {
 	}).WithError(err).Warn("got ENOSPC, gathering diagnostics")
 }
 
-// MountLayer first enforces the security policy for the container's layer paths
-// and then calls Mount to mount the layer paths as an overlayfs.
+// MountLayer calls Mount to mount the layer paths as an overlayfs.
 func MountLayer(
 	ctx context.Context,
 	layerPaths []string,
diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go
index ec62636590..83c586c3eb 100644
--- a/internal/guest/storage/scsi/scsi.go
+++ b/internal/guest/storage/scsi/scsi.go
@@ -121,8 +121,9 @@ type Config struct {
 // Mount creates a mount from the SCSI device on `controller` index `lun` to
 // `target`
 //
-// `target` will be created. On mount failure the created `target` will be
-// automatically cleaned up.
+// `target` will be created. On mount failure the created `target`, as well as
+// any associated dm-crypt or dm-verity devices, will be automatically cleaned
+// up.
 //
 // If the config has `encrypted` is set to true, the SCSI device will be
 // encrypted using dm-crypt.
@@ -200,7 +201,8 @@ func Mount(
 	var deviceFS string
 	if config.Encrypted {
 		cryptDeviceName := fmt.Sprintf(cryptDeviceFmt, controller, lun, partition)
-		encryptedSource, err := encryptDevice(spnCtx, source, cryptDeviceName)
+		var encryptedSource string
+		encryptedSource, err = encryptDevice(spnCtx, source, cryptDeviceName)
 		if err != nil {
 			// todo (maksiman): add better retry logic, similar to how SCSI device mounts are
 			// retried on unix.ENOENT and unix.ENXIO. The retry should probably be on an
 		}
 	}
 	source = encryptedSource
+	defer func() {
+		if err != nil {
+			if err := cleanupCryptDevice(spnCtx, cryptDeviceName); err != nil {
+				log.G(spnCtx).WithError(err).WithField("cryptDeviceName", cryptDeviceName).Debug("failed to cleanup dm-crypt device after mount failure")
+			}
+		}
+	}()
 	} else {
 		// Get the filesystem that is already on the device (if any) and use that
 		// as the mountType unless `Filesystem` was given.
diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index ebfcf8e382..94992047bd 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -999,6 +999,12 @@ func Test_Mount_EncryptDevice_Mkfs_Error(t *testing.T) { } return expectedDevicePath, nil } + cleanupCryptDevice = func(_ context.Context, dmCryptName string) error { + if dmCryptName != expectedCryptTarget { + t.Fatalf("expected cleanupCryptDevice name %q got %q", expectedCryptTarget, dmCryptName) + } + return nil + } osStat = osStatNoop xfsFormat = func(arg string) error { diff --git a/internal/regopolicyinterpreter/regopolicyinterpreter.go b/internal/regopolicyinterpreter/regopolicyinterpreter.go index 047a4a27b7..ebe3fcfff4 100644 --- a/internal/regopolicyinterpreter/regopolicyinterpreter.go +++ b/internal/regopolicyinterpreter/regopolicyinterpreter.go @@ -63,6 +63,9 @@ type RegoModule struct { type regoMetadata map[string]map[string]interface{} +const metadataRootKey = "metadata" +const metadataOperationsKey = "metadata" + type regoMetadataAction string const ( @@ -81,6 +84,11 @@ type regoMetadataOperation struct { // The result from a policy query type RegoQueryResult map[string]interface{} +// An immutable, saved copy of the metadata state. +type SavedMetadata struct { + metadataRoot regoMetadata +} + // deep copy for an object func copyObject(data map[string]interface{}) (map[string]interface{}, error) { objJSON, err := json.Marshal(data) @@ -113,6 +121,24 @@ func copyValue(value interface{}) (interface{}, error) { return valueCopy, nil } +// deep copy for regoMetadata. +// We cannot use copyObject for this due to the fact that map[string]interface{} +// is a concrete type and a map of it cannot be used as a map of interface{}. +func copyRegoMetadata(value regoMetadata) (regoMetadata, error) { + valueJSON, err := json.Marshal(value) + if err != nil { + return nil, err + } + + var valueCopy regoMetadata + err = json.Unmarshal(valueJSON, &valueCopy) + if err != nil { + return nil, err + } + + return valueCopy, nil +} + // NewRegoPolicyInterpreter creates a new RegoPolicyInterpreter, using the code provided. // inputData is the Rego data which should be used as the initial state // of the interpreter. A deep copy is performed on it such that it will @@ -123,8 +149,8 @@ func NewRegoPolicyInterpreter(code string, inputData map[string]interface{}) (*R return nil, fmt.Errorf("unable to copy the input data: %w", err) } - if _, ok := data["metadata"]; !ok { - data["metadata"] = make(regoMetadata) + if _, ok := data[metadataRootKey]; !ok { + data[metadataRootKey] = make(regoMetadata) } policy := &RegoPolicyInterpreter{ @@ -207,7 +233,7 @@ func (r *RegoPolicyInterpreter) GetMetadata(name string, key string) (interface{ r.dataAndModulesMutex.Lock() defer r.dataAndModulesMutex.Unlock() - metadataRoot, ok := r.data["metadata"].(regoMetadata) + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) if !ok { return nil, errors.New("illegal interpreter state: invalid metadata object type") } @@ -228,6 +254,32 @@ func (r *RegoPolicyInterpreter) GetMetadata(name string, key string) (interface{ } } +// Saves a copy of the internal policy metadata state. 
+func (r *RegoPolicyInterpreter) SaveMetadata() (s SavedMetadata, err error) { + r.dataAndModulesMutex.Lock() + defer r.dataAndModulesMutex.Unlock() + + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) + if !ok { + return SavedMetadata{}, errors.New("illegal interpreter state: invalid metadata object type") + } + s.metadataRoot, err = copyRegoMetadata(metadataRoot) + return s, err +} + +// Restores a previously saved metadata state. +func (r *RegoPolicyInterpreter) RestoreMetadata(m SavedMetadata) error { + r.dataAndModulesMutex.Lock() + defer r.dataAndModulesMutex.Unlock() + + copied, err := copyRegoMetadata(m.metadataRoot) + if err != nil { + return fmt.Errorf("unable to copy metadata: %w", err) + } + r.data[metadataRootKey] = copied + return nil +} + func newRegoMetadataOperation(operation interface{}) (*regoMetadataOperation, error) { var metadataOp regoMetadataOperation @@ -286,7 +338,7 @@ func (r *RegoPolicyInterpreter) UpdateOSType(os string) error { func (r *RegoPolicyInterpreter) updateMetadata(ops []*regoMetadataOperation) error { // dataAndModulesMutex must be held before calling this - metadataRoot, ok := r.data["metadata"].(regoMetadata) + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) if !ok { return errors.New("illegal interpreter state: invalid metadata object type") } @@ -431,7 +483,7 @@ func (r *RegoPolicyInterpreter) logMetadata() { return } - contents, err := json.Marshal(r.data["metadata"]) + contents, err := json.Marshal(r.data[metadataRootKey]) if err != nil { r.metadataLogger.Printf("error marshaling metadata: %v\n", err.Error()) } else { @@ -617,7 +669,7 @@ func (r *RegoPolicyInterpreter) Query(rule string, input map[string]interface{}) r.logResult(rule, resultSet) ops := []*regoMetadataOperation{} - if rawMetadata, ok := resultSet["metadata"]; ok { + if rawMetadata, ok := resultSet[metadataOperationsKey]; ok { metadata, ok := rawMetadata.([]interface{}) if !ok { return nil, errors.New("error loading metadata array: invalid type") @@ -640,7 +692,7 @@ func (r *RegoPolicyInterpreter) Query(rule string, input map[string]interface{}) } for name, value := range resultSet { - if name == "metadata" { + if name == metadataOperationsKey { continue } else { result[name] = value diff --git a/internal/regopolicyinterpreter/regopolicyinterpreter_test.go b/internal/regopolicyinterpreter/regopolicyinterpreter_test.go index b7d86609f7..3872afff51 100644 --- a/internal/regopolicyinterpreter/regopolicyinterpreter_test.go +++ b/internal/regopolicyinterpreter/regopolicyinterpreter_test.go @@ -72,6 +72,37 @@ func Test_copyValue(t *testing.T) { } } +func Test_copyRegoMetadata(t *testing.T) { + f := func(orig testRegoMetadata) bool { + copy, err := copyRegoMetadata(regoMetadata(orig)) + if err != nil { + t.Error(err) + return false + } + + if len(orig) != len(copy) { + t.Errorf("original and copy have different number of objects: %d != %d", len(orig), len(copy)) + return false + } + + for name, origObject := range orig { + if copyObject, ok := copy[name]; ok { + if !assertObjectsEqual(origObject, copyObject) { + t.Errorf("original and copy differ on key %s", name) + } + } else { + t.Errorf("copy missing object %s", name) + } + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 30, Rand: testRand}); err != nil { + t.Errorf("Test_copyRegoMetadata: %v", err) + } +} + //go:embed test.rego var testCode string @@ -364,6 +395,107 @@ func Test_Metadata_Remove(t *testing.T) { } } +func Test_Metadata_SaveRestore(t *testing.T) { + rego, err := 
setupRego() + if err != nil { + t.Fatal(err) + } + + f := func(pairs1before, pairs1after intPairArray, name1 metadataName, pairs2before, pairs2after intPairArray, name2 metadataName) bool { + if name1 == name2 { + t.Fatalf("generated two identical names: %s", name1) + } + + err := appendAll(rego, pairs1before, name1) + if err != nil { + t.Errorf("error appending pairs1before: %v", err) + return false + } + err = appendAll(rego, pairs2before, name2) + if err != nil { + t.Errorf("error appending pairs2before: %v", err) + return false + } + + saved, err := rego.SaveMetadata() + if err != nil { + t.Errorf("unable to save metadata: %v", err) + return false + } + + beforeSum1 := getExpectedGapFromPairs(pairs1before) + err = computeGap(rego, name1, beforeSum1) + if err != nil { + t.Error(err) + return false + } + + beforeSum2 := getExpectedGapFromPairs(pairs2before) + err = computeGap(rego, name2, beforeSum2) + if err != nil { + t.Error(err) + return false + } + + // computeGap would have cleared the list, so we restore it. + err = rego.RestoreMetadata(saved) + if err != nil { + t.Errorf("unable to restore metadata: %v", err) + return false + } + + err = appendAll(rego, pairs1after, name1) + if err != nil { + t.Errorf("error appending pairs1after: %v", err) + return false + } + + err = appendAll(rego, pairs2after, name2) + if err != nil { + t.Errorf("error appending pairs2after: %v", err) + return false + } + + afterSum1 := beforeSum1 + getExpectedGapFromPairs(pairs1after) + err = computeGap(rego, name1, afterSum1) + if err != nil { + t.Error(err) + return false + } + + afterSum2 := beforeSum2 + getExpectedGapFromPairs(pairs2after) + err = computeGap(rego, name2, afterSum2) + if err != nil { + t.Error(err) + return false + } + + err = rego.RestoreMetadata(saved) + if err != nil { + t.Errorf("unable to restore metadata: %v", err) + return false + } + + err = computeGap(rego, name1, beforeSum1) + if err != nil { + t.Errorf("computeGap failed for name1 after restore: %v", err) + return false + } + + err = computeGap(rego, name2, beforeSum2) + if err != nil { + t.Errorf("computeGap failed for name2 after restore: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 100, Rand: testRand}); err != nil { + t.Errorf("Test_Metadata_SaveRestore: %v", err) + } +} + //go:embed module.rego var moduleCode string @@ -508,6 +640,7 @@ type testValue struct { } type testArray []interface{} type testObject map[string]interface{} +type testRegoMetadata regoMetadata type testValueType int @@ -580,6 +713,16 @@ func (testObject) Generate(r *rand.Rand, _ int) reflect.Value { return reflect.ValueOf(value) } +func (testRegoMetadata) Generate(r *rand.Rand, _ int) reflect.Value { + numObjects := r.Intn(maxNumberOfFields) + metadata := make(testRegoMetadata) + for i := 0; i < numObjects; i++ { + name := uniqueString(r) + metadata[name] = generateObject(r, 0) + } + return reflect.ValueOf(metadata) +} + func getResult(r *RegoPolicyInterpreter, p intPair, rule string) (RegoQueryResult, error) { input := map[string]interface{}{"a": p.a, "b": p.b} result, err := r.Query("data.test."+rule, input) @@ -640,6 +783,27 @@ func appendLists(r *RegoPolicyInterpreter, p intPair, name metadataName) error { return nil } +func appendAll(r *RegoPolicyInterpreter, pairs intPairArray, name metadataName) error { + for _, pair := range pairs { + if err := appendLists(r, pair, name); err != nil { + return fmt.Errorf("error appending pair %v: %w", pair, err) + } + } + return nil +} + +func 
getExpectedGapFromPairs(pairs intPairArray) int { + expected := 0 + for _, pair := range pairs { + if pair.a >= pair.b { + expected += pair.a - pair.b + } else { + expected += pair.b - pair.a + } + } + return expected +} + func computeGap(r *RegoPolicyInterpreter, name metadataName, expected int) error { input := map[string]interface{}{"name": string(name)} result, err := r.Query("data.test.compute_gap", input) diff --git a/pkg/securitypolicy/rego_utils_test.go b/pkg/securitypolicy/rego_utils_test.go index 5ac12a5a0a..3adbb6a2b6 100644 --- a/pkg/securitypolicy/rego_utils_test.go +++ b/pkg/securitypolicy/rego_utils_test.go @@ -347,18 +347,25 @@ type regoPlan9MountTestConfig struct { } func mountImageForContainer(policy *regoEnforcer, container *securityPolicyContainer) (string, error) { - ctx := context.Background() containerID := testDataGenerator.uniqueContainerID() + if err := mountImageForContainerWithID(policy, container, containerID); err != nil { + return "", err + } + return containerID, nil +} + +func mountImageForContainerWithID(policy *regoEnforcer, container *securityPolicyContainer, containerID string) error { + ctx := context.Background() layerPaths, err := testDataGenerator.createValidOverlayForContainer(policy, container) if err != nil { - return "", fmt.Errorf("error creating valid overlay: %w", err) + return fmt.Errorf("error creating valid overlay: %w", err) } scratchDisk := getScratchDiskMountTarget(containerID) err = policy.EnforceRWDeviceMountPolicy(ctx, scratchDisk, true, true, "xfs") if err != nil { - return "", fmt.Errorf("error mounting scratch disk: %w", err) + return fmt.Errorf("error mounting scratch disk: %w", err) } overlayTarget := getOverlayMountTarget(containerID) @@ -367,12 +374,13 @@ func mountImageForContainer(policy *regoEnforcer, container *securityPolicyConta err = policy.EnforceOverlayMountPolicy( ctx, containerID, copyStrings(layerPaths), overlayTarget) if err != nil { - return "", fmt.Errorf("error mounting filesystem: %w", err) + return fmt.Errorf("error mounting filesystem: %w", err) } - return containerID, nil + return nil } + func buildMountSpecFromMountArray(mounts []mountInternal, sandboxID string, r *rand.Rand) *oci.Spec { mountSpec := new(oci.Spec) @@ -1404,6 +1412,10 @@ func setupRegoCreateContainerTest(gc *generatedConstraints, testContainer *secur return nil, err } + return createTestContainerSpec(gc, containerID, testContainer, privilegedError, policy, defaultMounts, privilegedMounts) +} + +func createTestContainerSpec(gc *generatedConstraints, containerID string, testContainer *securityPolicyContainer, privilegedError bool, policy *regoEnforcer, defaultMounts, privilegedMounts []mountInternal) (*regoContainerTestConfig, error) { envList := buildEnvironmentVariablesFromEnvRules(testContainer.EnvRules, testRand) sandboxID := testDataGenerator.uniqueSandboxID() @@ -2994,3 +3006,19 @@ type containerInitProcess struct { WorkingDir string AllowStdioAccess bool } + +func startRevertableSection(t *testing.T, policy *regoEnforcer) RevertableSectionHandle { + rev, err := policy.StartRevertableSection() + if err != nil { + t.Fatalf("Failed to start revertable section: %v", err) + } + return rev +} + +func commitOrRollback(rev RevertableSectionHandle, shouldCommit bool) { + if shouldCommit { + rev.Commit() + } else { + rev.Rollback() + } +} diff --git a/pkg/securitypolicy/regopolicy_linux_test.go b/pkg/securitypolicy/regopolicy_linux_test.go index 8dd409fccf..0ddd933bac 100644 --- a/pkg/securitypolicy/regopolicy_linux_test.go +++ 
b/pkg/securitypolicy/regopolicy_linux_test.go @@ -963,6 +963,90 @@ func Test_Rego_EnforceOverlayMountPolicy_Multiple_Instances_Same_Container(t *te } } +func Test_Rego_EnforceOverlayMountPolicy_MountFail(t *testing.T) { + f := func(gc *generatedConstraints, commitOnEnforcementFailure bool) bool { + securityPolicy := gc.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + tc := selectContainerFromContainerList(gc.containers, testRand) + tid := testDataGenerator.uniqueContainerID() + scratchTarget := getScratchDiskMountTarget(tid) + + rev := startRevertableSection(t, policy) + err = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchTarget, true, true, "xfs") + if err != nil { + t.Errorf("failed to EnforceRWDeviceMountPolicy: %v", err) + return false + } + rev.Commit() + + layerToErr := testRand.Intn(len(tc.Layers)) + errLayerPathIndex := len(tc.Layers) - layerToErr - 1 + layerPaths := make([]string, len(tc.Layers)) + for i, layerHash := range tc.Layers { + rev := startRevertableSection(t, policy) + target := testDataGenerator.uniqueLayerMountTarget() + layerPaths[len(tc.Layers)-i-1] = target + err = policy.EnforceDeviceMountPolicy(gc.ctx, target, layerHash) + if err != nil { + t.Errorf("failed to EnforceDeviceMountPolicy: %v", err) + return false + } + if i == layerToErr { + // Simulate a mount failure at this point, which will cause us to rollback + rev.Rollback() + } else { + rev.Commit() + } + } + + rev = startRevertableSection(t, policy) + overlayTarget := getOverlayMountTarget(tid) + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPaths, overlayTarget) + if !assertDecisionJSONContains(t, err, append(slices.Clone(layerPaths), "no matching containers for overlay")...) { + return false + } + commitOrRollback(rev, commitOnEnforcementFailure) + + layerPathsWithoutErr := make([]string, 0) + for i, layerPath := range layerPaths { + if i != errLayerPathIndex { + layerPathsWithoutErr = append(layerPathsWithoutErr, layerPath) + } + } + + rev = startRevertableSection(t, policy) + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPathsWithoutErr, overlayTarget) + if !assertDecisionJSONContains(t, err, append(slices.Clone(layerPathsWithoutErr), "no matching containers for overlay")...) 
{ + return false + } + commitOrRollback(rev, commitOnEnforcementFailure) + + retryTarget := layerPaths[errLayerPathIndex] + rev = startRevertableSection(t, policy) + err = policy.EnforceDeviceMountPolicy(gc.ctx, retryTarget, tc.Layers[layerToErr]) + if err != nil { + t.Errorf("failed to EnforceDeviceMountPolicy again after one previous reverted failure: %v", err) + return false + } + rev.Commit() + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPaths, overlayTarget) + if err != nil { + t.Errorf("failed to EnforceOverlayMountPolicy after one previous reverted failure: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceOverlayMountPolicy_MountFail: %v", err) + } +} + func Test_Rego_EnforceOverlayUnmountPolicy(t *testing.T) { f := func(p *generatedConstraints) bool { tc, err := setupRegoOverlayTest(p, true) @@ -6103,6 +6187,195 @@ func Test_Rego_Enforce_CreateContainer_RequiredEnvMissingHasErrorMessage(t *test } } +func Test_Rego_EnforceCreateContainer_RejectRevertedOverlayMount(t *testing.T) { + f := func(gc *generatedConstraints, commitOnEnforcementFailure bool) bool { + container := selectContainerFromContainerList(gc.containers, testRand) + securityPolicy := gc.toPolicy() + defaultMounts := generateMounts(testRand) + privilegedMounts := generateMounts(testRand) + + policy, err := newRegoPolicy(securityPolicy.marshalRego(), + toOCIMounts(defaultMounts), + toOCIMounts(privilegedMounts), testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + containerID := testDataGenerator.uniqueContainerID() + tc, err := createTestContainerSpec(gc, containerID, container, false, policy, defaultMounts, privilegedMounts) + if err != nil { + t.Fatal(err) + } + + layers, err := testDataGenerator.createValidOverlayForContainer(policy, container) + if err != nil { + t.Errorf("Failed to createValidOverlayForContainer: %v", err) + return false + } + + scratchMountTarget := getScratchDiskMountTarget(containerID) + rev := startRevertableSection(t, policy) + err = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchMountTarget, true, true, "xfs") + if err != nil { + t.Errorf("Failed to EnforceRWDeviceMountPolicy: %v", err) + return false + } + rev.Commit() + + rev = startRevertableSection(t, policy) + overlayTarget := getOverlayMountTarget(containerID) + err = policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layers, overlayTarget) + if err != nil { + t.Errorf("Failed to EnforceOverlayMountPolicy: %v", err) + return false + } + // Simulate a failure by rolling back the overlay mount + rev.Rollback() + + rev = startRevertableSection(t, policy) + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err == nil { + t.Errorf("EnforceCreateContainerPolicy should have failed due to missing (reverted) overlay mount") + return false + } + commitOrRollback(rev, commitOnEnforcementFailure) + + // "Retry" overlay mount + rev = startRevertableSection(t, policy) + err = policy.EnforceOverlayMountPolicy(gc.ctx, tc.containerID, layers, overlayTarget) + if err != nil { + t.Errorf("Failed to EnforceOverlayMountPolicy: %v", err) + return false + } + rev.Commit() + + rev = startRevertableSection(t, policy) + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, 
tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp)
+		if err != nil {
+			t.Errorf("Failed to EnforceCreateContainerPolicy after retrying overlay mount: %v", err)
+			return false
+		}
+		rev.Commit()
+
+		return true
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil {
+		t.Errorf("Test_Rego_EnforceCreateContainer_RejectRevertedOverlayMount: %v", err)
+	}
+}
+
+func Test_Rego_EnforceCreateContainer_RetryEverything(t *testing.T) {
+	f := func(gc *generatedConstraints,
+		newContainerID, failScratchMount, testDenyInvalidContainerCreation bool,
+	) bool {
+		container := selectContainerFromContainerList(gc.containers, testRand)
+		securityPolicy := gc.toPolicy()
+		defaultMounts := generateMounts(testRand)
+		privilegedMounts := generateMounts(testRand)
+
+		policy, err := newRegoPolicy(securityPolicy.marshalRego(),
+			toOCIMounts(defaultMounts),
+			toOCIMounts(privilegedMounts), testOSType)
+		if err != nil {
+			t.Errorf("cannot make rego policy from constraints: %v", err)
+			return false
+		}
+
+		containerID := testDataGenerator.uniqueContainerID()
+		tc, err := createTestContainerSpec(gc, containerID, container, false, policy, defaultMounts, privilegedMounts)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		scratchMountTarget := getScratchDiskMountTarget(containerID)
+		rev := startRevertableSection(t, policy)
+		err = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchMountTarget, true, true, "xfs")
+		if err != nil {
+			t.Errorf("Failed to EnforceRWDeviceMountPolicy: %v", err)
+			return false
+		}
+
+		succeedLayerPaths := make([]string, 0)
+
+		if failScratchMount {
+			rev.Rollback()
+		} else {
+			rev.Commit()
+
+			// Simulate one of the layers failing to mount, after which the host
+			// gives up on this container and starts over.
+			layerToErr := testRand.Intn(len(container.Layers))
+			for i, layerHash := range container.Layers {
+				rev := startRevertableSection(t, policy)
+				target := testDataGenerator.uniqueLayerMountTarget()
+				err = policy.EnforceDeviceMountPolicy(gc.ctx, target, layerHash)
+				if err != nil {
+					t.Errorf("failed to EnforceDeviceMountPolicy: %v", err)
+					return false
+				}
+				if i == layerToErr {
+					// Simulate a mount failure at this point, which will cause us to rollback
+					rev.Rollback()
+					break
+				} else {
+					rev.Commit()
+					succeedLayerPaths = append(succeedLayerPaths, target)
+				}
+			}
+
+			for _, layerPath := range succeedLayerPaths {
+				rev := startRevertableSection(t, policy)
+				err = policy.EnforceDeviceUnmountPolicy(gc.ctx, layerPath)
+				if err != nil {
+					t.Errorf("Failed to EnforceDeviceUnmountPolicy: %v", err)
+					return false
+				}
+				rev.Commit()
+			}
+
+			rev = startRevertableSection(t, policy)
+			err = policy.EnforceRWDeviceUnmountPolicy(gc.ctx, scratchMountTarget)
+			if err != nil {
+				t.Errorf("Failed to EnforceRWDeviceUnmountPolicy: %v", err)
+				return false
+			}
+			rev.Commit()
+		}
+
+		if testDenyInvalidContainerCreation {
+			rev = startRevertableSection(t, policy)
+			_, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp)
+			if err == nil {
+				t.Errorf("EnforceCreateContainerPolicy should have failed due to missing (reverted) overlay mount")
+			}
+			rev.Rollback()
+		}
+
+		if newContainerID {
+			tc.containerID = testDataGenerator.uniqueContainerID()
+		}
+
+		err = mountImageForContainerWithID(policy, container, tc.containerID)
+		if err != nil {
+			t.Errorf("Failed to mount image for container after reverting and retrying: %v", err)
+			return false
+		}
+		_, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp)
+		if err != nil {
+			t.Errorf("Failed to EnforceCreateContainerPolicy after retrying: %v", err)
+			return false
+		}
+
+		return true
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil {
+		t.Errorf("Test_Rego_EnforceCreateContainer_RetryEverything: %v", err)
+	}
+}
+
 func Test_Rego_ExecInContainerPolicy_RequiredEnvMissingHasErrorMessage(t *testing.T) {
 	constraints := generateConstraints(testRand, 1)
 	container := selectContainerFromContainerList(constraints.containers, testRand)
diff --git a/pkg/securitypolicy/securitypolicyenforcer.go b/pkg/securitypolicy/securitypolicyenforcer.go
index c9782e0d63..4014640ef7 100644
--- a/pkg/securitypolicy/securitypolicyenforcer.go
+++ b/pkg/securitypolicy/securitypolicyenforcer.go
@@ -57,6 +57,14 @@ func init() {
 	registeredEnforcers[openDoorEnforcerName] = createOpenDoorEnforcer
 }
 
+// Represents an in-progress revertable section. To ensure state is consistent,
+// Commit() and Rollback() must not fail, so they do not return anything, and if
+// an error does occur they should panic.
+type RevertableSectionHandle interface { + Commit() + Rollback() +} + type SecurityPolicyEnforcer interface { EnforceDeviceMountPolicy(ctx context.Context, target string, deviceHash string) (err error) EnforceRWDeviceMountPolicy(ctx context.Context, target string, encrypted, ensureFilesystem bool, filesystem string) (err error) @@ -127,6 +135,7 @@ type SecurityPolicyEnforcer interface { EnforceScratchUnmountPolicy(ctx context.Context, scratchPath string) (err error) GetUserInfo(spec *oci.Process, rootPath string) (IDName, []IDName, string, error) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) (err error) + StartRevertableSection() (RevertableSectionHandle, error) } //nolint:unused @@ -181,6 +190,11 @@ func CreateSecurityPolicyEnforcer( } } +type nopRevertableSectionHandle struct{} + +func (nopRevertableSectionHandle) Commit() {} +func (nopRevertableSectionHandle) Rollback() {} + type OpenDoorSecurityPolicyEnforcer struct { encodedSecurityPolicy string } @@ -319,6 +333,10 @@ func (OpenDoorSecurityPolicyEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Cont return nil } +func (*OpenDoorSecurityPolicyEnforcer) StartRevertableSection() (RevertableSectionHandle, error) { + return nopRevertableSectionHandle{}, nil +} + type ClosedDoorSecurityPolicyEnforcer struct{} var _ SecurityPolicyEnforcer = (*ClosedDoorSecurityPolicyEnforcer)(nil) @@ -443,3 +461,7 @@ func (ClosedDoorSecurityPolicyEnforcer) GetUserInfo(spec *oci.Process, rootPath func (ClosedDoorSecurityPolicyEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) error { return nil } + +func (*ClosedDoorSecurityPolicyEnforcer) StartRevertableSection() (RevertableSectionHandle, error) { + return nopRevertableSectionHandle{}, nil +} diff --git a/pkg/securitypolicy/securitypolicyenforcer_rego.go b/pkg/securitypolicy/securitypolicyenforcer_rego.go index 075d5450c5..1393d5109a 100644 --- a/pkg/securitypolicy/securitypolicyenforcer_rego.go +++ b/pkg/securitypolicy/securitypolicyenforcer_rego.go @@ -12,8 +12,10 @@ import ( "regexp" "slices" "strings" + "sync" "syscall" + "github.com/Microsoft/hcsshim/internal/gcs" "github.com/Microsoft/hcsshim/internal/guestpath" "github.com/Microsoft/hcsshim/internal/log" rpi "github.com/Microsoft/hcsshim/internal/regopolicyinterpreter" @@ -57,6 +59,10 @@ type regoEnforcer struct { maxErrorMessageLength int // OS type osType string + // Mutex to ensure only one revertable section is active + revertableSectionLock sync.Mutex + // Saved metadata for the revertable section + savedMetadata rpi.SavedMetadata } var _ SecurityPolicyEnforcer = (*regoEnforcer)(nil) @@ -1175,3 +1181,81 @@ func (policy *regoEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Context, conta func (policy *regoEnforcer) GetUserInfo(process *oci.Process, rootPath string) (IDName, []IDName, string, error) { return GetAllUserInfo(process, rootPath) } + +type revertableSectionHandle struct { + // policy is cleared once this struct is "used", to prevent accidental + // duplicate Commit/Rollback calls. + policy *regoEnforcer +} + +func (policy *regoEnforcer) inRevertableSection() bool { + succ := policy.revertableSectionLock.TryLock() + if succ { + // since nobody else has the lock, we're not in fact in a revertable + // section. + policy.revertableSectionLock.Unlock() + return false + } + // somebody else (i.e. the caller) has the lock, so we're in a revertable + // section. Don't unlock it here! 
+ return true +} + +// Starts a revertable section by saving the current policy state. If another +// revertable section is already active, this will wait until that one is +// finished. +func (policy *regoEnforcer) StartRevertableSection() (RevertableSectionHandle, error) { + policy.revertableSectionLock.Lock() + var err error + policy.savedMetadata, err = policy.rego.SaveMetadata() + if err != nil { + err = errors.Wrapf(err, "unable to save metadata for revertable section") + policy.revertableSectionLock.Unlock() + return &revertableSectionHandle{}, err + } + // Keep policy.revertableSectionLock locked until the end of the section. + sh := &revertableSectionHandle{ + policy: policy, + } + return sh, nil +} + +func (sh *revertableSectionHandle) Commit() { + if sh.policy == nil { + gcs.UnrecoverableError(errors.New("revertable section handle already used")) + } + + policy := sh.policy + sh.policy = nil + lockSucc := policy.revertableSectionLock.TryLock() + if lockSucc { + gcs.UnrecoverableError(errors.New("not in a revertable section")) + } else { + // somebody else (i.e. the caller) has the lock, so we're in a revertable + // section. Clear the saved metadata just in case, then unlock to exit the + // section. + policy.savedMetadata = rpi.SavedMetadata{} + policy.revertableSectionLock.Unlock() + } +} + +func (sh *revertableSectionHandle) Rollback() { + if sh.policy == nil { + gcs.UnrecoverableError(errors.New("revertable section handle already used")) + } + + policy := sh.policy + sh.policy = nil + lockSucc := policy.revertableSectionLock.TryLock() + if lockSucc { + gcs.UnrecoverableError(errors.New("not in a revertable section")) + } else { + // somebody else (i.e. the caller) has the lock, so we're in a revertable + // section. Restore the saved metadata, then unlock to exit the section. + err := policy.rego.RestoreMetadata(policy.savedMetadata) + if err != nil { + gcs.UnrecoverableError(errors.Wrap(err, "unable to restore metadata for revertable section")) + } + policy.revertableSectionLock.Unlock() + } +} From beb61972eb5e7bdf971f84b3764152f2d2061c50 Mon Sep 17 00:00:00 2001 From: Tingmao Wang Date: Fri, 19 Dec 2025 11:52:58 +0000 Subject: [PATCH 07/12] Allow unrecoverable_error.go to build on Windows and fix IsSNP() invocation IsSNP() now can return an error, although this is not expected on LCOW. Signed-off-by: Tingmao Wang --- internal/gcs/unrecoverable_error.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/internal/gcs/unrecoverable_error.go b/internal/gcs/unrecoverable_error.go index dbb7240266..96528088aa 100644 --- a/internal/gcs/unrecoverable_error.go +++ b/internal/gcs/unrecoverable_error.go @@ -1,6 +1,3 @@ -//go:build linux -// +build linux - package gcs import ( @@ -28,7 +25,14 @@ func UnrecoverableError(err error) { "Unrecoverable error in GCS: %v\n%s", err, stackTrace, ) - isSnp := amdsevsnp.IsSNP() + + isSnp, err := amdsevsnp.IsSNP() + if err != nil { + // IsSNP() cannot fail on LCOW + // but if it does, we proceed as if we're on SNP to be safe. + isSnp = true + } + if isSnp { errPrint += "\nThis thread will now enter an infinite loop." } From 5ef19a4671c04ed192e5fd750155786f9f60f79f Mon Sep 17 00:00:00 2001 From: Tingmao Wang Date: Tue, 11 Nov 2025 14:58:51 +0000 Subject: [PATCH 08/12] Merged PR 12878148: bridge: Force sequential message handling for confidential containers [cherry-picked from f81b450894206a79fff4d63182ff034ba503ebdb] This PR contains 2 commits. 
The first one is the fix:

**bridge: Force sequential message handling for confidential containers**

This fixes a vulnerability (and reduces the surface for other similar
potential vulnerabilities) in confidential containers where if the host
sends a mount/unmount modification request concurrently with an ongoing
CreateContainer request, the host could re-order or skip image layers
for the container to be started.

While this could be fixed by adding mutex lock/unlock around the
individual modifyMappedVirtualDisk/modifyCombinedLayers/CreateContainer
functions, we decided that in order to prevent any more of this class
of issues, the UVM, when running in confidential mode, should just not
allow concurrent requests (with an exception for any actually
long-running requests, which for now is just waitProcess).

The second one adds a log entry for when the processing thread blocks.
This will make it easier to debug should the gcs hang on a request.

This PR is created on ADO targeting the conf branch as this security
vulnerability is not public yet. This fix should be backported to main
once deployed.

Related work items: #33357501, #34327300

Signed-off-by: Tingmao Wang
---
 cmd/gcs/main.go | 5 ++
 internal/guest/bridge/bridge.go | 94 +++++++++++++++++++++++++--------
 2 files changed, 76 insertions(+), 23 deletions(-)

diff --git a/cmd/gcs/main.go b/cmd/gcs/main.go
index 36ae1991b6..a17f3ae232 100644
--- a/cmd/gcs/main.go
+++ b/cmd/gcs/main.go
@@ -31,6 +31,7 @@ import (
 "github.com/Microsoft/hcsshim/internal/log"
 "github.com/Microsoft/hcsshim/internal/oc"
 "github.com/Microsoft/hcsshim/internal/version"
+ "github.com/Microsoft/hcsshim/pkg/amdsevsnp"
 "github.com/Microsoft/hcsshim/pkg/securitypolicy"
 )
 
@@ -362,6 +363,10 @@ func main() {
 b := bridge.Bridge{
 Handler: mux,
 EnableV4: *v4,
+
+ // For confidential containers, we protect ourselves against attacks caused
+ // by concurrent modifications, by processing one request at a time.
+ ForceSequential: amdsevsnp.IsSNP(),
 }
 h := hcsv2.NewHost(rtime, tport, initialEnforcer, logWriter)
 // Initialize virtual pod support in the host
diff --git a/internal/guest/bridge/bridge.go b/internal/guest/bridge/bridge.go
index 4ea03ed104..875def5809 100644
--- a/internal/guest/bridge/bridge.go
+++ b/internal/guest/bridge/bridge.go
@@ -177,6 +177,10 @@ type Bridge struct {
 Handler Handler
 // EnableV4 enables the v4+ bridge and the schema v2+ interfaces.
 EnableV4 bool
+ // Setting ForceSequential to true will force the bridge to only process one
+ // request at a time, except for certain long-running operations (as defined
+ // in alwaysAsyncMessages).
+ ForceSequential bool
 
 // responseChan is the response channel used for both request/response
 // and publish notification workflows.
@@ -191,6 +195,14 @@ type Bridge struct {
 protVer prot.ProtocolVersion
 }
 
+// Messages that will be processed asynchronously even in sequential mode. Note
+// that in sequential mode, these messages will still wait for any in-progress
+// non-async messages to be handled before they are processed, but once they are
+// "acknowledged", the rest will be done asynchronously.
+var alwaysAsyncMessages map[prot.MessageIdentifier]bool = map[prot.MessageIdentifier]bool{
+ prot.ComputeSystemWaitForProcessV1: true,
+}
+
 // AssignHandlers creates and assigns the appropriate bridge
 // events to be listen for and intercepted on `mux` before forwarding
 // to `gcs` for handling.
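As a minimal standalone sketch of the idea (toy code, not hcsshim APIs):
funneling all handlers through a single worker goroutine serializes them, so
a modification request can no longer interleave with an in-flight
CreateContainer.

package main

import "fmt"

func main() {
	requests := make(chan func())
	done := make(chan struct{})

	// Single worker: one request at a time, in arrival order.
	go func() {
		defer close(done)
		for handle := range requests {
			handle()
		}
	}()

	layers := []string{}
	requests <- func() { layers = append(layers, "layer-1") } // "mount" request
	requests <- func() { layers = append(layers, "layer-2") } // "mount" request
	requests <- func() { fmt.Println("createContainer sees:", layers) }
	close(requests)
	<-done
}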
@@ -238,6 +250,10 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
 defer close(requestErrChan)
 defer bridgeIn.Close()
 
+ if b.ForceSequential {
+ log.G(context.Background()).Info("bridge: ForceSequential enabled")
+ }
+
 // Receive bridge requests and schedule them to be processed.
 go func() {
 var recverr error
@@ -340,30 +356,36 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
 }()
 // Process each bridge request async and create the response writer.
 go func() {
- for req := range requestChan {
- go func(r *Request) {
- br := bridgeResponse{
- ctx: r.Context,
- header: &prot.MessageHeader{
- Type: prot.GetResponseIdentifier(r.Header.Type),
- ID: r.Header.ID,
- },
- }
- resp, err := b.Handler.ServeMsg(r)
- if resp == nil {
- resp = &prot.MessageResponseBase{}
- }
- resp.Base().ActivityID = r.ActivityID
- if err != nil {
- span := trace.FromContext(r.Context)
- if span != nil {
- oc.SetSpanStatus(span, err)
- }
- setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */)
+ doOneRequest := func(r *Request) {
+ br := bridgeResponse{
+ ctx: r.Context,
+ header: &prot.MessageHeader{
+ Type: prot.GetResponseIdentifier(r.Header.Type),
+ ID: r.Header.ID,
+ },
+ }
+ resp, err := b.Handler.ServeMsg(r)
+ if resp == nil {
+ resp = &prot.MessageResponseBase{}
+ }
+ resp.Base().ActivityID = r.ActivityID
+ if err != nil {
+ span := trace.FromContext(r.Context)
+ if span != nil {
+ oc.SetSpanStatus(span, err)
 }
- br.response = resp
- b.responseChan <- br
- }(req)
+ setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */)
+ }
+ br.response = resp
+ b.responseChan <- br
+ }
+
+ for req := range requestChan {
+ if b.ForceSequential && !alwaysAsyncMessages[req.Header.Type] {
+ runSequentialRequestHandler(req, doOneRequest)
+ } else {
+ go doOneRequest(req)
+ }
 }
 }()
 // Process each bridge response sync. This channel is for request/response and publish workflows.
@@ -423,6 +445,32 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
 }
 }
 
+// Run handleFn(r), printing a warning if handleFn takes too long to return
+// (or never returns).
+func runSequentialRequestHandler(r *Request, handleFn func(*Request)) {
+ // Note that this is only a context used for triggering the blockage
+ // warning; the request processing still uses r.Context. We don't want to
+ // cancel the request handling itself when we reach the 5s timeout.
+ timeoutCtx, cancel := context.WithTimeout(r.Context, 5*time.Second)
+ go func() {
+ <-timeoutCtx.Done()
+ if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
+ log.G(timeoutCtx).WithFields(logrus.Fields{
+ // We want to log these even though we're providing r.Context, since if
+ // the request never finishes the span end log will never get written,
+ // and we may therefore not be able to find out about the following info
+ // otherwise:
+ "message-type": r.Header.Type.String(),
+ "message-id": r.Header.ID,
+ "activity-id": r.ActivityID,
+ "container-id": r.ContainerID,
+ }).Warnf("bridge: request processing thread in sequential mode blocked on the current request for more than 5 seconds")
+ }
+ }()
+ defer cancel()
+ handleFn(r)
+}
+
 // PublishNotification writes a specific notification to the bridge.
func (b *Bridge) PublishNotification(n *prot.ContainerNotification) {
 ctx, span := oc.StartSpan(context.Background(),
From fd18c53571146b838771910c5cacb05fff4e73bf Mon Sep 17 00:00:00 2001
From: Tingmao Wang
Date: Fri, 19 Dec 2025 11:48:20 +0000
Subject: [PATCH 09/12] Fix usage of IsSNP() to handle the new error return
 value

Signed-off-by: Tingmao Wang
---
 cmd/gcs/main.go | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/cmd/gcs/main.go b/cmd/gcs/main.go
index a17f3ae232..09bfe02305 100644
--- a/cmd/gcs/main.go
+++ b/cmd/gcs/main.go
@@ -360,13 +360,22 @@ func main() {
 logrus.WithError(err).Fatal("failed to initialize new runc runtime")
 }
 mux := bridge.NewBridgeMux()
+
+ forceSequential, err := amdsevsnp.IsSNP()
+ if err != nil {
+ // IsSNP cannot fail on LCOW
+ logrus.Errorf("Got unexpected error from IsSNP(): %v", err)
+ // If it fails, we proceed with forceSequential enabled to be safe
+ forceSequential = true
+ }
+
 b := bridge.Bridge{
 Handler: mux,
 EnableV4: *v4,
 
 // For confidential containers, we protect ourselves against attacks caused
 // by concurrent modifications, by processing one request at a time.
- ForceSequential: amdsevsnp.IsSNP(),
+ ForceSequential: forceSequential,
 }
 h := hcsv2.NewHost(rtime, tport, initialEnforcer, logWriter)
 // Initialize virtual pod support in the host
From 82ece465c7674f1dfad6dfdc14e68c2488006519 Mon Sep 17 00:00:00 2001
From: Tingmao Wang
Date: Mon, 20 Oct 2025 16:52:21 +0000
Subject: [PATCH 10/12] Merged PR 13618357: guest/network: Restrict hostname to
 valid characters

[cherry-picked from 055ee5eb4a802cb407575fb6cc1e9b07069d3319]

guest/network: Restrict hostname to valid characters

Because we write this hostname to /etc/hosts, without proper validation
the host can trick us into writing arbitrary data to /etc/hosts, which
can, for example, redirect things like ip6-localhost (but likely not
localhost itself) to an attacker-controlled IP address.

We implement a check here that the host-provided DNS name in the OCI
spec is valid. ACI actually restricts this to 5-63 characters of
a-zA-Z0-9 and '-', where the first and last characters cannot be '-'.
This aligns with the Kubernetes restriction. c.f. IsValidDnsLabel in
Compute-ACI.

However, there is no consistent official agreement on what a valid
hostname can contain. RFC 952 says that "Domain name" can be up to 24
characters of A-Z0-9 '.' and '-', RFC 1123 expands this to 255
characters, but RFC 1035 claims that domain names can in fact contain
anything if quoted (as long as the length is within 255 characters),
and this is confirmed again in RFC 2181. In practice we see names with
underscores, most commonly \_dmarc.example.com. curl allows 0-9a-zA-Z
and -.\_|~ and any other codepoints from \u0001-\u001f and above
\u007f: https://github.com/curl/curl/blob/master/lib/urlapi.c#L527-L545

With the above in mind, this commit allows up to 255 characters of
a-zA-Z0-9 and '_', '-' and '.'. This change is applied to all LCOW use
cases.
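For a standalone sketch of the accepted grammar (mirroring the regex this
commit adds; probe values invented):

package main

import (
	"fmt"
	"regexp"
)

// Same pattern as the ValidateHostname check added below.
var validHostname = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]{0,255}$`)

func main() {
	probes := []string{
		"my-host.example_1", // accepted
		"bad\n1.1.1.1 evil", // rejected: a newline would inject into /etc/hosts
	}
	for _, h := range probes {
		fmt.Printf("%q -> %v\n", h, validHostname.MatchString(h))
	}
}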
This fix can be tested with the below code to bypass any host-side
checks:

+++ b/internal/hcsoci/hcsdoc_lcow.go
@@ -52,6 +52,10 @@ func createLCOWSpec(ctx context.Context, coi *createOptionsInternal) (*specs.Spe
 spec.Linux.Seccomp = nil
 }
 
+ if spec.Annotations[annotations.KubernetesContainerType] == "sandbox" {
+ spec.Hostname = "invalid-hostname\n1.1.1.1 ip6-localhost ip6-loopback localhost"
+ }
+
 return spec, nil
 }

Output:

time="2025-10-01T15:13:41Z" level=fatal msg="run pod sandbox: rpc error: code = Unknown desc = failed to create containerd task: failed to create shim task: failed to create container f2209bb2960d5162fc9937d3362e1e2cf1724c56d1296ba2551ce510cb2bcd43: guest RPC failure: hostname \"invalid-hostname\\n1.1.1.1 ip6-localhost ip6-loopback localhost\" invalid: must match ^[a-zA-Z0-9_\\-\\.]{0,999}$: unknown"

Related work items: #34370598

Closes: https://msazure.visualstudio.com/One/_workitems/edit/34370598
Signed-off-by: Tingmao Wang
---
 internal/guest/network/network.go | 13 +++++++
 internal/guest/network/network_test.go | 35 +++++++++++++++++++
 .../guest/runtime/hcsv2/sandbox_container.go | 3 ++
 .../runtime/hcsv2/standalone_container.go | 3 ++
 4 files changed, 54 insertions(+)

diff --git a/internal/guest/network/network.go b/internal/guest/network/network.go
index 68f7c1bef1..4cac108b03 100644
--- a/internal/guest/network/network.go
+++ b/internal/guest/network/network.go
@@ -9,6 +9,7 @@ import (
 "fmt"
 "os"
 "path/filepath"
+ "regexp"
 "strings"
 "time"
 
@@ -32,6 +33,18 @@ var (
 // maxDNSSearches is limited to 6 in `man 5 resolv.conf`
 const maxDNSSearches = 6
 
+var validHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]{0,255}$`)
+
+// Check that the hostname is safe. This check is deliberately not as strict
+// as it technically could be, but ensures that when the hostname is inserted
+// into /etc/hosts, it cannot lead to injection attacks.
+func ValidateHostname(hostname string) error {
+ if !validHostnameRegex.MatchString(hostname) {
+ return errors.Errorf("hostname %q invalid: must match %s", hostname, validHostnameRegex.String())
+ }
+ return nil
+}
+
 // GenerateEtcHostsContent generates a /etc/hosts file based on `hostname`.
func GenerateEtcHostsContent(ctx context.Context, hostname string) string { _, span := oc.StartSpan(ctx, "network::GenerateEtcHostsContent") diff --git a/internal/guest/network/network_test.go b/internal/guest/network/network_test.go index 4ac6ff1f10..bb19db6974 100644 --- a/internal/guest/network/network_test.go +++ b/internal/guest/network/network_test.go @@ -7,6 +7,7 @@ import ( "context" "os" "path/filepath" + "strings" "testing" "time" ) @@ -122,6 +123,40 @@ func Test_MergeValues(t *testing.T) { } } +func Test_ValidateHostname(t *testing.T) { + validNames := []string{ + "localhost", + "my-hostname", + "my.hostname", + "my-host-name123", + "_underscores.are.allowed.too", + "", // Allow not specifying a hostname + } + + invalidNames := []string{ + "localhost\n13.104.0.1 ip6-localhost ip6-loopback localhost", + "localhost\n2603:1000::1 ip6-localhost ip6-loopback", + "hello@microsoft.com", + "has space", + "has,comma", + "\x00", + "a\nb", + strings.Repeat("a", 1000), + } + + for _, n := range validNames { + if err := ValidateHostname(n); err != nil { + t.Fatalf("expected %q to be valid, got: %v", n, err) + } + } + + for _, n := range invalidNames { + if err := ValidateHostname(n); err == nil { + t.Fatalf("expected %q to be invalid, but got nil error", n) + } + } +} + func Test_GenerateEtcHostsContent(t *testing.T) { type testcase struct { name string diff --git a/internal/guest/runtime/hcsv2/sandbox_container.go b/internal/guest/runtime/hcsv2/sandbox_container.go index 7456e1462a..da29a95835 100644 --- a/internal/guest/runtime/hcsv2/sandbox_container.go +++ b/internal/guest/runtime/hcsv2/sandbox_container.go @@ -54,6 +54,9 @@ func setupSandboxContainerSpec(ctx context.Context, id string, spec *oci.Spec) ( // Write the hostname hostname := spec.Hostname + if err = network.ValidateHostname(hostname); err != nil { + return err + } if hostname == "" { var err error hostname, err = os.Hostname() diff --git a/internal/guest/runtime/hcsv2/standalone_container.go b/internal/guest/runtime/hcsv2/standalone_container.go index bb1c5ad390..296b328cf5 100644 --- a/internal/guest/runtime/hcsv2/standalone_container.go +++ b/internal/guest/runtime/hcsv2/standalone_container.go @@ -61,6 +61,9 @@ func setupStandaloneContainerSpec(ctx context.Context, id string, spec *oci.Spec }() hostname := spec.Hostname + if err = network.ValidateHostname(hostname); err != nil { + return err + } if hostname == "" { var err error hostname, err = os.Hostname() From 6d74816ba902feac67d7afb60deff24162b7c99e Mon Sep 17 00:00:00 2001 From: Tingmao Wang Date: Tue, 11 Nov 2025 16:48:20 +0000 Subject: [PATCH 11/12] Merged PR 12878779: Enhance uvm_state::hostMounts to track in-use mounts, and prevent unmounting or deleting in-use things [cherry-picked from d0334883cd43eecbb401a6ded3e0317179a3e54b] This set of changes adds some checks (when running with a confidential policy) to prevent the host from trying to clean up mounts, overlays, or the container states dir when the container is running (or when the overlay has not been unmounted yet). This is through enhancing the existing `hostMounts` utility, as well as adding a `terminated` flag to the Container struct. The correct order of operations should always be: - mount read-only layers and scratch (in any order, and individual containers (not the sandbox) might not have their own scratch) - mount the overlay - start the container - container terminates - unmount overlay - unmount read-only layers and scratch The starting up order is implied, and we now explicitly deny e.g. 
unmounting layer/scratch before unmounting the overlay, or unmounting
the overlay while the container has not terminated. We also deny
deleteContainerState requests when the container is running or when the
overlay is mounted. Doing so when a container is running can result in
unexpectedly deleting its files, which breaks it in unpredictable ways.

Signed-off-by: Tingmao Wang
---
 internal/guest/bridge/bridge_v2.go | 8 +-
 internal/guest/runtime/hcsv2/container.go | 3 +
 internal/guest/runtime/hcsv2/process.go | 1 +
 internal/guest/runtime/hcsv2/uvm.go | 245 +++++++++---
 internal/guest/runtime/hcsv2/uvm_state.go | 359 +++++++++++++++---
 .../guest/runtime/hcsv2/uvm_state_test.go | 219 ++++++++++-
 6 files changed, 737 insertions(+), 98 deletions(-)

diff --git a/internal/guest/bridge/bridge_v2.go b/internal/guest/bridge/bridge_v2.go
index 800094e549..2f105ef5b6 100644
--- a/internal/guest/bridge/bridge_v2.go
+++ b/internal/guest/bridge/bridge_v2.go
@@ -467,16 +467,10 @@ func (b *Bridge) deleteContainerStateV2(r *Request) (_ RequestResponse, err erro
 return nil, errors.Wrapf(err, "failed to unmarshal JSON in message \"%s\"", r.Message)
 }
 
- c, err := b.hostState.GetCreatedContainer(request.ContainerID)
+ err = b.hostState.DeleteContainerState(ctx, request.ContainerID)
 if err != nil {
 return nil, err
 }
- // remove container state regardless of delete's success
- defer b.hostState.RemoveContainer(request.ContainerID)
-
- if err := c.Delete(ctx); err != nil {
- return nil, err
- }
 
 return &prot.MessageResponseBase{}, nil
 }
diff --git a/internal/guest/runtime/hcsv2/container.go b/internal/guest/runtime/hcsv2/container.go
index bb9c3af5ea..886eb0528d 100644
--- a/internal/guest/runtime/hcsv2/container.go
+++ b/internal/guest/runtime/hcsv2/container.go
@@ -73,6 +73,9 @@ type Container struct {
 // and deal with the extra pointer dereferencing overhead.
 status atomic.Uint32
 
+ // Set to true when the init process for the container has exited
+ terminated atomic.Bool
+
 // scratchDirPath represents the path inside the UVM where the scratch directory
 // of this container is located. Usually, this is either `/run/gcs/c/` or
 // `/run/gcs/c//container_` if scratch is shared with UVM scratch.
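To make the enforced ordering concrete, here is a test-style sketch against
the hostMounts API introduced later in this patch (paths invented; not part
of the original change):

func Test_HostMounts_OrderOfOperations_Sketch(t *testing.T) {
	hm := newHostMounts()
	hm.Lock()
	defer hm.Unlock()

	// 1. Mount a read-only layer and the scratch device (any order).
	if err := hm.AddRODevice("/run/mounts/scsi/m1", "/dev/sda"); err != nil {
		t.Fatal(err)
	}
	if err := hm.AddRWDevice("/run/gcs/c/x/scratch", "/dev/sdb", true); err != nil {
		t.Fatal(err)
	}
	// 2. Mount the overlay, which pins both devices.
	if err := hm.AddOverlay("/run/gcs/c/x/rootfs",
		[]string{"/run/mounts/scsi/m1"}, "/run/gcs/c/x/scratch"); err != nil {
		t.Fatal(err)
	}
	// 3. Unmounting a pinned device is now denied.
	if err := hm.RemoveRODevice("/run/mounts/scsi/m1", "/dev/sda"); err == nil {
		t.Fatal("expected denial while the overlay references the layer")
	}
	// 4. Remove the overlay first, then the devices.
	if _, err := hm.RemoveOverlay("/run/gcs/c/x/rootfs"); err != nil {
		t.Fatal(err)
	}
	if err := hm.RemoveRODevice("/run/mounts/scsi/m1", "/dev/sda"); err != nil {
		t.Fatal(err)
	}
}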
diff --git a/internal/guest/runtime/hcsv2/process.go b/internal/guest/runtime/hcsv2/process.go
index e94c9792f6..96564cfab0 100644
--- a/internal/guest/runtime/hcsv2/process.go
+++ b/internal/guest/runtime/hcsv2/process.go
@@ -99,6 +99,7 @@ func newProcess(c *Container, spec *oci.Process, process runtime.Process, pid ui
 log.G(ctx).WithError(err).Error("failed to wait for runc process")
 }
 p.exitCode = exitCode
+ c.terminated.Store(true)
 log.G(ctx).WithField("exitCode", p.exitCode).Debug("process exited")
 
 // Free any process waiters
diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go
index 694ced53e5..6d61222d84 100644
--- a/internal/guest/runtime/hcsv2/uvm.go
+++ b/internal/guest/runtime/hcsv2/uvm.go
@@ -27,6 +27,7 @@ import (
 "github.com/opencontainers/runtime-spec/specs-go"
 "github.com/pkg/errors"
 "github.com/sirupsen/logrus"
+ "go.opencensus.io/trace"
 "golang.org/x/sys/unix"
 
 "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr"
@@ -45,6 +46,7 @@ import (
 "github.com/Microsoft/hcsshim/internal/guestpath"
 "github.com/Microsoft/hcsshim/internal/log"
 "github.com/Microsoft/hcsshim/internal/logfields"
+ "github.com/Microsoft/hcsshim/internal/oc"
 "github.com/Microsoft/hcsshim/internal/oci"
 "github.com/Microsoft/hcsshim/internal/protocol/guestrequest"
 "github.com/Microsoft/hcsshim/internal/protocol/guestresource"
@@ -112,7 +114,9 @@ type Host struct {
 securityOptions *securitypolicy.SecurityOptions
 
 // hostMounts keeps the state of currently mounted devices and file systems,
- // which is used for GCS hardening.
+ // which is used for GCS hardening. It is only used for confidential
+ // containers, and is initialized in SetConfidentialUVMOptions. If this is
+ // nil, we do not add any special restrictions on mounts / unmounts.
 hostMounts *hostMounts
 // A permanent flag to indicate that further mounts, unmounts and container
 // creation should not be allowed. This is set when, because of a failure
@@ -141,7 +145,7 @@ func NewHost(rtime runtime.Runtime, vsock transport.Transport, initialEnforcer s
 rtime: rtime,
 vsock: vsock,
 devNullTransport: &transport.DevNullTransport{},
- hostMounts: newHostMounts(),
+ hostMounts: nil,
 securityOptions: securityPolicyOptions,
 mountsBroken: atomic.Bool{},
 }
@@ -400,6 +404,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM
 isSandbox: criType == "sandbox",
 exitType: prot.NtUnexpectedExit,
 processes: make(map[uint32]*containerProcess),
+ terminated: atomic.Bool{},
 scratchDirPath: settings.ScratchDirPath,
 }
 c.setStatus(containerCreating)
@@ -743,6 +748,25 @@ func writeSpecToFile(ctx context.Context, configFile string, spec *specs.Spec) e
 return nil
 }
 
+// Returns whether there is a running container that is currently using the
+// given overlay (as its rootfs).
+func (h *Host) IsOverlayInUse(overlayPath string) bool {
+ h.containersMutex.Lock()
+ defer h.containersMutex.Unlock()
+
+ for _, c := range h.containers {
+ if c.terminated.Load() {
+ continue
+ }
+
+ if c.spec.Root.Path == overlayPath {
+ return true
+ }
+ }
+
+ return false
+}
+
 func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) (retErr error) {
 if h.HasSecurityPolicy() {
 if err := checkValidContainerID(containerID, "container"); err != nil {
@@ -766,35 +790,6 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req *
 return err
 }
 mvd.Controller = cNum
- // first we try to update the internal state for read-write attachments.
- if !mvd.ReadOnly { - localCtx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - source, err := scsi.GetDevicePath(localCtx, mvd.Controller, mvd.Lun, mvd.Partition) - if err != nil { - return err - } - switch req.RequestType { - case guestrequest.RequestTypeAdd: - if err := h.hostMounts.AddRWDevice(mvd.MountPath, source, mvd.Encrypted); err != nil { - return err - } - defer func() { - if retErr != nil { - _ = h.hostMounts.RemoveRWDevice(mvd.MountPath, source) - } - }() - case guestrequest.RequestTypeRemove: - if err := h.hostMounts.RemoveRWDevice(mvd.MountPath, source); err != nil { - return err - } - defer func() { - if retErr != nil { - _ = h.hostMounts.AddRWDevice(mvd.MountPath, source, mvd.Encrypted) - } - }() - } - } return h.modifyMappedVirtualDisk(ctx, req.RequestType, mvd) case guestresource.ResourceTypeMappedDirectory: if err := h.checkMountsNotBroken(); err != nil { @@ -813,12 +808,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * return err } - cl := req.Settings.(*guestresource.LCOWCombinedLayers) - // when cl.ScratchPath == "", we mount overlay as read-only, in which case - // we don't really care about scratch encryption, since the host already - // knows about the layers and the overlayfs. - encryptedScratch := cl.ScratchPath != "" && h.hostMounts.IsEncrypted(cl.ScratchPath) - return h.modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers), encryptedScratch) + return h.modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers)) case guestresource.ResourceTypeNetwork: return modifyNetwork(ctx, req.RequestType, req.Settings.(*guestresource.LCOWNetworkAdapter)) case guestresource.ResourceTypeVPCIDevice: @@ -834,10 +824,22 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * if !ok { return errors.New("the request's settings are not of type ConfidentialOptions") } - return h.securityOptions.SetConfidentialOptions(ctx, + err := h.securityOptions.SetConfidentialOptions(ctx, r.EnforcerType, r.EncodedSecurityPolicy, r.EncodedUVMReference) + if err != nil { + return err + } + + // Start tracking mounts and restricting unmounts on confidential containers. + // As long as we started off with the ClosedDoorSecurityPolicyEnforcer, no + // mounts should have been allowed until this point. 
+ if h.HasSecurityPolicy() { + log.G(ctx).Debug("hostMounts initialized") + h.hostMounts = newHostMounts() + } + return nil case guestresource.ResourceTypePolicyFragment: r, ok := req.Settings.(*guestresource.SecurityPolicyFragment) if !ok { @@ -1211,18 +1213,33 @@ func (h *Host) modifyMappedVirtualDisk( rt guestrequest.RequestType, mvd *guestresource.LCOWMappedVirtualDisk, ) (err error) { + ctx, span := oc.StartSpan(ctx, "gcs::Host::modifyMappedVirtualDisk") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + span.AddAttributes( + trace.StringAttribute("requestType", string(rt)), + trace.BoolAttribute("hasHostMounts", h.hostMounts != nil), + trace.Int64Attribute("controller", int64(mvd.Controller)), + trace.Int64Attribute("lun", int64(mvd.Lun)), + trace.Int64Attribute("partition", int64(mvd.Partition)), + trace.BoolAttribute("readOnly", mvd.ReadOnly), + trace.StringAttribute("mountPath", mvd.MountPath), + ) + var verityInfo *guestresource.DeviceVerityInfo securityPolicy := h.securityOptions.PolicyEnforcer + devPath, err := scsi.GetDevicePath(ctx, mvd.Controller, mvd.Lun, mvd.Partition) + if err != nil { + return err + } + span.AddAttributes(trace.StringAttribute("devicePath", devPath)) + if mvd.ReadOnly { // The only time the policy is empty, and we want it to be empty // is when no policy is provided, and we default to open door // policy. In any other case, e.g. explicit open door or any // other rego policy we would like to mount layers with verity. if h.HasSecurityPolicy() { - devPath, err := scsi.GetDevicePath(ctx, mvd.Controller, mvd.Lun, mvd.Partition) - if err != nil { - return err - } verityInfo, err = verity.ReadVeritySuperBlock(ctx, devPath) if err != nil { return err @@ -1257,11 +1274,42 @@ func (h *Host) modifyMappedVirtualDisk( if err != nil { return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + err = h.hostMounts.AddRODevice(mvd.MountPath, devPath) + if err != nil { + return err + } + // Note: "When a function returns, its deferred calls are + // executed in last-in-first-out order." - so we are safe to + // call RemoveRODevice in this defer. 
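+ // (hostMounts.Unlock was deferred earlier in this function, so this
+ // cleanup defer runs before it, i.e. while the lock is still held.)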
+ defer func() { + if err != nil { + _ = h.hostMounts.RemoveRODevice(mvd.MountPath, devPath) + } + }() + } } else { err = securityPolicy.EnforceRWDeviceMountPolicy(ctx, mvd.MountPath, mvd.Encrypted, mvd.EnsureFilesystem, mvd.Filesystem) if err != nil { return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + err = h.hostMounts.AddRWDevice(mvd.MountPath, devPath, mvd.Encrypted) + if err != nil { + return err + } + defer func() { + if err != nil { + _ = h.hostMounts.RemoveRWDevice(mvd.MountPath, devPath, mvd.Encrypted) + } + }() + } } config := &scsi.Config{ Encrypted: mvd.Encrypted, @@ -1286,10 +1334,36 @@ func (h *Host) modifyMappedVirtualDisk( if err = securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + if err = h.hostMounts.RemoveRODevice(mvd.MountPath, devPath); err != nil { + return err + } + defer func() { + if err != nil { + _ = h.hostMounts.AddRODevice(mvd.MountPath, devPath) + } + }() + } } else { if err = securityPolicy.EnforceRWDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + if err = h.hostMounts.RemoveRWDevice(mvd.MountPath, devPath, mvd.Encrypted); err != nil { + return err + } + defer func() { + if err != nil { + _ = h.hostMounts.AddRWDevice(mvd.MountPath, devPath, mvd.Encrypted) + } + }() + } } // Check that the directory actually exists first, and if it does // not then we just refuse to do anything, without closing the dm @@ -1465,8 +1539,17 @@ func (h *Host) modifyCombinedLayers( ctx context.Context, rt guestrequest.RequestType, cl *guestresource.LCOWCombinedLayers, - scratchEncrypted bool, ) (err error) { + ctx, span := oc.StartSpan(ctx, "gcs::Host::modifyCombinedLayers") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + span.AddAttributes( + trace.StringAttribute("requestType", string(rt)), + trace.BoolAttribute("hasHostMounts", h.hostMounts != nil), + trace.StringAttribute("containerRootPath", cl.ContainerRootPath), + trace.StringAttribute("scratchPath", cl.ScratchPath), + ) + securityPolicy := h.securityOptions.PolicyEnforcer containerID := cl.ContainerID @@ -1480,6 +1563,12 @@ func (h *Host) modifyCombinedLayers( } defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + if h.hostMounts != nil { + // We will need this in multiple places, let's take the lock once here. 
+ h.hostMounts.Lock() + defer h.hostMounts.Unlock() + } + switch rt { case guestrequest.RequestTypeAdd: if h.HasSecurityPolicy() { @@ -1527,15 +1616,29 @@ func (h *Host) modifyCombinedLayers( } else { upperdirPath = filepath.Join(cl.ScratchPath, "upper") workdirPath = filepath.Join(cl.ScratchPath, "work") + scratchEncrypted := false + if h.hostMounts != nil { + scratchEncrypted = h.hostMounts.IsEncrypted(cl.ScratchPath) + } if err := securityPolicy.EnforceScratchMountPolicy(ctx, cl.ScratchPath, scratchEncrypted); err != nil { return fmt.Errorf("scratch mounting denied by policy: %w", err) } } - if err := securityPolicy.EnforceOverlayMountPolicy(ctx, containerID, layerPaths, cl.ContainerRootPath); err != nil { + if err = securityPolicy.EnforceOverlayMountPolicy(ctx, containerID, layerPaths, cl.ContainerRootPath); err != nil { return fmt.Errorf("overlay creation denied by policy: %w", err) } + if h.hostMounts != nil { + if err = h.hostMounts.AddOverlay(cl.ContainerRootPath, layerPaths, cl.ScratchPath); err != nil { + return err + } + defer func() { + if err != nil { + _, _ = h.hostMounts.RemoveOverlay(cl.ContainerRootPath) + } + }() + } // Correctness for policy revertable section: // MountLayer does two things - mkdir, then mount. On mount failure, the @@ -1549,6 +1652,23 @@ func (h *Host) modifyCombinedLayers( return errors.Wrap(err, "overlay removal denied by policy") } + // Check that no running container is using this overlay as its rootfs. + if h.HasSecurityPolicy() && h.IsOverlayInUse(cl.ContainerRootPath) { + return fmt.Errorf("overlay %q is in use by a running container", cl.ContainerRootPath) + } + + if h.hostMounts != nil { + var undoRemoveOverlay func() + if undoRemoveOverlay, err = h.hostMounts.RemoveOverlay(cl.ContainerRootPath); err != nil { + return err + } + defer func() { + if err != nil && undoRemoveOverlay != nil { + undoRemoveOverlay() + } + }() + } + // Note: storage.UnmountPath is a no-op if the path does not exist. 
err = storage.UnmountPath(ctx, cl.ContainerRootPath, true)
 if err != nil {
@@ -1866,3 +1986,40 @@ func (h *Host) commitOrRollbackPolicyRevSection(
 rev.Commit()
 }
 }
+
+func (h *Host) DeleteContainerState(ctx context.Context, containerID string) error {
+ if h.HasSecurityPolicy() {
+ if err := checkValidContainerID(containerID, "container"); err != nil {
+ return err
+ }
+ }
+
+ if err := h.checkMountsNotBroken(); err != nil {
+ return err
+ }
+
+ c, err := h.GetCreatedContainer(containerID)
+ if err != nil {
+ return err
+ }
+ if h.HasSecurityPolicy() {
+ if !c.terminated.Load() {
+ return errors.Errorf("Denied deleting state of a running container %q", containerID)
+ }
+ overlay := c.spec.Root.Path
+ h.hostMounts.Lock()
+ defer h.hostMounts.Unlock()
+ if h.hostMounts.HasOverlayMountedAt(overlay) {
+ return errors.Errorf("Denied deleting state of a container with an overlay mount still active")
+ }
+ }
+
+ // remove container state regardless of delete's success
+ defer h.RemoveContainer(containerID)
+
+ if err = c.Delete(ctx); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/guest/runtime/hcsv2/uvm_state.go b/internal/guest/runtime/hcsv2/uvm_state.go
index dd1ff521f0..96e64371a2 100644
--- a/internal/guest/runtime/hcsv2/uvm_state.go
+++ b/internal/guest/runtime/hcsv2/uvm_state.go
@@ -4,91 +4,360 @@ package hcsv2
 
 import (
+ "context"
+ "errors"
 "fmt"
 "path/filepath"
 "strings"
 "sync"
+
+ "github.com/Microsoft/hcsshim/internal/gcs"
+ "github.com/Microsoft/hcsshim/internal/log"
+ "github.com/sirupsen/logrus"
+)
+
+type deviceType int
+
+const (
+ DeviceTypeRW deviceType = iota
+ DeviceTypeRO
+ DeviceTypeOverlay
 )
 
-type rwDevice struct {
+func (d deviceType) String() string {
+ switch d {
+ case DeviceTypeRW:
+ return "RW"
+ case DeviceTypeRO:
+ return "RO"
+ case DeviceTypeOverlay:
+ return "Overlay"
+ default:
+ return fmt.Sprintf("Unknown(%d)", d)
+ }
+}
+
+type device struct {
+ // fields common to all
 mountPath string
+ ty deviceType
+ usage int
 sourcePath string
- encrypted bool
+
+ // rw devices
+ encrypted bool
+
+ // overlay devices
+ referencedDevices []*device
 }
 
+// hostMounts tracks the state of fs/overlay mounts and their usage
+// relationship. Users of this struct must call hm.Lock() before calling any
+// other methods and call hm.Unlock() when done.
+//
+// Since mount/unmount operations can fail, the expected way to use this struct
+// is to first lock it, call the method to add/remove the device, then, with the
+// lock still held, perform the actual operation. If the operation fails, the
+// caller must undo the operation by calling the appropriate remove/add method
+// or the returned undo function, before unlocking.
 type hostMounts struct {
- stateMutex sync.Mutex
+ stateMutex sync.Mutex
+ stateMutexLocked bool
 
- // Holds information about read-write devices, which can be encrypted and
- // contain overlay fs upper/work directory mounts.
- readWriteMounts map[string]*rwDevice
+ // Map from mountPath to device struct
+ devices map[string]*device
 }
 
 func newHostMounts() *hostMounts {
 return &hostMounts{
- readWriteMounts: map[string]*rwDevice{},
+ devices: make(map[string]*device),
 }
 }
 
-// AddRWDevice adds read-write device metadata for device mounted at `mountPath`.
-// Returns an error if there's an existing device mounted at `mountPath` location.
-func (hm *hostMounts) AddRWDevice(mountPath string, sourcePath string, encrypted bool) error {
+func (hm *hostMounts) expectLocked() {
+ if !hm.stateMutexLocked {
+ gcs.UnrecoverableError(errors.New("hostMounts: expected stateMutex to be locked, but it was not"))
+ }
+}
+
+// Locks the state mutex. This is not re-entrant; calling it twice in the same
+// thread will deadlock/panic.
+func (hm *hostMounts) Lock() {
 hm.stateMutex.Lock()
- defer hm.stateMutex.Unlock()
+ // Since we just acquired the lock, either it was not locked before, or
+ // somebody just unlocked it. In either case, hm.stateMutexLocked should be
+ // false.
+ if hm.stateMutexLocked {
+ gcs.UnrecoverableError(errors.New("hostMounts: stateMutexLocked already true when locking stateMutex"))
+ }
+ hm.stateMutexLocked = true
+}
+
+// Unlocks the state mutex
+func (hm *hostMounts) Unlock() {
+ hm.expectLocked()
+ hm.stateMutexLocked = false
+ hm.stateMutex.Unlock()
+}
 
- mountTarget := filepath.Clean(mountPath)
- if source, ok := hm.readWriteMounts[mountTarget]; ok {
- return fmt.Errorf("read-write with source %q and mount target %q already exists", source.sourcePath, mountPath)
+func (hm *hostMounts) findDeviceAtPath(mountPath string) *device {
+ hm.expectLocked()
+
+ if dev, ok := hm.devices[mountPath]; ok {
+ return dev
 }
- hm.readWriteMounts[mountTarget] = &rwDevice{
- mountPath: mountTarget,
- sourcePath: sourcePath,
- encrypted: encrypted,
+ return nil
+}
+
+func (hm *hostMounts) addDeviceToMapChecked(dev *device) error {
+ hm.expectLocked()
+
+ if _, ok := hm.devices[dev.mountPath]; ok {
+ return fmt.Errorf("device at mount path %q already exists", dev.mountPath)
 }
+ hm.devices[dev.mountPath] = dev
 return nil
 }
 
-// RemoveRWDevice removes the read-write device metadata for device mounted at
-// `mountPath`.
-func (hm *hostMounts) RemoveRWDevice(mountPath string, sourcePath string) error {
- hm.stateMutex.Lock()
- defer hm.stateMutex.Unlock()
+func (hm *hostMounts) findDeviceContainingPath(path string) *device {
+ hm.expectLocked()
+
+ // TODO: can we refactor this function by walking each component of the path
+ // from leaf to root, each time checking if the current component is a mount
+ // point? (i.e. why do we have to use filepath.Rel?)
+
+ var foundDev *device
+ cleanPath := filepath.Clean(path)
+ for devPath, dev := range hm.devices {
+ relPath, err := filepath.Rel(devPath, cleanPath)
+ // skip further checks if an error is returned or the relative path
+ // contains "..", meaning that the `path` isn't directly nested under
+ // `devPath`.
+ if err != nil || strings.HasPrefix(relPath, "..") {
+ continue
+ }
+ if foundDev == nil {
+ foundDev = dev
+ } else if len(dev.mountPath) > len(foundDev.mountPath) {
+ // The current device is mounted on top of a previously found device.
+ foundDev = dev
+ }
+ }
+ return foundDev
+}
+
+func (hm *hostMounts) usePath(path string) (*device, error) {
+ hm.expectLocked()
+
+ // Find the device at the given path and increment its usage count.
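+ // This usage count is what pins layers and scratch beneath an overlay:
+ // AddOverlay calls usePath for each referenced path, and doRemoveDevice
+ // refuses to remove any device whose usage is still greater than zero.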
+ dev := hm.findDeviceContainingPath(path)
+ if dev == nil {
+ return nil, nil
+ }
+ dev.usage++
+ return dev, nil
+}
+
+func (hm *hostMounts) releaseDeviceUsage(dev *device) {
+ hm.expectLocked()
+
+ if dev.usage <= 0 {
+ log.G(context.Background()).WithFields(logrus.Fields{
+ "device": dev.mountPath,
+ "deviceSource": dev.sourcePath,
+ "deviceType": dev.ty,
+ "usage": dev.usage,
+ }).Error("hostMounts::releaseDeviceUsage: unexpected non-positive usage count")
+ return
+ }
+ dev.usage--
+}
+
+// The caller should carefully handle any side-effects of adding a device if
+// the device fails to be added.
+func (hm *hostMounts) doAddDevice(mountPath string, ty deviceType, sourcePath string) (*device, error) {
+ hm.expectLocked()
+
+ dev := &device{
+ mountPath: filepath.Clean(mountPath),
+ ty: ty,
+ usage: 0,
+ sourcePath: sourcePath,
+ }
+
+ if err := hm.addDeviceToMapChecked(dev); err != nil {
+ return nil, err
+ }
+ return dev, nil
+}
+
+// Once `checks` has been called, this function will always succeed unless
+// `checks` itself returns an error.
+func (hm *hostMounts) doRemoveDevice(mountPath string, ty deviceType, sourcePath string, checks func(*device) error) error {
+ hm.expectLocked()
 
 unmountTarget := filepath.Clean(mountPath)
- device, ok := hm.readWriteMounts[unmountTarget]
- if !ok {
+ device := hm.findDeviceAtPath(unmountTarget)
+ if device == nil {
 // already removed or didn't exist
 return nil
 }
 if device.sourcePath != sourcePath {
- return fmt.Errorf("wrong sourcePath %s", sourcePath)
+ return fmt.Errorf("wrong sourcePath %s, expected %s", sourcePath, device.sourcePath)
+ }
+ if device.ty != ty {
+ return fmt.Errorf("wrong device type %s, expected %s", ty, device.ty)
+ }
+ if device.usage > 0 {
+ log.G(context.Background()).WithFields(logrus.Fields{
+ "device": device.mountPath,
+ "deviceSource": device.sourcePath,
+ "deviceType": device.ty,
+ "usage": device.usage,
+ }).Error("hostMounts::doRemoveDevice: device still in use, refusing unmount")
+ return fmt.Errorf("device at %q is still in use, can't unmount", unmountTarget)
+ }
+ if checks != nil {
+ if err := checks(device); err != nil {
+ return err
+ }
 }
- delete(hm.readWriteMounts, unmountTarget)
+ delete(hm.devices, unmountTarget)
 return nil
 }
 
+func (hm *hostMounts) AddRODevice(mountPath string, sourcePath string) error {
+ hm.expectLocked()
+
+ _, err := hm.doAddDevice(mountPath, DeviceTypeRO, sourcePath)
+ return err
+}
+
+// AddRWDevice adds read-write device metadata for device mounted at `mountPath`.
+// Returns an error if there's an existing device mounted at `mountPath` location.
+func (hm *hostMounts) AddRWDevice(mountPath string, sourcePath string, encrypted bool) error {
+ hm.expectLocked()
+
+ dev, err := hm.doAddDevice(mountPath, DeviceTypeRW, sourcePath)
+ if err != nil {
+ return err
+ }
+ dev.encrypted = encrypted
+ return nil
+}
+
+func (hm *hostMounts) AddOverlay(mountPath string, layers []string, scratchDir string) (err error) {
+ hm.expectLocked()
+
+ dev, err := hm.doAddDevice(mountPath, DeviceTypeOverlay, mountPath)
+ if err != nil {
+ return err
+ }
+ dev.referencedDevices = make([]*device, 0, len(layers)+1)
+ defer func() {
+ if err != nil {
+ // If we failed to use any of the paths, we need to release the ones
+ // that we did use.
+ for _, d := range dev.referencedDevices { + hm.releaseDeviceUsage(d) + } + delete(hm.devices, mountPath) + } + }() + + for _, layer := range layers { + refDev, err := hm.usePath(layer) + if err != nil { + return err + } + if refDev != nil { + dev.referencedDevices = append(dev.referencedDevices, refDev) + } + } + refDev, err := hm.usePath(scratchDir) + if err != nil { + return err + } + if refDev != nil { + dev.referencedDevices = append(dev.referencedDevices, refDev) + } + + return nil +} + +func (hm *hostMounts) RemoveRODevice(mountPath string, sourcePath string) error { + hm.expectLocked() + + return hm.doRemoveDevice(mountPath, DeviceTypeRO, sourcePath, nil) +} + +// RemoveRWDevice removes the read-write device metadata for device mounted at +// `mountPath`. +func (hm *hostMounts) RemoveRWDevice(mountPath string, sourcePath string, encrypted bool) error { + hm.expectLocked() + + return hm.doRemoveDevice(mountPath, DeviceTypeRW, sourcePath, func(dev *device) error { + if dev.encrypted != encrypted { + return fmt.Errorf("encrypted flag wrong, provided %v, expected %v", encrypted, dev.encrypted) + } + return nil + }) +} + +func (hm *hostMounts) RemoveOverlay(mountPath string) (undo func(), err error) { + hm.expectLocked() + + var dev *device + err = hm.doRemoveDevice(mountPath, DeviceTypeOverlay, mountPath, func(_dev *device) error { + dev = _dev + for _, refDev := range dev.referencedDevices { + hm.releaseDeviceUsage(refDev) + } + return nil + }) + if err != nil { + // If we get an error from doRemoveDevice, we have not released anything + // yet. + return nil, err + } + undo = func() { + hm.expectLocked() + + for _, refDev := range dev.referencedDevices { + refDev.usage++ + } + + if _, ok := hm.devices[mountPath]; ok { + log.G(context.Background()).WithField("mountPath", mountPath).Error( + "hostMounts::RemoveOverlay: failed to undo remove: device that was removed exists in map", + ) + return + } + + hm.devices[mountPath] = dev + } + return undo, nil +} + // IsEncrypted checks if the given path is a sub-path of an encrypted read-write // device. func (hm *hostMounts) IsEncrypted(path string) bool { - hm.stateMutex.Lock() - defer hm.stateMutex.Unlock() + hm.expectLocked() - parentPath := "" - encrypted := false - cleanPath := filepath.Clean(path) - for rwPath, rwDev := range hm.readWriteMounts { - relPath, err := filepath.Rel(rwPath, cleanPath) - // skip further checks if an error is returned or the relative path - // contains "..", meaning that the `path` isn't directly nested under - // `rwPath`. 
- if err != nil || strings.HasPrefix(relPath, "..") { - continue - } - if len(rwDev.mountPath) > len(parentPath) { - parentPath = rwDev.mountPath - encrypted = rwDev.encrypted - } + dev := hm.findDeviceContainingPath(path) + if dev == nil { + return false + } + return dev.encrypted +} + +func (hm *hostMounts) HasOverlayMountedAt(path string) bool { + hm.expectLocked() + + dev := hm.findDeviceAtPath(filepath.Clean(path)) + if dev == nil { + return false } - return encrypted + return dev.ty == DeviceTypeOverlay } diff --git a/internal/guest/runtime/hcsv2/uvm_state_test.go b/internal/guest/runtime/hcsv2/uvm_state_test.go index b708caaeba..e87a207308 100644 --- a/internal/guest/runtime/hcsv2/uvm_state_test.go +++ b/internal/guest/runtime/hcsv2/uvm_state_test.go @@ -12,10 +12,13 @@ func Test_Add_Remove_RWDevice(t *testing.T) { mountPath := "/run/gcs/c/abcd" sourcePath := "/dev/sda" + hm.Lock() + defer hm.Unlock() + if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error adding RW device: %s", err) } - if err := hm.RemoveRWDevice(mountPath, sourcePath); err != nil { + if err := hm.RemoveRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error removing RW device: %s", err) } } @@ -25,29 +28,55 @@ func Test_Cannot_AddRWDevice_Twice(t *testing.T) { mountPath := "/run/gcs/c/abc" sourcePath := "/dev/sda" + hm.Lock() if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error: %s", err) } + hm.Unlock() + + hm.Lock() if err := hm.AddRWDevice(mountPath, sourcePath, false); err == nil { t.Fatalf("expected error adding %q for the second time", mountPath) } + hm.Unlock() } func Test_Cannot_RemoveRWDevice_Wrong_Source(t *testing.T) { hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + mountPath := "/run/gcs/c/abcd" sourcePath := "/dev/sda" wrongSource := "/dev/sdb" if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error: %s", err) } - if err := hm.RemoveRWDevice(mountPath, wrongSource); err == nil { + if err := hm.RemoveRWDevice(mountPath, wrongSource, false); err == nil { t.Fatalf("expected error removing wrong source %s", wrongSource) } } +func Test_Cannot_RemoveRWDevice_Wrong_Encrypted(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/abcd" + sourcePath := "/dev/sda" + if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := hm.RemoveRWDevice(mountPath, sourcePath, true); err == nil { + t.Fatalf("expected error removing RW device with wrong encrypted flag") + } +} + func Test_HostMounts_IsEncrypted(t *testing.T) { hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + encryptedPath := "/run/gcs/c/encrypted" encryptedSource := "/dev/sda" if err := hm.AddRWDevice(encryptedPath, encryptedSource, true); err != nil { @@ -108,3 +137,189 @@ func Test_HostMounts_IsEncrypted(t *testing.T) { }) } } + +func Test_HostMounts_AddRemoveRODevice(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/abcd" + sourcePath := "/dev/sda" + + if err := hm.AddRODevice(mountPath, sourcePath); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + + if err := hm.RemoveRODevice(mountPath, sourcePath); err != nil { + t.Fatalf("unexpected error removing RO device: %s", err) + } +} + +func Test_HostMounts_Cannot_AddRODevice_Twice(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + 
mountPath := "/run/gcs/c/abc" + sourcePath := "/dev/sda" + + if err := hm.AddRODevice(mountPath, sourcePath); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := hm.AddRODevice(mountPath, sourcePath); err == nil { + t.Fatalf("expected error adding %q for the second time", mountPath) + } +} + +func Test_HostMounts_AddRemoveOverlay(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/aaaa/rootfs" + layers := []string{ + "/run/mounts/scsi/m1", + "/run/mounts/scsi/m2", + "/run/mounts/scsi/m3", + } + for _, layer := range layers { + if err := hm.AddRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + } + scratchDir := "/run/gcs/c/aaaa/scratch" + if err := hm.AddRWDevice(scratchDir, scratchDir, true); err != nil { + t.Fatalf("unexpected error adding RW device: %s", err) + } + if err := hm.AddOverlay(mountPath, layers, scratchDir); err != nil { + t.Fatalf("unexpected error adding overlay: %s", err) + } + undo, err := hm.RemoveOverlay(mountPath) + if err != nil { + t.Fatalf("unexpected error removing overlay: %s", err) + } + if undo == nil { + t.Fatalf("expected undo function to be non-nil") + } + undo() + if _, err = hm.RemoveOverlay(mountPath); err != nil { + t.Fatalf("unexpected error removing overlay again: %s", err) + } +} + +func Test_HostMounts_Cannot_RemoveInUseDeviceByOverlay(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/aaaa/rootfs" + layers := []string{ + "/run/mounts/scsi/m1", + "/run/mounts/scsi/m2", + "/run/mounts/scsi/m3", + } + for _, layer := range layers { + if err := hm.AddRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + } + scratchDir := "/run/gcs/c/aaaa/scratch" + if err := hm.AddRWDevice(scratchDir, scratchDir, true); err != nil { + t.Fatalf("unexpected error adding RW device: %s", err) + } + if err := hm.AddOverlay(mountPath, layers, scratchDir); err != nil { + t.Fatalf("unexpected error adding overlay: %s", err) + } + + for _, layer := range layers { + if err := hm.RemoveRODevice(layer, layer); err == nil { + t.Fatalf("expected error removing RO device %s while in use by overlay", layer) + } + } + if err := hm.RemoveRWDevice(scratchDir, scratchDir, true); err == nil { + t.Fatalf("expected error removing RW device %s while in use by overlay", scratchDir) + } + + if _, err := hm.RemoveOverlay(mountPath); err != nil { + t.Fatalf("unexpected error removing overlay: %s", err) + } + + // now we can remove + for _, layer := range layers { + if err := hm.RemoveRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error removing RO device %s: %s", layer, err) + } + } + if err := hm.RemoveRWDevice(scratchDir, scratchDir, true); err != nil { + t.Fatalf("unexpected error removing RW device %s: %s", scratchDir, err) + } +} + +func Test_HostMounts_Cannot_RemoveInUseDeviceByOverlay_MultipleUsers(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + overlay1 := "/run/gcs/c/aaaa/rootfs" + overlay2 := "/run/gcs/c/bbbb/rootfs" + layers := []string{ + "/run/mounts/scsi/m1", + "/run/mounts/scsi/m2", + "/run/mounts/scsi/m3", + } + for _, layer := range layers { + if err := hm.AddRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + } + sharedScratchMount := "/run/gcs/c/sandbox" + scratch1 := sharedScratchMount + "/scratch/aaaa" + scratch2 := sharedScratchMount + "/scratch/bbbb" + if err := 
hm.AddRWDevice(sharedScratchMount, sharedScratchMount, true); err != nil {
+ t.Fatalf("unexpected error adding RW device: %s", err)
+ }
+ if err := hm.AddOverlay(overlay1, layers, scratch1); err != nil {
+ t.Fatalf("unexpected error adding overlay1: %s", err)
+ }
+
+ if err := hm.AddOverlay(overlay2, layers[0:2], scratch2); err != nil {
+ t.Fatalf("unexpected error adding overlay2: %s", err)
+ }
+
+ for _, layer := range layers {
+ if err := hm.RemoveRODevice(layer, layer); err == nil {
+ t.Fatalf("expected error removing RO device %s while in use by overlay", layer)
+ }
+ }
+ if err := hm.RemoveRWDevice(sharedScratchMount, sharedScratchMount, true); err == nil {
+ t.Fatalf("expected error removing RW device %s while in use by overlay", sharedScratchMount)
+ }
+
+ if _, err := hm.RemoveOverlay(overlay1); err != nil {
+ t.Fatalf("unexpected error removing overlay 1: %s", err)
+ }
+
+ for _, layer := range layers[0:2] {
+ if err := hm.RemoveRODevice(layer, layer); err == nil {
+ t.Fatalf("expected error removing RO device %s (still in use by overlay 2)", layer)
+ }
+ }
+ if err := hm.RemoveRODevice(layers[2], layers[2]); err != nil {
+ t.Fatalf("unexpected error removing layers[2] which is not being used by overlay 2: %s", err)
+ }
+ if err := hm.RemoveRWDevice(sharedScratchMount, sharedScratchMount, true); err == nil {
+ t.Fatalf("expected error removing RW device %s while in use by overlay 2", sharedScratchMount)
+ }
+
+ if _, err := hm.RemoveOverlay(overlay2); err != nil {
+ t.Fatalf("unexpected error removing overlay 2: %s", err)
+ }
+ for _, layer := range layers[0:2] {
+ if err := hm.RemoveRODevice(layer, layer); err != nil {
+ t.Fatalf("unexpected error removing RO device %s: %s", layer, err)
+ }
+ }
+ if err := hm.RemoveRWDevice(sharedScratchMount, sharedScratchMount, true); err != nil {
+ t.Fatalf("unexpected error removing RW device %s: %s", sharedScratchMount, err)
+ }
+}
From 2e2e558641744d735644c884e6098a2983c0ddc2 Mon Sep 17 00:00:00 2001
From: Tingmao Wang
Date: Tue, 11 Nov 2025 16:55:12 +0000
Subject: [PATCH 12/12] Merged PR 13627088: guest: Don't allow host to set
 mount options

[cherry-picked from 1dd0b7ea0b0f91d3698f6008fb0bd5b0de777da6]

Blocks mount option passing for 9p (which was accidental) and SCSI
disks.

- guest: Restrict plan9 share names to digits only in confidential mode
- hcsv2/uvm: Restrict SCSI mount options in confidential mode (The only
 one we allow is `ro`)

Related work items: #34370380

Signed-off-by: Tingmao Wang
---
 internal/guest/runtime/hcsv2/uvm.go | 24 ++++++++++++++++++++++++
 internal/guest/storage/plan9/plan9.go | 14 ++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go
index 6d61222d84..33094d9848 100644
--- a/internal/guest/runtime/hcsv2/uvm.go
+++ b/internal/guest/runtime/hcsv2/uvm.go
@@ -1265,6 +1265,24 @@ func (h *Host) modifyMappedVirtualDisk(
 mountCtx, cancel := context.WithTimeout(ctx, time.Second*5)
 defer cancel()
 if mvd.MountPath != "" {
+ if h.HasSecurityPolicy() {
+ // The only option we allow if there is policy enforcement is
+ // "ro", and it must match the readonly field in the request.
+ mountOptionHasRo := false
+ for _, opt := range mvd.Options {
+ if opt == "ro" {
+ mountOptionHasRo = true
+ continue
+ }
+ return errors.Errorf("mounting scsi device controller %d lun %d onto %s: mount option %q denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath, opt)
+ }
+ if mvd.ReadOnly != mountOptionHasRo {
+ return errors.Errorf(
+ "mounting scsi device controller %d lun %d onto %s with mount option %q failed due to mount option mismatch: mvd.ReadOnly=%t but mountOptionHasRo=%t",
+ mvd.Controller, mvd.Lun, mvd.MountPath, strings.Join(mvd.Options, ","), mvd.ReadOnly, mountOptionHasRo,
+ )
+ }
+ }
 if mvd.ReadOnly {
 var deviceHash string
 if verityInfo != nil {
@@ -1429,6 +1447,12 @@ func (h *Host) modifyMappedDirectory(
 return errors.Wrapf(err, "mounting plan9 device at %s denied by policy", md.MountPath)
 }
 
+ if h.HasSecurityPolicy() {
+ if err = plan9.ValidateShareName(md.ShareName); err != nil {
+ return err
+ }
+ }
+
 // Similar to the reasoning in modifyMappedVirtualDisk, since we're
 // rolling back the policy metadata, plan9.Mount here must clean up
 // everything if it fails, which it does do.
diff --git a/internal/guest/storage/plan9/plan9.go b/internal/guest/storage/plan9/plan9.go
index 5c1f1d74f4..44ac0f4e4e 100644
--- a/internal/guest/storage/plan9/plan9.go
+++ b/internal/guest/storage/plan9/plan9.go
@@ -7,6 +7,7 @@ import (
 "context"
 "fmt"
 "os"
+ "regexp"
 "syscall"
 
 "github.com/Microsoft/hcsshim/internal/guest/transport"
@@ -25,6 +26,19 @@ var (
 unixMount = unix.Mount
 )
 
+// c.f. v9fs_parse_options in linux/fs/9p/v9fs.c - technically anything other
+// than ',' is ok (quoting is not handled); however, this name is generated from
+// a counter in AddPlan9 (internal/uvm/plan9.go), and therefore we expect only
+// digits from a normal hcsshim host.
+var validShareNameRegex = regexp.MustCompile(`^[0-9]+$`)
+
+func ValidateShareName(name string) error {
+ if !validShareNameRegex.MatchString(name) {
+ return fmt.Errorf("invalid plan9 share name %q: must match regex %q", name, validShareNameRegex.String())
+ }
+ return nil
+}
+
 // Mount dials a connection from `vsock` and mounts a Plan9 share to `target`.
 //
 // `target` will be created. On mount failure the created `target` will be