Recurrent layers with double state don't error when a single state is provided #47

Closed · MartinuzziFrancesco opened this issue on Jan 20, 2025 · 1 comment
Labels: bug (Something isn't working)


@MartinuzziFrancesco
Owner

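Pretty much what the title says. Passing a single state to a double-state layer such as RAN doesn't throw an error and silently returns a 10×5×10 array instead of a 10×5 matrix. MWE: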

julia> ran = RAN(2=>10)
RAN(2 => 30)        # 280 parameters

julia> ran(rand(Float32, 2, 5), (rand(Float32, 10), rand(Float32, 10)))
10×5 Matrix{Float32}:
 -0.234597  -0.296605   -0.299842   -0.267407   -0.168646
 -0.12655   -0.117301   -0.117308   -0.104518   -0.0664909
  0.435177   0.27139     0.281547    0.256575    0.155766
  0.159563  -0.0951298  -0.229988   -0.253865   -0.165425
  0.321613   0.29099     0.237064    0.200287    0.13197
  0.635304   0.494455    0.502597    0.459889    0.312964
  0.247967   0.195159    0.198724    0.179123    0.110956
  0.10326   -0.0818294   0.0143353   0.0493054   0.0204579
  0.131437   0.0887093  -0.0433349  -0.0800482  -0.0446744
  0.223177   0.238569    0.28191     0.270312    0.167857

julia> ran(rand(Float32, 2, 5), rand(Float32, 10))
10×5×10 Array{Float32, 3}:
[:, :, 1] =
 -0.0884085   -0.251769    -0.304124   -0.275792   -0.33086
 -0.00909825  -0.061802    -0.0765486  -0.0934748  -0.10374
  0.208017     0.192099     0.14573     0.216099    0.223822
 -0.108327    -0.211199    -0.216408   -0.261386   -0.285244
  0.215404     0.272211     0.289356    0.21939     0.264094
  0.270969     0.262589     0.210557    0.29559     0.301221
  0.166093     0.177697     0.167387    0.173029    0.190619
  0.132261    -0.0315028   -0.169854   -0.0238708  -0.0817408
 -0.015829     0.00937752   0.0581976  -0.0608296  -0.0295524
  0.184363     0.217556     0.206879    0.232572    0.256938

[:, :, 2] =
 -0.0977714   -0.255776    -0.305831   -0.276546   -0.331266
 -0.00371165  -0.0589298   -0.074972   -0.092785   -0.103401
  0.203491     0.191577     0.14603     0.216509    0.224161
 -0.0891033   -0.204475    -0.214071   -0.260395   -0.284939
  0.181264     0.25953      0.284303    0.217384    0.263431
  0.245049     0.246835     0.20113     0.289493    0.297416
  0.177979     0.184697     0.171492    0.17543     0.192053
  0.13712     -0.0281015   -0.167784   -0.0227935  -0.0811153
 -0.0244088    0.00577614   0.0565333  -0.0615115  -0.0298372
  0.190705     0.220517     0.208067    0.233196    0.257179

[:, :, 3] =
 -0.0844863   -0.249968   -0.303554   -0.275663   -0.330926
  0.00132415  -0.0569072  -0.0740381  -0.0924534  -0.103267
  0.22217      0.198618    0.148792    0.217442    0.224466
 -0.0820324   -0.201175   -0.212528   -0.259495   -0.284507
  0.203697     0.268002    0.287701    0.218698    0.263914
  0.253258     0.252194    0.204461    0.291727    0.298874
  0.170999     0.180041    0.168578    0.173692    0.191019
  0.146965    -0.0247973  -0.166759   -0.0219035  -0.0806948
 -0.0134071    0.0102147   0.0583425  -0.0606009  -0.0294495
  0.205456     0.227134    0.210793    0.23467     0.257836

[:, :, 4] =
 -0.0884534   -0.251166    -0.303622   -0.275467   -0.330748
  0.00813491  -0.0533      -0.0721297  -0.09164    -0.102879
  0.214304     0.195047     0.147084    0.216678    0.224094
 -0.108513    -0.211984    -0.217073   -0.261917   -0.285565
  0.191465     0.2637       0.286133    0.218167    0.263718
  0.258898     0.254701     0.205619    0.292327    0.299163
  0.170663     0.180646     0.16921     0.174136    0.191308
  0.140691    -0.0272202   -0.167594   -0.0224835  -0.0809709
 -0.0143571    0.00995612   0.0583625  -0.060704   -0.029478
  0.181218     0.216952     0.206938    0.23271     0.257026

[:, :, 5] =
 -0.111979    -0.260902   -0.307747   -0.277255   -0.331446
 -0.00431808  -0.0588142  -0.0748834  -0.0927081  -0.103354
  0.209274     0.193765    0.146812    0.216767    0.224216
 -0.0724934   -0.197829   -0.211328   -0.259011   -0.28437
  0.183333     0.259358    0.283774    0.216988    0.26317
  0.262413     0.256376    0.206473    0.292779    0.299382
  0.174976     0.183403    0.170856    0.175079    0.191856
  0.145288    -0.0244487  -0.166119   -0.021684   -0.0805221
 -0.0251922    0.0055794   0.0566221  -0.0615905  -0.0298728
  0.176764     0.214534    0.205789    0.232077    0.256748

[:, :, 6] =
 -0.0809046  -0.247818    -0.302134   -0.274782   -0.330437
  0.0129058  -0.0510784   -0.0709705  -0.0911472  -0.102638
  0.209453    0.193288     0.146473    0.216416    0.224
 -0.110786   -0.213333    -0.2177     -0.262285   -0.285733
  0.184233    0.26085      0.284901    0.217633    0.263496
  0.263185    0.257436     0.207296    0.293444    0.299881
  0.166381    0.17797      0.16759     0.173193    0.190748
  0.140138   -0.0274404   -0.167718   -0.0225487  -0.0810157
 -0.0195211   0.00778811   0.0573778  -0.0611324  -0.0296728
  0.199926    0.225351     0.210366    0.234572    0.257869

[:, :, 7] =
 -0.109076    -0.260067   -0.307497   -0.277201   -0.331469
 -0.00463054  -0.0591667  -0.0750171  -0.092752   -0.103362
  0.211766     0.195143    0.147576    0.217149    0.224423
 -0.0968479   -0.207135   -0.214986   -0.260752   -0.285026
  0.181884     0.259607    0.284268    0.217344    0.263395
  0.251125     0.250834    0.203675    0.291168    0.2985
  0.167099     0.178527    0.167958    0.173418    0.190889
  0.13293     -0.0304691  -0.169376   -0.0237214  -0.0816701
 -0.0133478    0.0102496   0.0584658  -0.060541   -0.0294449
  0.201672     0.225935    0.210451    0.234572    0.257841

[:, :, 8] =
 -0.103377    -0.256662    -0.305416   -0.275946   -0.330765
  0.00312736  -0.0556237   -0.0732768  -0.0920657  -0.103052
  0.222484     0.198756     0.148792    0.217431    0.224435
 -0.0998801   -0.208357    -0.215652   -0.261252   -0.285334
  0.204024     0.266004     0.286046    0.217811    0.263369
  0.271827     0.26148      0.209301    0.294465    0.300382
  0.185329     0.188206     0.173094    0.176111    0.19228
  0.136387    -0.0293435   -0.1686     -0.0231225  -0.0813003
 -0.0272666    0.00471168   0.0560549  -0.0617457  -0.0299319
  0.187594     0.220245     0.208525    0.233774    0.257631

[:, :, 9] =
 -0.0946515   -0.254205   -0.305152   -0.276222   -0.331033
  0.00158951  -0.0564778  -0.0737916  -0.0923169  -0.103191
  0.199848     0.188372    0.144044    0.215293    0.223441
 -0.094657    -0.206617   -0.214863   -0.260794   -0.285093
  0.203132     0.267003    0.286927    0.218274    0.263621
  0.272284     0.263737    0.211374    0.296156    0.301618
  0.164233     0.176745    0.166898    0.172784    0.190505
  0.145762    -0.0252113  -0.166895   -0.0219828  -0.080731
 -0.0067542    0.0130785   0.0598592  -0.0600537  -0.0292274
  0.194089     0.22186     0.208625    0.233532    0.257377

[:, :, 10] =
 -0.102582    -0.256951   -0.306228   -0.276687   -0.331224
 -0.00740377  -0.0604406  -0.0757019  -0.0930406  -0.103503
  0.221388     0.198265    0.148517    0.217316    0.224357
 -0.0952371   -0.206479   -0.214603   -0.260534   -0.284927
  0.189681     0.263037    0.285771    0.217941    0.26358
  0.269091     0.261689    0.210136    0.295371    0.301143
  0.1654       0.177396    0.167188    0.172908    0.190564
  0.149186    -0.0234476  -0.166207   -0.0215564  -0.0805256
 -0.00164255   0.0151025   0.0606154  -0.0595892  -0.0290267
  0.190514     0.221431    0.20885     0.233826    0.257581
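One way to make this case fail loudly is to dispatch on the state type, so that only an `(h, c)` tuple reaches the forward pass and anything else hits a fallback method that throws. The snippet below is a minimal sketch under that assumption. `ToyDoubleStateCell` and its update equations are hypothetical toy code, not the package's `RAN`, and this is not necessarily how #48 fixes the issue.

```julia
# Hypothetical toy layer with a double state (h, c).
# NOT the package's RAN; the update rule below is illustrative only.
struct ToyDoubleStateCell
    Wi::Matrix{Float32}   # input weights, size (n_out, n_in)
    Wh::Matrix{Float32}   # recurrent weights, size (n_out, n_out)
end

ToyDoubleStateCell(n_in::Int, n_out::Int) =
    ToyDoubleStateCell(randn(Float32, n_out, n_in), randn(Float32, n_out, n_out))

# Forward pass: only accepts a 2-tuple of arrays as the state.
function (cell::ToyDoubleStateCell)(x::AbstractVecOrMat,
                                    state::Tuple{AbstractVecOrMat, AbstractVecOrMat})
    h, c = state
    c_new = c .+ cell.Wi * x               # toy additive cell update
    h_new = tanh.(c_new .+ cell.Wh * h)    # toy hidden-state update
    return h_new, (h_new, c_new)
end

# Fallback: any other state (e.g. a single vector) raises an error
# instead of being broadcast into an unexpected shape.
(cell::ToyDoubleStateCell)(x, state) = throw(ArgumentError(
    "expected a tuple (h, c) of two states, got a single $(typeof(state))"))

cell = ToyDoubleStateCell(2, 10)
x = rand(Float32, 2, 5)

# Correct usage: (h, c) tuple gives a 10×5 output.
out, (h, c) = cell(x, (rand(Float32, 10), rand(Float32, 10)))
println(size(out))  # prints (10, 5)

# Incorrect usage: a single state vector now errors.
try
    cell(x, rand(Float32, 10))
catch err
    println(err isa ArgumentError)  # prints true
end
```

Relying on dispatch keeps the check out of the numerical code path: the valid method never needs an explicit `if`, and the error message names the type that was actually passed.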
@MartinuzziFrancesco MartinuzziFrancesco added the bug Something isn't working label Jan 20, 2025
@MartinuzziFrancesco
Owner Author

fixed by #48
