diff --git a/Transfer Learning/Atomistic graph data/config.json b/Transfer Learning/Atomistic graph data/config.json new file mode 100644 index 00000000..898298ea --- /dev/null +++ b/Transfer Learning/Atomistic graph data/config.json @@ -0,0 +1,26 @@ +{ + "_name_or_path": "/mnt/data2/s2ef_all_10epochs_weights/checkpoint-292950", + "architectures": [ + "AtomformerModel" + ], + "auto_map": { + "AutoConfig": "configuration_atomformer.AtomformerConfig", + "AutoModel": "modeling_atomformer.AtomformerModel" + }, + "bos_token_id": 120, + "cls_token_id": 122, + "depth": 12, + "dim": 768, + "dropout": 0.0, + "eos_token_id": 121, + "gradient_checkpointing": false, + "k": 128, + "mask_token_id": 0, + "mlp_ratio": 4, + "model_type": "atomformer", + "num_heads": 32, + "pad_token_id": 119, + "torch_dtype": "float32", + "transformers_version": "4.40.0", + "vocab_size": 123 +} diff --git a/Transfer Learning/Atomistic graph data/configuration_atomformer.py b/Transfer Learning/Atomistic graph data/configuration_atomformer.py new file mode 100644 index 00000000..927a7476 --- /dev/null +++ b/Transfer Learning/Atomistic graph data/configuration_atomformer.py @@ -0,0 +1,42 @@ +from transformers.configuration_utils import PretrainedConfig +from typing import Any + +class AtomformerConfig(PretrainedConfig): # type: ignore + r""" + Configuration of a :class:`~transform:class:`~transformers.AtomformerModel`. + + It is used to instantiate an Atomformer model according to the specified arguments. + """ + + model_type = "atomformer" + + def __init__( + self, + vocab_size: int = 123, + dim: int = 768, + num_heads: int = 32, + depth: int = 12, + mlp_ratio: int = 1, + k: int = 128, + dropout: float = 0.0, + mask_token_id: int = 0, + pad_token_id: int = 119, + bos_token_id: int = 120, + eos_token_id: int = 121, + cls_token_id: int = 122, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.dim = dim + self.num_heads = num_heads + self.depth = depth + self.mlp_ratio = mlp_ratio + self.k = k + + self.dropout = dropout + self.mask_token_id = mask_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.cls_token_id = cls_token_id \ No newline at end of file diff --git a/Transfer Learning/Atomistic graph data/model.py b/Transfer Learning/Atomistic graph data/model.py new file mode 100644 index 00000000..76e5bc52 --- /dev/null +++ b/Transfer Learning/Atomistic graph data/model.py @@ -0,0 +1,6 @@ +import kagglehub + +# Download latest version +path = kagglehub.model_download("tedlord/atomformer/pyTorch/default") + +print("Path to model files:", path) \ No newline at end of file diff --git a/Transfer Learning/Atomistic graph data/modeling_atomformer.py b/Transfer Learning/Atomistic graph data/modeling_atomformer.py new file mode 100644 index 00000000..a863760d --- /dev/null +++ b/Transfer Learning/Atomistic graph data/modeling_atomformer.py @@ -0,0 +1,2867 @@ +"""Implementation of the Atomformer model.""" + +from typing import Any, Optional, Tuple + +import torch +import torch.nn.functional as f +from torch import nn +from transformers.modeling_utils import PreTrainedModel +from .configuration_atomformer import AtomformerConfig + + +ATOM_METADATA = [ + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.106761565836299, + 0.4573170731707318, + 0.46896368424867707, + 0.0, + 0.0, + 0.0027383806383189145, + 0.0, + 1.0, + 0.0, + 0.0, + ], + [ + 0.008547008547008548, + 0.010187317385107808, + 0.011235955056179775, + 0.008547008547008548, + 0.008547008547008548, + 0.0, + 1.0, + 0.0, + -1.0, + 0.9999999999999999, + 2.1731754967921256e-06, + -1.0, + 0.0, + 0.010000000000000002, + 0.3588318085855031, + 0.0, + -1.0, + ], + [ + 0.017094017094017096, + 0.02018415404448405, + 0.02247191011235955, + 0.017094017094017096, + 0.017094017094017096, + 0.16666666666666666, + 0.0, + 0.5729537366548044, + 0.08536585365853658, + 0.0723802160098582, + 0.01302222611458848, + 0.1117635470484688, + 0.2746530986669577, + 0.010000000000000002, + 0.2454609429978888, + 0.16666666666666666, + 0.0, + ], + [ + 0.025641025641025644, + 0.027228539455021038, + 0.028089887640449437, + 0.025641025641025644, + 0.025641025641025644, + 0.16666666666666666, + 0.058823529411764705, + 0.32384341637010683, + 0.2652439024390244, + 0.2623432478797689, + 0.0451198574701265, + 0.39298038243761085, + 0.4668171696125004, + 0.015, + 0.12181562280084446, + 0.16666666666666666, + 0.14285714285714285, + ], + [ + 0.03418803418803419, + 0.03334773276914757, + 0.033707865168539325, + 0.03418803418803419, + 0.03418803418803419, + 0.16666666666666666, + 0.7058823529411764, + 0.25266903914590755, + 0.4085365853658537, + 0.2128252833015198, + 0.057071103187614054, + 0.6504807478441018, + 0.715419845245687, + 0.015, + 0.06558761435608726, + 0.16666666666666666, + 0.2857142857142857, + ], + [ + 0.042735042735042736, + 0.03742946260625253, + 0.033707865168539325, + 0.042735042735042736, + 0.042735042735042736, + 0.16666666666666666, + 0.7647058823529411, + 0.14946619217081855, + 0.5640243902439024, + 0.3559765143644139, + 0.055363782370830124, + 1.0000000000000002, + 0.7324707832177849, + 0.020000000000000004, + 0.04327938071780436, + 0.16666666666666666, + 0.42857142857142855, + ], + [ + 0.051282051282051294, + 0.04421873990197045, + 0.03932584269662921, + 0.051282051282051294, + 0.051282051282051294, + 0.16666666666666666, + 0.8235294117647058, + 0.09252669039145908, + 0.7134146341463414, + 0.514180781404789, + 2.8295183993586364e-05, + 0.012484827687008686, + 0.012471056032792366, + 0.025, + 0.06657283603096412, + 0.16666666666666666, + 0.5714285714285714, + ], + [ + 0.05982905982905984, + 0.050994411431564704, + 0.0449438202247191, + 0.05982905982905984, + 0.05982905982905984, + 0.16666666666666666, + 0.8823529411764706, + 0.05693950177935947, + 0.8353658536585367, + 0.4699156740039142, + 3.268543752245935e-05, + 0.00923366315240946, + 0.014660396468409729, + 0.025, + 0.057987332864180154, + 0.16666666666666666, + 0.7142857142857142, + ], + [ + 0.06837606837606838, + 0.06119533458279619, + 0.056179775280898875, + 0.06837606837606838, + 0.06837606837606838, + 0.16666666666666666, + 0.9411764705882353, + 0.028469750889679707, + 1.0, + 0.6537753400826345, + 3.9270817815768815e-05, + 0.01002929606822616, + 0.01377886297525227, + 0.015, + 0.051372273047149884, + 0.16666666666666666, + 0.8571428571428572, + ], + [ + 0.07692307692307693, + 0.06521583847234458, + 0.056179775280898875, + 0.07692307692307693, + 0.07692307692307693, + 0.16666666666666666, + 1.0, + 0.007117437722419961, + -1.0, + 0.8539203131418077, + 1.9758579909666677e-05, + 0.002676173590325307, + 0.003896139326624358, + 0.025, + 0.06586910626319493, + 0.16666666666666666, + 1.0, + ], + [ + 0.08547008547008549, + 0.07477388917423204, + 0.06741573033707865, + 0.08547008547008549, + 0.08547008547008549, + 0.33333333333333337, + 0.0, + 0.6085409252669042, + 0.07012195121951223, + 0.06017348442747725, + 0.023680786070796774, + 0.09074155275516495, + 0.19638929337502856, + 0.020000000000000004, + 0.07980295566502463, + 0.33333333333333337, + 0.0, + ], + [ + 0.09401709401709403, + 0.0792467847873929, + 0.06741573033707865, + 0.09401709401709403, + 0.09401709401709403, + 0.33333333333333337, + 0.058823529411764705, + 0.43060498220640586, + 0.1859756097560976, + 0.18132746997849566, + 0.04243692475803745, + 0.23105764525702377, + 0.2316847349772711, + 0.025, + 0.0653764954257565, + 0.33333333333333337, + 0.14285714285714285, + ], + [ + 0.10256410256410257, + 0.08835244376566789, + 0.07865168539325842, + 0.10256410256410257, + 0.10256410256410257, + 0.33333333333333337, + 0.7058823529411764, + 0.4661921708185055, + 0.2774390243902439, + 0.10108971416145165, + 0.06585161024536003, + 0.23366315240945865, + 0.4753426385985493, + 0.025, + 0.05650950035186488, + 0.33333333333333337, + 0.2857142857142857, + ], + [ + 0.11111111111111113, + 0.09210763521580445, + 0.07865168539325842, + 0.11111111111111113, + 0.11111111111111113, + 0.33333333333333337, + 0.7647058823529411, + 0.35943060498220647, + 0.36585365853658536, + 0.20575543044917485, + 0.05682720021378778, + 0.4242464682668294, + 0.6025426358703994, + 0.025, + 0.04299788881069668, + 0.33333333333333337, + 0.42857142857142855, + ], + [ + 0.11965811965811968, + 0.10193099835710374, + 0.0898876404494382, + 0.11965811965811968, + 0.11965811965811968, + 0.33333333333333337, + 0.8235294117647058, + 0.25266903914590755, + 0.4542682926829268, + 0.3185927948389591, + 0.04438814854864767, + 0.07704039807065373, + 0.09357213740327856, + 0.020000000000000004, + 0.04750175932441942, + 0.33333333333333337, + 0.5714285714285714, + ], + [ + 0.12820512820512822, + 0.10564197106733833, + 0.0898876404494382, + 0.12820512820512822, + 0.12820512820512822, + 0.33333333333333337, + 0.8823529411764706, + 0.21708185053380794, + 0.5731707317073171, + 0.31247009930654546, + 0.05048572289430458, + 0.09515439218602051, + 0.1216720831812958, + 0.035, + 0.04334975369458127, + 0.33333333333333337, + 0.7142857142857142, + ], + [ + 0.13675213675213677, + 0.11716605497409803, + 0.10112359550561797, + 0.13675213675213677, + 0.13675213675213677, + 0.33333333333333337, + 0.9411764705882353, + 0.17081850533807832, + 0.75, + 0.43848068233986515, + 7.610016686353661e-05, + 0.04019725595612581, + 0.040050948202660634, + 0.04, + 0.0270935960591133, + 0.33333333333333337, + 0.8571428571428572, + ], + [ + 0.1452991452991453, + 0.13245553465558702, + 0.12359550561797752, + 0.1452991452991453, + 0.1452991452991453, + 0.33333333333333337, + 1.0, + 0.13879003558718864, + -1.0, + 0.573402276077029, + 4.1222041606379033e-05, + 0.0177390552812359, + 0.01416591926721889, + 0.025, + 0.029978888106966927, + 0.33333333333333337, + 1.0, + ], + [ + 0.15384615384615385, + 0.12956430935430432, + 0.11235955056179775, + 0.15384615384615385, + 0.15384615384615385, + 0.5, + 0.0, + 0.8220640569395019, + 0.036585365853658514, + 0.021591320946190845, + 0.02102224365609036, + 0.08193366760083631, + 0.17524613028962724, + 0.035, + 0.04665728360309641, + 0.5, + 0.0, + ], + [ + 0.1623931623931624, + 0.13289772205460673, + 0.11235955056179775, + 0.1623931623931624, + 0.1623931623931624, + 0.5, + 0.058823529411764705, + 0.6085409252669042, + 0.09146341463414634, + 0.10724623674100561, + 0.03755886528151192, + 0.27910065518972543, + 0.2988654305873366, + 0.05500000000000001, + 0.038916256157635463, + 0.5, + 0.14285714285714285, + ], + [ + 0.17094017094017097, + 0.14948995384243843, + 0.1348314606741573, + 0.17094017094017097, + 0.17094017094017097, + 0.5, + 0.1176470588235294, + 0.5729537366548044, + 0.201219512195122, + 0.128910044216783, + 0.07292479648632205, + 0.4570377290145464, + 0.5293941119700996, + 0.06, + 0.03335679099225897, + 0.5, + -1.0, + ], + [ + 0.17948717948717952, + 0.15939155013894887, + 0.14606741573033707, + 0.17948717948717952, + 0.17948717948717952, + 0.5, + 0.1764705882352941, + 0.5373665480427048, + 0.25609756097560976, + 0.14179331674197213, + 0.11072975742939495, + 0.48779542320426544, + 0.6062938422242609, + 0.03, + 0.03019000703729768, + 0.5, + -1.0, + ], + [ + 0.18803418803418806, + 0.16985098284653036, + 0.15730337078651685, + 0.18803418803418806, + 0.18803418803418806, + 0.5, + 0.23529411764705882, + 0.5017793594306051, + 0.2835365853658536, + 0.13783555222654456, + 0.14902252432012042, + 0.5493108115837035, + 0.6267549677907782, + 0.03, + 0.027797325826882473, + 0.5, + -1.0, + ], + [ + 0.1965811965811966, + 0.17343610222012087, + 0.15730337078651685, + 0.1965811965811966, + 0.1965811965811966, + 0.5, + 0.2941176470588235, + 0.5017793594306051, + 0.29268292682926833, + 0.13881653659361634, + 0.1743884335980532, + 0.5378719996949651, + 0.5012600643161381, + 0.03, + 0.024982406755805767, + 0.5, + -1.0, + ], + [ + 0.20512820512820515, + 0.1834431432040899, + 0.16853932584269662, + 0.20512820512820515, + 0.20512820512820515, + 0.5, + 0.3529411764705882, + 0.4661921708185055, + 0.25914634146341464, + 0.17107304225964678, + 0.1814616198390152, + 0.38255835382787134, + 0.3972493426863412, + 0.04, + 0.0270935960591133, + 0.5, + -1.0, + ], + [ + 0.2136752136752137, + 0.18652825067263504, + 0.16853932584269662, + 0.2136752136752137, + 0.2136752136752137, + 0.5, + 0.4117647058823529, + 0.43060498220640586, + 0.3445121951219513, + 0.19370816923188441, + 0.1919494477135451, + 0.4560209457355474, + 0.5336568464631241, + 0.035, + 0.024982406755805767, + 0.5, + -1.0, + ], + [ + 0.22222222222222224, + 0.19703190212011848, + 0.1797752808988764, + 0.22222222222222224, + 0.22222222222222224, + 0.5, + 0.47058823529411764, + 0.43060498220640586, + 0.3597560975609756, + 0.19267402807644915, + 0.2160958421223465, + 0.4458531129455577, + 0.5449104655247086, + 0.05500000000000001, + 0.023011963406052074, + 0.5, + -1.0, + ], + [ + 0.2307692307692308, + 0.19621555615269748, + 0.1741573033707865, + 0.2307692307692308, + 0.2307692307692308, + 0.5, + 0.5294117647058824, + 0.3950177935943062, + 0.36890243902439024, + 0.18101819411892625, + 0.2173153569914779, + 0.4351768885160684, + 0.5425233342086149, + 0.04, + 0.024630541871921183, + 0.5, + -1.0, + ], + [ + 0.23931623931623935, + 0.21272275190225615, + 0.19662921348314605, + 0.23931623931623935, + 0.23931623931623935, + 0.5, + 0.5882352941176471, + 0.3950177935943062, + 0.36585365853658536, + 0.1852030830937251, + 0.2185348718606093, + 0.3415311485202626, + 0.4826745419265514, + 0.04, + 0.02047853624208304, + 0.5, + -1.0, + ], + [ + 0.2478632478632479, + 0.2189609956699649, + 0.19662921348314605, + 0.2478632478632479, + 0.2478632478632479, + 0.5, + 0.6470588235294117, + 0.35943060498220647, + 0.28963414634146345, + 0.2657984391233962, + 0.17390062765040062, + 0.17252397384325016, + 0.20048151848833204, + 0.06, + 0.020689655172413793, + 0.5, + -1.0, + ], + [ + 0.2564102564102564, + 0.23373345623875397, + 0.2191011235955056, + 0.2564102564102564, + 0.2564102564102564, + 0.5, + 0.7058823529411764, + 0.4661921708185055, + 0.33841463414634154, + 0.10174209292773093, + 0.14414446484359486, + 0.0733952300154424, + 0.4216321839864411, + 0.05500000000000001, + 0.019493314567206193, + 0.5, + 0.2857142857142857, + ], + [ + 0.26495726495726496, + 0.24365546118444995, + 0.2303370786516854, + 0.26495726495726496, + 0.26495726495726496, + 0.5, + 0.7647058823529411, + 0.35943060498220647, + 0.39939024390243894, + 0.19356319617271125, + 0.12975418938784455, + 0.30434230009087504, + 0.5288825838309366, + 0.07, + 0.015904292751583393, + 0.5, + 0.42857142857142855, + ], + [ + 0.27350427350427353, + 0.2514175507580112, + 0.23595505617977527, + 0.27350427350427353, + 0.27350427350427353, + 0.5, + 0.8235294117647058, + 0.2882562277580072, + 0.451219512195122, + 0.28485756396936235, + 0.1409737261838533, + 0.2735083471552311, + 0.15052227023008535, + 0.05500000000000001, + 0.016537649542575653, + 0.5, + 0.5714285714285714, + ], + [ + 0.28205128205128205, + 0.2651525716598694, + 0.25280898876404495, + 0.28205128205128205, + 0.28205128205128205, + 0.5, + 0.8823529411764706, + 0.25266903914590755, + 0.5640243902439024, + 0.2831082223886727, + 0.11731513772270441, + 0.12200763858438347, + 0.1626284361902748, + 0.085, + 0.01597466572836031, + 0.5, + 0.7142857142857142, + ], + [ + 0.2905982905982906, + 0.2683635324650587, + 0.25280898876404495, + 0.2905982905982906, + 0.2905982905982906, + 0.5, + 0.9411764705882353, + 0.21708185053380794, + 0.6890243902439025, + 0.38272404378186387, + 0.07609553514606364, + 0.06402557209946683, + 0.055889564484942325, + 0.08, + 0.026741731175228708, + 0.5, + 0.8571428571428572, + ], + [ + 0.29914529914529914, + 0.2816087457864643, + 0.2696629213483146, + 0.29914529914529914, + 0.29914529914529914, + 0.5, + 1.0, + 0.18149466192170824, + -1.0, + 0.48835141469543564, + 8.878312150250298e-05, + 0.025865695638635226, + 0.019729640327514418, + 0.1, + 0.010837438423645322, + 0.5, + 1.0, + ], + [ + 0.3076923076923077, + 0.28728915314310205, + 0.2696629213483146, + 0.3076923076923077, + 0.3076923076923077, + 0.6666666666666666, + 0.0, + 0.8932384341637012, + 0.036585365853658514, + 0.013685456785947292, + 0.03731496230768564, + 0.07590668471456988, + 0.16313996432943775, + 0.085, + 0.01893033075299085, + 0.6666666666666666, + 0.0, + ], + [ + 0.3162393162393162, + 0.2946090553176436, + 0.2808988764044944, + 0.3162393162393162, + 0.3162393162393162, + 0.6666666666666666, + 0.058823529411764705, + 0.7153024911032031, + 0.07621951219512194, + 0.08703215985695992, + 0.06438819240240236, + 0.26130694780724334, + 0.2814734738557968, + 0.075, + 0.014567206192821956, + 0.6666666666666666, + 0.14285714285714285, + ], + [ + 0.3247863247863248, + 0.29898330912640775, + 0.2808988764044944, + 0.3247863247863248, + 0.3247863247863248, + 0.6666666666666666, + 0.1176470588235294, + 0.6441281138790037, + 0.15853658536585366, + 0.11227680189431463, + 0.109022436612611, + 0.4537331833577997, + 0.6146488018305888, + 0.09, + 0.014356087262491202, + 0.6666666666666666, + -1.0, + ], + [ + 0.3333333333333333, + 0.3068678505950822, + 0.28651685393258425, + 0.3333333333333333, + 0.3333333333333333, + 0.6666666666666666, + 0.1764705882352941, + 0.6085409252669042, + 0.19207317073170735, + 0.1324087273781622, + 0.15877864327317145, + 0.5366010205962164, + 0.7976053662711987, + 0.085, + 0.012948627726952853, + 0.6666666666666666, + -1.0, + ], + [ + 0.3418803418803419, + 0.312589075250091, + 0.29213483146067415, + 0.3418803418803419, + 0.3418803418803419, + 0.6666666666666666, + 0.23529411764705882, + 0.5729537366548044, + 0.27439024390243905, + 0.13844927151037764, + 0.2090226558813845, + 0.6931856455620589, + 0.8547260084777264, + 0.105, + 0.012033779028852921, + 0.6666666666666666, + -1.0, + ], + [ + 0.35042735042735046, + 0.3229770776855231, + 0.3033707865168539, + 0.35042735042735046, + 0.35042735042735046, + 0.6666666666666666, + 0.2941176470588235, + 0.5373665480427048, + 0.44512195121951226, + 0.15456544325512842, + 0.24877884061506758, + 0.7310608227047707, + 0.8368225236070237, + 0.085, + 0.011048557353976075, + 0.6666666666666666, + -1.0, + ], + [ + 0.358974358974359, + 0.3299160184086015, + 0.3089887640449438, + 0.358974358974359, + 0.358974358974359, + 0.6666666666666666, + 0.3529411764705882, + 0.5373665480427048, + 0.36585365853658536, + 0.16363109188875738, + 0.28048622721248356, + 0.6250611658691273, + 0.8774037559806166, + 0.1, + -1.0, + 0.6666666666666666, + -1.0, + ], + [ + 0.36752136752136755, + 0.3403584439085284, + 0.3202247191011236, + 0.36752136752136755, + 0.36752136752136755, + 0.6666666666666666, + 0.4117647058823529, + 0.5017793594306051, + 0.4573170731707318, + 0.16752120230990405, + 0.30243749485684845, + 0.6377709568566146, + 0.7534434369234653, + 0.065, + 0.010133708655876143, + 0.6666666666666666, + -1.0, + ], + [ + 0.37606837606837606, + 0.346603490559299, + 0.3258426966292135, + 0.37606837606837606, + 0.37606837606837606, + 0.6666666666666666, + 0.47058823529411764, + 0.4661921708185055, + 0.4817073170731707, + 0.17227631865078405, + 0.30243749485684845, + 0.5655793440476872, + 0.67586166915042, + 0.085, + 0.01048557353976073, + 0.6666666666666666, + -1.0, + ], + [ + 0.38461538461538464, + 0.35855615609895475, + 0.33707865168539325, + 0.38461538461538464, + 0.38461538461538464, + 0.6666666666666666, + 0.5294117647058824, + 0.4661921708185055, + 0.4573170731707318, + 0.21470510063546525, + 0.2926813759037974, + 0.4603422746712931, + 0.5510488031946638, + 0.09, + 0.01055594651653765, + 0.6666666666666666, + -1.0, + ], + [ + 0.39316239316239315, + 0.363481443435728, + 0.34269662921348315, + 0.39316239316239315, + 0.39316239316239315, + 0.6666666666666666, + 0.5882352941176471, + 0.4661921708185055, + 0.375, + 0.17794476526445505, + 0.25609592982985585, + 0.31011254519919423, + 0.41447079003816, + 0.12000000000000001, + 0.00992258972554539, + 0.6666666666666666, + -1.0, + ], + [ + 0.4017094017094017, + 0.37893419231070125, + 0.3595505617977528, + 0.4017094017094017, + 0.4017094017094017, + 0.6666666666666666, + 0.6470588235294117, + 0.43060498220640586, + 0.301829268292683, + 0.24644936815908378, + 0.21194949156729978, + 0.1474729758069129, + 0.17661020532739505, + 0.095, + 0.00971147079521464, + 0.6666666666666666, + -1.0, + ], + [ + 0.4102564102564103, + 0.38712146207562764, + 0.3707865168539326, + 0.4102564102564103, + 0.4102564102564103, + 0.6666666666666666, + 0.7058823529411764, + 0.5373665480427048, + 0.3292682926829269, + 0.09145383816174166, + 0.1782908811792736, + 0.10567809912365993, + 0.39912494586327196, + 0.15500000000000003, + 0.009781843771991556, + 0.6666666666666666, + 0.2857142857142857, + ], + [ + 0.4188034188034188, + 0.40035987251397137, + 0.38764044943820225, + 0.4188034188034188, + 0.4188034188034188, + 0.6666666666666666, + 0.7647058823529411, + 0.43060498220640586, + 0.38414634146341464, + 0.16671901804914588, + 0.17780307523162106, + 0.12481904435081566, + 0.4894949171153905, + 0.125, + 0.009429978888106968, + 0.6666666666666666, + 0.42857142857142855, + ], + [ + 0.4273504273504274, + 0.4107342691832799, + 0.398876404494382, + 0.4273504273504274, + 0.4273504273504274, + 0.6666666666666666, + 0.8235294117647058, + 0.35943060498220647, + 0.41158536585365846, + 0.22782516249063714, + 0.16316889680204447, + 0.22620250509980366, + 0.3164278966985974, + 0.13, + 0.007952146375791697, + 0.6666666666666666, + 0.5714285714285714, + ], + [ + 0.4358974358974359, + 0.43059868772385734, + 0.42696629213483145, + 0.4358974358974359, + 0.4358974358974359, + 0.6666666666666666, + 0.8823529411764706, + 0.32384341637010683, + 0.426829268292683, + 0.24721289293739582, + 0.15194936000603573, + 0.1801295127701625, + 0.21429277824573129, + 0.13, + 0.007600281491907108, + 0.6666666666666666, + 0.7142857142857142, + ], + [ + 0.4444444444444445, + 0.42823128441833647, + 0.4157303370786517, + 0.4444444444444445, + 0.4444444444444445, + 0.6666666666666666, + 0.9411764705882353, + 0.2882562277580072, + 0.5975609756097562, + 0.31688211274071565, + 0.12024197340861972, + 0.09468158796128598, + 0.07727144070195302, + 0.105, + 0.008444757213230118, + 0.6666666666666666, + 0.8571428571428572, + ], + [ + 0.452991452991453, + 0.4431602112975479, + 0.43258426966292135, + 0.452991452991453, + 0.452991452991453, + 0.6666666666666666, + 1.0, + 0.25266903914590755, + -1.0, + 0.39799453934810447, + 0.0001414661638489788, + 0.03743668935364358, + 0.027419613352930545, + 0.14, + 0.00450387051372273, + 0.6666666666666666, + 1.0, + ], + [ + 0.46153846153846156, + 0.4486433350453922, + 0.4382022471910112, + 0.46153846153846156, + 0.46153846153846156, + 0.8333333333333334, + 0.0, + 1.0000000000000002, + 0.0274390243902439, + 0.0, + 0.045607663417779054, + 0.0730876530735452, + 0.16024130487418112, + 0.095, + 0.010415200562983815, + 0.8333333333333334, + 0.0, + ], + [ + 0.47008547008547014, + 0.463684509495124, + 0.4550561797752809, + 0.47008547008547014, + 0.47008547008547014, + 0.8333333333333334, + 0.058823529411764705, + 0.8220640569395019, + 0.05792682926829271, + 0.06368183245946799, + 0.08755897491589865, + 0.25113911501725356, + 0.36928580441210074, + 0.11, + 0.007741027445460941, + 0.8333333333333334, + 0.14285714285714285, + ], + [ + 0.47863247863247865, + 0.46905198423091704, + 0.4606741573033708, + 0.47863247863247865, + 0.47863247863247865, + 0.8333333333333334, + 0.1176470588235294, + 0.7864768683274024, + 0.12195121951219515, + 0.08132988619614859, + 0.14999813621542551, + 0.2996905165894547, + 0.6364740024348741, + 0.08, + 0.007107670654468685, + 0.8333333333333334, + -1.0, + ], + [ + 0.4871794871794872, + 0.47317112992486215, + 0.4606741573033708, + 0.4871794871794872, + 0.4871794871794872, + 0.8333333333333334, + -1.0, + 0.7864768683274024, + 0.1280487804878049, + 0.07948389590934354, + 0.16512012059265466, + 0.2686786265799859, + 0.6328933054607335, + 0.08, + 0.006896551724137932, + 0.8333333333333334, + -1.0, + ], + [ + 0.4957264957264958, + 0.47586507161735137, + 0.4606741573033708, + 0.4957264957264958, + 0.4957264957264958, + 0.8333333333333334, + -1.0, + 0.7864768683274024, + 0.13109756097560973, + 0.07630898591345106, + 0.16512012059265466, + 0.3024866706067019, + 0.6460225276992488, + 0.06, + 0.006966924700914849, + 0.8333333333333334, + -1.0, + ], + [ + 0.5042735042735044, + 0.48720547768144135, + 0.47191011235955055, + 0.5042735042735044, + 0.5042735042735044, + 0.8333333333333334, + -1.0, + 0.7508896797153027, + 0.13414634146341461, + 0.07882185227245272, + 0.1709737919644853, + 0.3240933152854302, + 0.5699753443436925, + 0.065, + 0.006755805770584096, + 0.8333333333333334, + -1.0, + ], + [ + 0.5128205128205129, + 0.4897837703618793, + 0.47191011235955055, + 0.5128205128205129, + 0.5128205128205129, + 0.8333333333333334, + -1.0, + 0.7508896797153027, + 0.13109756097560973, + 0.08157634039674291, + 0.17707136631014223, + 0.3024866706067019, + 0.55735765024434, + 0.05500000000000001, + -1.0, + 0.8333333333333334, + -1.0, + ], + [ + 0.5213675213675214, + 0.5080154969676149, + 0.4943820224719101, + 0.5213675213675214, + 0.5213675213675214, + 0.8333333333333334, + -1.0, + 0.7508896797153027, + 0.14329268292682926, + 0.0845579529804045, + 0.18341284362962543, + 0.33832828119141584, + 0.35172333830083996, + 0.07, + 0.007248416608022519, + 0.8333333333333334, + -1.0, + ], + [ + 0.52991452991453, + 0.5134714091832119, + 0.5, + 0.52991452991453, + 0.52991452991453, + 0.8333333333333334, + -1.0, + 0.7508896797153027, + 0.1524390243902439, + 0.08584821320704569, + 0.12780296559723434, + 0.27477932625397977, + 0.30653835267478063, + 0.09, + 0.0061928219563687536, + 0.8333333333333334, + -1.0, + ], + [ + 0.5384615384615385, + 0.5314514291156592, + 0.5224719101123595, + 0.5384615384615385, + 0.5384615384615385, + 0.8333333333333334, + -1.0, + 0.7153024911032031, + 0.1524390243902439, + 0.10902940536883562, + 0.19268115663502394, + 0.3993352779313545, + 0.6039067109081672, + 0.07, + 0.009992962702322309, + 0.8333333333333334, + -1.0, + ], + [ + 0.5470085470085471, + 0.5371488436799516, + 0.5280898876404494, + 0.5470085470085471, + 0.5470085470085471, + 0.8333333333333334, + -1.0, + 0.7153024911032031, + 0.1524390243902439, + 0.09519414308840943, + 0.20072995477129107, + 0.41077408982009295, + 0.596574807580165, + 0.105, + 0.0061928219563687536, + 0.8333333333333334, + -1.0, + ], + [ + 0.5555555555555557, + 0.5493089971529934, + 0.5449438202247191, + 0.5555555555555557, + 0.5555555555555557, + 0.8333333333333334, + -1.0, + 0.7153024911032031, + 0.15853658536585366, + 0.09882330200304443, + 0.20853484993373195, + 0.42348388080758015, + 0.48352708882515627, + 0.09, + 0.005348346235045743, + 0.8333333333333334, + -1.0, + ], + [ + 0.5641025641025642, + 0.557574500073131, + 0.550561797752809, + 0.5641025641025642, + 0.5641025641025642, + 0.8333333333333334, + -1.0, + 0.7153024911032031, + 0.16158536585365854, + 0.10281489356561238, + 0.21463242427938886, + 0.43949821745181405, + 0.509615023922466, + 0.13, + 0.004996481351161154, + 0.8333333333333334, + -1.0, + ], + [ + 0.5726495726495727, + 0.5654964573986455, + 0.5561797752808989, + 0.5726495726495727, + 0.5726495726495727, + 0.8333333333333334, + -1.0, + 0.7153024911032031, + 0.16463414634146342, + 0.10698045279918819, + 0.22121780457269832, + 0.45271640007880076, + 0.596574807580165, + 0.065, + 0.005207600281491907, + 0.8333333333333334, + -1.0, + ], + [ + 0.5811965811965812, + 0.5711938719629379, + 0.5617977528089888, + 0.5811965811965812, + 0.5811965811965812, + 0.8333333333333334, + -1.0, + 0.6797153024911033, + 0.1676829268292683, + 0.11068209824340977, + 0.22731537891835524, + 0.4585629039330449, + 0.37832280153731257, + 0.075, + 0.004644616467276566, + 0.8333333333333334, + -1.0, + ], + [ + 0.5897435897435899, + 0.5852078110703316, + 0.5786516853932584, + 0.5897435897435899, + 0.5897435897435899, + 0.8333333333333334, + -1.0, + 0.6797153024911033, + 0.12195121951219515, + 0.11405997052214464, + 0.1699981800691802, + 0.2752877178934793, + 0.24975872922769482, + 0.065, + 0.004292751583391977, + 0.8333333333333334, + -1.0, + ], + [ + 0.5982905982905984, + 0.5917147687189832, + 0.5842696629213483, + 0.5982905982905984, + 0.5982905982905984, + 0.8333333333333334, + -1.0, + 0.6441281138790037, + 0.17378048780487806, + 0.07403290888443234, + 0.2399983335573216, + 0.4885580106635147, + 0.6259024208921734, + 0.095, + 0.00422237860661506, + 0.8333333333333334, + -1.0, + ], + [ + 0.6068376068376069, + 0.6036980472324172, + 0.5955056179775281, + 0.6068376068376069, + 0.6068376068376069, + 0.8333333333333334, + 0.1764705882352941, + 0.6085409252669042, + 0.1829268292682927, + 0.14164834368279897, + 0.32438876250121335, + 0.6319244530023704, + 0.8306841859370685, + 0.07, + 0.0035186488388458817, + 0.8333333333333334, + -1.0, + ], + [ + 0.6153846153846155, + 0.6120587905154204, + 0.6067415730337078, + 0.6153846153846155, + 0.6153846153846155, + 0.8333333333333334, + 0.23529411764705882, + 0.5729537366548044, + 0.2439024390243902, + 0.1766593374731196, + 0.40731577360214744, + 0.8274010383899237, + 0.9764697055985051, + 0.08, + 0.003237156931738213, + 0.8333333333333334, + -1.0, + ], + [ + 0.623931623931624, + 0.6218957594228435, + 0.6179775280898876, + 0.623931623931624, + 0.623931623931624, + 0.8333333333333334, + 0.2941176470588235, + 0.5373665480427048, + 0.5060975609756098, + 0.19185251407446782, + 0.4707305467969794, + 0.9318755203070687, + 0.99300911543144, + 0.095, + 0.002674173117522871, + 0.8333333333333334, + -1.0, + ], + [ + 0.6324786324786326, + 0.6299469715265329, + 0.6235955056179775, + 0.6324786324786326, + 0.6324786324786326, + 0.8333333333333334, + 0.3529411764705882, + 0.5373665480427048, + 0.36585365853658536, + 0.19037862130620728, + 0.5121940523474464, + 0.8741730692238767, + 1.0, + 0.09, + 0.00302603800140746, + 0.8333333333333334, + -1.0, + ], + [ + 0.6410256410256411, + 0.6436309708054273, + 0.6404494382022472, + 0.6410256410256411, + 0.6410256410256411, + 0.8333333333333334, + 0.4117647058823529, + 0.5017793594306051, + 0.4573170731707318, + 0.21960035760021265, + 0.5512185281596508, + 0.8352811088021659, + 0.9004225222429487, + 0.08, + 0.0025334271639690367, + 0.8333333333333334, + -1.0, + ], + [ + 0.6495726495726497, + 0.650389635127367, + 0.6460674157303371, + 0.6495726495726497, + 0.6495726495726497, + 0.8333333333333334, + 0.47058823529411764, + 0.5017793594306051, + 0.4573170731707318, + 0.24515427549713678, + 0.5512185281596508, + 0.6868307500683152, + 0.8008450444858972, + 0.11, + 0.002603800140745954, + 0.8333333333333334, + -1.0, + ], + [ + 0.6581196581196582, + 0.6601415679965169, + 0.6573033707865168, + 0.6581196581196582, + 0.6581196581196582, + 0.8333333333333334, + 0.5294117647058824, + 0.4661921708185055, + 0.4817073170731707, + 0.2447531833667577, + 0.5243892010387603, + 0.5162653550162368, + 0.6980278885141472, + 0.14500000000000002, + 0.00274454609429979, + 0.8333333333333334, + -1.0, + ], + [ + 0.6666666666666667, + 0.6665464823992409, + 0.6629213483146067, + 0.6666666666666667, + 0.6666666666666667, + 0.8333333333333334, + 0.5882352941176471, + 0.4661921708185055, + 0.5609756097560976, + 0.25764612076255833, + 0.4707305467969794, + 0.33644214820887275, + 0.5328042995645191, + 0.09, + 0.002463054187192118, + 0.8333333333333334, + -1.0, + ], + [ + 0.6752136752136753, + 0.6788699050657669, + 0.6797752808988764, + 0.6752136752136753, + 0.6752136752136753, + 0.8333333333333334, + 0.6470588235294117, + 0.4661921708185055, + 0.39634146341463417, + 0.31621523666851903, + 0.3292668219777389, + 0.05598790027897992, + 0.1067013596417939, + 0.115, + 0.003237156931738213, + 0.8333333333333334, + -1.0, + ], + [ + 0.6837606837606839, + 0.6917715727925495, + 0.6910112359550562, + 0.6837606837606839, + 0.6837606837606839, + 0.8333333333333334, + 0.7058823529411764, + 0.5729537366548044, + 0.4085365853658537, + 0.10700461497571703, + 0.2902423461655346, + 0.14310589162361226, + 0.29698982741040586, + 0.125, + 0.002463054187192118, + 0.8333333333333334, + 0.2857142857142857, + ], + [ + 0.6923076923076924, + 0.7013534335851533, + 0.7022471910112359, + 0.6923076923076924, + 0.6923076923076924, + 0.8333333333333334, + 0.7647058823529411, + 0.4661921708185055, + 0.4969512195121951, + 0.1702370309517481, + 0.27560816773595803, + 0.14910491296970624, + 0.3440504162133959, + 0.13, + 0.002463054187192118, + 0.8333333333333334, + 0.42857142857142855, + ], + [ + 0.7008547008547009, + 0.7074079995101924, + 0.7078651685393258, + 0.7008547008547009, + 0.7008547008547009, + 0.8333333333333334, + 0.8235294117647058, + 0.3950177935943062, + 0.4024390243902439, + 0.1639017082658806, + 0.2392666246358428, + 0.13484961139814058, + 0.31250618096501487, + 0.08, + 0.0019704433497536944, + 0.8333333333333334, + 0.5714285714285714, + ], + [ + 0.7094017094017095, + 0.7108774698717316, + 0.7078651685393258, + 0.7094017094017095, + 0.7094017094017095, + 0.8333333333333334, + 0.8823529411764706, + 0.35943060498220647, + 0.39634146341463417, + 0.21857588131538888, + 0.22731537891835524, + 0.13039610063612506, + 0.20985953437298585, + 0.15500000000000003, + -1.0, + 0.8333333333333334, + 0.7142857142857142, + ], + [ + 0.7179487179487181, + 0.7108774698717316, + 0.7022471910112359, + 0.7179487179487181, + 0.7179487179487181, + 0.8333333333333334, + 0.9411764705882353, + 0.32384341637010683, + 0.4573170731707318, + 0.26124628506535874, + 0.17072988899065902, + 0.14259749998411278, + 0.10329117204737433, + 0.09, + -1.0, + 0.8333333333333334, + 0.8571428571428572, + ], + [ + 0.7264957264957266, + 0.7516947682427814, + 0.7640449438202247, + 0.7264957264957266, + 0.7264957264957266, + 0.8333333333333334, + 1.0, + 0.2882562277580072, + -1.0, + 0.3312441104694711, + 0.00023512490579826907, + 0.047782459217458176, + 0.03530908235262022, + 0.085, + 0.0, + 0.8333333333333334, + 1.0, + ], + [ + 0.7350427350427351, + 0.7550962097737021, + 0.7640449438202247, + 0.7350427350427351, + 0.7350427350427351, + 0.9999999999999999, + 0.0, + -1.0, + 0.0, + 0.00864039432672098, + 0.045607663417779054, + 0.0726936495529331, + 0.161264361152507, + 0.09, + -1.0, + 0.9999999999999999, + 0.0, + ], + [ + 0.7435897435897437, + 0.7653005343664645, + 0.7752808988764045, + 0.7435897435897437, + 0.7435897435897437, + 0.9999999999999999, + 0.058823529411764705, + -1.0, + 0.06097560975609759, + 0.06690506680841815, + 0.1341444429167175, + 0.243767436244511, + 0.34200430365674417, + 0.06, + -1.0, + 0.9999999999999999, + 0.14285714285714285, + ], + [ + 0.7521367521367522, + 0.7687019758973853, + 0.7752808988764045, + 0.7521367521367522, + 0.7521367521367522, + 0.9999999999999999, + 0.1176470588235294, + -1.0, + 0.12195121951219515, + 0.061666706936960886, + 0.24633981087680482, + 0.3327359731569215, + 0.5911185074290938, + 0.04, + 0.0018296973961998584, + 0.9999999999999999, + -1.0, + ], + [ + 0.7606837606837608, + 0.7858384383301644, + 0.797752808988764, + 0.7606837606837608, + 0.7606837606837608, + 0.9999999999999999, + -1.0, + -1.0, + 0.1829268292682927, + 0.11659699905767515, + 0.28536428668900904, + 0.5119440260804912, + 0.8622284211854495, + 0.045, + 0.001337086558761435, + 0.9999999999999999, + -1.0, + ], + [ + 0.7692307692307694, + 0.7824301939161817, + 0.7865168539325842, + 0.7692307692307694, + 0.7692307692307694, + 0.9999999999999999, + -1.0, + -1.0, + 0.2439024390243902, + 0.09646024113852172, + 0.3756083870047315, + 0.4725436740192808, + 0.7324707832177849, + 0.05500000000000001, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.7777777777777779, + 0.8062164745419109, + 0.8202247191011236, + 0.7777777777777779, + 0.7777777777777779, + 0.9999999999999999, + -1.0, + -1.0, + 0.2073170731707317, + 0.11115567690337544, + 0.4634134575821911, + 0.3535800303764005, + 0.7502037587087667, + 0.06, + 0.00154820548909219, + 0.9999999999999999, + -1.0, + ], + [ + 0.7863247863247864, + 0.8027163912065933, + 0.8089887640449438, + 0.7863247863247864, + 0.7863247863247864, + 0.9999999999999999, + -1.0, + -1.0, + 0.201219512195122, + 0.11461570058230844, + 0.49999890365613264, + 0.22851568705952632, + 0.7278670299653185, + 0.75, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.7948717948717949, + 0.826526481923039, + 0.8426966292134831, + 0.7948717948717949, + 0.7948717948717949, + 0.9999999999999999, + -1.0, + -1.0, + 0.17682926829268295, + 0.10304201802498372, + 0.4829256954882932, + 0.22851568705952632, + 0.5962337888207231, + 0.8, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8034188034188036, + 0.8231250403921182, + 0.8314606741573034, + 0.8034188034188036, + 0.8034188034188036, + 0.9999999999999999, + -1.0, + -1.0, + 0.1829268292682927, + 0.10050982192475896, + 0.3341448814542644, + 0.3185010072509358, + 0.4903474640139954, + 0.65, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8119658119658121, + 0.8367308065158015, + 0.848314606741573, + 0.8119658119658121, + 0.8119658119658121, + 0.9999999999999999, + -1.0, + -1.0, + 0.1829268292682927, + 0.10136516297388068, + 0.3292668219777389, + 0.3370573020926671, + 0.5761136820136477, + 0.65, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8205128205128206, + 0.8367308065158015, + 0.8426966292134831, + 0.8205128205128206, + 0.8205128205128206, + 0.9999999999999999, + -1.0, + -1.0, + 0.1829268292682927, + 0.11133930944499479, + 0.3609742085751549, + 0.31646744069293786, + 0.16689117068329928, + 0.4, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8290598290598292, + 0.8503365726394846, + 0.8595505617977528, + 0.8290598290598292, + 0.8290598290598292, + 0.9999999999999999, + -1.0, + -1.0, + 0.1829268292682927, + 0.11538889023123203, + 0.3682912977899432, + 0.48576185664626753, + 0.1992879528302852, + 0.6, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8376068376068377, + 0.8537380141704054, + 0.8595505617977528, + 0.8376068376068377, + 0.8376068376068377, + 0.9999999999999999, + -1.0, + -1.0, + 0.1829268292682927, + 0.12207214825911519, + 0.3292668219777389, + 0.28443876740447005, + -1.0, + 0.6, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8461538461538463, + 0.8707452218250095, + 0.8820224719101123, + 0.8461538461538463, + 0.8461538461538463, + 0.9999999999999999, + -1.0, + -1.0, + 0.1829268292682927, + 0.12593809650373308, + -1.0, + -1.0, + -1.0, + 0.5, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8547008547008548, + 0.8741466633559303, + 0.8820224719101123, + 0.8547008547008548, + 0.8547008547008548, + 0.9999999999999999, + -1.0, + -1.0, + 0.1829268292682927, + 0.12980404474835092, + -1.0, + -1.0, + -1.0, + 0.15000000000000002, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8632478632478634, + 0.8775481048868511, + 0.8820224719101123, + 0.8632478632478634, + 0.8632478632478634, + 0.9999999999999999, + -1.0, + -1.0, + 0.1829268292682927, + 0.1331867494623916, + -1.0, + -1.0, + -1.0, + 0.35, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8717948717948719, + 0.8877524294796135, + 0.8932584269662921, + 0.8717948717948719, + 0.8717948717948719, + 0.9999999999999999, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 1.0000000000000002, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8803418803418804, + 0.8843509879486927, + 0.8820224719101123, + 0.8803418803418804, + 0.8803418803418804, + 0.9999999999999999, + 0.1764705882352941, + -1.0, + -1.0, + -1.0, + 0.4414621899378262, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8888888888888891, + 0.8877524294796135, + 0.8820224719101123, + 0.8888888888888891, + 0.8888888888888891, + 0.9999999999999999, + 0.23529411764705882, + -1.0, + -1.0, + -1.0, + 0.9512194052347446, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.8974358974358976, + 0.9013581956032967, + 0.898876404494382, + 0.8974358974358976, + 0.8974358974358976, + 0.9999999999999999, + 0.2941176470588235, + -1.0, + -1.0, + -1.0, + 0.8536582157042338, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.9059829059829061, + 0.894555312541455, + 0.8820224719101123, + 0.9059829059829061, + 0.9059829059829061, + 0.9999999999999999, + 0.3529411764705882, + -1.0, + -1.0, + -1.0, + 0.9024388104694893, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.9145299145299146, + 0.9047596371342175, + 0.8932584269662921, + 0.9145299145299146, + 0.9145299145299146, + 0.9999999999999999, + 0.4117647058823529, + -1.0, + -1.0, + -1.0, + 1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.9230769230769232, + 0.9081610786651383, + 0.8932584269662921, + 0.9230769230769232, + 0.9230769230769232, + 0.9999999999999999, + 0.47058823529411764, + -1.0, + -1.0, + -1.0, + 0.8536582157042338, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.9316239316239318, + 0.9183654032579007, + 0.9044943820224719, + 0.9316239316239318, + 0.9316239316239318, + 0.9999999999999999, + 0.5294117647058824, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.9401709401709403, + 0.9217668447888215, + 0.9044943820224719, + 0.9401709401709403, + 0.9401709401709403, + 0.9999999999999999, + 0.5882352941176471, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.9487179487179489, + 0.965985584690792, + 0.9719101123595505, + 0.9487179487179489, + 0.9487179487179489, + 0.9999999999999999, + 0.6470588235294117, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + -1.0, + ], + [ + 0.9572649572649574, + 0.9625841431598712, + 0.9606741573033708, + 0.9572649572649574, + 0.9572649572649574, + 0.9999999999999999, + 0.7058823529411764, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + 0.2857142857142857, + ], + [ + 0.9658119658119659, + 0.9795913508144752, + 0.9831460674157303, + 0.9658119658119659, + 0.9658119658119659, + 0.9999999999999999, + 0.7647058823529411, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + 0.42857142857142855, + ], + [ + 0.9743589743589745, + 0.9761899092835544, + 0.9719101123595505, + 0.9743589743589745, + 0.9743589743589745, + 0.9999999999999999, + 0.8235294117647058, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + 0.5714285714285714, + ], + [ + 0.9829059829059831, + 0.9897956754072376, + 0.9887640449438202, + 0.9829059829059831, + 0.9829059829059831, + 0.9999999999999999, + 0.8823529411764706, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + 0.7142857142857142, + ], + [ + 0.9914529914529915, + 1.0, + 1.0, + 0.9914529914529915, + 0.9914529914529915, + 0.9999999999999999, + 0.9411764705882353, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + 0.8571428571428572, + ], + [ + 1.0000000000000002, + 0.9965985584690792, + 0.9887640449438202, + 1.0000000000000002, + 1.0000000000000002, + 0.9999999999999999, + 1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.9999999999999999, + 1.0, + ], +] + + +@torch.jit.script +def gaussian(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: + """Compute the Gaussian distribution probability density.""" + pi = 3.14159 + a = (2 * pi) ** 0.5 + output: torch.Tensor = torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std) + return output + + +class GaussianLayer(nn.Module): + """Gaussian pairwise positional embedding layer.""" + + def __init__(self, k: int = 128, edge_types: int = 1024): + super().__init__() + self.k = k + self.means = nn.Embedding(1, k) + self.stds = nn.Embedding(1, k) + self.mul = nn.Embedding(edge_types, 1) + self.bias = nn.Embedding(edge_types, 1) + nn.init.uniform_(self.means.weight, 0, 3) + nn.init.uniform_(self.stds.weight, 0, 3) + nn.init.constant_(self.bias.weight, 0) + nn.init.constant_(self.mul.weight, 1) + + def forward(self, x: torch.Tensor, edge_types: int) -> torch.Tensor: + """Forward pass to compute the Gaussian pos. embeddings.""" + mul = self.mul(edge_types) + bias = self.bias(edge_types) + x = mul * x.unsqueeze(-1) + bias + x = x.expand(-1, -1, -1, self.k) + mean = self.means.weight.float().view(-1) + std = self.stds.weight.float().view(-1).abs() + 1e-5 + output: torch.Tensor = gaussian(x.float(), mean, std).type_as(self.means.weight) + return output + + +class ParallelBlock(nn.Module): + """Parallel transformer block (MLP & Attention in parallel). + + Based on: + 'Scaling Vision Atomformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442 + + Adapted from TIMM implementation. + """ + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: int = 4, + dropout: float = 0.0, + k: int = 128, + gradient_checkpointing: bool = False, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divisible by num_heads {num_heads}" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.mlp_hidden_dim = int(mlp_ratio * dim) + self.proj_drop = nn.Dropout(dropout) + self.attn_drop = nn.Dropout(dropout) + self.gradient_checkpointing = gradient_checkpointing + + self.in_proj_in_dim = dim + self.in_proj_out_dim = self.mlp_hidden_dim + 3 * dim + self.out_proj_in_dim = self.mlp_hidden_dim + dim + self.out_proj_out_dim = 2 * dim + + self.in_split = [self.mlp_hidden_dim] + [dim] * 3 + self.out_split = [dim] * 2 + + self.in_norm = nn.LayerNorm(dim) + self.q_norm = nn.LayerNorm(self.head_dim) + self.k_norm = nn.LayerNorm(self.head_dim) + self.in_proj = nn.Linear(self.in_proj_in_dim, self.in_proj_out_dim, bias=False) + self.act_fn = nn.GELU() + self.out_proj = nn.Linear( + self.out_proj_in_dim, self.out_proj_out_dim, bias=False + ) + self.gaussian_proj = nn.Linear(k, 1) + self.pos_embed_ff_norm = nn.LayerNorm(k) + + def forward( + self, + x: torch.Tensor, + pos_embed: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for the parallel block.""" + b, n, c = x.shape + res = x + + # Combined MLP fc1 & qkv projections + x = self.in_proj(self.in_norm(x)) + x, q, k, v = torch.split(x, self.in_split, dim=-1) + x = self.act_fn(x) + x = self.proj_drop(x) + + # Dot product attention + q = self.q_norm(q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)) + k = self.k_norm(k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)) + v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2) + + x_attn = ( + f.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask + + self.gaussian_proj(self.pos_embed_ff_norm(pos_embed)).permute( + 0, 3, 1, 2 + ), + is_causal=False, + ) + .transpose(1, 2) + .reshape(b, n, c) + ) + + # Combined MLP fc2 & attn_output projection + x_mlp, x_attn = self.out_proj(torch.cat([x, x_attn], dim=-1)).split( + self.out_split, dim=-1 + ) + # Residual connections + x = x_mlp + x_attn + res + del x_mlp, x_attn, res + + return x, pos_embed + + +class AtomformerEncoder(nn.Module): + """Atomformer encoder. + + The transformer encoder consists of a series of parallel blocks, + each containing a multi-head self-attention mechanism and a feed-forward network. + """ + + def __init__(self, config: AtomformerConfig): + super().__init__() + self.vocab_size = config.vocab_size + self.dim = config.dim + self.num_heads = config.num_heads + self.depth = config.depth + self.mlp_ratio = config.mlp_ratio + self.dropout = config.dropout + self.k = config.k + self.gradient_checkpointing = config.gradient_checkpointing + + self.metadata_vocab = nn.Embedding(self.vocab_size, 17) + self.metadata_vocab.weight.requires_grad = False + self.metadata_vocab.weight.fill_(-1) + self.metadata_vocab.weight[1:-4] = torch.tensor( + ATOM_METADATA, dtype=torch.float32 + ) + self.embed_metadata = nn.Linear(17, self.dim) + + self.gaussian_embed = GaussianLayer( + k=self.k, edge_types=(self.vocab_size + 1) ** 2 + ) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.dim) + nn.init.normal_(self.embed_tokens.weight, std=0.02) + + self.blocks = nn.ModuleList() + for _ in range(self.depth): + self.blocks.append( + ParallelBlock( + self.dim, + self.num_heads, + self.mlp_ratio, + self.dropout, + self.k, + self.gradient_checkpointing, + ) + ) + + def _expand_mask( + self, + mask: torch.Tensor, + dtype: torch.dtype, + device: torch.device, + tgt_len: Optional[int] = None, + ) -> torch.Tensor: + """ + Expand attention mask. + + Expands attention_mask from `[bsz, seq_len]` to + `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = ( + mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + ) + + inverted_mask: torch.Tensor = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ).to(device) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for the transformer encoder.""" + # pad coords by zeros for graph token + coords_center = torch.sum(coords, dim=1, keepdim=True) / coords.shape[1] + coords = torch.cat([coords_center, coords], dim=1) + + r_ij = torch.cdist(coords, coords, p=2) # [B, N, N] + # pad input_ids by graph token + input_ids = torch.cat( + [ + torch.zeros( + input_ids.size(0), 1, dtype=torch.long, device=input_ids.device + ).fill_(122), + input_ids, + ], + dim=1, + ) + edge_type = input_ids.unsqueeze(-1) * self.vocab_size + input_ids.unsqueeze( + -2 + ) # [B, N, N] + pos_embeds = self.gaussian_embed(r_ij, edge_type) # [B, N, N, K] + + input_embeds = self.embed_tokens(input_ids) + atom_metadata = self.metadata_vocab(input_ids) + input_embeds = input_embeds + self.embed_metadata(atom_metadata) # [B, N, C] + + attention_mask = ( + torch.cat( + [ + torch.ones( + attention_mask.size(0), + 1, + dtype=torch.bool, + device=attention_mask.device, + ), + attention_mask.bool(), + ], + dim=1, + ) + if attention_mask is not None + else None + ) + + attention_mask = ( + self._expand_mask(attention_mask, input_embeds.dtype, input_embeds.device) + if attention_mask is not None + else None + ) + + for blk in self.blocks: + input_embeds, pos_embeds = blk(input_embeds, pos_embeds, attention_mask) + + return input_embeds, pos_embeds + + +class AtomformerPreTrainedModel(PreTrainedModel): # type: ignore + """Base class for all transformer models.""" + + config_class = AtomformerConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ParallelBlock"] + + def _set_gradient_checkpointing( + self, module: nn.Module, value: bool = False + ) -> None: + if isinstance(module, (AtomformerEncoder)): + module.gradient_checkpointing = value + + +class AtomformerModel(AtomformerPreTrainedModel): + """Atomformer model for atom modeling.""" + + def __init__(self, config: AtomformerConfig): + super().__init__(config) + self.config = config + self.encoder = AtomformerEncoder(config) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward function call for the transformer model.""" + output: torch.Tensor = self.encoder(input_ids, coords, attention_mask) + return output[0][:, :-1] + + +class AtomformerForMaskedAM(AtomformerPreTrainedModel): + """Atomformer with an atom modeling head on top for masked atom modeling.""" + + def __init__(self, config: AtomformerConfig): + super().__init__(config) + self.config = config + self.encoder = AtomformerEncoder(config) + self.am_head = nn.Linear(config.dim, config.vocab_size, bias=False) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + labels: Optional[torch.Tensor] = None, + fixed: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + """Forward function call for the masked atom modeling model.""" + hidden_states = self.encoder(input_ids, coords, attention_mask) + logits = self.am_head(hidden_states) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + logits, labels = logits.view(-1, self.config.vocab_size), labels.view(-1) + loss = loss_fct(logits, labels) + + return loss, logits + + +class AtomformerForCoordinateAM(AtomformerPreTrainedModel): + """Atomformer with an atom coordinate head on top for coordinate denoising.""" + + def __init__(self, config: AtomformerConfig): + super().__init__(config) + self.config = config + self.encoder = AtomformerEncoder(config) + self.coords_head = nn.Linear(config.dim, 3) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + labels_coords: Optional[torch.Tensor] = None, + fixed: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + """Forward function call for the coordinate atom modeling model.""" + hidden_states = self.encoder(input_ids, coords, attention_mask) + coords_pred = self.coords_head(hidden_states) + + loss = None + if labels_coords is not None: + labels_coords = labels_coords.to(coords_pred.device) + loss_fct = nn.L1Loss() + loss = loss_fct(coords_pred, labels_coords) + + return loss, coords_pred + + +class InitialStructure2RelaxedStructure(AtomformerPreTrainedModel): + """Atomformer with an coordinate head on top for relaxed structure prediction.""" + + def __init__(self, config: AtomformerConfig): + super().__init__(config) + self.config = config + self.encoder = AtomformerEncoder(config) + self.coords_head = nn.Linear(config.dim, 3) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + labels_coords: Optional[torch.Tensor] = None, + fixed: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + """Forward function call. + + Initial structure to relaxed structure model. + """ + hidden_states = self.encoder(input_ids, coords, attention_mask) + coords_pred = self.coords_head(hidden_states) + + loss = None + if labels_coords is not None: + labels_coords = labels_coords.to(coords_pred.device) + loss_fct = nn.L1Loss() + loss = loss_fct(coords_pred, labels_coords) + + return loss, coords_pred + + +class InitialStructure2RelaxedEnergy(AtomformerPreTrainedModel): + """Atomformer with an energy head on top for relaxed energy prediction.""" + + def __init__(self, config: AtomformerConfig): + super().__init__(config) + self.config = config + self.encoder = AtomformerEncoder(config) + self.energy_norm = nn.LayerNorm(config.dim) + self.energy_head = nn.Linear(config.dim, 1, bias=False) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + labels_energy: Optional[torch.Tensor] = None, + fixed: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + """Forward function call for the relaxed energy prediction model.""" + hidden_states = self.encoder(input_ids, coords, attention_mask) + energy = self.energy_head(self.energy_norm(hidden_states[:, 0])).squeeze(-1) + + loss = None + if labels_energy is not None: + loss_fct = nn.L1Loss() + loss = loss_fct(energy, labels_energy) + + return loss, energy + + +class InitialStructure2RelaxedStructureAndEnergy(AtomformerPreTrainedModel): + """Atomformer with an coordinate and energy head.""" + + def __init__(self, config: AtomformerConfig): + super().__init__(config) + self.config = config + self.encoder = AtomformerEncoder(config) + self.energy_norm = nn.LayerNorm(config.dim) + self.energy_head = nn.Linear(config.dim, 1, bias=False) + self.coords_head = nn.Linear(config.dim, 3) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + labels_coords: Optional[torch.Tensor] = None, + forces: Optional[torch.Tensor] = None, + total_energy: Optional[torch.Tensor] = None, + formation_energy: Optional[torch.Tensor] = None, + has_formation_energy: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Forward function call for the relaxed structure and energy model.""" + atom_hidden_states, pos_hidden_states = self.encoder( + input_ids, coords, attention_mask + ) + + formation_energy_pred = self.formation_energy_head( + self.energy_norm(atom_hidden_states[:, 0]) + ).squeeze(-1) + loss_formation_energy = None + if formation_energy is not None: + loss_fct = nn.L1Loss() + loss_formation_energy = loss_fct( + formation_energy_pred[has_formation_energy], + formation_energy[has_formation_energy], + ) + coords_pred = self.coords_head(atom_hidden_states[:, 1:]) + loss_coords = None + if labels_coords is not None: + loss_fct = nn.L1Loss() + loss_coords = loss_fct(coords_pred, labels_coords) + + loss = torch.Tensor(0).to(coords.device) + loss = ( + loss + loss_formation_energy if loss_formation_energy is not None else loss + ) + loss = loss + loss_coords if loss_coords is not None else loss + + return loss, (formation_energy_pred, coords_pred) + + +class Structure2Energy(AtomformerPreTrainedModel): + """Atomformer with an atom modeling head on top for masked atom modeling.""" + + def __init__(self, config: AtomformerConfig): + super().__init__(config) + self.config = config + self.encoder = AtomformerEncoder(config) + self.energy_norm = nn.LayerNorm(config.dim) + self.formation_energy_head = nn.Linear(config.dim, 1, bias=False) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + forces: Optional[torch.Tensor] = None, + total_energy: Optional[torch.Tensor] = None, + formation_energy: Optional[torch.Tensor] = None, + has_formation_energy: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[Optional[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward function call for the structure to energy model.""" + atom_hidden_states, pos_hidden_states = self.encoder( + input_ids, coords, attention_mask + ) + + formation_energy_pred: torch.Tensor = self.formation_energy_head( + self.energy_norm(atom_hidden_states[:, 0]) + ).squeeze(-1) + loss = torch.Tensor(0).to(coords.device) + if formation_energy is not None: + loss_fct = nn.L1Loss() + loss = loss_fct( + formation_energy_pred[has_formation_energy], + formation_energy[has_formation_energy], + ) + + return loss, ( + formation_energy_pred, + attention_mask.bool() if attention_mask is not None else None, + ) + + +class Structure2Forces(AtomformerPreTrainedModel): + """Atomformer with a forces head on top for forces prediction.""" + + def __init__(self, config: AtomformerConfig): + super().__init__(config) + self.config = config + self.encoder = AtomformerEncoder(config) + self.force_norm = nn.LayerNorm(config.dim) + self.force_head = nn.Linear(config.dim, 3) + self.energy_norm = nn.LayerNorm(config.dim) + self.formation_energy_head = nn.Linear(config.dim, 1, bias=False) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + forces: Optional[torch.Tensor] = None, + total_energy: Optional[torch.Tensor] = None, + formation_energy: Optional[torch.Tensor] = None, + has_formation_energy: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward function call for the structure to forces model.""" + atom_hidden_states, pos_hidden_states = self.encoder( + input_ids, coords, attention_mask + ) + attention_mask = attention_mask.bool() if attention_mask is not None else None + + forces_pred: torch.Tensor = self.force_head( + self.force_norm(atom_hidden_states[:, 1:]) + ) + loss = torch.Tensor(0).to(coords.device) + if forces is not None: + loss_fct = nn.L1Loss() + loss = loss_fct(forces_pred[attention_mask], forces[attention_mask]) + + return loss, ( + forces_pred, + attention_mask if attention_mask is not None else None, + ) + + +class Structure2EnergyAndForces(AtomformerPreTrainedModel): + """Atomformer with an energy and forces head for energy and forces prediction.""" + + def __init__(self, config: AtomformerConfig): + super().__init__(config) + self.config = config + self.encoder = AtomformerEncoder(config) + self.force_norm = nn.LayerNorm(config.dim) + self.force_head = nn.Linear(config.dim, 3) + self.energy_norm = nn.LayerNorm(config.dim) + self.formation_energy_head = nn.Linear(config.dim, 1, bias=False) + + def forward( + self, + input_ids: torch.Tensor, + coords: torch.Tensor, + forces: Optional[torch.Tensor] = None, + total_energy: Optional[torch.Tensor] = None, + formation_energy: Optional[torch.Tensor] = None, + has_formation_energy: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: + """Forward function call for the structure to energy and forces model.""" + atom_hidden_states, pos_hidden_states = self.encoder( + input_ids, coords, attention_mask + ) + + formation_energy_pred: torch.Tensor = self.formation_energy_head( + self.energy_norm(atom_hidden_states[:, 0]) + ).squeeze(-1) + loss_formation_energy = None + if formation_energy is not None: + loss_fct = nn.L1Loss() + loss_formation_energy = loss_fct( + formation_energy_pred[has_formation_energy], + formation_energy[has_formation_energy], + ) + attention_mask = attention_mask.bool() if attention_mask is not None else None + forces_pred: torch.Tensor = self.force_head( + self.force_norm(atom_hidden_states[:, 1:]) + ) + loss_forces = None + if forces is not None: + loss_fct = nn.L1Loss() + loss_forces = loss_fct(forces_pred[attention_mask], forces[attention_mask]) + + loss = torch.Tensor(0).to(coords.device) + loss = ( + loss + loss_formation_energy if loss_formation_energy is not None else loss + ) + loss = loss + loss_forces if loss_forces is not None else loss + + return loss, (formation_energy_pred, forces_pred, attention_mask) diff --git a/Transfer Learning/Atomistic graph data/readme.md b/Transfer Learning/Atomistic graph data/readme.md new file mode 100644 index 00000000..5eacaf13 --- /dev/null +++ b/Transfer Learning/Atomistic graph data/readme.md @@ -0,0 +1,32 @@ +# AtomFormer Base Model + +This model is a transformer-based architecture that utilizes Gaussian pair-wise positional embeddings to train on atomistic graph data. AtomFormer is part of the AtomGen project, which supports a range of methods for pre-training and fine-tuning models on atomistic graphs. + +## Model Description + +AtomFormer is a transformer model designed to work with atomistic graph data. It builds on the work from Uni-Mol+ by adding pair-wise positional embeddings to the attention mask, enabling the model to leverage 3D positional information. The model has been pre-trained on a diverse set of aggregated atomistic datasets, with the target tasks being per-atom force prediction and per-system energy prediction. + +In addition to the graph data, the model includes metadata about the atomic species being modeled, such as atomic radius, electronegativity, and valency. This metadata is normalized and projected into the atom embeddings within the model. + +## Intended Uses & Limitations + +While AtomFormer can be used for force and energy prediction, it is primarily intended to be fine-tuned for downstream tasks. The model’s performance on force and energy prediction tasks has not been extensively validated, as it was primarily used for pre-training tasks. + + +## Training Data + +AtomFormer was trained on an aggregated S2EF dataset sourced from multiple datasets, including OC20, OC22, ODAC23, MPtrj, and SPICE. This dataset contains structures, energies, and forces for pre-training. The model was trained using formation energy, although this data is not available for OC22, as indicated by the "has_formation_energy" column in the dataset. + +### Preprocessing + +The model expects input in the form of tokenized atomic symbols (`input_ids`) and 3D coordinates (`coords`). During pre-training, the model also requires labels for `forces` and `formation_energy`. + +The `DataCollatorForAtomModeling` utility in the AtomGen library provides dynamic padding for batching data and supports flattening the data for graph neural network (GNN)-style training. + +### Pretraining Details + +The model was trained on a node with 4xA40 (48 GB) GPUs for 10 epochs, which took approximately two weeks. For detailed hyperparameters and training code, refer to the [training code repository](https://github.com/VectorInstitute/AtomGen). + + + +