diff --git a/pkg/device_plugin/device_plugin.go b/pkg/device_plugin/device_plugin.go index 0a3681fb..a2fa9ea1 100644 --- a/pkg/device_plugin/device_plugin.go +++ b/pkg/device_plugin/device_plugin.go @@ -42,6 +42,10 @@ import ( pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" ) +const ( + nvidiaVendorID = "10de" +) + //Structure to hold details about Nvidia GPU Device type NvidiaGpuDevice struct { addr string // PCI address of device @@ -309,13 +313,11 @@ func getDeviceName(deviceID string) string { } defer file.Close() - // Find beginning of NVIDIA device list - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "10de") { - break - } + // Locate beginning of NVIDIA device list in pci.ids file + scanner, err := locateVendor(file, nvidiaVendorID) + if err != nil { + log.Printf("Error locating NVIDIA in pci.ds file: %v", err) + return "" } // Find NVIDIA device by device id @@ -354,3 +356,19 @@ func getDeviceName(deviceID string) string { } return deviceName } + +func locateVendor(pciIdsFile *os.File, vendorID string) (*bufio.Scanner, error) { + scanner := bufio.NewScanner(pciIdsFile) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, vendorID) { + return scanner, nil + } + } + + if err := scanner.Err(); err != nil { + return scanner, fmt.Errorf("error reading pci.ids file: %v", err) + } + + return scanner, fmt.Errorf("failed to find vendor id in pci.ids file: %s", vendorID) +}