Inference Error for v9-s model with CUDA #154

Open
tahsinalamin opened this issue Jan 9, 2025 · 2 comments
Labels
bug Something isn't working

Comments

@tahsinalamin

Describe the bug

Inference is not working as expected for the v9-s model when using it with CUDA. However, this error does not occur when using the v9-c model. The full error is below:

{
	"name": "RuntimeError",
	"message": "Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 7
      3 model = create_model(cfg.model, class_num=CLASS_NUM).to(device)
      5 transform = AugmentationComposer([], cfg.image_size)
----> 7 converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
      8 post_proccess = PostProcess(converter, cfg.task.nms)

File /app/yolo/utils/bounding_box_utils.py:456, in create_converter(model_version, *args, **kwargs)
    454     converter = Anc2Box(*args, **kwargs)
    455 else:
--> 456     converter = Vec2Box(*args, **kwargs)
    457 return converter

File /app/yolo/utils/bounding_box_utils.py:347, in Vec2Box.__init__(self, model, anchor_cfg, image_size, device)
    345 else:
    346     logger.info(\":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size\")
--> 347     self.strides = self.create_auto_anchor(model, image_size)
    349 anchor_grid, scaler = generate_anchors(image_size, self.strides)
    350 self.image_size = image_size

File /app/yolo/utils/bounding_box_utils.py:357, in Vec2Box.create_auto_anchor(self, model, image_size)
    355 # TODO: need accelerate dummy test
    356 dummy_input = torch.zeros(1, 3, H, W)
--> 357 dummy_output = model(dummy_input)
    358 strides = []
    359 for predict_head in dummy_output[\"Main\"]:

File /app/.env/lib/python3.8/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /app/.env/lib/python3.8/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /app/yolo/model/yolo.py:79, in YOLO.forward(self, x)
     77 else:
     78     model_input = y[layer.source]
---> 79 x = layer(model_input)
     80 y[-1] = x
     81 if layer.usable:

File /app/.env/lib/python3.8/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /app/.env/lib/python3.8/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /app/yolo/model/module.py:33, in Conv.forward(self, x)
     32 def forward(self, x: Tensor) -> Tensor:
---> 33     return self.act(self.bn(self.conv(x)))

File /app/.env/lib/python3.8/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /app/.env/lib/python3.8/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /app/.env/lib/python3.8/site-packages/torch/nn/modules/conv.py:458, in Conv2d.forward(self, input)
    457 def forward(self, input: Tensor) -> Tensor:
--> 458     return self._conv_forward(input, self.weight, self.bias)

File /app/.env/lib/python3.8/site-packages/torch/nn/modules/conv.py:454, in Conv2d._conv_forward(self, input, weight, bias)
    450 if self.padding_mode != 'zeros':
    451     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    452                     weight, bias, self.stride,
    453                     _pair(0), self.dilation, self.groups)
--> 454 return F.conv2d(input, weight, bias, self.stride,
    455                 self.padding, self.dilation, self.groups)

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor"
}

To Reproduce

Steps to reproduce the behavior:

  1. Go to /examples/notebook_inference.ipynb
  2. Change: MODEL = "v9-s"
  3. Run the notebook (the failing cell is sketched below)
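
For context, the failing cell looks roughly like the sketch below (reconstructed from the traceback; imports and config setup from earlier notebook cells are omitted, and the exact notebook contents may differ slightly):

# Reconstructed from the traceback above (Cell In[8]); cfg, CLASS_NUM and device come from earlier cells
model = create_model(cfg.model, class_num=CLASS_NUM).to(device)  # model weights moved to CUDA
transform = AugmentationComposer([], cfg.image_size)
converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)  # RuntimeError raised here
post_proccess = PostProcess(converter, cfg.task.nms)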

Expected behavior

It should run without any error.

Screenshots

[two screenshots of the error output attached]

System Info (please complete the following information):

  • OS: Ubuntu 20.04
  • Python Version: 3.8.10
  • PyTorch Version: 2.4.1+cu118
  • CUDA/cuDNN/MPS Version: CUDA 11.1
  • YOLO Model Version: YOLOv9-s

Additional context

If the device is CPU, it works perfectly (which is expected).

@tahsinalamin tahsinalamin added the bug Something isn't working label Jan 9, 2025
@Adamusen
Contributor

Adamusen commented Jan 10, 2025

Hey,

this bug happens because there is a fallback mechanism that computes a model's anchor strides, when they are not provided in the {model}.yaml file, by running a dummy forward pass through the model and deriving the strides from its outputs. However, this dummy computation currently runs on the CPU, because the model has not yet been copied to the GPU by the Lightning module at the point where it is executed, which leads to a tensor device mismatch later. This bug will certainly be fixed later; for now, the easiest solution is to simply put the strides in the model config yourself (just like it is done in v9-c):

anchor:
  reg_max: 16
  strides: [8, 16, 32]  # add this
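
For reference, a minimal sketch of what a code-level fix could look like (an assumption based on the lines visible in the traceback, not an actual patch from the repository): in Vec2Box.create_auto_anchor, create the dummy input on the same device as the model's parameters instead of the default CPU.

# Hypothetical fix sketch for yolo/utils/bounding_box_utils.py (assumption, not the project's patch):
# build the dummy input on the model's device so the auto-anchor dummy forward
# pass no longer feeds CPU tensors into CUDA weights.
device = next(model.parameters()).device
dummy_input = torch.zeros(1, 3, H, W, device=device)
dummy_output = model(dummy_input)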

@tahsinalamin
Author

@Adamusen Thanks for the explanation. The workaround you mentioned worked!

For anyone else: add the line in the /yolo/config/v9-s.yaml file.
