diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
new file mode 100644
index 000000000..80a301542
--- /dev/null
+++ b/.github/workflows/main.yml
@@ -0,0 +1,30 @@
+name: build
+
+on:
+ push:
+ branches: [ master ]
+ pull_request:
+ branches: [ master ]
+
+jobs:
+ build:
+
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: [ 3.7, 3.8, 3.9 ]
+ max-parallel: 3
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
+ - name: Test
+ run: |
+ python -m unittest discover
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index eeff9f585..028221ea9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,3 +41,4 @@ tmp.py
# models files
*.dat
+!examples/*.dat
diff --git a/README.md b/README.md
index a043b85ea..2437541b8 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,12 @@
# Medical oncept Annotation Tool
+[![Build Status](https://github.com/CogStack/MedCAT/actions/workflows/main.yml/badge.svg?branch=master)](https://github.com/CogStack/MedCAT/actions/workflows/main.yml?query=branch%3Amaster)
+[![Latest release](https://img.shields.io/github/v/release/CogStack/MedCAT)](https://github.com/CogStack/MedCAT/releases/latest)
+[![pypi Version](https://img.shields.io/pypi/v/medcat.svg?style=flat-square&logo=pypi&logoColor=white)](https://pypi.org/project/medcat/)
+
MedCAT can be used to extract information from Electronic Health Records (EHRs) and link it to biomedical ontologies like SNOMED-CT and UMLS. Paper on [arXiv](https://arxiv.org/abs/2010.01165).
-## News
+## News
- **New Feature and Tutorial \[8. July 2021\]**: [Integrating 🤗 Transformers with MedCAT for biomedical NER+L](https://towardsdatascience.com/integrating-transformers-with-medcat-for-biomedical-ner-l-8869c76762a)
- **General \[1. April 2021\]**: MedCAT is upgraded to v1, unfortunately this introduces breaking changes with older models (MedCAT v0.4),
as well as potential problems with all code that used the MedCAT package. MedCAT v0.4 is available on the legacy
@@ -30,9 +34,9 @@ A guide on how to use MedCAT is available in the [tutorial](https://github.com/C
2. Get the scispacy models:
-`pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.3.0/en_core_sci_md-0.3.0.tar.gz`
+`pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_md-0.4.0.tar.gz`
-`pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.3.0/en_core_sci_lg-0.3.0.tar.gz`
+`pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_lg-0.4.0.tar.gz`
3. Download the Vocabulary and CDB from the Models section below
@@ -98,7 +102,7 @@ CDB [Download](https://medcat.rosalind.kcl.ac.uk/media/cdb-medmen-v1.dat) - Buil
MetaCAT Status [Download](https://medcat.rosalind.kcl.ac.uk/media/mc_status.zip) - Built from a sample from MIMIC-III, detects if an annotation is Affirmed (Positive) or Other (Negated or Hypothetical)
-(Note: This is was compiled from MedMentions and does not have any data from [NLM](https://www.nlm.nih.gov/research/umls/) as
+(Note: This was compiled from MedMentions and does not have any data from [NLM](https://www.nlm.nih.gov/research/umls/) as
that data is not publicly available.)
### SNOMED-CT and UMLS
diff --git a/examples/cdb.dat b/examples/cdb.dat
new file mode 100644
index 000000000..d972e22f5
Binary files /dev/null and b/examples/cdb.dat differ
diff --git a/examples/vocab.dat b/examples/vocab.dat
new file mode 100644
index 000000000..32fd3ed2b
Binary files /dev/null and b/examples/vocab.dat differ
diff --git a/examples/vocab_data.txt b/examples/vocab_data.txt
new file mode 100644
index 000000000..da3f43048
--- /dev/null
+++ b/examples/vocab_data.txt
@@ -0,0 +1,2 @@
+house 34444 0.3232 0.123213 1.231231
+dog 14444 0.76762 0.76767 1.45454
diff --git a/medcat/cat.py b/medcat/cat.py
index 9a7db9f15..ce8a24628 100644
--- a/medcat/cat.py
+++ b/medcat/cat.py
@@ -70,7 +70,7 @@ def __init__(self, cdb, config, vocab, meta_cats=[]):
# Build the pipeline
self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config)
- self.nlp.add_tagger(tagger=partial(tag_skip_and_punct, config=self.config),
+ self.nlp.add_tagger(tagger=tag_skip_and_punct,
name='skip_and_punct',
additional_fields=['is_punct'])
@@ -116,7 +116,7 @@ def __call__(self, text, do_train=False):
Returns:
A spacy document with the extracted entities
'''
- # Should we train - do not use this for training, unles you know what you are doing. Use the
+ # Should we train - do not use this for training, unless you know what you are doing. Use the
#self.train() function
self.config.linking['train'] = do_train
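
The `__call__` signature in this hunk (`__call__(self, text, do_train=False)`) is what the new `tests/test_cat.py` further down exercises. A minimal usage sketch, assuming the `examples/cdb.dat` and `examples/vocab.dat` files added in this PR and an installed scispacy model:

```python
from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.cat import CAT

# Load the small example model files shipped under examples/
cdb = CDB.load("examples/cdb.dat")
vocab = Vocab.load("examples/vocab.dat")
cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)

# do_train=False: annotate only, do not update context vectors
doc = cat("The dog is sitting outside the house.", do_train=False)
print([(ent.text, ent._.cui) for ent in doc._.ents])
```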
diff --git a/medcat/cdb_maker.py b/medcat/cdb_maker.py
index c1cdc758e..3597831aa 100644
--- a/medcat/cdb_maker.py
+++ b/medcat/cdb_maker.py
@@ -44,28 +44,28 @@ def __init__(self, config, cdb=None, name_max_words=20):
# Build the required spacy pipeline
self.nlp = Pipe(tokenizer=spacy_split_all, config=config)
- self.nlp.add_tagger(tagger=partial(tag_skip_and_punct, config=self.config),
+ self.nlp.add_tagger(tagger=tag_skip_and_punct,
name='skip_and_punct',
additional_fields=['is_punct'])
def prepare_csvs(self, csv_paths, sep=',', encoding=None, escapechar=None, index_col=False, full_build=False, only_existing_cuis=False, **kwargs):
- r''' Compile one or multipe CSVs into a CDB.
+ r''' Compile one or multiple CSVs into a CDB.
Args:
csv_paths (`List[str]`):
An array of paths to the csv files that should be processed
- full_build (`bool`, defautls to `True`):
+            full_build (`bool`, defaults to `False`):
If False only the core portions of the CDB will be built (the ones required for
the functioning of MedCAT). If True, everything will be added to the CDB - this
usually includes concept descriptions, various forms of names etc (take care that
this option produces a much larger CDB).
sep (`str`, defaults to `,`):
- If necessarya a custom separator for the csv files
+ If necessary a custom separator for the csv files
encoding (`str`, optional):
- Encoing to be used for reading the CSV file
+ Encoding to be used for reading the CSV file
escapechar (`str`, optional):
- Escapechar for the CSV
+ Escape char for the CSV
index_col (`bool`, defaults_to `False`):
Index column for pandas read_csv
only_existing_cuis (`bool`, defaults to False):
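
As a quick reference for the corrected docstring, a sketch of `prepare_csvs` with these arguments, assuming the `examples/cdb.csv` and `examples/cdb_2.csv` files used by the tests and the scispacy model from `requirements.txt`:

```python
import logging

from medcat.config import Config
from medcat.cdb_maker import CDBMaker

config = Config()
config.general['log_level'] = logging.DEBUG
maker = CDBMaker(config)

# full_build=True additionally keeps descriptions, original name forms etc.
cdb = maker.prepare_csvs(['examples/cdb.csv', 'examples/cdb_2.csv'],
                         sep=',', full_build=True)
print(len(cdb.cui2names))
```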
diff --git a/medcat/linking/context_based_linker.py b/medcat/linking/context_based_linker.py
index e7fb09578..94d1a01dc 100644
--- a/medcat/linking/context_based_linker.py
+++ b/medcat/linking/context_based_linker.py
@@ -1,7 +1,8 @@
-from medcat.utils.filters import check_filters
-from medcat.linking.vector_context_model import ContextModel
import random
import logging
+from medcat.utils.filters import check_filters
+from medcat.linking.vector_context_model import ContextModel
+
class Linker(object):
r''' Link to a biomedical database.
diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py
index a511074f7..a8566bef9 100644
--- a/medcat/meta_cat.py
+++ b/medcat/meta_cat.py
@@ -10,6 +10,7 @@
from medcat.preprocessing.tokenizers import TokenizerWrapperBPE
from medcat.preprocessing.tokenizers import TokenizerWrapperBERT
+
class MetaCAT(object):
r''' TODO: Add documentation
'''
diff --git a/medcat/ner/vocab_based_ner.py b/medcat/ner/vocab_based_ner.py
index 9aef8f41f..7c12c5ff2 100644
--- a/medcat/ner/vocab_based_ner.py
+++ b/medcat/ner/vocab_based_ner.py
@@ -1,5 +1,6 @@
-from medcat.ner.vocab_based_annotator import maybe_annotate_name
import logging
+from medcat.ner.vocab_based_annotator import maybe_annotate_name
+
class NER(object):
r'''
diff --git a/medcat/pipe.py b/medcat/pipe.py
index 470fc146c..7fd7c5a10 100644
--- a/medcat/pipe.py
+++ b/medcat/pipe.py
@@ -1,7 +1,8 @@
+import spacy
from spacy.tokens import Token, Doc, Span
+from spacy.language import Language
from medcat.utils.normalizers import TokenNormalizer
-import spacy
-import os
+
class Pipe(object):
r''' A wrapper around the standard spacy pipeline.
@@ -21,7 +22,7 @@ def __init__(self, tokenizer, config):
if config.preprocessing['stopwords'] is not None:
self.nlp.Defaults.stop_words = set(config.preprocessing['stopwords'])
self.nlp.tokenizer = tokenizer(self.nlp)
-
+ self.config = config
def add_tagger(self, tagger, name, additional_fields=[]):
r''' Add any kind of a tagger for tokens.
@@ -35,7 +36,9 @@ def add_tagger(self, tagger, name, additional_fields=[]):
additional_fields (`List[str]`):
Fields to be added to the `_` properties of a token.
'''
- self.nlp.add_pipe(tagger, name='tag_' + name, first=True)
+ component_factory_name = spacy.util.get_object_name(tagger)
+ Language.factory(name=component_factory_name, default_config={"config": self.config}, func=tagger)
+ self.nlp.add_pipe(component_factory_name, name='tag_' + name, first=True)
# Add custom fields needed for this usecase
Token.set_extension('to_skip', default=False, force=True)
@@ -43,21 +46,23 @@ def add_tagger(self, tagger, name, additional_fields=[]):
for field in additional_fields:
Token.set_extension(field, default=False, force=True)
-
def add_token_normalizer(self, config, spell_checker=None):
token_normalizer = TokenNormalizer(spell_checker=spell_checker, config=config)
- self.nlp.add_pipe(token_normalizer, name='token_normalizer', last=True)
+ component_name = spacy.util.get_object_name(token_normalizer)
+ Language.component(name=component_name, func=token_normalizer)
+ self.nlp.add_pipe(component_name, name='token_normalizer', last=True)
# Add custom fields needed for this usecase
Token.set_extension('norm', default=None, force=True)
-
def add_ner(self, ner):
r''' Add NER from CAT to the pipeline, will also add the necessary fields
to the document and Span objects.
'''
- self.nlp.add_pipe(ner, name='cat_ner', last=True)
+ component_name = spacy.util.get_object_name(ner)
+ Language.component(name=component_name, func=ner)
+ self.nlp.add_pipe(component_name, name='cat_ner', last=True)
Doc.set_extension('ents', default=[], force=True)
Span.set_extension('confidence', default=-1, force=True)
@@ -67,7 +72,6 @@ def add_ner(self, ner):
Span.set_extension('detected_name', default=None, force=True)
Span.set_extension('link_candidates', default=None, force=True)
-
def add_linker(self, linker):
r''' Add entity linker to the pipeline, will also add the necessary fields
to Span object.
@@ -76,18 +80,20 @@ def add_linker(self, linker):
Any object/function created based on the requirements for a spaCy pipeline components. Have
a look at https://spacy.io/usage/processing-pipelines#custom-components
'''
- self.nlp.add_pipe(linker, name='cat_linker', last=True)
+ component_name = spacy.util.get_object_name(linker)
+ Language.component(name=component_name, func=linker)
+ self.nlp.add_pipe(component_name, name='cat_linker', last=True)
Span.set_extension('cui', default=-1, force=True)
Span.set_extension('context_similarity', default=-1, force=True)
-
def add_meta_cat(self, meta_cat, name):
- self.nlp.add_pipe(meta_cat, name=name, last=True)
+ component_name = spacy.util.get_object_name(meta_cat)
+ Language.component(name=component_name, func=meta_cat)
+ self.nlp.add_pipe(component_name, name=name, last=True)
# Only the meta_anns field is needed, it will be a dictionary
#of {category_name: value, ...}
Span.set_extension('meta_anns', default=None, force=True)
-
def __call__(self, text):
return self.nlp(text)
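
The `Pipe` changes above swap the spaCy 2 style of passing callables straight into `add_pipe` for the spaCy 3 pattern of registering a named component first and then adding it by name. A standalone sketch of that pattern (the component and blank pipeline here are illustrative, not MedCAT code):

```python
import spacy
from spacy.language import Language

def report_tokens(doc):
    # A trivial stateless component: doc in, doc out.
    print([token.lower_ for token in doc])
    return doc

# Register the callable under its own name, then add it to the pipeline by that name.
component_name = spacy.util.get_object_name(report_tokens)
Language.component(name=component_name, func=report_tokens)

nlp = spacy.blank("en")
nlp.add_pipe(component_name, name="reporter", last=True)
nlp("Some text")
```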
diff --git a/medcat/preprocessing/taggers.py b/medcat/preprocessing/taggers.py
index 474e2d019..6ec2ac6e0 100644
--- a/medcat/preprocessing/taggers.py
+++ b/medcat/preprocessing/taggers.py
@@ -1,30 +1,38 @@
-import re
-
-def tag_skip_and_punct(doc, config):
+def tag_skip_and_punct(nlp, name, config):
r''' Detects and tags spacy tokens that are punctuation and that should be skipped.
- Args:
- doc (`spacy.tokens.Doc`):
- Spacy document that will be tagged.
- config (`medcat.config.Config`):
- Global config for medcat.
-
- Return:
- (`spacy.tokens.Doc):
- Tagged spacy document
+ Args:
+        nlp (spacy.language.Language):
+ The base spacy NLP pipeline.
+ name (`str`):
+ The component instance name.
+ config (`medcat.config.Config`):
+ Global config for medcat.
'''
- # Make life easier
- cnf_p = config.preprocessing
-
- for token in doc:
- if config.punct_checker.match(token.lower_) and token.text not in cnf_p['keep_punct']:
- # There can't be punct in a token if it also has text
- token._.is_punct = True
- token._.to_skip = True
- elif config.word_skipper.match(token.lower_):
- # Skip if specific strings
- token._.to_skip = True
- elif cnf_p['skip_stopwords'] and token.is_stop:
- token._.to_skip = True
-
- return doc
+
+ return _Tagger(nlp, name, config)
+
+
+class _Tagger(object):
+
+ def __init__(self, nlp, name, config):
+ self.nlp = nlp
+ self.name = name
+ self.config = config
+
+ def __call__(self, doc):
+ # Make life easier
+ cnf_p = self.config.preprocessing
+
+ for token in doc:
+ if self.config.punct_checker.match(token.lower_) and token.text not in cnf_p['keep_punct']:
+ # There can't be punct in a token if it also has text
+ token._.is_punct = True
+ token._.to_skip = True
+ elif self.config.word_skipper.match(token.lower_):
+ # Skip if specific strings
+ token._.to_skip = True
+ elif cnf_p['skip_stopwords'] and token.is_stop:
+ token._.to_skip = True
+
+ return doc
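
`tag_skip_and_punct` is now a component factory rather than a plain `doc -> doc` function: spaCy calls it with the pipeline, the component instance name and the MedCAT config, and it returns the stateful `_Tagger`. A small sketch of instantiating it directly, assuming a default `Config` and a blank pipeline:

```python
import spacy
from medcat.config import Config
from medcat.preprocessing.taggers import tag_skip_and_punct

config = Config()
nlp = spacy.blank("en")

# Calling the factory yields the callable component that tags each doc.
tagger = tag_skip_and_punct(nlp, "tag_skip_and_punct", config)

# In the real pipeline this is wired up by Pipe.add_tagger, which registers the
# factory via Language.factory(..., default_config={"config": config}) and also
# sets up the `to_skip`/`is_punct` Token extensions the component writes to.
```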
diff --git a/medcat/utils/make_vocab.py b/medcat/utils/make_vocab.py
index 01547d4c9..a9ebfc18f 100644
--- a/medcat/utils/make_vocab.py
+++ b/medcat/utils/make_vocab.py
@@ -40,7 +40,7 @@ def __init__(self, config, cdb=None, vocab=None, word_tokenizer=None):
# Build the required spacy pipeline
self.nlp = Pipe(tokenizer=spacy_split_all, config=config)
- self.nlp.add_tagger(tagger=partial(tag_skip_and_punct, config=self.config),
+ self.nlp.add_tagger(tagger=tag_skip_and_punct,
name='skip_and_punct',
additional_fields=['is_punct'])
diff --git a/medcat/utils/normalizers.py b/medcat/utils/normalizers.py
index cc15bb0f6..8e6a07e27 100644
--- a/medcat/utils/normalizers.py
+++ b/medcat/utils/normalizers.py
@@ -1,9 +1,6 @@
-#import hunspell
import re
-from collections import Counter
-from spacy.tokens import Span
import spacy
-import os
+
CONTAINS_NUMBER = re.compile('[0-9]+')
diff --git a/medcat/vocab.py b/medcat/vocab.py
index d2558744a..f21de2c8a 100644
--- a/medcat/vocab.py
+++ b/medcat/vocab.py
@@ -134,7 +134,7 @@ def add_word(self, word, cnt=1, vec=None, replace=True):
cnt (int):
count of this word in your dataset
vec (np.array):
- the vector repesentation of the word
+ the vector representation of the word
replace (bool):
will replace old vector representation
"""
@@ -170,17 +170,16 @@ def add_words(self, path, replace=True):
replace (bool):
existing words in the vocabulary will be replaced
"""
- f = open(path)
+ with open(path) as f:
+ for line in f:
+ parts = line.split("\t")
+ word = parts[0]
+ cnt = int(parts[1].strip())
+ vec = None
+ if len(parts) == 3:
+ vec = np.array([float(x) for x in parts[2].strip().split(" ")])
- for line in f:
- parts = line.split("\t")
- word = parts[0]
- cnt = int(parts[1].strip())
- vec = None
- if len(parts) == 3:
- vec = np.array([float(x) for x in parts[2].strip().split(" ")])
-
- self.add_word(word, cnt, vec)
+ self.add_word(word, cnt, vec, replace)
def make_unigram_table(self, table_size=100000000):
@@ -232,13 +231,17 @@ def get_negative_samples(self, n=6, ignore_punct_and_num=False):
def __getitem__(self, word):
- return self.vocab[word]['cnt']
+ return self.count(word)
def vec(self, word):
return self.vocab[word]['vec']
+ def count(self, word):
+ return self.vocab[word]['cnt']
+
+
def item(self, word):
return self.vocab[word]
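
The reworked `add_words` now closes the file via a context manager and forwards `replace` to `add_word`, while `__getitem__` delegates to the new `count` helper. A short sketch against the `examples/vocab_data.txt` file added above (each line parsed as a tab-separated word and count plus an optional space-separated vector):

```python
from medcat.vocab import Vocab

vocab = Vocab()
vocab.add_words("examples/vocab_data.txt", replace=True)

print(vocab.count("house"))   # 34444
print(vocab["house"])         # same value, __getitem__ now calls count()
print(vocab.vec("dog"))       # the vector parsed from the third column
```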
diff --git a/requirements-lg.txt b/requirements-lg.txt
new file mode 100644
index 000000000..7d293df4e
--- /dev/null
+++ b/requirements-lg.txt
@@ -0,0 +1,2 @@
+.
+https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_lg-0.4.0.tar.gz
\ No newline at end of file
diff --git a/requirements-sm.txt b/requirements-sm.txt
new file mode 100644
index 000000000..03885cbf2
--- /dev/null
+++ b/requirements-sm.txt
@@ -0,0 +1,2 @@
+.
+https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_sm-0.4.0.tar.gz
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9c558e357..78c7cf284 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
.
+https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_md-0.4.0.tar.gz
\ No newline at end of file
diff --git a/setup.py b/setup.py
index a3a5652a0..f65830479 100644
--- a/setup.py
+++ b/setup.py
@@ -21,11 +21,11 @@
'numpy~=1.20',
'pandas~=1.0',
'gensim~=3.8',
- 'spacy==2.3.4',
+ 'spacy<3.1.0,>=3.0.1',
'scipy~=1.5',
'transformers~=4.5.1',
'torch~=1.8.1',
- 'Flask~=1.1',
+ 'tqdm<4.50.0,>=4.27',
'sklearn~=0.0',
'elasticsearch~=7.10',
'dill~=0.3.3',
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/archive_tests/test_cdb_maker_archive.py b/tests/archive_tests/test_cdb_maker_archive.py
index 422fde2cd..8b8bb2acd 100644
--- a/tests/archive_tests/test_cdb_maker_archive.py
+++ b/tests/archive_tests/test_cdb_maker_archive.py
@@ -1,119 +1,124 @@
-r''' The tests here are a bit messy but they work, should be converted to python unittests.
-'''
-from medcat.cdb_maker import CDBMaker
-from medcat.config import Config
-import numpy as np
import logging
-
-config = Config()
-config.general['log_level'] = logging.DEBUG
-maker = CDBMaker(config)
-
-# Building a new CDB from two files (full_build)
-csvs = ['../examples/cdb.csv', '../examples/cdb_2.csv']
-cdb = maker.prepare_csvs(csvs, full_build=True)
-
-assert len(cdb.cui2names) == 3
-assert len(cdb.cui2snames) == 3
-assert len(cdb.name2cuis) == 5
-assert len(cdb.cui2tags) == 3
-assert len(cdb.cui2preferred_name) == 2
-assert len(cdb.cui2context_vectors) == 3
-assert len(cdb.cui2count_train) == 3
-assert cdb.name2cuis2status['virus']['C0000039'] == 'P'
-assert cdb.cui2type_ids['C0000039'] == {'T234', 'T109', 'T123'}
-assert cdb.addl_info['cui2original_names']['C0000039'] == {'Virus', 'Virus K', 'Virus M', 'Virus Z'}
-assert cdb.addl_info['cui2description']['C0000039'].startswith("Synthetic")
-
-# Test name addition
-from medcat.preprocessing.cleaners import prepare_name
-cdb.add_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', maker.nlp, {}, config), name_status='P', full_build=True)
-assert cdb.addl_info['cui2original_names']['C0000239'] == {'MY: new,-_! Name.', 'Second csv'}
-assert 'my:newname.' in cdb.name2cuis
-assert 'my:new' in cdb.snames
-assert 'my:newname.' in cdb.name2cuis2status
-assert cdb.name2cuis2status['my:newname.'] == {'C0000239': 'P'}
-
-# Test name removal
-cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', maker.nlp, {}, config))
-# Run again to make sure it does not break anything
-cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', maker.nlp, {}, config))
-assert len(cdb.name2cuis) == 5
-assert 'my:newname.' not in cdb.name2cuis2status
-
-# Test filtering
-cuis_to_keep = {'C0000039'} # Because of transition 2 will be kept
-cdb.filter_by_cui(cuis_to_keep=cuis_to_keep)
-assert len(cdb.cui2names) == 2
-assert len(cdb.name2cuis) == 4
-assert len(cdb.snames) == 4
-
-# Test vector addition
+import unittest
import numpy as np
-cdb.reset_training()
-np.random.seed(11)
-cuis = list(cdb.cui2names.keys())
-for i in range(2):
- for cui in cuis:
- vectors = {}
- for cntx_type in config.linking['context_vector_sizes']:
- vectors[cntx_type] = np.random.rand(300)
- cdb.update_context_vector(cui, vectors, negative=False)
-
-assert cdb.cui2count_train['C0000139'] == 2
-assert cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300
-
-
-# Test negative
-for cui in cuis:
- vectors = {}
- for cntx_type in config.linking['context_vector_sizes']:
- vectors[cntx_type] = np.random.rand(300)
- cdb.update_context_vector(cui, vectors, negative=True)
-
-assert cdb.cui2count_train['C0000139'] == 2
-assert cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300
-
-# Test save/load
from medcat.cdb import CDB
-cdb.save("./tmp_cdb.dat")
-cdb2 = CDB.load('./tmp_cdb.dat')
-# Check a random thing
-assert cdb2.cui2context_vectors['C0000139']['long'][7] == cdb.cui2context_vectors['C0000139']['long'][7]
-
-# Test training import
-cdb.reset_training()
-cdb2.reset_training()
-np.random.seed(11)
-cuis = list(cdb.cui2names.keys())
-for i in range(2):
- for cui in cuis:
- vectors = {}
- for cntx_type in config.linking['context_vector_sizes']:
- vectors[cntx_type] = np.random.rand(300)
- cdb.update_context_vector(cui, vectors, negative=False)
-
-cdb2.import_training(cdb=cdb, overwrite=True)
-assert cdb2.cui2context_vectors['C0000139']['long'][7] == cdb.cui2context_vectors['C0000139']['long'][7]
-assert cdb2.cui2count_train['C0000139'] == cdb.cui2count_train['C0000139']
-
-# Test concept similarity
-cdb = CDB(config=config)
-np.random.seed(11)
-for i in range(500):
- cui = "C" + str(i)
- type_ids = {'T-' + str(i%10)}
- cdb.add_concept(cui=cui, names=prepare_name('Name: ' + str(i), maker.nlp, {}, config), ontologies=set(),
- name_status='P', type_ids=type_ids, description='', full_build=True)
-
- vectors = {}
- for cntx_type in config.linking['context_vector_sizes']:
- vectors[cntx_type] = np.random.rand(300)
- cdb.update_context_vector(cui, vectors, negative=False)
-res = cdb.most_similar('C200', 'long', type_id_filter=['T-0'], min_cnt=1, topn=10, force_build=True)
-assert len(res) == 10
-
-# Test training reset
-cdb.reset_training()
-assert len(cdb.cui2context_vectors['C0']) == 0
-assert cdb.cui2count_train['C0'] == 0
+from medcat.cdb_maker import CDBMaker
+from medcat.config import Config
+from medcat.preprocessing.cleaners import prepare_name
+
+
+class CdbMakerArchiveTests(unittest.TestCase):
+
+ def setUp(self):
+ self.config = Config()
+ self.config.general['log_level'] = logging.DEBUG
+ self.maker = CDBMaker(self.config)
+
+ # Building a new CDB from two files (full_build)
+ csvs = ['../examples/cdb.csv', '../examples/cdb_2.csv']
+ self.cdb = self.maker.prepare_csvs(csvs, full_build=True)
+
+ def test_prepare_csvs(self):
+ assert len(self.cdb.cui2names) == 3
+ assert len(self.cdb.cui2snames) == 3
+ assert len(self.cdb.name2cuis) == 5
+ assert len(self.cdb.cui2tags) == 3
+ assert len(self.cdb.cui2preferred_name) == 2
+ assert len(self.cdb.cui2context_vectors) == 3
+ assert len(self.cdb.cui2count_train) == 3
+ assert self.cdb.name2cuis2status['virus']['C0000039'] == 'P'
+ assert self.cdb.cui2type_ids['C0000039'] == {'T234', 'T109', 'T123'}
+ assert self.cdb.addl_info['cui2original_names']['C0000039'] == {'Virus', 'Virus K', 'Virus M', 'Virus Z'}
+ assert self.cdb.addl_info['cui2description']['C0000039'].startswith("Synthetic")
+
+ def test_name_addition(self):
+ self.cdb.add_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.nlp, {}, self.config), name_status='P', full_build=True)
+ assert self.cdb.addl_info['cui2original_names']['C0000239'] == {'MY: new,-_! Name.', 'Second csv'}
+ assert 'my:newname.' in self.cdb.name2cuis
+ assert 'my:new' in self.cdb.snames
+ assert 'my:newname.' in self.cdb.name2cuis2status
+ assert self.cdb.name2cuis2status['my:newname.'] == {'C0000239': 'P'}
+
+ def test_name_removal(self):
+ self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.nlp, {}, self.config))
+ # Run again to make sure it does not break anything
+ self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.nlp, {}, self.config))
+ assert len(self.cdb.name2cuis) == 5
+ assert 'my:newname.' not in self.cdb.name2cuis2status
+
+ def test_filtering(self):
+ cuis_to_keep = {'C0000039'} # Because of transition 2 will be kept
+ self.cdb.filter_by_cui(cuis_to_keep=cuis_to_keep)
+ assert len(self.cdb.cui2names) == 2
+ assert len(self.cdb.name2cuis) == 4
+ assert len(self.cdb.snames) == 4
+
+ def test_vector_addition(self):
+ self.cdb.reset_training()
+ np.random.seed(11)
+ cuis = list(self.cdb.cui2names.keys())
+ for i in range(2):
+ for cui in cuis:
+ vectors = {}
+ for cntx_type in self.config.linking['context_vector_sizes']:
+ vectors[cntx_type] = np.random.rand(300)
+ self.cdb.update_context_vector(cui, vectors, negative=False)
+
+ assert self.cdb.cui2count_train['C0000139'] == 2
+ assert self.cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300
+
+
+ def test_negative(self):
+ cuis = list(self.cdb.cui2names.keys())
+ for cui in cuis:
+ vectors = {}
+ for cntx_type in self.config.linking['context_vector_sizes']:
+ vectors[cntx_type] = np.random.rand(300)
+ self.cdb.update_context_vector(cui, vectors, negative=True)
+
+ assert self.cdb.cui2count_train['C0000139'] == 2
+ assert self.cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300
+
+ def test_save_and_load(self):
+ self.cdb.save("./tmp_cdb.dat")
+ cdb2 = CDB.load('./tmp_cdb.dat')
+ # Check a random thing
+ assert cdb2.cui2context_vectors['C0000139']['long'][7] == self.cdb.cui2context_vectors['C0000139']['long'][7]
+
+ def test_training_import(self):
+ cdb2 = CDB.load('./tmp_cdb.dat')
+ self.cdb.reset_training()
+ cdb2.reset_training()
+ np.random.seed(11)
+ cuis = list(self.cdb.cui2names.keys())
+ for i in range(2):
+ for cui in cuis:
+ vectors = {}
+ for cntx_type in self.config.linking['context_vector_sizes']:
+ vectors[cntx_type] = np.random.rand(300)
+ self.cdb.update_context_vector(cui, vectors, negative=False)
+
+ cdb2.import_training(cdb=self.cdb, overwrite=True)
+ assert cdb2.cui2context_vectors['C0000139']['long'][7] == self.cdb.cui2context_vectors['C0000139']['long'][7]
+ assert cdb2.cui2count_train['C0000139'] == self.cdb.cui2count_train['C0000139']
+
+ def test_concept_similarity(self):
+ cdb = CDB(config=self.config)
+ np.random.seed(11)
+ for i in range(500):
+ cui = "C" + str(i)
+ type_ids = {'T-' + str(i%10)}
+ cdb.add_concept(cui=cui, names=prepare_name('Name: ' + str(i), self.maker.nlp, {}, self.config), ontologies=set(),
+ name_status='P', type_ids=type_ids, description='', full_build=True)
+
+ vectors = {}
+ for cntx_type in self.config.linking['context_vector_sizes']:
+ vectors[cntx_type] = np.random.rand(300)
+ cdb.update_context_vector(cui, vectors, negative=False)
+ res = cdb.most_similar('C200', 'long', type_id_filter=['T-0'], min_cnt=1, topn=10, force_build=True)
+ assert len(res) == 10
+
+ def test_training_reset(self):
+ self.cdb.reset_training()
+ assert len(self.cdb.cui2context_vectors['C0']) == 0
+ assert self.cdb.cui2count_train['C0'] == 0
diff --git a/tests/archive_tests/test_ner_archive.py b/tests/archive_tests/test_ner_archive.py
index 1be695f35..6037f3c16 100644
--- a/tests/archive_tests/test_ner_archive.py
+++ b/tests/archive_tests/test_ner_archive.py
@@ -1,3 +1,9 @@
+import logging
+import os
+import unittest
+import numpy as np
+from timeit import default_timer as timer
+from medcat.cdb import CDB
from medcat.preprocessing.tokenizers import spacy_split_all
from medcat.ner.vocab_based_ner import NER
from medcat.preprocessing.taggers import tag_skip_and_punct
@@ -6,127 +12,127 @@
from medcat.vocab import Vocab
from medcat.preprocessing.cleaners import prepare_name
from medcat.linking.vector_context_model import ContextModel
-from functools import partial
from medcat.linking.context_based_linker import Linker
from medcat.config import Config
-import logging
-from medcat.cdb import CDB
-import os
-import requests
-
-config = Config()
-config.general['log_level'] = logging.INFO
-cdb = CDB(config=config)
-
-nlp = Pipe(tokenizer=spacy_split_all, config=config)
-nlp.add_tagger(tagger=partial(tag_skip_and_punct, config=config),
- name='skip_and_punct',
- additional_fields=['is_punct'])
-
-# Add a couple of names
-cdb.add_names(cui='S-229004', names=prepare_name('Movar', nlp, {}, config))
-cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', nlp, {}, config))
-cdb.add_names(cui='S-229005', names=prepare_name('CDB', nlp, {}, config))
-# Check
-#assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}}
-
-vocab_path = "./tmp_vocab.dat"
-if not os.path.exists(vocab_path):
- import requests
- tmp = requests.get("https://s3-eu-west-1.amazonaws.com/zkcl/vocab.dat")
- with open(vocab_path, 'wb') as f:
- f.write(tmp.content)
-
-vocab = Vocab.load(vocab_path)
-# Make the pipeline
-nlp = Pipe(tokenizer=spacy_split_all, config=config)
-nlp.add_tagger(tagger=partial(tag_skip_and_punct, config=config),
- name='skip_and_punct',
- additional_fields=['is_punct'])
-spell_checker = BasicSpellChecker(cdb_vocab=cdb.vocab, config=config, data_vocab=vocab)
-nlp.add_token_normalizer(spell_checker=spell_checker, config=config)
-ner = NER(cdb, config)
-nlp.add_ner(ner)
-
-# Add Linker
-link = Linker(cdb, vocab, config)
-nlp.add_linker(link)
-
-# Test limits for tokens and uppercase
-config.ner['max_skip_tokens'] = 1
-config.ner['upper_case_limit_len'] = 4
-config.linking['disamb_length_limit'] = 2
-text = "CDB - I was running and then Movar Virus attacked and CDb"
-d = nlp(text)
-
-assert len(d._.ents) == 2
-assert d._.ents[0]._.link_candidates[0] == 'S-229004'
-
-# Change limit for skip
-config.ner['max_skip_tokens'] = 3
-d = nlp(text)
-assert len(d._.ents) == 3
-
-# Change limit for upper_case
-config.ner['upper_case_limit_len'] = 3
-d = nlp(text)
-assert len(d._.ents) == 4
-
-# Check name length limit
-config.ner['min_name_len'] = 4
-d = nlp(text)
-assert len(d._.ents) == 2
-
-# Speed tests
-from timeit import default_timer as timer
-text = "CDB - I was running and then Movar Virus attacked and CDb"
-text = text * 300
-config.general['spell_check'] = True
-start = timer()
-for i in range(50):
- d = nlp(text)
-end = timer()
-print("Time: ", end - start)
-
-# Now without spell check
-config.general['spell_check'] = False
-start = timer()
-for i in range(50):
- d = nlp(text)
-end = timer()
-print("Time: ", end - start)
-
-
-# Test for linker
-import numpy as np
-
-config = Config()
-config.general['log_level'] = logging.DEBUG
-cdb = CDB(config=config)
-
-# Add a couple of names
-cdb.add_names(cui='S-229004', names=prepare_name('Movar', nlp, {}, config))
-cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', nlp, {}, config))
-cdb.add_names(cui='S-229005', names=prepare_name('CDB', nlp, {}, config))
-cdb.add_names(cui='S-2290045', names=prepare_name('Movar', nlp, {}, config))
-# Check
-#assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}, 'S-2290045': {'movar'}}
-
-cuis = list(cdb.cui2names.keys())
-for cui in cuis[0:50]:
- vectors = {'short': np.random.rand(300),
- 'long': np.random.rand(300),
- 'medium': np.random.rand(300)
- }
- cdb.update_context_vector(cui, vectors, negative=False)
-
-vocab = Vocab.load(vocab_path)
-cm = ContextModel(cdb, vocab, config)
-cm.train_using_negative_sampling('S-229004')
-config.linking['train_count_threshold'] = 0
-
-cm.train('S-229004', d._.ents[1], d)
-cm.similarity('S-229004', d._.ents[1], d)
-cm.disambiguate(['S-2290045', 'S-229004'], d._.ents[1], 'movar', d)
+class NerArchiveTests(unittest.TestCase):
+
+ def setUp(self) -> None:
+ self.config = Config()
+ self.config.general['log_level'] = logging.INFO
+ cdb = CDB(config=self.config)
+
+ self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config)
+ self.nlp.add_tagger(tagger=tag_skip_and_punct,
+ name='skip_and_punct',
+ additional_fields=['is_punct'])
+
+ # Add a couple of names
+ cdb.add_names(cui='S-229004', names=prepare_name('Movar', self.nlp, {}, self.config))
+ cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', self.nlp, {}, self.config))
+ cdb.add_names(cui='S-229005', names=prepare_name('CDB', self.nlp, {}, self.config))
+ # Check
+ #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}}
+
+ self.vocab_path = "./tmp_vocab.dat"
+ if not os.path.exists(self.vocab_path):
+ import requests
+ tmp = requests.get("https://medcat.rosalind.kcl.ac.uk/media/vocab.dat")
+ with open(self.vocab_path, 'wb') as f:
+ f.write(tmp.content)
+
+ vocab = Vocab.load(self.vocab_path)
+ # Make the pipeline
+ self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config)
+ self.nlp.add_tagger(tagger=tag_skip_and_punct,
+ name='skip_and_punct',
+ additional_fields=['is_punct'])
+ spell_checker = BasicSpellChecker(cdb_vocab=cdb.vocab, config=self.config, data_vocab=vocab)
+ self.nlp.add_token_normalizer(spell_checker=spell_checker, config=self.config)
+ ner = NER(cdb, self.config)
+ self.nlp.add_ner(ner)
+
+ # Add Linker
+ link = Linker(cdb, vocab, self.config)
+ self.nlp.add_linker(link)
+
+ self.text = "CDB - I was running and then Movar Virus attacked and CDb"
+
+ def test_limits_for_tokens_and_uppercase(self):
+ self.config.ner['max_skip_tokens'] = 1
+ self.config.ner['upper_case_limit_len'] = 4
+ self.config.linking['disamb_length_limit'] = 2
+
+ d = self.nlp(self.text)
+
+ assert len(d._.ents) == 2
+ assert d._.ents[0]._.link_candidates[0] == 'S-229004'
+
+ def test_change_limit_for_skip(self):
+ self.config.ner['max_skip_tokens'] = 3
+ d = self.nlp(self.text)
+ assert len(d._.ents) == 3
+
+ def test_change_limit_for_upper_case(self):
+ self.config.ner['upper_case_limit_len'] = 3
+ d = self.nlp(self.text)
+ assert len(d._.ents) == 4
+
+ def test_check_name_length_limit(self):
+ self.config.ner['min_name_len'] = 4
+ d = self.nlp(self.text)
+ assert len(d._.ents) == 2
+
+ def test_speed(self):
+ text = "CDB - I was running and then Movar Virus attacked and CDb"
+ text = text * 300
+ self.config.general['spell_check'] = True
+ start = timer()
+ for i in range(50):
+ d = self.nlp(text)
+ end = timer()
+ print("Time: ", end - start)
+
+ def test_without_spell_check(self):
+ # Now without spell check
+ self.config.general['spell_check'] = False
+ start = timer()
+ for i in range(50):
+ d = self.nlp(self.text)
+ end = timer()
+ print("Time: ", end - start)
+
+
+ def test_for_linker(self):
+ self.config = Config()
+ self.config.general['log_level'] = logging.DEBUG
+ cdb = CDB(config=self.config)
+
+ # Add a couple of names
+ cdb.add_names(cui='S-229004', names=prepare_name('Movar', self.nlp, {}, self.config))
+ cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', self.nlp, {}, self.config))
+ cdb.add_names(cui='S-229005', names=prepare_name('CDB', self.nlp, {}, self.config))
+ cdb.add_names(cui='S-2290045', names=prepare_name('Movar', self.nlp, {}, self.config))
+ # Check
+ #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}, 'S-2290045': {'movar'}}
+
+ cuis = list(cdb.cui2names.keys())
+ for cui in cuis[0:50]:
+ vectors = {'short': np.random.rand(300),
+ 'long': np.random.rand(300),
+ 'medium': np.random.rand(300)
+ }
+ cdb.update_context_vector(cui, vectors, negative=False)
+
+ d = self.nlp(self.text)
+ vocab = Vocab.load(self.vocab_path)
+ cm = ContextModel(cdb, vocab, self.config)
+ cm.train_using_negative_sampling('S-229004')
+ self.config.linking['train_count_threshold'] = 0
+
+ cm.train('S-229004', d._.ents[1], d)
+
+ cm.similarity('S-229004', d._.ents[1], d)
+
+ cm.disambiguate(['S-2290045', 'S-229004'], d._.ents[1], 'movar', d)
diff --git a/tests/test_cat.py b/tests/test_cat.py
new file mode 100644
index 000000000..869bcba54
--- /dev/null
+++ b/tests/test_cat.py
@@ -0,0 +1,27 @@
+import os
+import unittest
+from medcat.vocab import Vocab
+from medcat.cdb import CDB
+from medcat.cat import CAT
+
+
+class CATTests(unittest.TestCase):
+
+ def setUp(self) -> None:
+ self.cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat"))
+ self.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat"))
+ self.cdb.config.ner['min_name_len'] = 2
+ self.cdb.config.ner['upper_case_limit_len'] = 3
+ self.cdb.config.general['spell_check'] = True
+ self.cdb.config.linking['train_count_threshold'] = 10
+ self.cdb.config.linking['similarity_threshold'] = 0.3
+ self.cdb.config.linking['train'] = True
+ self.cdb.config.linking['disamb_length_limit'] = 5
+ self.cdb.config.general['full_unlink'] = True
+ self.undertest = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab)
+
+ def test_pipeline(self):
+ text = "The dog is sitting outside the house."
+ doc = self.undertest(text)
+ self.assertEqual(text, doc.text)
+
diff --git a/tests/test_cdb.py b/tests/test_cdb.py
new file mode 100644
index 000000000..0d9bad237
--- /dev/null
+++ b/tests/test_cdb.py
@@ -0,0 +1,51 @@
+import os
+import shutil
+import unittest
+from medcat.config import Config
+from medcat.cdb_maker import CDBMaker
+
+
+class CDBTests(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ config = Config()
+ config.general["spacy_model"] = "en_core_sci_md"
+ cls.cdb_maker = CDBMaker(config)
+
+ def setUp(self) -> None:
+ cdb_csv = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.csv")
+ cdb_2_csv = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb_2.csv")
+ self.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
+ os.makedirs(self.tmp_dir, exist_ok=True)
+ self.undertest = CDBTests.cdb_maker.prepare_csvs([cdb_csv, cdb_2_csv], full_build=True)
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self.tmp_dir)
+
+ def test_name2cuis(self):
+ self.assertEqual({
+ 'second~csv': ['C0000239'],
+ 'virus': ['C0000039', 'C0000139'],
+ 'virus~k': ['C0000039', 'C0000139'],
+ 'virus~m': ['C0000039', 'C0000139'],
+ 'virus~z': ['C0000039', 'C0000139']
+ }, self.undertest.name2cuis)
+
+ def test_cui2names(self):
+ self.assertEqual({
+ 'C0000039': {'virus~z', 'virus~k', 'virus~m', 'virus'},
+ 'C0000139': {'virus~z', 'virus', 'virus~m', 'virus~k'},
+ 'C0000239': {'second~csv'}
+ }, self.undertest.cui2names)
+
+ def test_cui2preferred_name(self):
+ self.assertEqual({'C0000039': 'Virus', 'C0000139': 'Virus Z'}, self.undertest.cui2preferred_name)
+
+ def test_cui2type_ids(self):
+ self.assertEqual({'C0000039': {'T109', 'T234', 'T123'}, 'C0000139': set(), 'C0000239': set()}, self.undertest.cui2type_ids)
+
+ def test_save_and_load(self):
+ cdb_path = f"{self.tmp_dir}/cdb.dat"
+ self.undertest.save(cdb_path)
+ self.undertest.load(cdb_path)
diff --git a/tests/test_cdb_maker.py b/tests/test_cdb_maker.py
index 5a96a701f..eb58ee481 100644
--- a/tests/test_cdb_maker.py
+++ b/tests/test_cdb_maker.py
@@ -1,10 +1,11 @@
import unittest
+import logging
+import os
+import numpy as np
from medcat.cdb_maker import CDBMaker
from medcat.cdb import CDB
from medcat.config import Config
from medcat.preprocessing.cleaners import prepare_name
-import numpy as np
-import logging
#cdb.csv
#cui name ontologies name_status type_ids description
@@ -20,6 +21,7 @@
#TESTS RUN IN ALPHABETICAL ORDER - CONTROLLING WITH '[class_letter]Class and test_[classletter subclassletter]' function syntax
+
class A_CDBMakerLoadTests(unittest.TestCase):
@classmethod
@@ -28,7 +30,10 @@ def setUpClass(cls):
config = Config()
config.general['log_level'] = logging.DEBUG
maker = CDBMaker(config)
- csvs = ['../examples/cdb.csv', '../examples/cdb_2.csv']
+ csvs = [
+ os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'examples', 'cdb.csv'),
+ os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'examples', 'cdb_2.csv')
+ ]
cls.cdb = maker.prepare_csvs(csvs, full_build=True)
def test_aa_cdb_names_length(self):
@@ -110,7 +115,10 @@ def setUpClass(cls):
cls.config = Config()
cls.config.general['log_level'] = logging.DEBUG
cls.maker = CDBMaker(cls.config)
- csvs = ['../examples/cdb.csv', '../examples/cdb_2.csv']
+ csvs = [
+ os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'examples', 'cdb.csv'),
+ os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'examples', 'cdb_2.csv')
+ ]
cls.cdb = cls.maker.prepare_csvs(csvs, full_build=True)
cls.cdb2 = CDB(cls.config)
diff --git a/tests/test_ner.py b/tests/test_ner.py
index e05be9659..14720c205 100644
--- a/tests/test_ner.py
+++ b/tests/test_ner.py
@@ -28,7 +28,7 @@ def setUpClass(cls):
print("Set up Vocab")
vocab_path = "./tmp_vocab.dat"
if not os.path.exists(vocab_path):
- tmp = requests.get("https://s3-eu-west-1.amazonaws.com/zkcl/vocab.dat")
+ tmp = requests.get("https://medcat.rosalind.kcl.ac.uk/media/vocab.dat")
with open(vocab_path, 'wb') as f:
f.write(tmp.content)
@@ -37,7 +37,7 @@ def setUpClass(cls):
print("Set up NLP pipeline")
cls.nlp = Pipe(tokenizer=spacy_split_all, config=cls.config)
- cls.nlp.add_tagger(tagger=partial(tag_skip_and_punct, config=cls.config),
+ cls.nlp.add_tagger(tagger=tag_skip_and_punct,
name='skip_and_punct',
additional_fields=['is_punct'])
diff --git a/tests/test_vocab.py b/tests/test_vocab.py
new file mode 100644
index 000000000..8db82df41
--- /dev/null
+++ b/tests/test_vocab.py
@@ -0,0 +1,39 @@
+import os
+import shutil
+import unittest
+from medcat.vocab import Vocab
+
+
+class CATTests(unittest.TestCase):
+
+ def setUp(self) -> None:
+ self.undertest = Vocab()
+ self.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
+ os.makedirs(self.tmp_dir, exist_ok=True)
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self.tmp_dir)
+
+ def test_add_words(self):
+ self.undertest.add_words(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab_data.txt"))
+ self.assertEqual(["house", "dog"], list(self.undertest.vocab.keys()))
+
+ def test_add_word(self):
+ self.undertest.add_word("test", cnt=31, vec=[1.42, 1.44, 1.55])
+ self.assertEqual(["test"], list(self.undertest.vocab.keys()))
+ self.assertTrue("test" in self.undertest)
+
+ def test_count(self):
+ self.undertest.add_words(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab_data.txt"))
+ self.assertEqual(34444, self.undertest.count("house"))
+
+ def test_save_and_load(self):
+ self.undertest.add_words(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab_data.txt"))
+ self.undertest.add_word("test", cnt=31, vec=[1.42, 1.44, 1.55])
+ vocab_path = f"{self.tmp_dir}/vocab.dat"
+ self.undertest.save(vocab_path)
+ vocab = Vocab.load(vocab_path)
+ self.assertEqual(["house", "dog", "test"], list(vocab.vocab.keys()))
+
+
+
diff --git a/webapp/envs/env_medmen b/webapp/envs/env_medmen
index 85b79f850..447f0dfb9 100644
--- a/webapp/envs/env_medmen
+++ b/webapp/envs/env_medmen
@@ -19,8 +19,8 @@ KEEP_PUNCT=:|.
SPACY_MODEL=en_core_sci_md
VOCAB_PATH=/webapp/models/vocab.dat
CDB_PATH=/webapp/models/cdb.dat
-VOCAB_URL=https://s3-eu-west-1.amazonaws.com/zkcl/vocab.dat
-CDB_URL=https://s3-eu-west-1.amazonaws.com/zkcl/cdb-medmen.dat
+VOCAB_URL=https://medcat.rosalind.kcl.ac.uk/media/vocab.dat
+CDB_URL=https://medcat.rosalind.kcl.ac.uk/media/cdb-medmen-v1.dat
MKL_NUM_THREAD=1
NUMEXPR_NUM_THREADS=1
diff --git a/webapp/webapp/Dockerfile b/webapp/webapp/Dockerfile
index ccf9dfa81..61575edba 100644
--- a/webapp/webapp/Dockerfile
+++ b/webapp/webapp/Dockerfile
@@ -6,8 +6,8 @@ RUN mkdir -p /webapp/models
# Copy everything
COPY . /webapp
-ENV VOCAB_URL=https://s3-eu-west-1.amazonaws.com/zkcl/vocab.dat
-ENV CDB_URL=https://s3-eu-west-1.amazonaws.com/zkcl/cdb-medmen.dat
+ENV VOCAB_URL=https://medcat.rosalind.kcl.ac.uk/media/vocab.dat
+ENV CDB_URL=https://medcat.rosalind.kcl.ac.uk/media/cdb-medmen-v1.dat
ENV CDB_PATH=/webapp/models/cdb.dat
ENV VOCAB_PATH=/webapp/models/vocab.dat
@@ -17,9 +17,8 @@ WORKDIR /webapp
RUN pip install -r requirements.txt
-# Get the spacy and scipspacy model
+# Get the spacy model
RUN python -m spacy download en_core_web_md
-RUN pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.4/en_core_sci_md-0.2.4.tar.gz
# Build the db
RUN python manage.py makemigrations && \