Fix predictors (#52)
* Fixed ngram_predictor

* RUFF fix
michaelbeale-IL authored Dec 18, 2024
1 parent 8b456c4 commit 66b2fcf
Showing 16 changed files with 1,974 additions and 1,961 deletions.
19 changes: 7 additions & 12 deletions .vscode/launch.json
@@ -1,23 +1,18 @@
{
"version": "0.2.0",
"configurations": [
// {
// "name": "Python Debugger: Module",
// "type": "debugpy",
// "request": "launch",
// "module": "pyinstaller"
// },
{
"name": "Python Debugger: Current File",
"name": "Python Debugger: Main with Arguments",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "internalConsole",
"justMyCode": false,
"redirectOutput": true,
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
"PYTHONPATH": "${workspaceFolder}/src/"
},
"justMyCode": false,
"args": "${command:pickArgs}"
},
{
"name": "Continuous Predict Util",
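Read as old/new pairs, this hunk deletes the commented-out pyinstaller entry and reworks the "Current File" configuration. The flattened diff does not mark which lines are additions and which are deletions, so the resulting entry below is a best-effort reconstruction — in particular, whether "program": "${file}" survived the edit is an assumption:

    {
        "name": "Python Debugger: Main with Arguments",
        "type": "debugpy",
        "request": "launch",
        "program": "${file}",
        "console": "integratedTerminal",
        "cwd": "${workspaceFolder}",
        "env": {
            "PYTHONPATH": "${workspaceFolder}/src/"
        },
        "justMyCode": false,
        "args": "${command:pickArgs}"
    }

Functionally: PYTHONPATH now points at src/ so the package resolves from the source tree, output goes to the integrated terminal rather than the internal console, and "${command:pickArgs}" prompts for command-line arguments at launch.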
57 changes: 57 additions & 0 deletions 3rd_party_resources/utils/database_generator.py
@@ -0,0 +1,57 @@
import argparse
from convassist.utilities.ngram.ngramutil import NGramUtil

def configure():
    # Create top-level parser
    parser = argparse.ArgumentParser(description="Recreate an NGram Database for ConvAssist")

    # File command
    parser.add_argument(
        "database",
        type=str,
        help="The database file to use for the n-gram model. Must be a path to a db file."
    )

    parser.add_argument(
        'input_file',
        type=str,
        help='The input file to use for the database. Must be a text file with one sentence per line.'
    )

    parser.add_argument(
        '--cardinality',
        type=int,
        default=3,
        help='The number of tokens to consider in the n-gram model'
    )

    parser.add_argument(
        "--lowercase",
        type=bool,
        default=False,
        help="Whether to convert all tokens to lowercase"
    )

    parser.add_argument(
        "--normalize",
        type=bool,
        default=False,
        help="Whether to normalize the database"
    )
    return parser

def main(argv=None):
    parser = configure()
    args = parser.parse_args(argv)

    with NGramUtil(args.database, args.cardinality, args.lowercase, args.normalize) as ngramutil:
        phrases = []

        with open(args.input_file) as f:
            for line in f:
                phrases.append(line.strip())

        ngramutil.update(phrases)

if __name__ == "__main__":
    main()
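Two usage notes on this new script, sketched below. Because main() accepts an optional argv list, it can be driven from Python as well as from a shell. Separately, argparse's type=bool is a well-known pitfall (an observation about the new code, not part of the commit): bool() is applied to the argument string, and bool("False") is True, so passing any explicit value enables the flag; an action="store_true" argument would behave more predictably.

    # Equivalent to: python database_generator.py ngram.db corpus.txt --cardinality 3
    # (the file names are illustrative)
    main(["ngram.db", "corpus.txt", "--cardinality", "3"])

    # The type=bool pitfall in isolation:
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument("--lowercase", type=bool, default=False)
    print(p.parse_args(["--lowercase", "False"]).lowercase)  # True: bool("False") is truthy
    print(p.parse_args([]).lowercase)                        # False: only the default is False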
5 changes: 2 additions & 3 deletions convassist/ConvAssist.py
@@ -2,8 +2,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

 import logging
-import os
-from configparser import ConfigParser, ExtendedInterpolation
+from configparser import ConfigParser
 
 import nltk
 
@@ -24,7 +23,7 @@ def __init__(
         ini_file: str | None = None,
         config: ConfigParser | None = None,
         log_location: str | None = None,
-        log_level: int = logging.ERROR,
+        log_level: int = logging.DEBUG,
     ):
         """
         Initializes an instance of the class.
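One side effect to be aware of: the default log level drops from ERROR (40) to DEBUG (10), so an instance constructed without an explicit log_level now emits every record at DEBUG and above. For reference:

    import logging
    print(logging.DEBUG, logging.ERROR)  # 10 40 - the lower default threshold means far more output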
70 changes: 0 additions & 70 deletions convassist/predictor/predictor.py
@@ -15,76 +15,6 @@


class Predictor(ABC):
"""
Predictor is an abstract base class that defines the interface for various predictors.
Attributes:
config (ConfigParser): Configuration parser object.
context_tracker (ContextTracker): Object to track the context.
predictor_name (str): Name of the predictor.
logger (logging.Logger): Logger object for logging information.
_aac_dataset (str): Path to the AAC dataset.
_blacklist_file (str): Path to the blacklist file.
_database (str): Path to the database.
_deltas (str): String of delta values.
_embedding_cache_path (str): Path to the embedding cache.
_generic_phrases (str): Path to the generic phrases file.
_index_path (str): Path to the index file.
_learn (bool): Flag to enable learning.
_modelname (str): Path to the model name.
_personalized_allowed_toxicwords_file (str): Path to the personalized allowed toxic words file.
_personalized_cannedphrases (str): Path to the personalized canned phrases file.
_personalized_resources_path (str): Path to the personalized resources.
_predictor_class (str): Class of the predictor.
_retrieve_database (str): Path to the retrieve database.
_retrieveaac (bool): Flag to enable AAC retrieval.
_sbertmodel (str): SBERT model name.
_sent_database (str): Path to the sentence database.
_sentence_transformer_model (str): Path to the sentence transformer model.
_sentences_db (str): Path to the sentences database.
_spellingdatabase (str): Path to the spelling database.
_startsents (str): Path to the start sentences file.
_startwords (str): Path to the start words file.
_static_resources_path (str): Path to the static resources.
_stopwords (str): Path to the stopwords file.
_test_generalsentenceprediction (bool): Flag to enable general sentence prediction testing.
_tokenizer (str): Path to the tokenizer.
Methods:
predictor_name: Returns the name of the predictor.
aac_dataset: Returns the path to the AAC dataset.
database: Returns the path to the database.
deltas: Gets and sets the list of delta values.
cardinality: Returns the number of delta values.
generic_phrases: Returns the path to the generic phrases file.
learn_enabled: Returns the learning flag.
modelname: Returns the path to the model name.
personalized_cannedphrases: Returns the path to the personalized canned phrases file.
predictor_class: Returns the class of the predictor.
retrieveaac: Returns the AAC retrieval flag.
sbertmodel: Returns the SBERT model name.
sentence_transformer_model: Returns the path to the sentence transformer model.
sent_database: Returns the path to the sentence database.
retrieve_database: Returns the path to the retrieve database.
blacklist_file: Returns the path to the blacklist file.
embedding_cache_path: Returns the path to the embedding cache.
index_path: Returns the path to the index file.
stopwordsFile: Returns the path to the stopwords file.
personalized_allowed_toxicwords_file: Returns the path to the personalized allowed toxic words file.
startsents: Returns the path to the start sentences file.
tokenizer: Returns the path to the tokenizer.
startwords: Returns the path to the start words file.
test_generalsentenceprediction: Gets and sets the flag for general sentence prediction testing.
configure: Abstract method to configure the predictor.
predict: Abstract method to predict the next word and sentence based on the context.
read_personalized_corpus: Reads the personalized corpus from the canned phrases file.
learn: Method for learning, to be implemented by subclasses if needed.
recreate_database: Method to recreate the database, to be implemented by subclasses if needed.
load_model: Method to load the model, to be implemented by subclasses if needed.
read_personalized_toxic_words: Reads personalized toxic words, to be implemented by subclasses if needed.
_find_option_in_section: Finds an option in the given section or in the "Common" section.
_read_config: Reads the configuration for the predictor.
__repr__: Returns a string representation of the predictor.
"""

def __init__(
self,
config: ConfigParser,
3 changes: 1 addition & 2 deletions convassist/predictor/sentence_completion_predictor.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import collections
-import os
 import re
 from pathlib import Path
 from typing import Any, Dict, List, Optional
@@ -13,7 +12,7 @@
 import numpy
 import torch
 import transformers
-from nltk import sent_tokenize, word_tokenize
+from nltk import word_tokenize
 from nltk.stem.porter import PorterStemmer
 from sentence_transformers import SentenceTransformer
 
@@ -10,25 +10,6 @@


class GeneralWordPredictor(SmoothedNgramPredictor):
"""
GeneralWordPredictor is a class that extends SmoothedNgramPredictor to provide word predictions
based on a precomputed set of most frequent starting words from an AAC dataset. It overrides
certain properties and methods to achieve this functionality.
Methods:
configure():
Configures the predictor by precomputing the most frequent starting words from an AAC dataset
and storing them in a file if they are not already stored.
aac_dataset:
Property that returns the path to the AAC dataset file.
database:
Property that returns the path to the database file.
startwords:
Property that returns the path to the file where precomputed starting words are stored.
predict(max_partial_prediction_size: int, filter):
Predicts the next word based on the context tracker and the n-gram model. If no tokens are
available in the context tracker, it returns the most frequent starting words.
"""

def configure(self):
super().configure()

@@ -2,8 +2,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import string
-from abc import ABC, abstractmethod
-from io import FileIO
+from abc import ABC
 from typing import List
 
 from convassist.predictor import Predictor
@@ -25,23 +24,25 @@ def configure(self) -> None:
         # ngramutil.learn(line.strip('.\n'))
         pass
 
-    @abstractmethod
     def extract_svo(self, sent):
-        raise NotImplementedError(f"extract_svo not implemented in {self.predictor_name}")
+        return sent
 
     def predict(self, max_partial_prediction_size: int, filter):
 
         sentence_prediction = Prediction()
         word_prediction = Prediction()
 
+        self.logger.debug("Starting Ngram prediction")
+
         # get self.cardinality tokens from the context tracker
         actual_tokens, tokens = self.context_tracker.get_tokens(self.cardinality)
         prefix_completion_candidates: List[str] = []
 
         try:
             partial = None
             prefix_ngram = None
-            for ngram_len in reversed(range(1, actual_tokens + 1)):
+            # for ngram_len in reversed(range(1, actual_tokens + 1)):
+            for ngram_len in range(actual_tokens, 0, -1):
                 if len(prefix_completion_candidates) >= max_partial_prediction_size:
                     break

@@ -63,12 +64,13 @@ def predict(self, max_partial_prediction_size: int, filter):
                     self.logger.error(f"Error fetching ngrams for {prefix_ngram}: {e}")
                     continue
                 for p in partial:
-                    candidate = p[-2]
-                    if (
-                        candidate not in tokens
-                        and candidate not in prefix_completion_candidates
-                    ):
-                        prefix_completion_candidates.append(candidate)
+                    # candidate = p[-2]
+                    # if (
+                    #     candidate not in tokens
+                    #     and candidate not in prefix_completion_candidates
+                    # ):
+                    #     prefix_completion_candidates.append(candidate)
+                    prefix_completion_candidates.append(p[-2])
 
             # smoothing
             unigram_counts_sum = ngramutil.unigram_counts_sum()
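Two changes in this hunk are worth separating, sketched below with throwaway values. The loop rewrite is behavior-preserving: reversed(range(1, n + 1)) and range(n, 0, -1) yield the same sequence. Dropping the membership checks, by contrast, is a real behavior change: a completion that already appears in the context, or that was already collected, is no longer filtered out. The tuple shape here is illustrative, assuming each row ends with (word, count):

    actual_tokens = 3
    assert list(reversed(range(1, actual_tokens + 1))) == list(range(actual_tokens, 0, -1)) == [3, 2, 1]

    tokens = ["to", "the"]
    partial = [("to", "the", "crazy", 5), ("to", "the", "the", 2)]
    candidates = [p[-2] for p in partial]
    print(candidates)  # ['crazy', 'the'] - 'the' would previously have been skipped, since it is in tokens

The extract_svo change is similar in spirit: instead of an abstract method that raises NotImplementedError, it is now a concrete pass-through returning the sentence unchanged, so subclasses override it only when they need real subject-verb-object extraction.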
3 changes: 2 additions & 1 deletion convassist/predictor/utilities/suggestion.py
@@ -50,7 +50,8 @@ def probability(self):

     @probability.setter
     def probability(self, value):
-        if value < MIN_PROBABILITY or value > MAX_PROBABILITY:
+        # value = round(value, 2)
+        if value < MIN_PROBABILITY or round(value, 2) > MAX_PROBABILITY:
             raise SuggestionException("Probability is too high or too low = " + str(value))
         self._probability = value

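Why round before comparing: probabilities assembled from floating-point arithmetic can land a hair above 1.0 without being meaningfully wrong. A minimal sketch, assuming MIN_PROBABILITY and MAX_PROBABILITY are 0.0 and 1.0:

    MIN_PROBABILITY, MAX_PROBABILITY = 0.0, 1.0

    value = (0.1 + 0.2) / 0.3                  # 1.0000000000000002 - float noise, not a real overflow
    print(value > MAX_PROBABILITY)             # True:  the old check raises SuggestionException
    print(round(value, 2) > MAX_PROBABILITY)   # False: the new check accepts it
    print(round(1.01, 2) > MAX_PROBABILITY)    # True:  genuinely out-of-range values still fail

Note the asymmetry that remains: the low end still compares value directly, so slightly negative noise such as -1e-17 is still rejected.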
7 changes: 5 additions & 2 deletions convassist/tests/predictors/test_canned_word_predictor.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import configparser
-import os
 import unittest
 from unittest.mock import patch
 
@@ -60,8 +59,11 @@ def test_configure(self):
     [
         ("no_context", "", 1, "all"),
         ("trigram", "to the ", 1, "crazy"),
+        ("trigram", "to the crazy", 1, "crazy"),
         ("bigram", "the ", 1, "crazy"),
-        ("unigram", "bec", 1, "because"),
+        ("bigram", "the crazy", 1, "crazy"),
+        ("unigram", "cra", 1, "crazy"),
+        ("unigram", "crazy", 1, "crazy"),
     ]
 )
def test_predict(self, name, context, max, expected_word):
@@ -74,6 +76,7 @@ def test_predict(self, name, context, max, expected_word):
         self.assertEqual(len(sentence_predictions), 0)
         self.assertIsNotNone(word_predictions)
         self.assertEqual(len(word_predictions), max_partial_prediction_size)
+        self.assertEqual(word_predictions[0].word, expected_word)
 
     def test_learn_new_sentence(self):
         change_tokens = "This is a new sentence to learn."
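The added cases cover contexts that end on a complete word ("to the crazy", "the crazy", "crazy") alongside the existing prefix-completion cases, and the new assertion pins the top-ranked word rather than just the prediction count. The tuples feed a parameterization decorator whose opening line sits above this hunk — presumably something like parameterized.expand, sketched here as an assumption:

    from parameterized import parameterized

    @parameterized.expand([
        ("unigram", "cra", 1, "crazy"),  # (name, context, max, expected_word)
    ])
    def test_predict(self, name, context, max, expected_word):
        ...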
3 changes: 2 additions & 1 deletion convassist/utilities/databaseutils/sqllite_dbconnector.py
@@ -66,7 +66,8 @@ def fetch_all(self, query: str, params: Optional[Tuple[Any, ...]] = None):

         finally:
             cursor.close()
-            return result
+
+        return result
 
     def begin_transaction(self) -> None:
         if not self.conn:
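This fix matters beyond style: a return inside a finally block silently swallows any exception raised in the try body, so a failed query would vanish and the caller would receive a possibly empty result. A minimal sketch of the difference (function names are illustrative):

    def fetch_all_old():
        result = None
        try:
            raise RuntimeError("query failed")
        finally:
            return result  # the return discards the in-flight RuntimeError

    def fetch_all_new():
        result = None
        try:
            raise RuntimeError("query failed")
        finally:
            pass  # cleanup only (cursor.close() in the real code)
        return result  # skipped while an exception is propagating

    print(fetch_all_old())  # None - the error disappeared
    fetch_all_new()         # RuntimeError propagates to the caller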
2 changes: 0 additions & 2 deletions convassist/utilities/logging_utility.py
@@ -12,8 +12,6 @@
 if sys.platform == "win32":
     import pydebugstring
 
-from .singleton import Singleton
-
 
 class QueueHandler(logging.Handler):
     """