Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 108 additions & 4 deletions internal/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ const (
deviceListEnvVar = "NVIDIA_VISIBLE_DEVICES"
deviceListAsVolumeMountsHostPath = "/dev/null"
deviceListAsVolumeMountsContainerPathRoot = "/var/run/nvidia-container-devices"

// healthChannelBufferSize defines the buffer capacity for the health
// channel. This is sized to handle bursts of unhealthy device reports
// without blocking the health check goroutine. With 8 GPUs and
// potential for multiple events per GPU (XID errors, ECC errors, etc.),
// a buffer of 64 provides ample headroom while using a power-of-2 size
// for cache-friendly alignment.
healthChannelBufferSize = 64
)

// nvidiaDevicePlugin implements the Kubernetes device plugin API
Expand All @@ -64,6 +72,10 @@ type nvidiaDevicePlugin struct {
health chan *rm.Device
stop chan interface{}

// deviceListUpdate is used to trigger ListAndWatch to send updated device
// list to kubelet (e.g., when devices recover from unhealthy state)
deviceListUpdate chan struct{}

imexChannels imex.Channels

mps mpsOptions
Expand Down Expand Up @@ -108,15 +120,20 @@ func getPluginSocketPath(resource spec.ResourceName) string {

func (plugin *nvidiaDevicePlugin) initialize() {
plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
plugin.health = make(chan *rm.Device)
plugin.health = make(chan *rm.Device, healthChannelBufferSize)
plugin.stop = make(chan interface{})
plugin.deviceListUpdate = make(chan struct{}, 1)
}

func (plugin *nvidiaDevicePlugin) cleanup() {
close(plugin.stop)
if plugin.deviceListUpdate != nil {
close(plugin.deviceListUpdate)
}
plugin.server = nil
plugin.health = nil
plugin.stop = nil
plugin.deviceListUpdate = nil
}

// Devices returns the full set of devices associated with the plugin.
Expand Down Expand Up @@ -156,6 +173,9 @@ func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error {
}
}()

// Start recovery worker to detect when unhealthy devices become healthy
go plugin.runRecoveryWorker()

return nil
}

Expand Down Expand Up @@ -263,7 +283,9 @@ func (plugin *nvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *plugi
return options, nil
}

// ListAndWatch lists devices and update that list according to the health status
// ListAndWatch lists devices and update that list according to the health
// status. This now supports device recovery: when devices that were marked
// unhealthy recover, they are automatically re-advertised to kubelet.
func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return err
Expand All @@ -274,9 +296,17 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D
case <-plugin.stop:
return nil
case d := <-plugin.health:
// FIXME: there is no way to recover from the Unhealthy state.
// Device marked unhealthy by health check
d.Health = pluginapi.Unhealthy
klog.Infof("'%s' device marked unhealthy: %s", plugin.rm.Resource(), d.ID)
klog.Infof("'%s' device marked unhealthy: %s (reason: %s)",
plugin.rm.Resource(), d.ID, d.UnhealthyReason)
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return nil
}
case <-plugin.deviceListUpdate:
// Device recovery or other device list change
klog.Infof("'%s' device list updated, notifying kubelet",
plugin.rm.Resource())
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return nil
}
Expand Down Expand Up @@ -512,6 +542,80 @@ func (plugin *nvidiaDevicePlugin) updateResponseForDeviceMounts(response *plugin
}
}

// runRecoveryWorker periodically checks if unhealthy devices have recovered
// and notifies kubelet when they do.
func (plugin *nvidiaDevicePlugin) runRecoveryWorker() {
const recoveryInterval = 30 * time.Second

ticker := time.NewTicker(recoveryInterval)
defer ticker.Stop()

klog.V(2).Infof("Recovery worker started for '%s' (interval=%v)",
plugin.rm.Resource(), recoveryInterval)

for {
select {
case <-plugin.stop:
klog.V(2).Info("Recovery worker stopped")
return
case <-ticker.C:
plugin.checkForRecoveredDevices()
}
}
}

