diff --git a/internal/lm/nvml.go b/internal/lm/nvml.go index 4d3e00496..017df6337 100644 --- a/internal/lm/nvml.go +++ b/internal/lm/nvml.go @@ -85,6 +85,11 @@ func NewDeviceLabeler(manager resource.Manager, config *spec.Config) (Labeler, e return nil, fmt.Errorf("error creating IMEX labeler: %v", err) } + uuidLabler, err := newGPUUUIDLabeler(devices) + if err != nil { + return nil, fmt.Errorf("error creating UUID labeler: %v", err) + } + l := Merge( machineTypeLabeler, versionLabeler, @@ -93,6 +98,7 @@ func NewDeviceLabeler(manager resource.Manager, config *spec.Config) (Labeler, e resourceLabeler, gpuModeLabeler, imexLabeler, + uuidLabler, ) return l, nil @@ -261,3 +267,16 @@ func getDeviceClasses(devices []resource.Device) ([]uint32, error) { } return classes, nil } + +// newGPUUUIDLabeler creates a new labeler that reports the UUIDs of GPUs on the node. +func newGPUUUIDLabeler(devices []resource.Device) (Labeler, error) { + labels := make(Labels, len(devices)) + for idx, d := range devices { + uuid, err := d.GetUUID() + if err != nil { + return nil, err + } + labels[fmt.Sprintf("nvidia.com/gpu-%d.uuid", idx)] = uuid + } + return labels, nil +} diff --git a/internal/resource/device_mock.go b/internal/resource/device_mock.go index eadc8932b..83a8b6306 100644 --- a/internal/resource/device_mock.go +++ b/internal/resource/device_mock.go @@ -41,6 +41,9 @@ var _ Device = &DeviceMock{} // GetTotalMemoryMBFunc: func() (uint64, error) { // panic("mock out the GetTotalMemoryMB method") // }, +// GetUUIDFunc: func() (string, error) { +// panic("mock out the GetUUID method") +// }, // IsFabricAttachedFunc: func() (bool, error) { // panic("mock out the IsFabricAttached method") // }, @@ -81,6 +84,9 @@ type DeviceMock struct { // GetTotalMemoryMBFunc mocks the GetTotalMemoryMB method. GetTotalMemoryMBFunc func() (uint64, error) + // GetUUIDFunc mocks the GetUUID method. + GetUUIDFunc func() (string, error) + // IsFabricAttachedFunc mocks the IsFabricAttached method. IsFabricAttachedFunc func() (bool, error) @@ -116,6 +122,9 @@ type DeviceMock struct { // GetTotalMemoryMB holds details about calls to the GetTotalMemoryMB method. GetTotalMemoryMB []struct { } + // GetUUID holds details about calls to the GetUUID method. + GetUUID []struct { + } // IsFabricAttached holds details about calls to the IsFabricAttached method. IsFabricAttached []struct { } @@ -134,6 +143,7 @@ type DeviceMock struct { lockGetName sync.RWMutex lockGetPCIClass sync.RWMutex lockGetTotalMemoryMB sync.RWMutex + lockGetUUID sync.RWMutex lockIsFabricAttached sync.RWMutex lockIsMigCapable sync.RWMutex lockIsMigEnabled sync.RWMutex @@ -355,6 +365,33 @@ func (mock *DeviceMock) GetTotalMemoryMBCalls() []struct { return calls } +// GetUUID calls GetUUIDFunc. +func (mock *DeviceMock) GetUUID() (string, error) { + if mock.GetUUIDFunc == nil { + panic("DeviceMock.GetUUIDFunc: method is nil but Device.GetUUID was just called") + } + callInfo := struct { + }{} + mock.lockGetUUID.Lock() + mock.calls.GetUUID = append(mock.calls.GetUUID, callInfo) + mock.lockGetUUID.Unlock() + return mock.GetUUIDFunc() +} + +// GetUUIDCalls gets all the calls that were made to GetUUID. +// Check the length with: +// +// len(mockedDevice.GetUUIDCalls()) +func (mock *DeviceMock) GetUUIDCalls() []struct { +} { + var calls []struct { + } + mock.lockGetUUID.RLock() + calls = mock.calls.GetUUID + mock.lockGetUUID.RUnlock() + return calls +} + // IsFabricAttached calls IsFabricAttachedFunc. func (mock *DeviceMock) IsFabricAttached() (bool, error) { if mock.IsFabricAttachedFunc == nil { diff --git a/internal/resource/nvml-device.go b/internal/resource/nvml-device.go index 9b29dc7bd..56825b68c 100644 --- a/internal/resource/nvml-device.go +++ b/internal/resource/nvml-device.go @@ -81,6 +81,15 @@ func (d nvmlDevice) GetName() (string, error) { return name, nil } +// GetUUID returns the device UUID. +func (d nvmlDevice) GetUUID() (string, error) { + uuid, ret := d.Device.GetUUID() + if ret != nvml.SUCCESS { + return "", ret + } + return uuid, nil +} + // GetTotalMemoryMB returns the total memory on a device in MB func (d nvmlDevice) GetTotalMemoryMB() (uint64, error) { info, ret := d.Device.GetMemoryInfo() diff --git a/internal/resource/nvml-mig-device.go b/internal/resource/nvml-mig-device.go index cf3d05300..2c74a918c 100644 --- a/internal/resource/nvml-mig-device.go +++ b/internal/resource/nvml-mig-device.go @@ -104,6 +104,15 @@ func (d nvmlMigDevice) GetName() (string, error) { return resourceName, nil } +// GetUUID returns the UUID of the nvmlMigDevice. +func (d nvmlMigDevice) GetUUID() (string, error) { + uuid, ret := d.MigDevice.GetUUID() + if ret != nvml.SUCCESS { + return "", ret + } + return uuid, nil +} + // GetTotalMemoryMB returns the total memory on a device in MB func (d nvmlMigDevice) GetTotalMemoryMB() (uint64, error) { attr, err := d.GetAttributes() diff --git a/internal/resource/sysfs-device.go b/internal/resource/sysfs-device.go index a3097a108..f5c8a4415 100644 --- a/internal/resource/sysfs-device.go +++ b/internal/resource/sysfs-device.go @@ -51,6 +51,11 @@ func (d vfioDevice) GetName() (string, error) { return d.nvidiaPCIDevice.DeviceName, nil } +// GetUUID is unsupported for vfio devices +func (d vfioDevice) GetUUID() (string, error) { + return "", fmt.Errorf("GetUUID is not supported for vfio devices") +} + // GetTotalMemoryMB returns the total memory on a device in MB func (d vfioDevice) GetTotalMemoryMB() (uint64, error) { _, val := d.nvidiaPCIDevice.Resources.GetTotalAddressableMemory(true) @@ -72,6 +77,7 @@ func (d vfioDevice) GetPCIClass() (uint32, error) { func (d vfioDevice) IsFabricAttached() (bool, error) { return false, nil } + func (d vfioDevice) GetFabricIDs() (string, string, error) { return "", "", fmt.Errorf("GetFabricIDs is not supported for vfio devices") } diff --git a/internal/resource/types.go b/internal/resource/types.go index dc6fa6a77..8e790052e 100644 --- a/internal/resource/types.go +++ b/internal/resource/types.go @@ -37,6 +37,7 @@ type Device interface { GetMigDevices() ([]Device, error) GetAttributes() (map[string]interface{}, error) GetName() (string, error) + GetUUID() (string, error) GetTotalMemoryMB() (uint64, error) GetDeviceHandleFromMigDeviceHandle() (Device, error) GetCudaComputeCapability() (int, int, error)