
[Bug]: NPU compile: L0 zeFenceHostSynchronize result: ZE_RESULT_ERROR_UNKNOWN #27099

Open
3 tasks done
Zctoylm0927 opened this issue Oct 17, 2024 · 1 comment
Labels: bug, category: NPU, support_request

@Zctoylm0927

OpenVINO Version

2024.3

Operating System

Ubuntu 20.04 (LTS)

Device used for inference

NPU

Framework

PyTorch

Model used

torch.nn.MultiheadAttention

Issue description

I have written a Transformer model by hand that consists of three parts: self-attention, cross-attention, and an MLP. The full model runs on the NPU, but when I run only the cross-attention part on its own, inference fails with the following error:

RuntimeError: Exception from src/inference/src/cpp/infer_request.cpp:223:
Exception from src/plugins/intel_npu/src/backend/include/zero_utils.hpp:21:
L0 zeFenceHostSynchronize result: ZE_RESULT_ERROR_UNKNOWN, code 0x7ffffffe - an action is required to complete the desired operation

Step-by-step reproduction

My cross-attention code:

```python
import torch
import torch.nn as nn


class cross_block(nn.Module):
    def __init__(self, hidden_size=1200, num_heads=16):
        super().__init__()
        self.head_dim = hidden_size // num_heads  # 1200 // 16 = 75
        self.dim = hidden_size
        self.d_model = hidden_size
        self.num_heads = num_heads

        # batch_first=False (the default): inputs are (seq_len, batch, embed_dim)
        self.mha = nn.MultiheadAttention(embed_dim=self.d_model, num_heads=self.num_heads)

    def cross_attn(self, q, k, v):
        N, B, C = q.shape  # N = query length, B = batch, C = embedding dim
        x, output_weights = self.mha(q, k, v)  # x: (N, B, C)
        x = x.view(2, N // 2, C)  # just for testing; only valid when B == 1 and N is even
        return x

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return self.cross_attn(q, k, v)
```
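The shape flow through `cross_block.forward` can be traced without PyTorch. The sketch below is a pure-Python model of the shape rules only. It assumes the `nn.MultiheadAttention` defaults (`batch_first=False`, `kdim = vdim = embed_dim`, `average_attn_weights=True`), and `mha_output_shapes` and `checked_view` are hypothetical helpers rather than PyTorch APIs. It shows that `x.view(2, N//2, C)` succeeds only because the batch dimension is 1 and the query length 1920 is even.

```python
from math import prod


def mha_output_shapes(q_shape, k_shape, v_shape, embed_dim, num_heads):
    """Shape rules of nn.MultiheadAttention with its defaults (sketch, not the real module).

    Inputs are (seq_len, batch, embed_dim). Returns (attn_output_shape, attn_weights_shape),
    where the weights are averaged over heads.
    """
    if embed_dim % num_heads:
        raise ValueError("embed_dim must be divisible by num_heads")
    L, N, E = q_shape
    S, Nk, Ek = k_shape
    if E != embed_dim or (Nk, Ek) != (N, E) or tuple(v_shape) != (S, N, E):
        raise ValueError("incompatible q/k/v shapes")
    return (L, N, E), (N, L, S)


def checked_view(shape, target):
    """Mimic the element-count check of torch.Tensor.view (contiguity is not modelled)."""
    if prod(shape) != prod(target):
        raise ValueError(f"cannot view {tuple(shape)} as {tuple(target)}")
    return tuple(target)


# Shapes from the conversion call in this issue
q, k, v = (1920, 1, 1200), (300, 1, 1200), (300, 1, 1200)
out, weights = mha_output_shapes(q, k, v, embed_dim=1200, num_heads=16)
L, B, C = out  # called N, B, C inside cross_attn
x = checked_view(out, (2, L // 2, C))
print(out, weights, x)
```

With a batch of 2, the query would give an output of shape `(1920, 2, 1200)`, whose element count no longer matches `(2, 960, 1200)`, so the test-only `view` would fail.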

This is followed by my conversion code:

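```python
import openvino as ov
import torch

# Static input shapes (seq_len, batch, embed_dim), matching the `input=` argument below
q_shape, k_shape, v_shape = (1920, 1, 1200), (300, 1, 1200), (300, 1, 1200)
CROSS_OV_PATH = "cross.xml"  # placeholder: the actual output path is not shown in the issue

example_input = {
    "q": torch.randn(q_shape),
    "k": torch.randn(k_shape),
    "v": torch.randn(v_shape),
}

core = ov.Core()
model = cross_block()
print("--------after model-------")
model = ov.convert_model(model, input=[[1920, 1, 1200], [300, 1, 1200], [300, 1, 1200]], example_input=example_input)
ov.save_model(model, CROSS_OV_PATH)
print("--------after convert-------")
compiled_model = core.compile_model(model, device_name="NPU")  # compilation succeeds
print("--------after compile-------")
```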

The error occurs when I run inference with the converted OpenVINO cross block:

```python
t = compiled_model(example_input)
```

The original PyTorch model runs without this problem. My cross block IR (XML) is attached:

cross.xml.txt
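The PyTorch model works and only NPU inference fails. One way to narrow this down is to run the same saved IR on CPU and NPU and compare the outputs. The sketch below assumes that the IR path and NumPy input arrays come from the conversion script above. `compare_devices` and `max_abs_diff` are hypothetical helpers, not OpenVINO APIs, and OpenVINO is imported lazily so that the diff helper can be used without it.

```python
def max_abs_diff(a, b):
    """Largest element-wise |a - b| between two equal-length flat sequences."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def compare_devices(ir_path, feed, devices=("CPU", "NPU")):
    """Run one inference of the same IR on each device and compare against devices[0]."""
    import openvino as ov  # lazy import: max_abs_diff() works without OpenVINO installed

    core = ov.Core()
    print("available devices:", core.available_devices)
    outputs = {}
    for device in devices:
        try:
            compiled = core.compile_model(ir_path, device_name=device)
            result = compiled(feed)
        except RuntimeError as err:  # e.g. the L0 zeFenceHostSynchronize error above
            print(f"{device}: failed: {err}")
            continue
        # First model output, flattened to a plain list of floats
        outputs[device] = result[compiled.output(0)].ravel().tolist()

    reference = devices[0]
    for device, values in outputs.items():
        if device != reference and reference in outputs:
            diff = max_abs_diff(outputs[reference], values)
            print(f"{device} vs {reference}: max |diff| = {diff:.3e}")
    return outputs
```

Usage with the names from the conversion script: `compare_devices(CROSS_OV_PATH, {name: t.numpy() for name, t in example_input.items()})`. If the NPU computes in reduced precision, a small nonzero difference may be expected even when it succeeds.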

Relevant log output

Traceback (most recent call last):
  File "/home/mla/model.py", line 50, in <module>
    t = compiled_model(example_input)
  File "/home/xxx/anaconda3/envs/env1/lib/python3.10/site-packages/openvino/runtime/ie_api.py", line 388, in __call__
    return self._infer_request.infer(
  File "/home/xxx/anaconda3/envs/env1/lib/python3.10/site-packages/openvino/runtime/ie_api.py", line 132, in infer
    return OVDict(super().infer(_data_dispatch(
RuntimeError: Exception from src/inference/src/cpp/infer_request.cpp:223:
Exception from src/plugins/intel_npu/src/backend/include/zero_utils.hpp:21:
L0 zeFenceHostSynchronize result: ZE_RESULT_ERROR_UNKNOWN, code 0x7ffffffe - an action is required to complete the desired operation

Issue submission checklist

  • I'm reporting an issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
Zctoylm0927 added the bug and support_request labels Oct 17, 2024
andrei-kochin added the category: NPU label Oct 17, 2024
avitial (Contributor) commented Oct 25, 2024

@Zctoylm0927 thanks for reaching out. Do you observe the same behavior on the latest 2024.4 release or on a nightly build? If possible, please share a minimal reproducer and the IR model, and let us know which NPU driver version you are using.