// checkForRecoveredDevices checks all unhealthy devices to see if they have
// recovered. If any have recovered, triggers a device list update to
// kubelet.
func (plugin *nvidiaDevicePlugin) checkForRecoveredDevices() {
recoveredDevices := []*rm.Device{}

for _, d := range plugin.rm.Devices() {
if !d.IsUnhealthy() {
continue
}

// Increment recovery attempts
d.RecoveryAttempts++

// Check if device has recovered
healthy, err := plugin.rm.CheckDeviceHealth(d)
if err != nil {
klog.V(4).Infof("Device %s recovery check failed (attempt %d): %v",
d.ID, d.RecoveryAttempts, err)
continue
}

if healthy {
klog.Infof("Device %s has RECOVERED! Was unhealthy for %v (reason: %s)",
d.ID, d.UnhealthyDuration(), d.UnhealthyReason)
d.MarkHealthy()
recoveredDevices = append(recoveredDevices, d)
} else {
klog.V(3).Infof("Device %s still unhealthy (attempt %d, duration %v)",
d.ID, d.RecoveryAttempts, d.UnhealthyDuration())
}
}

// If any devices recovered, notify ListAndWatch
if len(recoveredDevices) > 0 {
klog.Infof("Total recovered devices: %d", len(recoveredDevices))
plugin.triggerDeviceListUpdate()
}
}

// triggerDeviceListUpdate sends a signal to ListAndWatch to send an updated
// device list to kubelet. Uses a buffered channel with non-blocking send to
// avoid blocking the recovery worker.
func (plugin *nvidiaDevicePlugin) triggerDeviceListUpdate() {
select {
case plugin.deviceListUpdate <- struct{}{}:
klog.V(3).Info("Device list update triggered")
default:
klog.V(4).Info("Device list update already pending, skipping")
}
}

