Before we start, let's briefly talk about machine learning.
In simple terms:
Machine learning is a type of artificial intelligence (AI) that provides computers with the ability to learn without being explicitly programmed.
There are three main types of machine learning algorithms: supervised learning, unsupervised learning, and reinforcement learning.
Today, we are only going to use supervised learning, so let's talk about it.
Supervised learning is the machine learning task of learning a function that maps an input to an output based on example input-output pairs. It builds this function from labeled training data consisting of a set of training examples.
Since we are going to classify emails into spam and non-spam categories, let's see how we can use supervised learning to solve this problem.
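To make the idea concrete, here is a minimal sketch of supervised learning on a handful of hypothetical labeled messages. The "model" is a one-feature decision stump. From the input-output pairs, it learns the message-length threshold that best separates spam from ham. The messages, the `learn_length_threshold` and `predict` helpers, and the choice of length as the only feature are all illustrative assumptions. This is not the approach used later in the project.

```python
# Hypothetical toy data: (message, label) pairs written for illustration only
training_pairs = [
    ("WINNER!! Claim your FREE prize now, call 09061701461", "spam"),
    ("URGENT! Your mobile number has won a 2000 cash award", "spam"),
    ("Ok lar, see you later", "ham"),
    ("Are we still meeting for lunch today?", "ham"),
]

def learn_length_threshold(pairs):
    """Learn the length threshold that classifies the most training pairs correctly.

    The learned function is: predict spam if len(message) > threshold, else ham.
    """
    best_threshold, best_correct = 0, -1
    # Try every observed message length as a candidate threshold
    for candidate in sorted(len(msg) for msg, _ in pairs):
        correct = sum(
            (len(msg) > candidate) == (label == "spam") for msg, label in pairs
        )
        if correct > best_correct:
            best_threshold, best_correct = candidate, correct
    return best_threshold

def predict(message, threshold):
    """Apply the learned function to a new, unlabeled message."""
    return "spam" if len(message) > threshold else "ham"

threshold = learn_length_threshold(training_pairs)
print(threshold)  # 37
print(predict("Congratulations! You have won a $1000 gift card. Call now!", threshold))  # spam
```

Real models like the ones trained below follow the same pattern on many more examples and features: fit on labeled pairs, then predict labels for new inputs.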
Hopefully you are now familiar with the concept of supervised learning.
Now, let's move on to our main objective: creating an email classifier.
Important Note 1:
Important Note 2:
First things first: we need to install the libraries that we are going to use in this project.
%pip install pandas scikit-learn numpy matplotlib seaborn nltk wordcloud xgboost lightgbm
Next, we import all the libraries used in this project, grouped by purpose with a brief comment for each group.
# For data manipulation and analysis
import pandas as pd
import numpy as np
import string
from collections import Counter
# For plotting
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# For data splitting, label encoding, and text vectorization
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
# For text preprocessing (tokenization, stopword removal, stemming)
import nltk
from nltk import bigrams
from nltk.corpus import stopwords
from nltk.stem.porter import PorterStemmer
from nltk.tokenize import word_tokenize, sent_tokenize
# Machine learning models and evaluation metrics
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from lightgbm.sklearn import LGBMClassifier
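# Machine learning evaluation metrics
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, roc_auc_score, f1_score, recall_score
# For word clouds
from wordcloud import WordCloud
# Silence library warnings (e.g. deprecation notices) to keep the output readable
import warnings
warnings.filterwarnings('ignore')
# Download the NLTK data used below: the punkt tokenizer and the stopword list
nltk.download('punkt')
nltk.download('stopwords')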
[nltk_data] Downloading package punkt to /root/nltk_data... [nltk_data] Package punkt is already up-to-date! [nltk_data] Downloading package stopwords to /root/nltk_data... [nltk_data] Package stopwords is already up-to-date!
After importing the libraries, the next step is to load the dataset.
DATA_SET = "./data/spam_dataset.csv"
df = pd.read_csv(DATA_SET, encoding='latin-1')
df.head()
| v1 | v2 | Unnamed: 2 | Unnamed: 3 | Unnamed: 4 |
0 | ham | Go until jurong point, crazy.. Available only ... | NaN | NaN | NaN |
1 | ham | Ok lar... Joking wif u oni... | NaN | NaN | NaN |
2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... | NaN | NaN | NaN |
3 | ham | U dun say so early hor... U c already then say... | NaN | NaN | NaN |
4 | ham | Nah I don't think he goes to usf, he lives aro... | NaN | NaN | NaN |
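The dataset contains five columns: v1 holds the label (ham or spam), v2 holds the message text, and the three Unnamed columns appear to be mostly empty.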
To proceed with the exploratory data analysis (EDA):
The next function gives us a summary of the dataset.
# Basic info about the dataset
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5572 entries, 0 to 5571
Data columns (total 5 columns):
 #   Column      Non-Null Count  Dtype
---  ------      --------------  -----
 0   v1          5572 non-null   object
 1   v2          5572 non-null   object
 2   Unnamed: 2  50 non-null     object
 3   Unnamed: 3  12 non-null     object
 4   Unnamed: 4  6 non-null      object
dtypes: object(5)
memory usage: 217.8+ KB
The three Unnamed columns are almost entirely empty (50, 12, and 6 non-null values out of 5,572 rows), so we drop them in the next code cell.
# The last 3 columns are unnecessary, so we drop them
df.drop(['Unnamed: 2', 'Unnamed: 3', 'Unnamed: 4'], axis=1, inplace=True)
Then we rename the columns to more descriptive names.
df.rename(columns={'v1': 'label', 'v2': 'message'}, inplace=True)
df.head(10)
| label | message |
0 | ham | Go until jurong point, crazy.. Available only ... |
1 | ham | Ok lar... Joking wif u oni... |
2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... |
3 | ham | U dun say so early hor... U c already then say... |
4 | ham | Nah I don't think he goes to usf, he lives aro... |
5 | spam | FreeMsg Hey there darling it's been 3 week's n... |
6 | ham | Even my brother is not like to speak with me. ... |
7 | ham | As per your request 'Melle Melle (Oru Minnamin... |
8 | spam | WINNER!! As a valued network customer you have... |
9 | spam | Had your mobile 11 months or more? U R entitle... |
For this dataset, the target variable is the label column.
Let's create a pie chart to see the distribution of the target variable.
target_cat = df.label.value_counts()
print(target_cat)
# Create a pie chart with matplotlib, using colors from seaborn's viridis palette
color1 = sns.color_palette("viridis")[0]
color2 = sns.color_palette("viridis")[3]
plt.pie(target_cat, labels=target_cat.index, colors=[color1, color2], autopct='%1.1f%%', shadow=True, startangle=90, explode=(0, 0.1))
plt.title('Target Variable Distribution (Spam and Ham)')
plt.show()
label
ham     4825
spam     747
Name: count, dtype: int64
The output shows that the dataset is imbalanced: 4,825 messages (86.6%) are ham and only 747 (13.4%) are spam.
When the data is skewed like this, machine learning models might not perform at their best.
The uneven distribution can also cause problems when we split the data into training and test sets. As a knock-on effect, the model tends to be biased towards the majority class.
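A quick way to see the problem is a "classifier" that ignores the message entirely and always answers ham. Using the class counts printed above, it would already look quite accurate while catching no spam at all. This is a small arithmetic sketch, not part of the project pipeline.

```python
# Class counts from the value_counts() output above
ham_count, spam_count = 4825, 747
total = ham_count + spam_count

# A "classifier" that ignores its input and always predicts the majority class (ham)
baseline_accuracy = ham_count / total
baseline_spam_recall = 0 / spam_count  # it never flags a single spam message

print(f"Accuracy: {baseline_accuracy:.1%}")        # Accuracy: 86.6%
print(f"Spam recall: {baseline_spam_recall:.1%}")  # Spam recall: 0.0%
```

A high accuracy on skewed data can therefore hide a model that is useless for the minority class. This is why the classes are balanced before training.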
There are a few ways to handle this. To keep things simple, we'll randomly remove 'ham' messages until both classes are the same size.
After doing this, we'll look at the target variable distribution once more.
num_spam = target_cat['spam']
# Randomly sample the majority class (ham) down to the size of the minority class (spam)
df = pd.concat([
    df[df['label'] == 'ham'].sample(n=num_spam, random_state=42),
    df[df['label'] == 'spam']
])
# Check the target variable distribution again
target_cat = df.label.value_counts()
print(target_cat)
# Set the colors
color1 = sns.color_palette("viridis")[0]
color2 = sns.color_palette("viridis")[3]
# Plot the pie chart
plt.pie(target_cat, labels=target_cat.index, colors=[color1, color2], autopct='%1.1f%%', shadow=True, startangle=90, explode=(0, 0.1))
plt.title('Target Variable Distribution after dropping ham messages')
plt.show()
label
ham     747
spam    747
Name: count, dtype: int64
The ham and spam classes now have identical counts (747 each), so we can proceed to the next step.
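What are unigrams? A unigram is a single word (token) taken from a text. Counting unigrams tells us which individual words appear most often.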
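Counting unigrams comes down to tokenizing, filtering, and tallying. The sketch below does this on a hypothetical three-message corpus with a hypothetical four-word stopword set. It uses `str.split` in place of NLTK's `word_tokenize` so that it runs without extra downloads. One consequence of that swap is that `str.split` leaves punctuation attached. Tokens like "entry:" and "Ok," therefore fail the `isalpha()` check and are dropped, whereas `word_tokenize` would split off the punctuation and keep the word.

```python
from collections import Counter

# Hypothetical mini-corpus and stopword subset, for illustration only
messages = ["Call now to claim your FREE prize", "Free entry: call to win", "Ok, call me later"]
stop_words = {"to", "your", "me", "now"}

def top_unigrams(texts, n=3):
    """Return the n most frequent lowercase, alphabetic, non-stopword tokens."""
    words = []
    for text in texts:
        # str.split() stands in for word_tokenize; isalpha() drops tokens with punctuation
        words.extend(w.lower() for w in text.split() if w.isalpha() and w.lower() not in stop_words)
    # Counter.most_common breaks ties by first occurrence
    return Counter(words).most_common(n)

print(top_unigrams(messages))  # [('call', 3), ('free', 2), ('claim', 1)]
```

The original `get_most_common_words` follows the same three steps on the real messages.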
The next code cell finds the 10 most common unigrams in spam and ham messages. It then displays their counts for the spam messages so that we can see what they look like.
# Fallback: a predefined list of stopwords, in case the NLTK stopword corpus cannot be downloaded
# stop_words = set([
# "ourselves", "hers", "between", "yourself", "but",
# "again", "there", "about", "once", "during", "out",
# "very", "having", "with", "they", "own", "an", "be",
# "some", "for", "do", "its", "yours", "such", "into",
# "of", "most", "itself", "other", "off", "is", "s", "am",
# "or", "who", "as", "from", "him", "each", "the", "themselves",
# "until", "below", "are", "we", "these", "your", "his", "through",
# "don", "nor", "me", "were", "her", "more", "himself", "this",
# "down", "should", "our", "their", "while", "above", "both", "up",
# "to", "ours", "had", "she", "all", "no", "when", "at", "any", "before",
# "them", "same", "and", "been", "have", "in", "will", "on", "does",
# "yourselves", "then", "that", "because", "what", "over", "why", "so",
# "can", "did", "not", "now", "under", "he", "you", "herself", "has",
# "just", "where", "too", "only", "myself", "which", "those", "i",
# "after", "few", "whom", "t", "being", "if", "theirs", "my", "against",
# "a", "by", "doing", "it", "how", "further", "was", "here", "than"
# ])
#get the stopwords from nltk
stop_words = set(stopwords.words('english'))
def get_most_common_words(texts, n=10):
    words = []
    for text in texts:
        tokens = word_tokenize(text)
        words.extend([word.lower() for word in tokens if word.isalpha() and word.lower() not in stop_words])
    # Count the frequency of each word
    word_freq = Counter(words)
    return word_freq.most_common(n)
# Get the 10 most common words for spam and ham messages
common_words_ham = get_most_common_words(df[df['label'] == 'ham']['message'], 10)
common_words_spam = get_most_common_words(df[df['label'] == 'spam']['message'], 10)
# Create tables for common_words_ham and common_words_spam
common_words_ham_df = pd.DataFrame(common_words_ham, columns=['word', 'count'])
common_words_spam_df = pd.DataFrame(common_words_spam, columns=['word', 'count'])
# Display the most common words in spam messages
common_words_spam_df
| word | count |
0 | call | 346 |
1 | free | 219 |
2 | txt | 156 |
3 | u | 144 |
4 | ur | 144 |
5 | mobile | 123 |
6 | text | 121 |
7 | stop | 114 |
8 | claim | 113 |
9 | reply | 104 |
Next, let's create word clouds for the spam and ham messages.
def generate_wordcloud(texts_ham, texts_spam):
    # Join all the text messages together
    all_text_ham = ' '.join(texts_ham)
    all_text_spam = ' '.join(texts_spam)
    # Create a word cloud for ham and spam messages
    wordcloud_ham = WordCloud(stopwords=stop_words, background_color="white", width=800, height=400).generate(all_text_ham)
    wordcloud_spam = WordCloud(stopwords=stop_words, background_color="white", width=800, height=400).generate(all_text_spam)
    # Plot the word clouds
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 5))
    ax1.imshow(wordcloud_ham, interpolation='bilinear')
    ax1.set_title("Ham Messages Word Cloud", fontsize=25)
    ax2.imshow(wordcloud_spam, interpolation='bilinear')
    ax2.set_title("Spam Messages Word Cloud", fontsize=25)
    # Adjust layout and display the figure
    plt.tight_layout()
    plt.show()
# Generate word clouds for ham and spam messages side by side
generate_wordcloud(df[df['label'] == 'ham']['message'].tolist(), df[df['label'] == 'spam']['message'].tolist())
Previously we looked at the most common unigrams in spam and ham messages. Now we are going to look at the most common bigrams.
Bigrams, in the context of text analysis, are pairs of consecutive words from a given text.
Example: for the sentence "I love ice cream", the bigrams are ("I", "love"), ("love", "ice"), and ("ice", "cream").
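The same pairs can be produced in a couple of lines. The sketch below zips the token list with itself shifted by one. This yields the same pairs that `nltk.bigrams` (used in `generate_bigrams`) returns. The `make_bigrams` helper name is hypothetical. The second call mimics `preprocess_message` by dropping stopwords first, using a one-word stand-in for the NLTK list. That is why "I" disappears from the pairs.

```python
def make_bigrams(tokens):
    """Pair each token with the one that follows it (same pairs as list(nltk.bigrams(tokens)))."""
    return list(zip(tokens, tokens[1:]))

tokens = "I love ice cream".split()
print(make_bigrams(tokens))
# [('I', 'love'), ('love', 'ice'), ('ice', 'cream')]

# Hypothetical one-word subset of the NLTK English stopword list
stop_words = {"i"}
filtered = [w.lower() for w in tokens if w.lower() not in stop_words]
print(make_bigrams(filtered))
# [('love', 'ice'), ('ice', 'cream')]
```

Removing stopwords before pairing means the bigrams count only content-word pairs. This is why phrases like "please call" rather than "to call" show up in the results below.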
Let's take a look at the most common bigrams in spam and ham messages.
def remove_punctuation(message):
    return ''.join([char for char in message if char not in string.punctuation])
def tokenize(message):
    return [word.lower() for word in message.split() if word.lower() not in stop_words]
def generate_bigrams(tokens):
    return list(bigrams(tokens))
def preprocess_message(message):
    message = remove_punctuation(message)
    tokens = tokenize(message)
    return generate_bigrams(tokens)
def get_top_bigrams(series, n=10):
    # Flatten the list of bigrams and count occurrences
    all_bigrams = [bigram for sublist in series for bigram in sublist]
    bigram_counts = Counter(all_bigrams)
    return bigram_counts.most_common(n)
def plot_top_bigrams(df, column, title, ax, palette="viridis"):
    sns.barplot(data=df, y=column, x='Count', ax=ax, palette=palette)
    ax.set_title(title)
# Assuming your original DataFrame is loaded as 'df'
df["bigrams"] = df["message"].apply(preprocess_message)
# Extract top bigrams
top_spam_bigrams = get_top_bigrams(df[df['label'] == 'spam']['bigrams'])
top_ham_bigrams = get_top_bigrams(df[df['label'] == 'ham']['bigrams'])
# Create dataframes for visualization
spam_df = pd.DataFrame(top_spam_bigrams, columns=['Bigrams', 'Count'])
ham_df = pd.DataFrame(top_ham_bigrams, columns=['Bigrams', 'Count'])
# Adjust bigram representation for readability
spam_df['Bigrams'] = spam_df['Bigrams'].apply(lambda x: ' '.join(x))
ham_df['Bigrams'] = ham_df['Bigrams'].apply(lambda x: ' '.join(x))
# Create subplots
fig, axes = plt.subplots(1, 2, figsize=(16, 8), sharex=True)
plot_top_bigrams(spam_df, 'Bigrams', 'Top Bigrams in Spam Messages', axes[0])
plot_top_bigrams(ham_df, 'Bigrams', 'Top Bigrams in Ham Messages', axes[1])
plt.tight_layout()
plt.show()
From the graphs above, we can observe that:
Spam messages frequently include prompts for the recipient to take action, such as "please call", "contact u", and "await collection".
On the other hand, non-spam messages seem to revolve more around everyday conversations, with phrases like "call later", "let know", and "good morning".
Feature engineering means using what you know about the data to create new features, helping machine learning models perform better.
# Create a new column for message length and calculate the length of each message
df["msg_len"] = df["message"].apply(len)
# Plot the distribution of message lengths
# (sns.distplot has been deprecated since seaborn 0.11, so we use histplot with a KDE overlay)
sns.histplot(df[df.label == 'ham'].msg_len, label='Ham', color=sns.color_palette("viridis")[0], kde=True, stat='density')
sns.histplot(df[df.label == 'spam'].msg_len, label='Spam', color=sns.color_palette("viridis")[3], kde=True, stat='density')
plt.title('Distribution Plot for Length of Messages')
plt.xlabel('Length of Messages')
plt.legend()
plt.show()
print('Average message length for ham messages: ', df[df['label'] == 'ham']['message'].str.len().mean())
print('Average message length for spam messages: ', df[df['label'] == 'spam']['message'].str.len().mean())
Average message length for ham messages: 69.11780455153949
Average message length for spam messages: 138.8661311914324
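The graph above shows that spam messages tend to be longer than ham messages. The average spam message is about 139 characters long, roughly twice the length of the average ham message (about 69 characters).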
Due to limited time, we are going to create all the remaining features at once, without analyzing each one.
In real-world projects, however, we need to analyze each feature and decide whether to use it.
Following the same approach as before, we can create other features, such as the number of punctuation marks, exclamation marks, uppercase letters, numeric characters, sentences, and words in each message.
The next code cell creates all of these features.
# Create new column for number of punctuations, and calculate the number of punctuations in each message
df['num_punctuations'] = df['message'].apply(lambda x: sum([1 for char in x if char in string.punctuation]))
# Create new column for number of exclamation marks, and calculate the number of exclamation marks in each message
df['num_exclamation_marks'] = df['message'].apply(lambda x: sum([1 for char in x if char == '!']))
# Calculate the number of uppercase letters in each message and create a new column for it
df['num_upper_case'] = df['message'].apply(lambda x: sum([1 for char in x if char.isupper()]))
# Count the number of numeric characters in each message and create a new column for it
df['num_numeric'] = df['message'].apply(lambda x: sum([1 for char in x if char.isdigit()]))
# Tokenize messages and compute the number of sentences in each message and create a new column for it
df['num_sentences'] = df['message'].apply(lambda x: len(sent_tokenize(x)))
# Create a new column for number of words in each message and calculate the number of words in each message
df['num_words'] = df['message'].apply(lambda x: len(word_tokenize(x)))
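As a quick sanity check of what the character-level lambdas above compute, here is a pure-Python version applied to a single made-up message. The `char_features` helper and the sample message are hypothetical. The NLTK-based `num_sentences` and `num_words` counts are left out so that the sketch runs without downloading tokenizer data.

```python
import string

def char_features(message):
    """Character-level counts mirroring the DataFrame lambdas above (NLTK counts omitted)."""
    return {
        "msg_len": len(message),
        "num_punctuations": sum(1 for ch in message if ch in string.punctuation),
        "num_exclamation_marks": message.count("!"),
        "num_upper_case": sum(1 for ch in message if ch.isupper()),
        "num_numeric": sum(1 for ch in message if ch.isdigit()),
    }

print(char_features("WINNER!! Call 0800 now."))
# {'msg_len': 23, 'num_punctuations': 3, 'num_exclamation_marks': 2, 'num_upper_case': 7, 'num_numeric': 4}
```

Counting by hand on a short message like this is an easy way to catch off-by-one or wrong-condition mistakes in feature code before running it on thousands of rows.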
Let's check the current columns in the DataFrame.
df.columns
Index(['label', 'message', 'bigrams', 'msg_len', 'num_punctuations', 'num_exclamation_marks', 'num_upper_case', 'num_numeric', 'num_sentences', 'num_words'], dtype='object')
X = df[['msg_len', 'num_exclamation_marks', 'num_punctuations', 'num_upper_case', 'num_numeric', 'num_words', 'num_sentences']]
Most machine learning models expect numeric targets, so we need to encode the label column as numbers.
For example, 'ham' becomes 0 and 'spam' becomes 1.
encoder = LabelEncoder()
encoded_labels = encoder.fit_transform(df['label'])
df['encoded_label'] = encoded_labels
df.tail()
| label | message | bigrams | msg_len | num_punctuations | num_exclamation_marks | num_upper_case | num_numeric | num_sentences | num_words | encoded_label |
5537 | spam | Want explicit SEX in 30 secs? Ring 02073162414... | [(want, explicit), (explicit, sex), (sex, 30),... | 90 | 3 | 1 | 17 | 21 | 3 | 18 | 1 |
5540 | spam | ASKED 3MOBILE IF 0870 CHATLINES INCLU IN FREE ... | [(asked, 3mobile), (3mobile, 0870), (0870, cha... | 160 | 5 | 0 | 104 | 14 | 6 | 38 | 1 |
5547 | spam | Had your contract mobile 11 Mnths? Latest Moto... | [(contract, mobile), (mobile, 11), (11, mnths)... | 160 | 8 | 1 | 20 | 2 | 5 | 35 | 1 |
5566 | spam | REMINDER FROM O2: To get 2.50 pounds free call... | [(reminder, o2), (o2, get), (get, 250), (250, ... | 147 | 3 | 0 | 14 | 5 | 1 | 30 | 1 |
5567 | spam | This is the 2nd time we have tried 2 contact u... | [(2nd, time), (time, tried), (tried, 2), (2, c... | 161 | 8 | 1 | 9 | 21 | 4 | 35 | 1 |
Let's check how our target variable is encoded.
# Display classes and their encoded values
for encoded, original in enumerate(encoder.classes_):
    print(f"'{original}' is encoded as {encoded}")
'ham' is encoded as 0
'spam' is encoded as 1
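Why is ham 0 and spam 1? `LabelEncoder` stores the unique classes in sorted order (its `classes_` attribute) and encodes each label as its position in that list. Since 'ham' sorts before 'spam', it gets 0. The minimal pure-Python sketch below reproduces the same mapping, using a hypothetical `label_encode` helper.

```python
def label_encode(labels):
    """Map each class to its index in the sorted list of unique classes."""
    classes = sorted(set(labels))
    mapping = {cls: idx for idx, cls in enumerate(classes)}
    return classes, [mapping[label] for label in labels]

classes, encoded = label_encode(["spam", "ham", "ham", "spam"])
print(classes)  # ['ham', 'spam']
print(encoded)  # [1, 0, 0, 1]
```

Because spam ends up as 1, scikit-learn's precision and recall scores (which default to `pos_label=1`) measure performance on the spam class.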
Now we can assign the encoded labels to the variable y.
y = df["encoded_label"]
To evaluate the models on data they have not seen during training, we need to split the data into a training set and a test set.
The next code cell keeps 70% of the data for training and holds out 30% for testing.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
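`train_test_split` receives a float `test_size`, so it has to turn 30% into a whole number of rows. The sketch below mirrors the rounding rule we understand scikit-learn to apply when only `test_size` is given: round the test set up with `ceil` and give the rest to training. Treat that rule as an assumption to check against the scikit-learn documentation. On the 1,494 balanced messages, it gives 1,045 training and 449 test messages. This matches the 449 messages counted in each confusion matrix reported later (for example, 214 + 3 + 9 + 223 for Random Forest).

```python
import math

n_samples = 747 * 2   # balanced dataset: 747 ham + 747 spam
test_size = 0.3

# Assumed rule: round the test set size up, the training set gets the remainder
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test
print(n_train, n_test)  # 1045 449
```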
Let's try different algorithms and compare their performance.
First, we create a dictionary of machine learning algorithms.
I'm not going to explain each algorithm in detail, but I will give a brief description of each one.
num_feat_models = {
    "Random Forest": RandomForestClassifier(n_estimators=1000, random_state=42),
    "Decision Tree": DecisionTreeClassifier(random_state=42),
    "Logistic Regression": LogisticRegression(random_state=42),
    "SVC": SVC(kernel="linear", random_state=42),
    "KNN": KNeighborsClassifier(),
    "XGB": XGBClassifier(objective='binary:hinge', random_state=42),
    "Multinomial NB": MultinomialNB(),
    "LGBM": LGBMClassifier(boosting_type='gbdt', objective='binary', max_depth=4, random_state=42, verbose=-1)
}
We have now come to the final step of our project: training the models.
def train_clf(model, X_train, y_train, X_test, y_test):
    # Train the model
    model.fit(X_train, y_train)
    # Make predictions on the test set and compute the evaluation metrics
    y_pred = model.predict(X_test)
    results = {
        'Accuracy': accuracy_score(y_test, y_pred),
        'Precision': precision_score(y_test, y_pred),
        'Recall': recall_score(y_test, y_pred),
        'F1 Score': f1_score(y_test, y_pred),
        # Computed from hard 0/1 predictions rather than predicted probabilities
        'Roc Auc Score': roc_auc_score(y_test, y_pred),
        'Confusion Matrix': confusion_matrix(y_test, y_pred)
    }
    return results
# Store scores for each model
model_scores = {}
for name, model in num_feat_models.items():
    model_scores[name] = train_clf(model, X_train, y_train, X_test, y_test)
# Convert the scores dictionary to a DataFrame with one row per model
# (reset_index turns the model names into a column named 'index', which we rename)
scores_df = pd.DataFrame(model_scores).T.reset_index()
scores_df.rename(columns={'index': 'Algorithm'}, inplace=True)
scores_df = scores_df.sort_values(by="F1 Score", ascending=False)
scores_df
Now let's check the performance of the models.
| Algorithm | Accuracy | Precision | Recall | F1 Score | Roc Auc Score | Confusion Matrix |
0 | Random Forest | 0.973274 | 0.986726 | 0.961207 | 0.973799 | 0.973691 | [[214, 3], [9, 223]] |
7 | LGBM | 0.966592 | 0.977974 | 0.956897 | 0.96732 | 0.966928 | [[212, 5], [10, 222]] |
5 | XGB | 0.962138 | 0.969432 | 0.956897 | 0.963124 | 0.962319 | [[210, 7], [10, 222]] |
1 | Decision Tree | 0.957684 | 0.965066 | 0.952586 | 0.958785 | 0.95786 | [[209, 8], [11, 221]] |
4 | KNN | 0.955457 | 0.969027 | 0.943966 | 0.956332 | 0.955854 | [[210, 7], [13, 219]] |
2 | Logistic Regression | 0.953229 | 0.968889 | 0.939655 | 0.954048 | 0.953699 | [[210, 7], [14, 218]] |
3 | SVC | 0.951002 | 0.972973 | 0.931034 | 0.951542 | 0.951692 | [[211, 6], [16, 216]] |
6 | Multinomial NB | 0.919822 | 0.962264 | 0.87931 | 0.918919 | 0.921222 | [[209, 8], [28, 204]] |
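Due to limited time, we are not going to explain each metric in detail.
From the output above, we can see that Random Forest achieves the highest F1 score (0.974), followed by LGBM (0.967) and XGB (0.963). Multinomial NB scores lowest (0.919), mainly because of its lower recall (0.879). To read the confusion matrices, we need four terms, taking spam as the positive class: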
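True Positive (TP): the message is spam and the model predicts spam.
True Negative (TN): the message is ham and the model predicts ham.
False Positive (FP), also known as a Type I error: the message is ham but the model predicts spam.
False Negative (FN), also known as a Type II error: the message is spam but the model predicts ham.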
These four measures are often displayed together in what is known as a Confusion Matrix. Understanding these terms is crucial since they form the basis for many other performance metrics in classification, like accuracy, precision, recall, and the F1 score.
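To see how the scores in the results table follow from these four counts, here is a worked computation on the Random Forest confusion matrix, [[214, 3], [9, 223]]. With ham = 0 and spam = 1, scikit-learn's `confusion_matrix` puts true labels on the rows and predictions on the columns. The layout is therefore [[TN, FP], [FN, TP]]. The last line reflects the fact that `train_clf` passes hard 0/1 predictions to `roc_auc_score`. In that case, the ROC curve has a single corner point, and its area reduces to the average of recall and specificity.

```python
# Random Forest confusion matrix from the results table, layout [[TN, FP], [FN, TP]]
cm = [[214, 3],
      [9, 223]]
(tn, fp), (fn, tp) = cm

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)    # of messages flagged as spam, the share that really were spam
recall = tp / (tp + fn)       # of actual spam messages, the share the model caught
f1 = 2 * precision * recall / (precision + recall)
specificity = tn / (tn + fp)  # of actual ham messages, the share correctly left alone
# With hard 0/1 predictions, ROC AUC = (recall + specificity) / 2
roc_auc_hard = (recall + specificity) / 2

for name, value in [("Accuracy", accuracy), ("Precision", precision), ("Recall", recall),
                    ("F1", f1), ("ROC AUC (hard labels)", roc_auc_hard)]:
    print(f"{name}: {value:.6f}")
# Accuracy: 0.973274
# Precision: 0.986726
# Recall: 0.961207
# F1: 0.973799
# ROC AUC (hard labels): 0.973691
```

These values match the Random Forest row of the results table. Passing predicted probabilities (e.g. from `predict_proba`) to `roc_auc_score` would give a more informative AUC than hard labels.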
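Broadly, the higher these scores, the better the model. Because we balanced the classes, accuracy is a reasonable summary here, but it is not the whole story. For a spam filter, a false positive (a legitimate message sent to the spam folder) is usually more costly than a false negative, so precision deserves close attention too.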
The next code cell plots the confusion matrix for each model.
# Plot the confusion matrix of each model, ordered by F1 score
classes = ['ham', 'spam']
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(20, 10))
for idx, ax in enumerate(axes.flatten()):
    algo = scores_df.iloc[idx]['Algorithm']
    cm = scores_df.iloc[idx]['Confusion Matrix']
    # Label each cell: with ham = 0 and spam = 1, sklearn's layout is [[TN, FP], [FN, TP]]
    labels = np.array([["TN: " + str(cm[0, 0]), "FP: " + str(cm[0, 1])],
                       ["FN: " + str(cm[1, 0]), "TP: " + str(cm[1, 1])]])
    sns.heatmap(cm, annot=labels, fmt="", cmap="Blues", ax=ax, cbar=False)
    ax.set_title(algo)
    ax.set_xticklabels(classes)
    ax.set_yticklabels(classes, rotation=0)
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
plt.tight_layout()
plt.show()
Now let's try predicting new messages with the models trained on numerical features.
In summary, the following code extracts the same features from each new message and predicts its class with every model, so we can compare the predictions of each model.
def extract_numerical_features(message):
    # Keys are in the same order as the columns of X used for training
    features = {
        'msg_len': len(message),
        'num_exclamation_marks': message.count('!'),
        'num_punctuations': sum(1 for char in message if char in string.punctuation),
        'num_upper_case': sum(1 if char.isupper() else 0 for char in message),
        'num_numeric': sum(1 if char.isdigit() else 0 for char in message),
        'num_words': len(word_tokenize(message)),
        'num_sentences': len(sent_tokenize(message))
    }
    return features
def predict_new_data(message, num_feat_models, encoder):
    # Extract features
    features = extract_numerical_features(message)
    # Convert to a one-row DataFrame (a new name, so the global df is not shadowed)
    features_df = pd.DataFrame(features, index=[0])
    # Dictionary to store the results
    predictions = {}
    for model_name, model in num_feat_models.items():
        pred = model.predict(features_df)
        # encoder.inverse_transform maps the encoded prediction back to 'ham' or 'spam'
        predicted_class = encoder.inverse_transform(pred)[0]
        predictions[model_name] = predicted_class
    return predictions
def compare_predictions(text_samples, num_feat_models, encoder):
    for text in text_samples:
        print(f"Predicting for: {text}")
        predictions = predict_new_data(text, num_feat_models, encoder)
        # Print individual model predictions
        for model, pred in predictions.items():
            print(f"{model} predicted: {pred}")
        # Check whether all models agree on the prediction
        unique_predictions = set(predictions.values())
        if len(unique_predictions) == 1:
            print("All models agree on the prediction.")
        else:
            print("Models made different predictions.")
        print("-" * 10)
# Usage:
text_samples = [
    "Congratulations! You have won a $1000 Walmart gift card. Click here to claim:",
    "Hey, how are you doing? Let's meet up soon!"
]
# 'num_feat_models' holds the trained models and 'encoder' is the fitted LabelEncoder
print("Comparing predictions on new data using numeric features model")
compare_predictions(text_samples, num_feat_models, encoder)
Comparing predictions on new data using numeric features model
Predicting for: Congratulations! You have won a $1000 Walmart gift card. Click here to claim:
Random Forest predicted: spam
Decision Tree predicted: spam
Logistic Regression predicted: spam
SVC predicted: spam
KNN predicted: ham
XGB predicted: spam
Multinomial NB predicted: spam
LGBM predicted: spam
Models made different predictions.
----------
Predicting for: Hey, how are you doing? Let's meet up soon!
Random Forest predicted: ham
Decision Tree predicted: ham
Logistic Regression predicted: ham
SVC predicted: ham
KNN predicted: ham
XGB predicted: ham
Multinomial NB predicted: ham
LGBM predicted: ham
All models agree on the prediction.
----------