When I tried to load the pretrained PromptEHR model, I got the following error:
AttributeError Traceback (most recent call last)
Input In [9], in <cell line: 5>()
3 vocs = data['voc']
4 model = PromptEHR()
----> 5 model.from_pretrained()
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/pytrial/tasks/trial_simulation/sequence/promptehr.py:222, in PromptEHR.from_pretrained(self, input_dir)
211 def from_pretrained(self, input_dir='./simulation/pretrained_promptEHR'):
212 '''
213 Load pretrained PromptEHR model and make patient EHRs generation.
214 Pretrained model was learned from MIMIC-III patient sequence data.
(...)
220 to this folder.
221 '''
--> 222 self.model.from_pretrained(input_dir=input_dir)
223 self.config.update(self.model.config)
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/promptehr.py:359, in PromptEHR.from_pretrained(self, input_dir)
356 print(f'Download pretrained PromptEHR model, save to {input_dir}.')
358 print('Load pretrained PromptEHR model from', input_dir)
--> 359 self.load_model(input_dir)
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/promptehr.py:298, in PromptEHR.load_model(self, checkpoint)
295 self._load_tokenizer(data_tokenizer_file, model_tokenizer_file)
297 # load configuration
--> 298 self.configuration = EHRBartConfig(self.data_tokenizer, self.model_tokenizer, n_num_feature=self.config['n_num_feature'], cat_cardinalities=self.config['cat_cardinalities'])
299 self.configuration.from_pretrained(checkpoint)
301 # build model
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/modeling_config.py:24, in EHRBartConfig(data_tokenizer, model_tokenizer, **kwargs)
22 bart_config = BartConfig.from_pretrained('facebook/bart-base')
23 kwargs.update(model_tokenizer.get_num_tokens)
---> 24 kwargs['data_tokenizer_num_vocab'] = len(data_tokenizer)
25 if 'd_prompt_hidden' not in kwargs:
26 kwargs['d_prompt_hidden'] = 128
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/tokenization_utils.py:431, in PreTrainedTokenizer.__len__(self)
426 def __len__(self):
427 """
428 Size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because otherwise if
429 there is a hole in the vocab, we will add tokenizers at a wrong index.
430 """
--> 431 return len(set(self.get_vocab().keys()))
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/models/bart/tokenization_bart.py:243, in BartTokenizer.get_vocab(self)
242 def get_vocab(self):
--> 243 return dict(self.encoder, **self.added_tokens_encoder)
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/tokenization_utils.py:391, in PreTrainedTokenizer.added_tokens_encoder(self)
385 @property
386 def added_tokens_encoder(self) -> Dict[str, int]:
387 """
388 Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
389 optimisation in `self._added_tokens_encoder` for the slow tokenizers.
390 """
--> 391 return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}
AttributeError: 'DataTokenizer' object has no attribute '_added_tokens_decoder'
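If I read the traceback correctly, the failure chain is len(data_tokenizer) -> BartTokenizer.get_vocab() -> PreTrainedTokenizer.added_tokens_encoder, which reads the private attribute self._added_tokens_decoder. DataTokenizer appears to subclass BartTokenizer, and that attribute is only set in __init__ by recent transformers releases, so an instance restored from a checkpoint saved under an older transformers version can lack it. A minimal sketch that reproduces the same AttributeError under that assumption (the stale-restore part is hypothetical, just to illustrate the mechanism):

from transformers import BartTokenizer

# A fresh instance works: current transformers sets _added_tokens_decoder
# in __init__, so len() can read it.
tok = BartTokenizer.from_pretrained('facebook/bart-base')

# Hypothetical illustration of the failure: rebuild an instance whose state
# was captured under an older transformers release, i.e. without
# _added_tokens_decoder. __new__ bypasses __init__, much like unpickling
# a saved tokenizer does.
stale = BartTokenizer.__new__(BartTokenizer)
state = dict(tok.__dict__)
state.pop('_added_tokens_decoder', None)  # simulate the old attribute set
stale.__dict__.update(state)

try:
    len(stale)  # __len__ -> get_vocab() -> added_tokens_encoder -> _added_tokens_decoder
except AttributeError as e:
    print(e)  # 'BartTokenizer' object has no attribute '_added_tokens_decoder'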
——————————————————————————
My code is:
from pytrial.tasks.trial_simulation.data import SequencePatient
from pytrial.data.demo_data import load_synthetic_ehr_sequence
from pytrial.tasks.trial_simulation.sequence import PromptEHR

# Load the synthetic EHR demo data and wrap it as a SequencePatient dataset.
data = load_synthetic_ehr_sequence()
train_data = SequencePatient(
    data={
        'v': data['visit'],
        'y': data['y'],
        'x': data['feature'],
    },
    metadata={
        'visit': {'mode': 'dense'},
        'label': {'mode': 'tensor'},
        'voc': data['voc'],
        'max_visit': 20,
        'n_num_feature': data['n_num_feature'],
        'cat_cardinalities': data['cat_cardinalities'],
    },
)

# Build the model and load the pretrained MIMIC-III checkpoint.
vocs = data['voc']
model = PromptEHR()
model.from_pretrained()
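(For reference, the explicit form of that last call, using the default input_dir shown in the traceback:)

# Equivalent to model.from_pretrained(); the default directory comes from
# the signature in the traceback above.
model.from_pretrained(input_dir='./simulation/pretrained_promptEHR')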
By contrast, loading BartTokenizer directly in the same environment works fine.
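(A minimal sketch of that check; my exact snippet may have differed slightly:)

from transformers import BartTokenizer

# Loading the stock BART tokenizer directly works, and len() succeeds,
# so _added_tokens_decoder is set correctly for a fresh instance.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
print(len(tokenizer))  # base vocab plus added tokens, e.g. 50265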
Could you please help me fix this bug?