func (plugin *nvidiaDevicePlugin) apiDeviceSpecs(devRoot string, ids []string) []*pluginapi.DeviceSpec {
optional := map[string]bool{
"/dev/nvidiactl": true,
Expand Down
95 changes: 95 additions & 0 deletions internal/plugin/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package plugin

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
Expand Down Expand Up @@ -254,3 +256,96 @@ func TestCDIAllocateResponse(t *testing.T) {
func ptr[T any](x T) *T {
return &x
}

func TestTriggerDeviceListUpdate_Phase2(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a matter of interest, what is Phase2? (Were these tests generated?)

plugin := &nvidiaDevicePlugin{
deviceListUpdate: make(chan struct{}, 1),
}

// First trigger should send signal
plugin.triggerDeviceListUpdate()
select {
case <-plugin.deviceListUpdate:
t.Log("✓ Device list update signal sent")
case <-time.After(100 * time.Millisecond):
t.Fatal("Signal not sent")
}

// Second trigger with pending signal should not block
plugin.triggerDeviceListUpdate()
plugin.triggerDeviceListUpdate() // Should not block
t.Log("✓ triggerDeviceListUpdate doesn't block when signal pending")
}

func TestCheckForRecoveredDevices_Phase2(t *testing.T) {
// Create persistent device map
devices := rm.Devices{
"GPU-0": &rm.Device{
Device: pluginapi.Device{
ID: "GPU-0",
Health: pluginapi.Unhealthy,
},
UnhealthyReason: "XID-79",
},
"GPU-1": &rm.Device{
Device: pluginapi.Device{
ID: "GPU-1",
Health: pluginapi.Unhealthy,
},
UnhealthyReason: "XID-48",
},
"GPU-2": &rm.Device{
Device: pluginapi.Device{
ID: "GPU-2",
Health: pluginapi.Healthy,
},
},
}

// Create mock resource manager with persistent devices
mockRM := &rm.ResourceManagerMock{
DevicesFunc: func() rm.Devices {
return devices
},
CheckDeviceHealthFunc: func(d *rm.Device) (bool, error) {
// GPU-0 recovers, GPU-1 stays unhealthy
if d.ID == "GPU-0" {
return true, nil
}
return false, fmt.Errorf("still unhealthy")
},
}

plugin := &nvidiaDevicePlugin{
rm: mockRM,
deviceListUpdate: make(chan struct{}, 1),
}

plugin.checkForRecoveredDevices()

// Verify GPU-0 recovered
gpu0 := devices["GPU-0"]
require.Equal(t, pluginapi.Healthy, gpu0.Health, "GPU-0 should be healthy")
require.Equal(t, "", gpu0.UnhealthyReason)
t.Logf("✓ GPU-0 recovered: Health=%s, Reason=%s", gpu0.Health, gpu0.UnhealthyReason)

// Verify GPU-1 still unhealthy
gpu1 := devices["GPU-1"]
require.Equal(t, pluginapi.Unhealthy, gpu1.Health, "GPU-1 should still be unhealthy")
require.Equal(t, 1, gpu1.RecoveryAttempts, "GPU-1 recovery attempts should increment")
t.Logf("✓ GPU-1 still unhealthy: attempts=%d", gpu1.RecoveryAttempts)

// Verify GPU-2 unchanged
gpu2 := devices["GPU-2"]
require.Equal(t, pluginapi.Healthy, gpu2.Health)
require.Equal(t, 0, gpu2.RecoveryAttempts, "Healthy device shouldn't be probed")
t.Log("✓ GPU-2 unchanged (was already healthy)")

// Verify deviceListUpdate was triggered
select {
case <-plugin.deviceListUpdate:
t.Log("✓ Device list update triggered for recovery")
case <-time.After(100 * time.Millisecond):
t.Fatal("Device list update not triggered")
}
}
41 changes: 41 additions & 0 deletions internal/rm/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"strconv"
"strings"
"time"

"k8s.io/klog/v2"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
Expand All @@ -35,6 +36,12 @@ type Device struct {
// Replicas stores the total number of times this device is replicated.
// If this is 0 or 1 then the device is not shared.
Replicas int

// Health tracking fields for recovery detection
LastHealthyTime time.Time // Last time device was confirmed healthy
LastUnhealthyTime time.Time // When device became unhealthy
UnhealthyReason string // Human-readable reason (e.g., "XID-79")
RecoveryAttempts int // Number of recovery probes attempted
}

// deviceInfo defines the information the required to construct a Device
Expand Down Expand Up @@ -239,6 +246,40 @@ func (d *Device) GetUUID() string {
return AnnotatedID(d.ID).GetID()
}

// MarkUnhealthy marks the device as unhealthy and records the reason and
// timestamp. This should be called when a health check detects a device
// failure (e.g., XID error).
func (d *Device) MarkUnhealthy(reason string) {
d.Health = pluginapi.Unhealthy
d.LastUnhealthyTime = time.Now()
d.UnhealthyReason = reason
d.RecoveryAttempts = 0
}

// MarkHealthy marks the device as healthy and clears unhealthy state. This
// should be called when recovery detection confirms the device is working
// again.
func (d *Device) MarkHealthy() {
d.Health = pluginapi.Healthy
d.LastHealthyTime = time.Now()
d.UnhealthyReason = ""
d.RecoveryAttempts = 0
}

// IsUnhealthy returns true if the device is currently marked as unhealthy.
func (d *Device) IsUnhealthy() bool {
return d.Health == pluginapi.Unhealthy
}

// UnhealthyDuration returns how long the device has been unhealthy. Returns
// zero duration if the device is healthy.
func (d *Device) UnhealthyDuration() time.Duration {
if !d.IsUnhealthy() {
return 0
}
return time.Since(d.LastUnhealthyTime)
}

// NewAnnotatedID creates a new AnnotatedID from an ID and a replica number.
func NewAnnotatedID(id string, replica int) AnnotatedID {
return AnnotatedID(fmt.Sprintf("%s::%d", id, replica))
Expand Down
Loading