Use stopwords from NLTK and download NLTK data
- app/cli.py +6 -4
- app/model.py +16 -1
app/cli.py CHANGED

@@ -117,15 +117,17 @@ def train(
     click.echo(DONE_STR)
 
     click.echo("Creating model... ", nl=False)
-    model = create_model(max_features, seed=None if seed == -1 else seed)
+    model = create_model(max_features, seed=None if seed == -1 else seed, verbose=True)
     click.echo(DONE_STR)
 
-    click.echo("Training model... ", nl=False)
+    # click.echo("Training model... ", nl=False)
+    click.echo("Training model... ")
     accuracy = train_model(model, text_data, label_data)
     joblib.dump(model, model_path)
-    click.echo(
+    click.echo("Model saved to: ", nl=False)
+    click.secho(str(model_path), fg="blue")
 
-    click.echo("Model accuracy: ")
+    click.echo("Model accuracy: ", nl=False)
     click.secho(f"{accuracy:.2%}", fg="blue")
 
     # TODO: Add hyperparameter options
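The output pattern the new lines settle on — click.echo(..., nl=False) for the label followed by click.secho(..., fg="blue") for the value — keeps both on one line, with only the value colored. A minimal standalone sketch of that pattern (the accuracy value is made up for illustration):

    import click

    # nl=False suppresses the trailing newline, so the colored value printed
    # by secho lands on the same line as the plain label.
    click.echo("Model accuracy: ", nl=False)
    click.secho(f"{0.8731:.2%}", fg="blue")  # prints: Model accuracy: 87.31%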
app/model.py CHANGED

@@ -5,8 +5,10 @@ import re
 import warnings
 from typing import Literal
 
+import nltk
 import pandas as pd
 from joblib import Memory
+from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

@@ -248,28 +250,41 @@ def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> t
 def create_model(
     max_features: int,
     seed: int | None = None,
+    verbose: bool = False,
 ) -> Pipeline:
     """Create a sentiment analysis model.
 
     Args:
         max_features: Maximum number of features
         seed: Random seed (None for random seed)
+        verbose: Whether to log progress during training
 
     Returns:
         Untrained model
     """
+    # Download NLTK data if not already downloaded
+    nltk.download("wordnet", quiet=True)
+    nltk.download("stopwords", quiet=True)
+
+    # Load English stopwords
+    stopwords_en = set(stopwords.words("english"))
+
     return Pipeline(
         [
             # Text preprocessing
             ("clean", TextCleaner()),
             ("lemma", TextLemmatizer()),
             # Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
-            (
+            (
+                "vectorize",
+                CountVectorizer(stop_words=stopwords_en, ngram_range=(1, 2), max_features=max_features),
+            ),
             ("tfidf", TfidfTransformer()),
             # Classifier
             ("clf", LogisticRegression(max_iter=1000, random_state=seed)),
         ],
         memory=Memory(CACHE_DIR, verbose=0),
+        verbose=verbose,
     )
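The core of the model.py change, isolated: fetch the NLTK corpora up front (nltk.download is a no-op once the data is already present, and quiet=True silences its log), then pass the English stopword list to CountVectorizer. A minimal sketch on toy data, assuming nltk and scikit-learn are installed:

    import nltk
    from nltk.corpus import stopwords
    from sklearn.feature_extraction.text import CountVectorizer

    # Download once; later calls find the cached copy and return immediately.
    nltk.download("stopwords", quiet=True)

    stopwords_en = set(stopwords.words("english"))

    # Same vectorizer configuration as the diff, with a small feature budget.
    vectorizer = CountVectorizer(stop_words=stopwords_en, ngram_range=(1, 2), max_features=100)
    X = vectorizer.fit_transform(["the movie was great", "the movie was awful"])

    # Stopwords such as "the" and "was" are removed before n-grams are built.
    print(sorted(vectorizer.vocabulary_))

Two caveats worth knowing: scikit-learn may emit a UserWarning that the stop list is inconsistent with its preprocessing, because the NLTK list contains contractions ("wasn't", "isn't") that the default tokenizer splits; it is harmless here. And Pipeline's verbose flag, which the new create_model parameter forwards, simply prints the elapsed time of each step as it completes during fit.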