diff --git a/dptb/nn/deeptb.py b/dptb/nn/deeptb.py index a736e98f..22f4296c 100644 --- a/dptb/nn/deeptb.py +++ b/dptb/nn/deeptb.py @@ -230,6 +230,8 @@ def __init__( def forward(self, data: AtomicDataDict.Type): + if data.get(AtomicDataDict.EDGE_TYPE_KEY, None) is None: + self.idp(data) data = self.embedding(data) if hasattr(self, "overlap"): @@ -368,6 +370,10 @@ def __init__( def forward(self, data: AtomicDataDict.Type): + + if data.get(AtomicDataDict.EDGE_TYPE_KEY, None) is None: + self.idp(data) + data_nnenv = self.nnenv(data) if hasattr(self, "nnsk"): data_sk = self.nnsk(data) diff --git a/dptb/nn/hr2hk.py b/dptb/nn/hr2hk.py index dfcdb3f4..0f34889b 100644 --- a/dptb/nn/hr2hk.py +++ b/dptb/nn/hr2hk.py @@ -57,8 +57,10 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: bondwise_hopping.type(self.dtype) onsite_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb,), dtype=self.dtype, device=self.device) - - if data[AtomicDataDict.NODE_SOC_SWITCH_KEY].all(): + soc = data.get(AtomicDataDict.NODE_SOC_SWITCH_KEY, False) + if isinstance(soc, torch.Tensor): + soc = soc.all() + if soc: if self.overlap: raise NotImplementedError("Overlap is not implemented for SOC.") @@ -92,7 +94,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: if i <= j: onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1) - if data[AtomicDataDict.NODE_SOC_SWITCH_KEY].all() and i==j: + if soc and i==j: soc_updn_tmp = orbpair_soc[:,self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1)) soc_upup_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,:2*lj+1] soc_updn_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,2*lj+1:] @@ -150,7 +152,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: block = block + block.transpose(1,2).conj() block = block.contiguous() - if data[AtomicDataDict.NODE_SOC_SWITCH_KEY].all(): + if soc: HK_SOC = torch.zeros(data[AtomicDataDict.KPOINT_KEY].shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device) #HK_SOC[:,:all_norb,:all_norb] = block + block_uu #HK_SOC[:,:all_norb,all_norb:] = block_ud diff --git a/dptb/nn/nnsk.py b/dptb/nn/nnsk.py index 33900b3e..31c1af95 100644 --- a/dptb/nn/nnsk.py +++ b/dptb/nn/nnsk.py @@ -250,6 +250,8 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: data = AtomicDataDict.with_edge_vectors(data, with_lengths=True) + if data.get(AtomicDataDict.EDGE_TYPE_KEY, None) is None: + self.idp_sk(data) # edge_number = data[AtomicDataDict.ATOMIC_NUMBERS_KEY][data[AtomicDataDict.EDGE_INDEX_KEY]].reshape(2, -1) # edge_index = self.idp_sk.transform_reduced_bond(*edge_number)