PyTorch: fails on model with multiple return values. #1147

Open
4 tasks done
sei-jgwohlbier opened this issue Dec 12, 2024 · 1 comment
@sei-jgwohlbier

Prerequisites

Please make sure to check off these prerequisites before submitting a bug report.

  • Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
  • Check that the issue hasn't already been reported, by checking the currently open issues.
  • If there are steps to reproduce the problem, make sure to write them down below.
  • If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.

Quick summary

hls4ml fails on a PyTorch model with multiple return values.

Details

hls4ml fails on code below that has two linear layers and returns output from both layers.

Steps to Reproduce

Add what needs to be done to reproduce the bug. Add commented code examples and make sure to include the original model files / code, and the commit hash you are working on.

  1. Clone the hls4ml repository
  2. Check out the master branch, with commit hash: [cc4fbf9]
  3. Run conversion for code below.
from pathlib import Path

import numpy as np
import os
import shutil
import torch
import torch.nn as nn
from torchinfo import summary

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model

test_root_path = Path(__file__).parent

class test(nn.Module):
    def __init__(self, n_in, n1, n2):
        super().__init__()

        self.lin1 = nn.Linear(n_in, n1, bias=True)
        self.lin2 = nn.Linear(n_in, n2, bias=True)

    def forward(self, x):
        y = self.lin1(x)
        z = self.lin2(x)
        return y, z

if __name__ == "__main__":

    n_batch = 16
    n_in = 16
    n1 = 32
    n2 = 64
    X_input_shape = (n_batch, n_in)

    model = test(n_in, n1, n2)
    io_type='io_stream'
    backend='Vitis'
    output_dir = str(test_root_path / f'hls4mlprj_2lin_{backend}_{io_type}')
    if os.path.exists(output_dir):
        print("delete project dir")
        shutil.rmtree(output_dir)

    model.eval()
    summary(model, input_size=X_input_shape)

    X_input = np.random.rand(*X_input_shape)
    #X_input = np.ones(X_input_shape)
    with torch.no_grad():
        pytorch_prediction = [p.detach().numpy()
                              for p in model(torch.Tensor(X_input))]

    # input is 2D (batch, features), so no channels-last transpose is needed; just make it contiguous
    X_input_hls = np.ascontiguousarray(X_input)

    # write tb data
    ipf = "./tb_input_features.dat"
    if os.path.isfile(ipf):
        os.remove(ipf)
    with open(ipf, "ab") as f:
        for x in X_input_hls:
            np.savetxt(f, x.flatten(), newline=" ")
    opf = "./tb_output_predictions.dat"
    if os.path.isfile(opf):
        os.remove(opf)
    with open(opf, "ab") as f:
        for p0,p1 in zip(pytorch_prediction[0],
                         pytorch_prediction[1]):
            np.savetxt(f, p0.flatten(), newline=" ")
            np.savetxt(f, p1.flatten(), newline=" ")

    #default_precision='ap_fixed<16,6>'
    default_precision='ap_fixed<32,12>'
    #default_precision='ap_fixed<64,24>'
    config = config_from_pytorch_model(model,
                                       input_shape=X_input_shape[-1:],
                                       backend=backend,
                                       default_precision=default_precision,
                                       default_reuse_factor=1,
                                       channels_last_conversion='internal',
                                       transpose_outputs=False)
    config['Model']['Strategy'] = 'Resource'
    print(config)
    print(output_dir)

    hls_model = convert_from_pytorch_model(
        model,
        output_dir=output_dir,
        input_data_tb=ipf,
        output_data_tb=opf,
        backend=backend,
        hls_config=config,
        io_type=io_type,
        part='xcvu9p-flga2104-2-e'
    )
    hls_model.compile()

    print("pytorch_prediction")
    print(pytorch_prediction)

    # run the compiled HLS model on the same input (no reshape or transpose needed for 2D data)
    hls_prediction = hls_model.predict(X_input_hls)
    print("hls_prediction")
    print(hls_prediction)

    rtol = 1.0e-2
    atol = 1.0e-2
    assert len(pytorch_prediction) == len(hls_prediction), "length mismatch"

    for p0, h0 in zip(pytorch_prediction[0], hls_prediction[0]):
        np.testing.assert_allclose(p0,
                                   h0,
                                   rtol=rtol, atol=atol)
    for p1, h1 in zip(pytorch_prediction[1], hls_prediction[1]):
        np.testing.assert_allclose(p1,
                                   h1,
                                   rtol=rtol, atol=atol)
    # synthesize
    hls_model.build(csim=True, synth=True, cosim=True, validation=True)

Expected behavior

Successful synthesis.
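
Put differently, the comparison step in the script is expected to pass for every output before synthesis runs. A minimal sketch of that check, assuming `hls_model.predict` would return one array per model output, in the same order as the PyTorch tuple (an assumption about the intended behavior, not something confirmed in this report). `compare_outputs` is a hypothetical helper, and the arrays are stand-ins with the shapes used above:

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1.0e-2, atol=1.0e-2):
    """Compare two sequences of arrays, output by output (hypothetical helper)."""
    if len(reference) != len(candidate):
        raise AssertionError(
            f"output count mismatch: {len(reference)} vs {len(candidate)}"
        )
    for i, (ref, cand) in enumerate(zip(reference, candidate)):
        # assert_allclose(actual, desired): candidate is checked against reference
        np.testing.assert_allclose(
            cand, ref, rtol=rtol, atol=atol, err_msg=f"output {i} differs"
        )


# Stand-in data: two outputs of shape (16, 32) and (16, 64)
rng = np.random.default_rng(0)
reference = [rng.random((16, 32)), rng.random((16, 64))]
# A candidate that differs only by a small offset, well inside the tolerances
candidate = [r + 1.0e-4 for r in reference]

compare_outputs(reference, candidate)
```

Checking the number of outputs first gives a clearer message than comparing `len()` of a list against `len()` of a bare array.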

Actual behavior

python test_2lin.py 
2024-12-12 18:12:01.988121: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-12 18:12:02.040888: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-12 18:12:02.892484: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
delete project dir
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
test                                     [16, 32]                  --
├─Linear: 1-1                            [16, 32]                  544
├─Linear: 1-2                            [16, 64]                  1,088
==========================================================================================
Total params: 1,632
Trainable params: 1,632
Non-trainable params: 0
Total mult-adds (M): 0.03
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.01
Estimated Total Size (MB): 0.02
==========================================================================================
{'Model': {'Precision': {'default': 'ap_fixed<32,12>'}, 'ReuseFactor': 1, 'ChannelsLastConversion': 'internal', 'TransposeOutputs': False, 'Strategy': 'Resource', 'BramFactor': 1000000000, 'TraceOutput': False}, 'PytorchModel': test(
  (lin1): Linear(in_features=16, out_features=32, bias=True)
  (lin2): Linear(in_features=16, out_features=64, bias=True)
), 'InputShape': (16,)}
/home/hls4ml-user/work/ewstapp_research/isolate/NETWORK/hls4mlprj_2lin_Vitis_io_stream
Interpreting Model ...
Topology:
Layer name: lin1, layer type: Dense, input shape: [[None, 16]]
Layer name: lin2, layer type: Dense, input shape: [[None, 16]]
Creating HLS model
Writing HLS project
Done
pytorch_prediction
[array([[ 0.8891059 , -0.44395483, -0.05747134,  0.18016575,  0.04786346,
         0.5514014 , -0.16852657,  0.00964493,  0.3273672 ,  0.42060843,
         0.06706502,  0.15498346,  0.2457329 , -0.15184441,  0.09685186,
        -0.6596167 ,  0.2790345 ,  0.40409216, -0.23034032,  0.26463172,
        -0.46979874, -0.11001211,  0.35551935,  0.09460301, -0.10833421,
        -0.4492357 , -0.28191066, -0.26569235, -0.12289155, -0.5352483 ,
         0.5751673 , -0.2317074 ],
       [ 0.76985276, -0.1049726 ,  0.07535005,  0.10176656,  0.09320992,
         0.28592783, -0.04348151,  0.03626189,  0.00936881,  0.4154517 ,
         0.37312955,  0.14893359,  0.1893669 , -0.22227341, -0.08566827,
        -0.58724916, -0.25961903,  0.65872145, -0.34750587,  0.315466  ,
        -0.47754753,  0.00142413,  0.28266868,  0.29222852,  0.03858323,
        -0.2562538 , -0.35519725,  0.18092948, -0.07686479, -0.6124781 ,
         0.31275678, -0.27654955],
       [ 0.85511225, -0.24324456, -0.00846682,  0.22668763,  0.03866964,
         0.2381162 ,  0.0457862 ,  0.05140384,  0.04126164,  0.32769206,
         0.2474686 ,  0.20140621,  0.0427063 , -0.2780934 , -0.01511081,
        -0.58079964, -0.19579059,  0.5630405 , -0.37406617,  0.4501509 ,
        -0.47031593, -0.1698381 ,  0.46137342,  0.19732511, -0.02999109,
        -0.27819347, -0.33464026,  0.00155999, -0.07882459, -0.51667523,
         0.4007225 , -0.1982677 ],
       [ 1.0654068 , -0.30002496,  0.16547504,  0.2570746 ,  0.07158022,
         0.4347104 , -0.06412914,  0.16597775,  0.16384612,  0.4160096 ,
         0.08880451,  0.1005227 ,  0.1824699 , -0.19954087,  0.34508896,
        -0.53782004,  0.09642816,  0.8185116 , -0.34626994,  0.471716  ,
        -0.5092526 , -0.06822003,  0.3831837 , -0.01965211, -0.01387932,
        -0.37834692, -0.3783682 , -0.3562213 , -0.27375486, -0.6525427 ,
         0.6037679 ,  0.17533389],
       [ 0.52793086,  0.08541805,  0.03517117, -0.4244916 ,  0.10885802,
         0.43530622,  0.3118299 , -0.01598971,  0.3790553 ,  0.5554543 ,
         0.05826975,  0.11390461,  0.2410459 ,  0.0613706 ,  0.26139343,
        -0.27970743,  0.26997155,  0.46432167,  0.00322317, -0.15576953,
        -0.340056  , -0.08219175,  0.24044743, -0.10614166,  0.1167696 ,
        -0.38514078, -0.20315412, -0.13610272, -0.13506019, -0.39643157,
         0.43387794, -0.22893703],
       [ 0.83738965,  0.01773065,  0.01746632,  0.0049476 , -0.02727026,
         0.17095442,  0.26207945,  0.1697861 ,  0.34357035,  0.2642256 ,
         0.29654276,  0.2556939 ,  0.06309891, -0.10552   ,  0.08774575,
        -0.5153153 , -0.06944568,  0.31070724, -0.21419683,  0.21724322,
        -0.45854414, -0.04687934,  0.29160213,  0.29456928,  0.14869723,
        -0.2757703 , -0.3541801 ,  0.08705469, -0.09899832, -0.37215212,
         0.6330352 , -0.5796311 ],
       [ 0.58500886,  0.2640052 , -0.0189429 , -0.2794629 ,  0.13246663,
        -0.267674  ,  0.24941778, -0.04296389,  0.15840055,  0.01208394,
         0.1177678 ,  0.39987636,  0.08620736, -0.03397053,  0.12804201,
        -0.65928245, -0.05545972,  0.69912994, -0.16601579,  0.18794903,
        -0.7339839 ,  0.03901096,  0.30852503, -0.01032168,  0.08174405,
        -0.27913028, -0.23137385,  0.00499156,  0.09213072, -0.759608  ,
         0.91822934, -0.5346441 ],
       [ 0.56878453,  0.11198848,  0.05960959, -0.12241329,  0.12977597,
         0.08147588,  0.3533719 ,  0.16589719, -0.06445619,  0.4639053 ,
         0.462967  ,  0.2932239 ,  0.13533969, -0.2153621 ,  0.14075479,
        -0.5042372 , -0.26714593,  0.48706523, -0.33529457,  0.3466363 ,
        -0.34024402, -0.11915696,  0.3307217 ,  0.36106905,  0.18427725,
        -0.20332047, -0.35370904,  0.09610368, -0.09901851, -0.46560162,
         0.5428121 , -0.42115593],
       [ 0.6866636 , -0.11477496, -0.01831607, -0.02805769, -0.01344819,
         0.68515   , -0.04925771, -0.1972478 ,  0.31140062,  0.40612757,
         0.2530442 ,  0.21337444,  0.6395557 , -0.09065704,  0.19372801,
        -0.45185912,  0.50229234,  0.3983093 , -0.14483142,  0.07841846,
        -0.49485716, -0.02537266,  0.29264736,  0.1069174 ,  0.11361703,
        -0.20951061, -0.26409623, -0.32147884, -0.02064542, -0.511343  ,
         0.38575244,  0.04794483],
       [ 0.67789733, -0.30219573, -0.12434944,  0.11558396,  0.06762291,
         0.3116027 ,  0.15201744,  0.15036204,  0.06727348,  0.42700085,
         0.3871081 ,  0.3823516 ,  0.24762037, -0.17611447,  0.13901351,
        -0.53381497,  0.14353468,  0.49727145, -0.15057111,  0.32427734,
        -0.40415937, -0.0112884 ,  0.31515226,  0.16169925,  0.0040657 ,
        -0.19852826, -0.22190264, -0.18831491, -0.13794504, -0.5023341 ,
         0.69033474, -0.38809985],
       [ 0.81814086, -0.05858143,  0.07434263,  0.01335097,  0.01402169,
         0.5058249 ,  0.16288948, -0.10923052,  0.21130612,  0.52750933,
         0.12909204,  0.04708859,  0.51017034, -0.48467067,  0.236849  ,
        -0.53199136,  0.5741215 ,  0.66357696, -0.08221325,  0.04293117,
        -0.21072648,  0.13694671,  0.34113112,  0.00190126,  0.07781912,
        -0.01927111, -0.48293623, -0.401911  , -0.00399454, -0.6709269 ,
         0.76886785, -0.07476626],
       [ 0.6913032 , -0.1981028 , -0.08275409,  0.10008418, -0.07262716,
         0.36380088,  0.08553496, -0.16448833,  0.21087572,  0.53764087,
         0.19602291,  0.09081438,  0.2667737 , -0.33534533, -0.2282128 ,
        -0.5492    ,  0.21781437,  0.72637093, -0.14848016,  0.04423207,
        -0.2934043 ,  0.05480177,  0.37749898,  0.06654172,  0.00630023,
        -0.1037505 , -0.30250746, -0.19109204, -0.05297701, -0.60283387,
         0.3600675 , -0.24646895],
       [ 0.81141824, -0.33150065,  0.06518545, -0.02965383,  0.24818233,
         0.43532866,  0.10186243,  0.38129237,  0.31362757,  0.4576823 ,
         0.2271005 ,  0.2522714 ,  0.22506769,  0.23416433,  0.22466838,
        -0.17880167, -0.13405931,  0.50129116, -0.32637626,  0.40901405,
        -0.30316994, -0.19033791,  0.07678111,  0.04296248,  0.0158764 ,
        -0.37330478, -0.03181724, -0.10275158, -0.16623837, -0.33406097,
         0.5428152 , -0.32601902],
       [ 0.93497413, -0.09970862,  0.05688836,  0.05871909,  0.20382336,
         0.2938374 , -0.08205896,  0.13411754,  0.06540591,  0.35490224,
         0.16667452,  0.32861888,  0.15358543, -0.31738937,  0.5358051 ,
        -0.7712355 ,  0.17287332,  0.72606945, -0.15157634,  0.25406668,
        -0.7806479 , -0.15099436,  0.36042407, -0.00913737,  0.08302039,
        -0.4754683 , -0.47949797, -0.18461673, -0.27660465, -0.78692245,
         0.6459955 ,  0.0683116 ],
       [ 0.689624  ,  0.2917918 ,  0.23204178, -0.24007678,  0.15531424,
         0.06384554,  0.05830847, -0.05809493,  0.2642526 , -0.0223131 ,
         0.05240817,  0.23094033,  0.34774926, -0.06030922,  0.3801004 ,
        -0.45984933,  0.06846657,  0.45807868, -0.28368205,  0.19050267,
        -0.60009193,  0.00462461,  0.19878045,  0.07578797,  0.16144395,
        -0.32184047, -0.33067778, -0.05224532,  0.04741496, -0.5540669 ,
         0.75985986, -0.11818236],
       [ 0.42422694, -0.13985819,  0.15179643, -0.0994494 ,  0.05480643,
         0.4122148 ,  0.03583536, -0.03997545,  0.0027165 ,  0.5133945 ,
         0.23952836,  0.02225108,  0.21865284,  0.06876859,  0.21143112,
        -0.6492269 ,  0.18834093,  0.43246025, -0.3369369 ,  0.12497532,
        -0.4983435 , -0.05300211,  0.35259238,  0.36499974, -0.01017824,
        -0.51138484, -0.400999  , -0.19466041, -0.20390618, -0.6352968 ,
         0.6848689 , -0.19771972]], dtype=float32), array([[-0.64182013,  0.40184742, -0.07485253, ...,  0.11808984,
        -0.22370778,  0.11206529],
       [-0.5201124 ,  0.18107897, -0.02317163, ..., -0.2986831 ,
         0.10982952,  0.18085258],
       [-0.50899744,  0.2951257 , -0.13951811, ..., -0.02748931,
        -0.06415796,  0.18319744],
       ...,
       [-0.63268775,  0.62129   , -0.16334109, ..., -0.16879866,
        -0.31751645,  0.15808316],
       [-0.71309054,  0.2369004 ,  0.10258856, ..., -0.11202749,
        -0.46535045,  0.2605038 ],
       [-0.71718186,  0.42549616,  0.10741499, ..., -0.22839454,
        -0.1128529 ,  0.39642552]], dtype=float32)]
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
hls_prediction
[[-0.64183044  0.40183544 -0.07486248 ...  0.11807919 -0.22371769
   0.11205196]
 [-0.52012253  0.18106556 -0.02318382 ... -0.29869556  0.10981655
   0.18084049]
 [-0.50900841  0.29511356 -0.13953018 ... -0.02750206 -0.06417084
   0.18318558]
 ...
 [-0.63269997  0.62127781 -0.16335392 ... -0.1688118  -0.31753063
   0.15807152]
 [-0.71309853  0.23688602  0.10257721 ... -0.11204052 -0.46536255
   0.26049042]
 [-0.7171917   0.42548656  0.10740471 ... -0.22840595 -0.11286354
   0.3964119 ]]
Traceback (most recent call last):
  File "/home/hls4ml-user/work/ewstapp_research/isolate/NETWORK/test_2lin.py", line 107, in <module>
    assert len(pytorch_prediction) == len(hls_prediction), "length mismatch"
AssertionError: length mismatch
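
Reading the output above, the assertion appears to fail because `pytorch_prediction` is a list of two arrays while `hls_prediction` is a single array whose values match the second PyTorch output (`lin2`), so `len()` compares the number of outputs against the batch size. A minimal sketch of that mismatch, assuming the returned array has shape (16, 64) (inferred from the printed values, not stated in the log). `as_output_list` and `describe_outputs` are hypothetical helpers, and the arrays are zero-filled stand-ins:

```python
import numpy as np


def as_output_list(prediction):
    """Wrap a bare array in a list so single and multiple outputs look alike."""
    if isinstance(prediction, (list, tuple)):
        return list(prediction)
    return [prediction]


def describe_outputs(prediction):
    """Return the shape of every output as a list of tuples."""
    return [tuple(np.asarray(p).shape) for p in as_output_list(prediction)]


# Stand-ins with the shapes from the report (values are irrelevant here)
pytorch_prediction = [np.zeros((16, 32), dtype=np.float32),
                      np.zeros((16, 64), dtype=np.float32)]
hls_prediction = np.zeros((16, 64))  # assumed shape of the single returned array

# len() counts outputs for the list but batch rows for the bare array
print(len(pytorch_prediction), len(hls_prediction))
print(describe_outputs(pytorch_prediction))
print(describe_outputs(hls_prediction))
```

Under these assumptions the first output (`lin1`) is missing from the HLS result, which would also be consistent with the "leftover data" warnings for `layer2_out`.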
@sei-jgwohlbier sei-jgwohlbier changed the title PyTorch: fail on model with multiple return values. PyTorch: fails on model with multiple return values. Dec 12, 2024
@JanFSchulte
Contributor

Hi @sei-jgwohlbier, thanks for reporting this issue, and thanks for always writing such clear and complete bug reports. It makes the issues very easy to fix.

I had indeed overlooked this case when implementing the PyTorch parser. It is fixed in PR #1151.
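
For readers who want to see what "multiple return values" looks like to a graph-based parser, here is a minimal sketch using `torch.fx`. It assumes the parser works from a traced graph of the module (an assumption made for illustration; this is not the code from the PR). `TwoHeads` is a hypothetical stand-in for the `test` class in the report:

```python
import torch.nn as nn
from torch.fx import symbolic_trace


class TwoHeads(nn.Module):
    """Stand-in for the reported model: two linear layers fed by the same input."""

    def __init__(self, n_in=16, n1=32, n2=64):
        super().__init__()
        self.lin1 = nn.Linear(n_in, n1)
        self.lin2 = nn.Linear(n_in, n2)

    def forward(self, x):
        return self.lin1(x), self.lin2(x)


traced = symbolic_trace(TwoHeads())

# The graph ends in a single "output" node; its first argument holds what forward() returns
output_node = next(n for n in traced.graph.nodes if n.op == "output")
returned = output_node.args[0]

# With a tuple return, the argument is a tuple of nodes rather than a single node
if isinstance(returned, (tuple, list)):
    output_names = [n.name for n in returned]
else:
    output_names = [returned.name]

print(output_names)
```

A converter that treats only the last layer as the model output would miss the first entry of that tuple, which matches the behavior described in this issue.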
