This GitHub repository is a general-purpose EEG neural network training framework. Some of the code is derived and reorganized from the author's upcoming private repository EEG-IRT. 🔥 🔥
Quick Start:
- Create the environment with
  $ conda create --name <your env name> --file <requirements.txt>
  In the current version, CUDA is bundled with the environment, so creating it with this command requires no separate installation or modification of your CUDA setup.
- Edit config.ini to modify training parameters.
- Run main.py.
The architectures directory includes five deep neural network models implemented using PyTorch. Detailed descriptions are as follows:
The models in this project are mainly used for Motor Imagery classification tasks. They are implemented with the PyTorch framework, which makes them easy to extend and use. The author of EEGSym provides the model code in TensorFlow, which can be downloaded directly from the links given in the original paper.
- model_standard_EEGNet.py: Contains the implementation code for the EEGNet model. (Research Paper)
- model_standard_InceptionEEG.py: Contains the implementation code for the InceptionEEG model. (Research Paper)
- model_standard_Deep.py: Contains the implementation code for the DeepConvNet model. (Research Paper, Code Reference)
- model_standard_ShallowFBCSPNet.py: Contains the implementation code for the ShallowConvNet model. (Research Paper, Code Reference)
- model_standard_EEGSym.py: Contains the implementation code for the EEGSym model. (Research Paper)
- utils: Contains dependencies for the DeepConvNet and ShallowConvNet models.
To use the models, simply incorporate them into your framework. A demo of the model input and output is available in each model file. The input shape is [batch_size, n, num_channels, num_sampling]. n is typically 1. If samples are built differently, for example with two trials forming one sample, you can set n to 2.
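As a minimal sketch of the input layout described above, the snippet below builds a dummy batch and passes it through a stand-in network. TinyEEGClassifier is a hypothetical placeholder, not one of the repository's models, and the sizes (22 channels, 640 samples, 2 classes) are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn


class TinyEEGClassifier(nn.Module):
    """Hypothetical stand-in model; not part of the repository."""

    def __init__(self, n: int, num_channels: int, num_classes: int):
        super().__init__()
        # Temporal convolution along the sampling axis (length preserved by padding).
        self.temporal = nn.Conv2d(n, 8, kernel_size=(1, 25), padding=(0, 12))
        # Spatial convolution that collapses the electrode axis to size 1.
        self.spatial = nn.Conv2d(8, 16, kernel_size=(num_channels, 1))
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classify = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.temporal(x))
        x = torch.relu(self.spatial(x))
        x = self.pool(x).flatten(start_dim=1)
        return self.classify(x)


# Assumed sizes for illustration only.
batch_size, n, num_channels, num_sampling = 4, 1, 22, 640
x = torch.randn(batch_size, n, num_channels, num_sampling)

model = TinyEEGClassifier(n=n, num_channels=num_channels, num_classes=2)
logits = model(x)
print(x.shape, logits.shape)
```

Setting n to 2 only changes the second axis of the input and the first convolution's input channels; the rest of the sketch is unaffected.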
- data_loader.py: A class that implements PyTorch-specific EEG data loading.
- parse_config.py: A class that reads the training parameters.
- test.py: A function containing the model's test code.
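To illustrate how training parameters might be read from an INI file, here is a small sketch using Python's standard configparser module. It is not the code in parse_config.py; the section name training and all keys shown are hypothetical, and the actual config.ini in the repository may be organized differently.

```python
import configparser

# Hypothetical config.ini content; the repository's actual keys may differ.
EXAMPLE_INI = """
[training]
batch_size = 32
learning_rate = 0.001
epochs = 100
model = EEGNet
"""


def read_training_params(ini_text: str) -> dict:
    """Parse training parameters from INI text into typed Python values."""
    parser = configparser.ConfigParser()
    parser.read_string(ini_text)
    section = parser["training"]
    return {
        "batch_size": section.getint("batch_size"),
        "learning_rate": section.getfloat("learning_rate"),
        "epochs": section.getint("epochs"),
        "model": section.get("model"),
    }


params = read_training_params(EXAMPLE_INI)
print(params)
```

Converting values with getint and getfloat at load time keeps type errors close to the configuration file rather than deep inside the training loop.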
Stores training results
Stores example datasets. It currently contains the BCICIV 2a dataset. The DL_BCICIV2a_c2RL_640_no_filter_pro folder stores data indices split for 5-fold cross-validation; the splits also follow the criteria of cross-subject studies. For instance, MI_train0, MI_eval0, and MI_test0 hold the data indices for the training set, validation set, and test set of fold 0, respectively.
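The idea of a cross-subject 5-fold split can be sketched at the subject level as follows. How the stored indices were actually generated is not documented here, so the round-robin grouping and the use of the next group as the validation set are assumptions, and cross_subject_folds is a hypothetical helper. Mapping subjects to trial indices is left out.

```python
def cross_subject_folds(subjects, num_folds=5):
    """Build folds in which no subject is shared between train, eval, and test."""
    # Deal subjects round-robin into num_folds groups (assumed strategy).
    groups = [subjects[i::num_folds] for i in range(num_folds)]
    folds = []
    for k in range(num_folds):
        eval_k = (k + 1) % num_folds  # next group serves as the validation set
        train = [s for i, g in enumerate(groups) if i not in (k, eval_k) for s in g]
        folds.append({"train": train, "eval": groups[eval_k], "test": groups[k]})
    return folds


subjects = list(range(1, 10))  # BCICIV 2a contains nine subjects
folds = cross_subject_folds(subjects)
for k, fold in enumerate(folds):
    print(f"MI_train{k}: {fold['train']}  MI_eval{k}: {fold['eval']}  MI_test{k}: {fold['test']}")
```

Because every subject lands in exactly one group, each subject is tested exactly once across the five folds and never appears in the training data of the fold that tests it.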
The code of this project is released under the MIT License.