diff --git a/camel_tools/disambig/mle.py b/camel_tools/disambig/mle.py index 6dcf828..8ed6829 100644 --- a/camel_tools/disambig/mle.py +++ b/camel_tools/disambig/mle.py @@ -133,12 +133,13 @@ def __init__(self, analyzer, mle_path=None, top=1, cache_size=100000): top = 1 self._top = top - if cache_size < 0: + if cache_size <= 0: cache_size = 0 - - self._cache = LFUCache(cache_size) - self._scored_analyses = cached(self._cache)( - self._scored_analyses) + self._cache = None + self._score_fn = self._scored_analyses + else: + self._cache = LFUCache(cache_size) + self._score_fn = self._scored_analyses_cached @staticmethod def pretrained(model_name=None, analyzer=None, top=1, cache_size=100000): @@ -208,10 +209,13 @@ def _scored_analyses(self, word_dd): w.analysis['diac'])) return scored_analyses[0:self._top] + + def _scored_analyses_cached(self, word_dd): + return self._cache.get(word_dd, self._scored_analyses(word_dd)) def _disambiguate_word(self, word): word_dd = dediac_ar(word) - scored_analyses = self._scored_analyses(word_dd) + scored_analyses = self._score_fn(word_dd) return DisambiguatedWord(word, scored_analyses)