Commit

CAT-18 refactor and add more tests
baixiac committed Jul 28, 2021
1 parent 6c07e78 commit a70409d
Showing 7 changed files with 130 additions and 24 deletions.
2 changes: 2 additions & 0 deletions examples/vocab_data.txt
@@ -0,0 +1,2 @@
house 34444 0.3232 0.123213 1.231231
dog 14444 0.76762 0.76767 1.45454
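For reference, each line of vocab_data.txt is a word followed by an occurrence count and the word's vector components. A minimal parsing sketch (illustrative only; the real logic lives in medcat.vocab.Vocab.add_words):

def parse_vocab_line(line):
    # Hypothetical helper, not MedCAT code: split "word count v1 v2 ...".
    parts = line.split()
    word = parts[0]                       # e.g. "house"
    cnt = int(parts[1])                   # e.g. 34444
    vec = [float(x) for x in parts[2:]]   # remaining fields form the vector
    return word, cnt, vec

word, cnt, vec = parse_vocab_line("house 34444 0.3232 0.123213 1.231231")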
2 changes: 1 addition & 1 deletion medcat/cat.py
@@ -116,7 +116,7 @@ def __call__(self, text, do_train=False):
Returns:
A spacy document with the extracted entities
'''
# Should we train - do not use this for training, unles you know what you are doing. Use the
# Should we train - do not use this for training, unless you know what you are doing. Use the
#self.train() function
self.config.linking['train'] = do_train

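The only change in this file is the comment typo fix; do_train still routes straight into config.linking['train']. Usage is unchanged (a sketch mirroring the fixtures added in tests/test_cat.py below, not code from this commit):

import os
from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.cat import CAT

# Paths mirror the test fixtures added in this commit.
cdb = CDB.load(os.path.join("examples", "cdb.dat"))
vocab = Vocab.load(os.path.join("examples", "vocab.dat"))
cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)

doc = cat("The dog is sitting outside the house.")                  # inference only
doc = cat("The dog is sitting outside the house.", do_train=True)   # sets config.linking['train'] = True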
12 changes: 3 additions & 9 deletions medcat/pipe.py
@@ -24,7 +24,6 @@ def __init__(self, tokenizer, config):
self.nlp.tokenizer = tokenizer(self.nlp)
self.config = config


def add_tagger(self, tagger, name, additional_fields=[]):
r''' Add any kind of tagger for tokens.
@@ -37,17 +36,16 @@ def add_tagger(self, tagger, name, additional_fields=[]):
additional_fields (`List[str]`):
Fields to be added to the `_` properties of a token.
'''
component_name = spacy.util.get_object_name(tagger)
Language.factory(name=component_name, default_config={"config": self.config}, func=tagger)
self.nlp.add_pipe(component_name, name='tag_' + name, first=True)
component_factory_name = spacy.util.get_object_name(tagger)
Language.factory(name=component_factory_name, default_config={"config": self.config}, func=tagger)
self.nlp.add_pipe(component_factory_name, name='tag_' + name, first=True)
# Add custom fields needed for this use case
Token.set_extension('to_skip', default=False, force=True)

# Add any additional fields that are required
for field in additional_fields:
Token.set_extension(field, default=False, force=True)


def add_token_normalizer(self, config, spell_checker=None):
token_normalizer = TokenNormalizer(spell_checker=spell_checker, config=config)
component_name = spacy.util.get_object_name(token_normalizer)
@@ -57,7 +55,6 @@ def add_token_normalizer(self, config, spell_checker=None):
# Add custom fields needed for this use case
Token.set_extension('norm', default=None, force=True)


def add_ner(self, ner):
r''' Add NER from CAT to the pipeline, will also add the necessary fields
to the document and Span objects.
@@ -75,7 +72,6 @@ def add_ner(self, ner):
Span.set_extension('detected_name', default=None, force=True)
Span.set_extension('link_candidates', default=None, force=True)


def add_linker(self, linker):
r''' Add entity linker to the pipeline, will also add the necessary fields
to Span object.
@@ -90,7 +86,6 @@ def add_linker(self, linker):
Span.set_extension('cui', default=-1, force=True)
Span.set_extension('context_similarity', default=-1, force=True)


def add_meta_cat(self, meta_cat, name):
component_name = spacy.util.get_object_name(meta_cat)
Language.component(name=component_name, func=meta_cat)
@@ -100,6 +95,5 @@ def add_meta_cat(self, meta_cat, name):
#of {category_name: value, ...}
Span.set_extension('meta_anns', default=None, force=True)


def __call__(self, text):
return self.nlp(text)
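Taken together, the pipe.py changes rename the local variable to component_factory_name and drop the duplicated blank lines; the public add_* API is unchanged. A sketch of how a caller wires the pipeline (the spacy_split_all tokenizer import is an assumption; this is not code from this commit):

from medcat.config import Config
from medcat.pipe import Pipe
from medcat.preprocessing.taggers import tag_skip_and_punct
from medcat.preprocessing.tokenizers import spacy_split_all  # assumed tokenizer factory

config = Config()
pipe = Pipe(tokenizer=spacy_split_all, config=config)
# add_tagger registers the factory under its object name and inserts the
# component first in the spacy pipeline as 'tag_' + name.
pipe.add_tagger(tagger=tag_skip_and_punct,
                name='skip_and_punct',
                additional_fields=['is_punct'])  # 'is_punct' is illustrative
doc = pipe("The dog is sitting outside the house.")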
21 changes: 7 additions & 14 deletions medcat/preprocessing/taggers.py
@@ -1,33 +1,26 @@
import re

def tag_skip_and_punct(nlp, name, config):
r''' Detects and tags spacy tokens that are punctuation and that should be skipped.
Args:
nlp (spacy.language.<lng>):
The base spacy NLP pipeline.
name (`str`):
The component instance name.
config (`medcat.config.Config`):
Global config for medcat.
'''

return TagSkipAndPunct(nlp, name, config)
return _Tagger(nlp, name, config)


class TagSkipAndPunct(object):
class _Tagger(object):

def __init__(self, nlp, name, config):
self.nlp = nlp
self.name = name
self.config = config

def __call__(self, doc):
r''' Detects and tags spacy tokens that are punctuation and that should be skipped.
Args:
doc (`spacy.tokens.Doc`):
Spacy document that will be tagged.
Returns:
(`spacy.tokens.Doc`):
Tagged spacy document
'''
# Make life easier
cnf_p = self.config.preprocessing

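The hunk is cut off after cnf_p is bound, but the shape of the refactor is clear: the public TagSkipAndPunct class becomes a private _Tagger, and the module keeps a single public factory, tag_skip_and_punct, in the form spacy's Language.factory registration expects. A hedged sketch of what the remainder of __call__ plausibly does with the preprocessing config (the 'words_to_skip' key and the is_punct extension are assumptions, not verbatim MedCAT code):

import re
from spacy.tokens import Token

# Extensions mirror the ones registered by Pipe.add_tagger above.
Token.set_extension('to_skip', default=False, force=True)
Token.set_extension('is_punct', default=False, force=True)

PUNCT_RE = re.compile(r'^[^\w\s]+$')  # tokens made of punctuation only

def tag_doc(doc, cnf_p):
    # cnf_p stands in for self.config.preprocessing.
    for token in doc:
        if PUNCT_RE.match(token.text):
            token._.is_punct = True
            token._.to_skip = True
        elif token.lower_ in cnf_p.get('words_to_skip', set()):
            token._.to_skip = True
    return doc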
27 changes: 27 additions & 0 deletions tests/test_cat.py
@@ -0,0 +1,27 @@
import os
import unittest
from medcat.vocab import Vocab
from medcat.cdb import CDB
from medcat.cat import CAT


class CATTests(unittest.TestCase):

def setUp(self) -> None:
self.cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat"))
self.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat"))
self.cdb.config.ner['min_name_len'] = 2
self.cdb.config.ner['upper_case_limit_len'] = 3
self.cdb.config.general['spell_check'] = True
self.cdb.config.linking['train_count_threshold'] = 10
self.cdb.config.linking['similarity_threshold'] = 0.3
self.cdb.config.linking['train'] = True
self.cdb.config.linking['disamb_length_limit'] = 5
self.cdb.config.general['full_unlink'] = True
self.undertest = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab)

def test_pipeline(self):
text = "The dog is sitting outside the house."
doc = self.undertest(text)
self.assertEqual(text, doc.text)

51 changes: 51 additions & 0 deletions tests/test_cdb.py
@@ -0,0 +1,51 @@
import os
import shutil
import unittest
from medcat.config import Config
from medcat.cdb_maker import CDBMaker


class CDBTests(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
config = Config()
config.general["spacy_model"] = "en_core_sci_md"
cls.cdb_maker = CDBMaker(config)

def setUp(self) -> None:
cdb_csv = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.csv")
cdb_2_csv = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb_2.csv")
self.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
os.makedirs(self.tmp_dir, exist_ok=True)
self.undertest = CDBTests.cdb_maker.prepare_csvs([cdb_csv, cdb_2_csv], full_build=True)

def tearDown(self) -> None:
shutil.rmtree(self.tmp_dir)

def test_name2cuis(self):
self.assertEqual({
'second~csv': ['C0000239'],
'virus': ['C0000039', 'C0000139'],
'virus~k': ['C0000039', 'C0000139'],
'virus~m': ['C0000039', 'C0000139'],
'virus~z': ['C0000039', 'C0000139']
}, self.undertest.name2cuis)

def test_cui2names(self):
self.assertEqual({
'C0000039': {'virus~z', 'virus~k', 'virus~m', 'virus'},
'C0000139': {'virus~z', 'virus', 'virus~m', 'virus~k'},
'C0000239': {'second~csv'}
}, self.undertest.cui2names)

def test_cui2preferred_name(self):
self.assertEqual({'C0000039': 'Virus', 'C0000139': 'Virus Z'}, self.undertest.cui2preferred_name)

def test_cui2type_ids(self):
self.assertEqual({'C0000039': {'T109', 'T234', 'T123'}, 'C0000139': set(), 'C0000239': set()}, self.undertest.cui2type_ids)

def test_save_and_load(self):
cdb_path = f"{self.tmp_dir}/cdb.dat"
self.undertest.save(cdb_path)
self.undertest.load(cdb_path)
39 changes: 39 additions & 0 deletions tests/test_vocab.py
@@ -0,0 +1,39 @@
import os
import shutil
import unittest
from medcat.vocab import Vocab


class VocabTests(unittest.TestCase):

def setUp(self) -> None:
self.undertest = Vocab()
self.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
os.makedirs(self.tmp_dir, exist_ok=True)

def tearDown(self) -> None:
shutil.rmtree(self.tmp_dir)

def test_add_words(self):
self.undertest.add_words(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab_data.txt"))
self.assertEqual(["house", "dog"], list(self.undertest.vocab.keys()))

def test_add_word(self):
self.undertest.add_word("test", cnt=31, vec=[1.42, 1.44, 1.55])
self.assertEqual(["test"], list(self.undertest.vocab.keys()))
self.assertTrue("test" in self.undertest)

def test_count(self):
self.undertest.add_words(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab_data.txt"))
self.assertEqual(34444, self.undertest.count("house"))

def test_save_and_load(self):
self.undertest.add_words(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab_data.txt"))
self.undertest.add_word("test", cnt=31, vec=[1.42, 1.44, 1.55])
vocab_path = f"{self.tmp_dir}/vocab.dat"
self.undertest.save(vocab_path)
vocab = Vocab.load(vocab_path)
self.assertEqual(["house", "dog", "test"], list(vocab.vocab.keys()))


