Before we begin, here is a brief introduction to machine learning.
In short:
Machine learning is a branch of artificial intelligence (AI) that gives computers the ability to learn without being explicitly programmed.
There are three main types of machine learning algorithms: supervised learning, unsupervised learning, and reinforcement learning.
Today we will use only supervised learning, so let's take a closer look at what it is.
Supervised learning is the machine learning task of learning a function that maps an input to an output based on example input-output pairs. It infers such a function from labeled training data consisting of a set of training examples.
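To make this concrete, here is a minimal sketch of the fit-then-predict pattern (the toy data and the choice of DecisionTreeClassifier below are illustrative assumptions, not part of this project):

# A toy supervised learning example: learn a mapping from inputs to labels
# using a handful of labeled (input, output) pairs, then predict unseen inputs.
from sklearn.tree import DecisionTreeClassifier

X_toy = [[1], [2], [3], [10], [11], [12]]  # inputs (one numeric feature per sample)
y_toy = [0, 0, 0, 1, 1, 1]                 # outputs (the label paired with each input)

toy_model = DecisionTreeClassifier(random_state=0)
toy_model.fit(X_toy, y_toy)                # learn the input-to-output mapping
print(toy_model.predict([[2.5], [11.5]]))  # expected: [0 1]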
Since our goal is to classify email messages into spam and non-spam categories, let's see how supervised learning can be used to solve this problem.
Is everyone comfortable with the concept of supervised learning?
Now, let's move on to the main topic: building the email classifier.
Note 1:
Note 2:
First, we install the libraries that will be used in this project.
%pip install pandas scikit-learn numpy matplotlib seaborn nltk wordcloud xgboost lightgbm
Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (1.5.3)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (1.2.2)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.23.5)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (3.7.1)
Requirement already satisfied: seaborn in /usr/local/lib/python3.10/dist-packages (0.12.2)
Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (3.8.1)
Requirement already satisfied: wordcloud in /usr/local/lib/python3.10/dist-packages (1.9.2)
Requirement already satisfied: xgboost in /usr/local/lib/python3.10/dist-packages (2.0.1)
Requirement already satisfied: lightgbm in /usr/local/lib/python3.10/dist-packages (4.1.0)
Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.3.post1)
Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.11.3)
Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.3.2)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (3.2.0)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.1.1)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (4.43.1)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.4.5)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (23.2)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (9.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (3.1.1)
Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk) (8.1.7)
Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk) (2023.6.3)
Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk) (4.66.1)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)
Next, we import all the libraries used in this project.
The libraries are grouped below, each with a short comment describing what it is used for.
# For data manipulation and analysis
import pandas as pd
import numpy as np
import string
from collections import Counter
# For plotting
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# For data splitting, label encoding, and TF-IDF feature extraction
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
# For text preprocessing (tokenization, stopwords, stemming)
import nltk
from nltk import bigrams
from nltk.corpus import stopwords
from nltk.stem.porter import PorterStemmer
from nltk.tokenize import word_tokenize, sent_tokenize
# Machine learning models and evaluation metrics
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from lightgbm.sklearn import LGBMClassifier
# Machine learning evaluation metrics
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, roc_auc_score, f1_score, recall_score
from wordcloud import WordCloud
import warnings
warnings.filterwarnings("ignore")
nltk.download('punkt')
nltk.download('stopwords')
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
True
With the libraries imported, the next step is to load the dataset.
# DATA_SET = "./data/spam_dataset.csv"
DATA_SET = "https://raw.githubusercontent.com/qualitia-cdev/hands-on-data/main/spam_dataset.csv"
df = pd.read_csv(DATA_SET, encoding='latin-1')
df.head()
| | v1 | v2 | Unnamed: 2 | Unnamed: 3 | Unnamed: 4 |
|---|---|---|---|---|---|
0 | ham | Go until jurong point, crazy.. Available only ... | NaN | NaN | NaN |
1 | ham | Ok lar... Joking wif u oni... | NaN | NaN | NaN |
2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... | NaN | NaN | NaN |
3 | ham | U dun say so early hor... U c already then say... | NaN | NaN | NaN |
4 | ham | Nah I don't think he goes to usf, he lives aro... | NaN | NaN | NaN |
The dataset appears to contain several columns:
To proceed with exploratory data analysis (EDA):
The following function is useful for getting a summary of the dataset.
# basic info about the dataset
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5572 entries, 0 to 5571
Data columns (total 5 columns):
 #   Column      Non-Null Count  Dtype
---  ------      --------------  -----
 0   v1          5572 non-null   object
 1   v2          5572 non-null   object
 2   Unnamed: 2  50 non-null     object
 3   Unnamed: 3  12 non-null     object
 4   Unnamed: 4  6 non-null      object
dtypes: object(5)
memory usage: 217.8+ KB
The last three 'Unnamed' columns are almost entirely empty, so we drop them using the next code cell.
# last 3 columns are unnecessary. So we are going to drop them.
df.drop(['Unnamed: 2', 'Unnamed: 3', 'Unnamed: 4'], axis=1, inplace=True)
Next, for clarity, we rename the dataset's columns.
df.rename(columns={'v1': 'label', 'v2': 'message'}, inplace=True)
df.head(10)
| | label | message |
|---|---|---|
0 | ham | Go until jurong point, crazy.. Available only ... |
1 | ham | Ok lar... Joking wif u oni... |
2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... |
3 | ham | U dun say so early hor... U c already then say... |
4 | ham | Nah I don't think he goes to usf, he lives aro... |
5 | spam | FreeMsg Hey there darling it's been 3 week's n... |
6 | ham | Even my brother is not like to speak with me. ... |
7 | ham | As per your request 'Melle Melle (Oru Minnamin... |
8 | spam | WINNER!! As a valued network customer you have... |
9 | spam | Had your mobile 11 months or more? U R entitle... |
target_cat = df.label.value_counts()
#create pie chart using seaborn
color1 = sns.color_palette("viridis")[0]
color2 = sns.color_palette("viridis")[3]
plt.pie(target_cat, labels=target_cat.index, colors=[color1, color2], autopct='%1.1f%%', shadow=True, startangle=90, explode=(0, 0.1))
plt.title('Target Variable Distribution (Spam and Ham)')
plt.show()
print(target_cat)
ham     4825
spam     747
Name: label, dtype: int64
The output shows that the classes are imbalanced: about 87% of the messages are ham and only about 13% are spam.
When the data is skewed like this, a machine learning model may not be able to perform at its best.
An imbalanced distribution can cause problems when splitting the data into training and test sets, and as a result the model may become biased toward the majority class.
There are several ways to handle this, but to keep things simple we will randomly drop 'ham' messages until the classes are balanced.
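As an aside, another common way to deal with class imbalance, instead of dropping data, is to give the minority class a larger weight during training. The cell below is only a minimal sketch of that idea (using LogisticRegression with class_weight='balanced' as an illustrative assumption); in this hands-on we will stick with the simple undersampling described above.

# Alternative to undersampling (sketch only): keep the full, imbalanced dataset
# and let the model compensate through per-class weights.
from sklearn.linear_model import LogisticRegression

weighted_clf = LogisticRegression(class_weight='balanced', random_state=42)
# weighted_clf.fit(X_train, y_train)  # would be fit on features built from the full dataset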
After the undersampling, let's look at the target variable distribution again.
num_spam = target_cat['spam']
# Randomly sampling the majority class to match the count of the minority class
df = pd.concat(
[
df[df['label'] == 'ham'].sample(n=num_spam, random_state=42),
df[df['label'] == 'spam']
]
)
#check the target variable distribution again
target_cat = df.label.value_counts()
#set the colors
color1 = sns.color_palette("viridis")[0]
color2 = sns.color_palette("viridis")[3]
#plot the pie chart
plt.pie(target_cat, labels=target_cat.index, colors=[color1, color2], autopct='%1.1f%%', shadow=True, startangle=90, explode=(0, 0.1))
plt.title('Target Variable Distribution after dropping ham messages')
plt.show()
print(target_cat)
ham     747
spam    747
Name: label, dtype: int64
Now ham and spam have the same proportion and the same number of messages, so let's move on to the next step.
What is a unigram? A unigram is simply a single word (token) taken from a text.
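As a quick illustration (an extra snippet, not part of the main flow), tokenizing a sentence gives its unigrams:

# Unigrams are simply the individual tokens of a text.
from nltk.tokenize import word_tokenize

print(word_tokenize("Free entry in a weekly competition"))
# -> ['Free', 'entry', 'in', 'a', 'weekly', 'competition']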
In the next code cell, we get the top 10 unigrams for spam and ham messages, and check what the unigram counts look like for spam messages.
# A predefined list of stopwords is kept below (commented out) as a fallback in case the nltk download is unavailable
# stop_words = set([
# "ourselves", "hers", "between", "yourself", "but",
# "again", "there", "about", "once", "during", "out",
# "very", "having", "with", "they", "own", "an", "be",
# "some", "for", "do", "its", "yours", "such", "into",
# "of", "most", "itself", "other", "off", "is", "s", "am",
# "or", "who", "as", "from", "him", "each", "the", "themselves",
# "until", "below", "are", "we", "these", "your", "his", "through",
# "don", "nor", "me", "were", "her", "more", "himself", "this",
# "down", "should", "our", "their", "while", "above", "both", "up",
# "to", "ours", "had", "she", "all", "no", "when", "at", "any", "before",
# "them", "same", "and", "been", "have", "in", "will", "on", "does",
# "yourselves", "then", "that", "because", "what", "over", "why", "so",
# "can", "did", "not", "now", "under", "he", "you", "herself", "has",
# "just", "where", "too", "only", "myself", "which", "those", "i",
# "after", "few", "whom", "t", "being", "if", "theirs", "my", "against",
# "a", "by", "doing", "it", "how", "further", "was", "here", "than"
# ])
#get the stopwords from nltk
stop_words = set(stopwords.words('english'))
def get_most_common_words(texts, n=10):
words = []
for text in texts:
tokens = word_tokenize(text)
words.extend([word.lower() for word in tokens if word.isalpha() and word.lower() not in stop_words])
# Count the frequency of each word
word_freq = Counter(words)
return word_freq.most_common(n)
# Get the 10 most common words for spam and ham messages
common_words_ham = get_most_common_words(df[df['label'] == 'ham']['message'], 10)
common_words_spam = get_most_common_words(df[df['label'] == 'spam']['message'], 10)
#create table to for common_words_ham, common_words_spam
common_words_ham_df = pd.DataFrame(common_words_ham, columns=['word', 'count'])
common_words_spam_df = pd.DataFrame(common_words_spam, columns=['word', 'count'])
common_words_spam_df.head(10)
| | word | count |
|---|---|---|
0 | call | 346 |
1 | free | 219 |
2 | txt | 156 |
3 | u | 144 |
4 | ur | 144 |
5 | mobile | 123 |
6 | text | 121 |
7 | stop | 114 |
8 | claim | 113 |
9 | reply | 104 |
Next, let's build word clouds for the spam messages and the ham messages.
def generate_wordcloud(texts_ham, texts_spam):
# Join all the text messages together
all_text_ham = ' '.join(texts_ham)
all_text_spam = ' '.join(texts_spam)
# Create a word cloud for ham and spam messages
wordcloud_ham = WordCloud(stopwords=stop_words, background_color="white", width=800, height=400).generate(all_text_ham)
wordcloud_spam = WordCloud(stopwords=stop_words, background_color="white", width=800, height=400).generate(all_text_spam)
# Plot the word clouds
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 5))
ax1.imshow(wordcloud_ham, interpolation='bilinear')
ax1.axis('off')
ax1.set_title("Ham Messages Word Cloud", fontsize=25)
ax2.imshow(wordcloud_spam, interpolation='bilinear')
ax2.axis('off')
ax2.set_title("Spam Messages Word Cloud", fontsize=25)
# Adjust layout
plt.tight_layout()
plt.show()
# Generate word clouds for ham and spam messages side by side
generate_wordcloud(df[df['label'] == 'ham']['message'].tolist(), df[df['label'] == 'spam']['message'].tolist())
Above, we looked at the most common unigrams in spam and ham messages. Next, let's look at the most common bigrams.
In text analysis, a bigram is a pair of consecutive words in a given text.
Example: for the sentence "I love ice cream", the bigrams are "I love", "love ice", and "ice cream".
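As a quick check, the same example can be reproduced with nltk (a small extra snippet):

# Bigrams are consecutive word pairs.
from nltk import bigrams
from nltk.tokenize import word_tokenize

print(list(bigrams(word_tokenize("I love ice cream"))))
# -> [('I', 'love'), ('love', 'ice'), ('ice', 'cream')]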
Let's look at the most frequently used bigrams in spam and ham messages.
def remove_punctuation(message):
return ''.join([char for char in message if char not in string.punctuation])
def tokenize(message):
return [word.lower() for word in message.split() if word.lower() not in stop_words]
def generate_bigrams(tokens):
return list(bigrams(tokens))
def preprocess_message(message):
message = remove_punctuation(message)
tokens = tokenize(message)
return generate_bigrams(tokens)
def get_top_bigrams(series, n=10):
# Flatten the list of bigrams and count occurrences
all_bigrams = [bigram for sublist in series for bigram in sublist]
bigram_counts = Counter(all_bigrams)
return bigram_counts.most_common(n)
def plot_top_bigrams(df, column, title, ax, palette="viridis"):
sns.barplot(data=df, y=column, x='Count', ax=ax, palette=palette)
ax.set_title(title)
ax.set_xlabel('Count')
ax.set_ylabel(column)
# Assuming your original DataFrame is loaded as 'df'
df["bigrams"] = df["message"].apply(preprocess_message)
# Extract top bigrams
top_spam_bigrams = get_top_bigrams(df[df['label'] == 'spam']['bigrams'])
top_ham_bigrams = get_top_bigrams(df[df['label'] == 'ham']['bigrams'])
# Create dataframes for visualization
spam_df = pd.DataFrame(top_spam_bigrams, columns=['Bigrams', 'Count'])
ham_df = pd.DataFrame(top_ham_bigrams, columns=['Bigrams', 'Count'])
# Adjust bigram representation for readability
spam_df['Bigrams'] = spam_df['Bigrams'].apply(lambda x: ' '.join(x))
ham_df['Bigrams'] = ham_df['Bigrams'].apply(lambda x: ' '.join(x))
# Create subplots
fig, axes = plt.subplots(1, 2, figsize=(16, 8), sharex=True)
plot_top_bigrams(spam_df, 'Bigrams', 'Top Bigrams in Spam Messages', axes[0])
plot_top_bigrams(ham_df, 'Bigrams', 'Top Bigrams in Ham Messages', axes[1])
plt.tight_layout()
plt.show()
From the charts above, we can see the following:
Spam messages frequently contain calls to action for the recipient, such as "please call", "contact u", and "await collection".
Non-spam messages, on the other hand, seem to revolve around everyday conversation, with phrases like "call later", "let know", and "good morning".
# create a new column for message length and calculate the length of each message
df["msg_len"] = df["message"].apply(len)
# Plot the distribution of message lengths
sns.distplot(df[df.label == 'ham'].msg_len, label='Ham', color=sns.color_palette("viridis")[0])
sns.distplot(df[df.label == 'spam'].msg_len, label='Spam', color=sns.color_palette("viridis")[3])
plt.title('Distribution Plot for Length of Messages')
plt.xlabel('Length of Messages')
plt.legend()
plt.show()
print('Average message length for ham messages: ', df[df['label'] == 'ham']['message'].str.len().mean())
print('Average message length for spam messages: ', df[df['label'] == 'spam']['message'].str.len().mean())
Average message length for ham messages: 69.11780455153949
Average message length for spam messages: 138.8661311914324
Looking at the plot above, spam messages tend to be noticeably longer than ham messages (on average roughly 139 characters versus roughly 69).
Since time is limited, we will create all of the features at once without explaining each one.
In a real project, however, you should analyze each feature and decide whether or not to use it.
Using the same approach as before, we can create features such as the number of punctuation marks, exclamation marks, uppercase letters, numeric characters, sentences, and words in each message.
The next code cell creates all of the features listed above.
# Create new column for number of punctuations, and calculate the number of punctuations in each message
df['num_punctuations'] = df['message'].apply(lambda x: sum([1 for char in x if char in string.punctuation]))
# Create new column for number of exclamation marks, and calculate the number of exclamation marks in each message
df['num_exclamation_marks'] = df['message'].apply(lambda x: sum([1 for char in x if char == '!']))
# Calculate the number of uppercase letters in each message and create a new column for it
df['num_upper_case'] = df['message'].apply(lambda x: sum([1 for char in x if char.isupper()]))
# Count the number of numeric characters in each message and create a new column for it
df['num_numeric'] = df['message'].apply(lambda x: sum([1 for char in x if char.isdigit()]))
# Tokenize messages and compute the number of sentences in each message and create a new column for it
df['num_sentences'] = df['message'].apply(lambda x: len(sent_tokenize(x)))
# Create a new column for number of words in each message and calculate the number of words in each message
df['num_words'] = df['message'].apply(lambda x: len(word_tokenize(x)))
df.columns
Index(['label', 'message', 'bigrams', 'msg_len', 'num_punctuations', 'num_exclamation_marks', 'num_upper_case', 'num_numeric', 'num_sentences', 'num_words'], dtype='object')
Note: the numerical feature columns created above are assigned to the variable X.
X = df[['msg_len', 'num_exclamation_marks', 'num_punctuations', 'num_upper_case', 'num_numeric', 'num_words', 'num_sentences']]
For this dataset, we also need to encode the label column.
For example, 'ham' can be encoded as 0 and 'spam' as 1.
encoder = LabelEncoder()
encoded_labels = encoder.fit_transform(df['label'])
df['encoded_label'] = encoded_labels
df.tail()
| | label | message | bigrams | msg_len | num_punctuations | num_exclamation_marks | num_upper_case | num_numeric | num_sentences | num_words | encoded_label |
|---|---|---|---|---|---|---|---|---|---|---|---|
5537 | spam | Want explicit SEX in 30 secs? Ring 02073162414... | [(want, explicit), (explicit, sex), (sex, 30),... | 90 | 3 | 1 | 17 | 21 | 3 | 18 | 1 |
5540 | spam | ASKED 3MOBILE IF 0870 CHATLINES INCLU IN FREE ... | [(asked, 3mobile), (3mobile, 0870), (0870, cha... | 160 | 5 | 0 | 104 | 14 | 6 | 38 | 1 |
5547 | spam | Had your contract mobile 11 Mnths? Latest Moto... | [(contract, mobile), (mobile, 11), (11, mnths)... | 160 | 8 | 1 | 20 | 2 | 5 | 35 | 1 |
5566 | spam | REMINDER FROM O2: To get 2.50 pounds free call... | [(reminder, o2), (o2, get), (get, 250), (250, ... | 147 | 3 | 0 | 14 | 5 | 1 | 30 | 1 |
5567 | spam | This is the 2nd time we have tried 2 contact u... | [(2nd, time), (time, tried), (tried, 2), (2, c... | 161 | 8 | 1 | 9 | 21 | 4 | 35 | 1 |
Let's check how the target variable has been encoded.
# Display classes and their encoded values
for original, encoded in zip(encoder.classes_, range(len(encoder.classes_))):
print(f"'{original}' is encoded as {encoded}")
'ham' is encoded as 0
'spam' is encoded as 1
We assign the encoded labels to the variable y.
Note:
y = df["encoded_label"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
Let's try various algorithms and compare their performance.
We'll create a dictionary of machine learning algorithms.
Each algorithm is described only briefly here.
num_feat_models = {
"Random Forest": RandomForestClassifier(n_estimators=1000, random_state=42),
"Decision Tree": DecisionTreeClassifier(random_state=42),
"Logistic Regression": LogisticRegression(random_state=42),
"SVC": SVC(kernel="linear", random_state=42),
"KNN" : KNeighborsClassifier(),
"XGB": XGBClassifier(objective='binary:hinge', random_state=42),
"Multinomial NB": MultinomialNB(),
"LGBM" : LGBMClassifier(boosting_type='gbdt',objective='binary', max_depth=4, random_state=42, verbose=-1)
}
We've reached the final stage of the project: training the models.
def train_clf(model, X_train, y_train, X_test, y_test):
# Train the model
model.fit(X_train, y_train)
# Make predictions and get accuracy and precision
y_pred = model.predict(X_test)
results = {
'Accuracy': accuracy_score(y_test, y_pred),
'Precision': precision_score(y_test, y_pred),
'Recall': recall_score(y_test, y_pred),
'F1 Score': f1_score(y_test, y_pred),
'Roc Auc Score': roc_auc_score(y_test, y_pred),
'Confusion Matrix': confusion_matrix(y_test, y_pred)
}
return results
# Store scores for each model
model_scores = {}
for name, model in num_feat_models.items():
model_scores[name] = train_clf(model, X_train, y_train, X_test, y_test)
# Convert scores dictionary to a DataFrame
scores_df = pd.DataFrame(model_scores).T
scores_df.reset_index(inplace=True)
scores_df.rename(columns={'index': 'Algorithm'}, inplace=True)
scores_df = scores_df.sort_values(by="F1 Score", ascending=False)
Now, let's check the models' performance.
scores_df
| | Algorithm | Accuracy | Precision | Recall | F1 Score | Roc Auc Score | Confusion Matrix |
|---|---|---|---|---|---|---|---|
0 | Random Forest | 0.973274 | 0.986726 | 0.961207 | 0.973799 | 0.973691 | [[214, 3], [9, 223]] |
7 | LGBM | 0.966592 | 0.977974 | 0.956897 | 0.96732 | 0.966928 | [[212, 5], [10, 222]] |
5 | XGB | 0.962138 | 0.969432 | 0.956897 | 0.963124 | 0.962319 | [[210, 7], [10, 222]] |
1 | Decision Tree | 0.957684 | 0.965066 | 0.952586 | 0.958785 | 0.95786 | [[209, 8], [11, 221]] |
4 | KNN | 0.955457 | 0.969027 | 0.943966 | 0.956332 | 0.955854 | [[210, 7], [13, 219]] |
2 | Logistic Regression | 0.953229 | 0.968889 | 0.939655 | 0.954048 | 0.953699 | [[210, 7], [14, 218]] |
3 | SVC | 0.951002 | 0.972973 | 0.931034 | 0.951542 | 0.951692 | [[211, 6], [16, 216]] |
6 | Multinomial NB | 0.919822 | 0.962264 | 0.87931 | 0.918919 | 0.921222 | [[209, 8], [28, 204]] |
Since time is limited, we won't explain every metric in detail; here is what you need to know.
From the output above, we can see the following:
Scores:
True Positive (TP): a spam message that is correctly predicted as spam.
True Negative (TN): a ham message that is correctly predicted as ham.
False Positive (FP), also known as a Type I error: a ham message that is incorrectly predicted as spam.
False Negative (FN), also known as a Type II error: a spam message that is incorrectly predicted as ham.
These four quantities are often displayed together in what is known as a confusion matrix. Understanding them is essential because they form the basis of many other classification performance metrics, such as Accuracy, Precision, Recall, and the F1 score.
For now, what we need to keep in mind is that the higher the accuracy, the better the model (a reasonable rule of thumb here, since we balanced the classes earlier).
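As a worked example, take the Random Forest confusion matrix from the table above, [[214, 3], [9, 223]]. With spam as the positive class this means TN = 214, FP = 3, FN = 9, TP = 223, and the headline metrics can be recomputed by hand (this cell is just a sanity check of the numbers already reported):

# Recompute the Random Forest metrics from its confusion matrix.
TN, FP, FN, TP = 214, 3, 9, 223
accuracy = (TP + TN) / (TP + TN + FP + FN)          # 437 / 449 ≈ 0.9733
precision = TP / (TP + FP)                          # 223 / 226 ≈ 0.9867
recall = TP / (TP + FN)                             # 223 / 232 ≈ 0.9612
f1 = 2 * precision * recall / (precision + recall)  # ≈ 0.9738
print(accuracy, precision, recall, f1)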
The next code cell plots the confusion matrix for each model.
# Plot
classes = ['ham', 'spam']
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(20, 10))
for idx, ax in enumerate(axes.flatten()):
algo = scores_df.iloc[idx]['Algorithm']
cm = scores_df.iloc[idx]['Confusion Matrix']
# Prepare the labels for each cell (sklearn's confusion_matrix returns [[TN, FP], [FN, TP]],
# with ham = 0 as the negative class and spam = 1 as the positive class)
labels = np.array([["TN: " + str(cm[0, 0]), "FP: " + str(cm[0, 1])],
                   ["FN: " + str(cm[1, 0]), "TP: " + str(cm[1, 1])]])
sns.heatmap(cm, annot=labels, fmt="", cmap="Blues", ax=ax, cbar=False)
ax.set_title(algo)
ax.set_xticklabels(classes)
ax.set_yticklabels(classes, rotation=0)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
plt.tight_layout()
plt.show()
To wrap up, the following code creates the numerical features for new messages and makes predictions with every model, so that we can compare each model's predictions.
def extract_numerical_features(message):
features = {
'msg_len': len(message),
'num_exclamation_marks': message.count('!'),
'num_punctuations': sum(1 for char in message if char in string.punctuation),
'num_upper_case': sum(1 if char.isupper() else 0 for char in message),
'num_numeric': sum(1 if char.isdigit() else 0 for char in message),
'num_words': len(word_tokenize(message)),
'num_sentences': len(sent_tokenize(message))
}
return features
def predict_new_data(message, num_feat_models, encoder):
# Extract features
features = extract_numerical_features(message)
# Convert to dataframe
df = pd.DataFrame(features, index=[0])
# Dictionary to store the results
predictions = {}
for model_name, model in num_feat_models.items():
pred = model.predict(df)
# Assuming encoder.inverse_transform(pred) returns a list of class labels
predicted_class = encoder.inverse_transform(pred)[0]
predictions[model_name] = predicted_class
return predictions
def compare_predictions(text_samples, num_feat_models, encoder):
for text in text_samples:
print(f"Predicting for: {text}")
predictions = predict_new_data(text, num_feat_models, encoder)
# Print individual model predictions
for model, pred in predictions.items():
print(f"{model} predicted: {pred}")
# Check if all model predictions are the same; if not, compare the results
unique_predictions = set(predictions.values())
if len(unique_predictions) == 1:
print("All models agree on the prediction.")
else:
print("Models made different predictions.")
print("----------")
# Usage:
text_samples = [
"Congratulations! You have won a $1000 Walmart gift card. Click here to claim: https://bit.ly/3kceqyh",
"Hey, how are you doing? Let's meet up soon!"
]
# Assuming 'num_feat_models' is a dictionary of your models and 'encoder' is your label encoder
print("Comparing predictions on new data using numeric features model")
compare_predictions(text_samples, num_feat_models, encoder)
Comparing predictions on new data using numeric features model
Predicting for: Congratulations! You have won a $1000 Walmart gift card. Click here to claim: https://bit.ly/3kceqyh
Random Forest predicted: spam
Decision Tree predicted: spam
Logistic Regression predicted: spam
SVC predicted: spam
KNN predicted: ham
XGB predicted: spam
Multinomial NB predicted: spam
LGBM predicted: spam
Models made different predictions.
----------
Predicting for: Hey, how are you doing? Let's meet up soon!
Random Forest predicted: ham
Decision Tree predicted: ham
Logistic Regression predicted: ham
SVC predicted: ham
KNN predicted: ham
XGB predicted: ham
Multinomial NB predicted: ham
LGBM predicted: ham
All models agree on the prediction.
----------