残念なお知らせですが、今日はメールの件名生成はできません。メールには個人情報が多く含まれており、学習に利用するメールデータをみなさんと共有できないからです。
その代わり、ニュース記事の本文からニュースのタイトルを生成してみたいと思います
メールのデータセットから件名を生成する流れ
ニュース記事からニュースのタイトルを生成する流れ
コンセプトはほとんど同じです。興味があれば、後で同じコードを使って少し調整するだけで、メールのデータを学習させることが出来ます。サポートが必要な場合は遠慮なくnuwan@qualitia.comまでご連絡ください。
重要事項 1:
重要事項 2:
まず、必要なライブラリをインストールしてみましょう。
pipコマンドを使用してライブラリをインストールすることが出来ます。 本プロジェクトではつぎのライブラリを使用します。 Google Colabでシステムのコマンドを実行する場合は、コマンドの先頭に「!」または「%」を付けてください。
使用するライブラリ一覧
%pip install transformers[torch,ja]==4.33.3 datasets==2.14.5 sentencepiece matplotlib seaborn evaluate absl-py bert_score pandas tokenizers==0.13.3
Collecting transformers[ja,torch]==4.33.3 Downloading transformers-4.33.3-py3-none-any.whl (7.6 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.6/7.6 MB 51.3 MB/s eta 0:00:00 Collecting datasets==2.14.5 Downloading datasets-2.14.5-py3-none-any.whl (519 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 519.6/519.6 kB 52.3 MB/s eta 0:00:00 Collecting sentencepiece Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 62.1 MB/s eta 0:00:00 Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (3.7.1) Requirement already satisfied: seaborn in /usr/local/lib/python3.10/dist-packages (0.12.2) Collecting evaluate Downloading evaluate-0.4.1-py3-none-any.whl (84 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.1/84.1 kB 11.4 MB/s eta 0:00:00 Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (1.4.0) Collecting bert_score Downloading bert_score-0.3.13-py3-none-any.whl (61 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 61.1/61.1 kB 7.8 MB/s eta 0:00:00 Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (1.5.3) Collecting tokenizers==0.13.3 Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 100.9 MB/s eta 0:00:00 Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (3.12.4) Collecting huggingface-hub<1.0,>=0.15.1 (from transformers[ja,torch]==4.33.3) Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 302.0/302.0 kB 32.5 MB/s eta 0:00:00 Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (1.23.5) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (23.2) Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (6.0.1) Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (2023.6.3) Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (2.31.0) Collecting safetensors>=0.3.1 (from transformers[ja,torch]==4.33.3) Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 78.6 MB/s eta 0:00:00 Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (4.66.1) Requirement already satisfied: torch!=1.12.0,>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (2.1.0+cu118) Collecting accelerate>=0.20.3 (from transformers[ja,torch]==4.33.3) Downloading accelerate-0.24.1-py3-none-any.whl (261 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 261.4/261.4 kB 30.3 MB/s eta 0:00:00 Collecting fugashi>=1.0 (from transformers[ja,torch]==4.33.3) Downloading fugashi-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (600 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 600.9/600.9 kB 53.2 MB/s eta 0:00:00 Collecting ipadic<2.0,>=1.0.0 (from transformers[ja,torch]==4.33.3) Downloading ipadic-1.0.0.tar.gz (13.4 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.4/13.4 MB 77.9 MB/s eta 0:00:00 Preparing metadata (setup.py) ... done Collecting unidic-lite>=1.0.7 (from transformers[ja,torch]==4.33.3) Downloading unidic-lite-1.0.8.tar.gz (47.4 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 47.4/47.4 MB 16.9 MB/s eta 0:00:00 Preparing metadata (setup.py) ... done Collecting unidic>=1.0.2 (from transformers[ja,torch]==4.33.3) Downloading unidic-1.1.0.tar.gz (7.7 kB) Preparing metadata (setup.py) ... done Collecting sudachipy>=0.6.6 (from transformers[ja,torch]==4.33.3) Downloading SudachiPy-0.6.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 24.0 MB/s eta 0:00:00 Collecting sudachidict-core>=20220729 (from transformers[ja,torch]==4.33.3) Downloading SudachiDict_core-20230927-py3-none-any.whl (71.7 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.7/71.7 MB 8.8 MB/s eta 0:00:00 Collecting rhoknp<1.3.1,>=1.1.0 (from transformers[ja,torch]==4.33.3) Downloading rhoknp-1.3.0-py3-none-any.whl (86 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.8/86.8 kB 11.5 MB/s eta 0:00:00 Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.5) (9.0.0) Collecting dill<0.3.8,>=0.3.0 (from datasets==2.14.5) Downloading dill-0.3.7-py3-none-any.whl (115 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115.3/115.3 kB 14.0 MB/s eta 0:00:00 Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.5) (3.4.1) Collecting multiprocess (from datasets==2.14.5) Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.8/134.8 kB 15.5 MB/s eta 0:00:00 Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.5) (2023.6.0) Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.5) (3.8.6) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.1.1) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (4.43.1) Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.4.5) Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (9.4.0) Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (3.1.1) Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (2.8.2) Collecting responses<0.19 (from evaluate) Downloading responses-0.18.0-py3-none-any.whl (38 kB) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.3.post1) Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.20.3->transformers[ja,torch]==4.33.3) (5.9.5) Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (23.1.0) Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (3.3.1) Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (6.0.4) Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (4.0.3) Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (1.9.2) Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (1.4.0) Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (1.3.1) Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers[ja,torch]==4.33.3) (4.5.0) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib) (1.16.0) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[ja,torch]==4.33.3) (3.4) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[ja,torch]==4.33.3) (2.0.7) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[ja,torch]==4.33.3) (2023.7.22) Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (1.12) Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (3.2) Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (3.1.2) Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (2.1.0) Collecting wasabi<1.0.0,>=0.6.0 (from unidic>=1.0.2->transformers[ja,torch]==4.33.3) Downloading wasabi-0.10.1-py3-none-any.whl (26 kB) Collecting plac<2.0.0,>=1.1.3 (from unidic>=1.0.2->transformers[ja,torch]==4.33.3) Downloading plac-1.4.1-py2.py3-none-any.whl (22 kB) Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (2.1.3) Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (1.3.0) Building wheels for collected packages: ipadic, unidic, unidic-lite Building wheel for ipadic (setup.py) ... done Created wheel for ipadic: filename=ipadic-1.0.0-py3-none-any.whl size=13556703 sha256=c734594f6ebc044bf9e2ba4e0351f66bffea4a62d1c1c56f2c5445abd1449190 Stored in directory: /root/.cache/pip/wheels/5b/ea/e3/2f6e0860a327daba3b030853fce4483ed37468bbf1101c59c3 Building wheel for unidic (setup.py) ... done Created wheel for unidic: filename=unidic-1.1.0-py3-none-any.whl size=7406 sha256=21c1c4b149b24a152647c2b7122b508f03f674ab6c7275fbd942ecf8992bfd0b Stored in directory: /root/.cache/pip/wheels/7a/72/72/1f3d654c345ea69d5d51b531c90daf7ba14cc555eaf2c64ab0 Building wheel for unidic-lite (setup.py) ... done Created wheel for unidic-lite: filename=unidic_lite-1.0.8-py3-none-any.whl size=47658816 sha256=bbdeafc6a0e71520c94eb7be23a273d26c3e3b6792b08dcaba6da4902f71935e Stored in directory: /root/.cache/pip/wheels/89/e8/68/f9ac36b8cc6c8b3c96888cd57434abed96595d444f42243853 Successfully built ipadic unidic unidic-lite Installing collected packages: wasabi, unidic-lite, tokenizers, sudachipy, sentencepiece, plac, ipadic, sudachidict-core, safetensors, rhoknp, fugashi, dill, unidic, responses, multiprocess, huggingface-hub, transformers, accelerate, datasets, bert_score, evaluate Attempting uninstall: wasabi Found existing installation: wasabi 1.1.2 Uninstalling wasabi-1.1.2: Successfully uninstalled wasabi-1.1.2 Successfully installed accelerate-0.24.1 bert_score-0.3.13 datasets-2.14.5 dill-0.3.7 evaluate-0.4.1 fugashi-1.3.0 huggingface-hub-0.18.0 ipadic-1.0.0 multiprocess-0.70.15 plac-1.4.1 responses-0.18.0 rhoknp-1.3.0 safetensors-0.4.0 sentencepiece-0.1.99 sudachidict-core-20230927 sudachipy-0.6.7 tokenizers-0.13.3 transformers-4.33.3 unidic-1.1.0 unidic-lite-1.0.8 wasabi-0.10.1
さて、始める前にGPUが利用可能かどうか確認します。 次のコードで確認できます。
import torch
if torch.cuda.is_available():
status = "GPU is enabled."
device_count = torch.cuda.device_count()
current_device = torch.cuda.current_device()
print(f"{status}\ndevice count: {device_count}, current device: {current_device}")
else:
print("GPU is disabled.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")
GPU is enabled. device count: 1, current device: 0 device: cuda
重要 !!!:
GPU is disabled.
Device: cpu
機械学習やディープラーニングでは結果の再現性を確保するために乱数のシードの設定が重要になります。 今回のプロジェクトでは42を使用します。
これらのシードを設定しても、異なるプラットフォームやGPUアーキテクチャ間で再現性を実現することは依然として困難であることに注意してください。
import random
import numpy as np
import torch
import warnings
warnings.filterwarnings('ignore')
def seed_everything(seed_value):
random.seed(seed_value) # Python
np.random.seed(seed_value) # Numpy
torch.manual_seed(seed_value) # CPU
if torch.cuda.is_available():
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value) # GPU if available
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_value = 42
seed_everything(seed_value)
先ほどお知らせしたように、ニュースのデータセットを使用します。データセットは日本語のLivedoorニュースのコーパスです。
データセットをロードするには、load_dataset関数を使います。 load_datasetを使用して、Hugging Face Hubからデータセットをロードします。この関数は、多くのデータセットに簡単にアクセスする方法を提供します。また、トレーニング、検証、テストセットへの分割など、標準的なデータ処理を実現することができます。
参考情報:
llm-book/livedoor-news-corpus: データセットの識別子や場所
train_ratio=0.8: データセットの80%を学習に割り当てる
validation_ratio=0.1: データセットの10%を検証に割り当てる
seed=42: 再現性確保のために乱数のシードに固定値を使用します
shuffle=False: 学習データ、検証データ、テストデータに分割する際にシャッフルしないようにします
from datasets import load_dataset
dataset = load_dataset(
"llm-book/livedoor-news-corpus",
train_ratio = 0.8,
validation_ratio = 0.1,
seed=42,
shuffle=False,
)
Downloading builder script: 0%| | 0.00/3.88k [00:00<?, ?B/s]
Downloading readme: 0%| | 0.00/823 [00:00<?, ?B/s]
Downloading data: 0%| | 0.00/8.86M [00:00<?, ?B/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating validation split: 0 examples [00:00, ? examples/s]
Generating test split: 0 examples [00:00, ? examples/s]
dataset
DatasetDict({ train: Dataset({ features: ['url', 'date', 'title', 'content', 'category'], num_rows: 5893 }) validation: Dataset({ features: ['url', 'date', 'title', 'content', 'category'], num_rows: 736 }) test: Dataset({ features: ['url', 'date', 'title', 'content', 'category'], num_rows: 738 }) })
それぞれのデータセットは次のような要素で構成されています。
とりあえず、テストデータをpandasのデータフレームにロードし、それがどのように見えるか確認してみましょう。
dataframeは2次元のラベル付きデータ構造で、異なるタイプの列を持つことができます。Excelの表のようなものだと考えてください。
要約すると、このコードはpandasライブラリをインポートし、データセットの'test'部分をpandas DataFrameに変換し、そのDataFrameの最初の5行を表示しています。
参考情報:
import pandas as pd: この行はpandasライブラリをインポートし、pdという別名で使えるようにします。Pandasは、データ分析や操作、特に表形式のデータで広く使われているPythonライブラリです。
test_df = pd.DataFrame(dataset['test']): This line does two main things:
import pandas as pd
test_df = pd.DataFrame(dataset['test'])
test_df.head()
url | date | title | content | category | |
---|---|---|---|---|---|
0 | http://news.livedoor.com/article/detail/5936102/ | 2011-10-14T09:11:00+0900 | ゼンショー「事実無根」と反論 | 10月13日の夜、ゼンショーの広報室長がTwitterで読売新聞の報道に「事実無根」と反論し... | topic-news |
1 | http://news.livedoor.com/article/detail/5936557/ | 2011-10-14T11:16:00+0900 | 「報ステ」OP曲演奏のジャズミュージシャンに“売名行為”と批判相次ぐ | 先日、福島県が行っている新米の放射性物質本検査が全て終了した。規制値を超える放射性セシウムは... | topic-news |
2 | http://news.livedoor.com/article/detail/5936721/ | 2011-10-14T11:46:00+0900 | 「何のための“予約”なんですか」孫社長に批判殺到 | ソフトバンクは、今朝から発売が始まった“iPhone4S”をはじめとする、ソフトバンク全ての... | topic-news |
3 | http://news.livedoor.com/article/detail/5937177/ | 2011-10-15T10:00:00+0900 | あまりにも多すぎる「会いたくて」への皮肉か!? 「西野カナゲーム」が流行 | いま巷で「西野カナゲーム」なるものが流行してるという。 簡単に説明すると、1人目、2人目は... | topic-news |
4 | http://news.livedoor.com/article/detail/5937649/ | 2011-10-14T16:03:00+0900 | 憶測呼ぶ紳助さんの“天敵”引退 | 10月13日発売の東スポに「紳助の天敵が引退」との見出しが躍った。その天敵とは、警察庁の安藤... | topic-news |
このタスクではgoogleのmt5-smallモデルを使用します。
mt5はgoogleが公開したText-to-textトランスフォーマーモデルです。多言語モデルで、翻訳、要約、分類、その他多くのタスクに使用できます。
Text-to-textトランスフォーマーとは、テキストを入力してテキスト出力するような、さまざまなタスクに使用できるトランスフォーマーモデルのクラスです。例えば、
など。
今回はm5-smallを使用しますが、ほかにもたくさんの同じようなモデルがあります。例えば、mt5-baseです。 mt5-baseはmt5-smallより大きなモデルで、もっとたくさんのパラメータがあり学習時間もかかりますが、学習能力は高いです。
モデルのトークナイザをロードするためにAutoTokenizerを使用します。 次のコードは mt5-small モデルのトークナイザをロードします。
トークナイザとはとは文字列をトークンに分割する機能です。例えば、"Hello world!"という文字列は['Hello', 'world', '!']というトークンのリストに分割されます。
このコードでは、transformersライブラリからAutoTokenizerクラスをインポートし、"google/mt5-small"モデルを指定して、関連するトークナイザをmt5_tokenizerという変数にセットしますします。このトークナイザーはmT5 smallモデルのテキスト入力時に入力をトークン化します。
参考情報:
コードの内訳を見てみましょう:
from transformers import AutoTokenizer: transformersライブラリからAutoTokenizerクラスをimportします。このクラスは、与えられたモデルに対して適切なトークナイザを自動的に取得するように設計されています。
MODEL_NAME = "google/mt5-small": この行は、定数MODEL_NAMEにGoogleが提供する小さなバージョンであるmT5(多言語(multilingual) T5)モデルの識別子を設定します。T5(Text-to-Textトランスフォーマー)は、様々なNLPタスクのために設計されたトランスフォーマーベースの一般的なモデルで、多言語バージョン(mT5)は複数の言語で学習されています。
mt5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME): この行はmT5モデルに関連付けられたトークナイザを初期化します。from_pretrainedメソッドは、指定されたモデル(google/mt5-small)のトークナイザをHugging Faceのモデルハブから取得し、mt5_tokenizer変数に代入します。
from transformers import AutoTokenizer
MODEL_NAME = "google/mt5-small"
mt5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
mt5_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
(…)small/resolve/main/tokenizer_config.json: 0%| | 0.00/82.0 [00:00<?, ?B/s]
(…)oogle/mt5-small/resolve/main/config.json: 0%| | 0.00/553 [00:00<?, ?B/s]
(…)ogle/mt5-small/resolve/main/spiece.model: 0%| | 0.00/4.31M [00:00<?, ?B/s]
(…)all/resolve/main/special_tokens_map.json: 0%| | 0.00/99.0 [00:00<?, ?B/s]
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
トークナイザーは、特にMT5のようなトランスフォーマベースのモデルを扱う場合、自然言語処理パイプラインにおいて非常に重要なコンポーネントです。これは、人間が読めるテキストからモデルが期待するフォーマットへの変換をします。
Encoding: テキスト(この場合は日本語)をモデルが理解できる形式に変換します。これには、テキストを単語や形態素などに分割し、モデルか理解できるIDにマッピングすることが含まれます。
Decoding: モデルの出力(トークンID)を人間が読めるテキストに変換します。
たとえば、次のような日本語の文章があるとしましょう。
「本日はAIトレーニングセッションへようこそ!」
text = "本日はAIトレーニングセッションへようこそ!"
encoded_text = mt5_tokenizer(text)
print("Encoded text: ", encoded_text)
tokenized_text = mt5_tokenizer.tokenize(text)
print("Tokenized text: ", tokenized_text)
decoded_text = mt5_tokenizer.decode(encoded_text["input_ids"], skip_special_tokens=True)
print("Decoded Text: ", decoded_text)
Encoded text: {'input_ids': [259, 212152, 15428, 96992, 191286, 6031, 15578, 68875, 309, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} Tokenized text: ['▁', '本日は', 'AI', 'トレーニング', 'セッション', 'へ', 'よう', 'こそ', '!'] Decoded Text: 本日はAIトレーニングセッションへようこそ!
MT5トークナイザは「本日はAIトレーニングセッションへようこそ!」という文章を単語などのモデルが理解できる単位に分割しました。 それぞれのトークンにはユニークなIDが割り振られていて、decodeするときにはこれを逆に利用して元の文章に戻します。トークン化とdecodeは可逆的に設計されていて、元の文章は完全に復元可能です。
参考情報:
Encoded Text:
input_ids: これらはそれぞれのトークンに割り当てられたIDです。 [259, 212152, 15428, 96992, 191286, 6031, 15578, 68875, 309, 1]の並びはTokenized textのトークンと対応します。
attention_mask: この[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]はinput_idsに含まれる全てのトークンがモデルに対応していることを示しています。
Tokenized Text:
Tokenized textは['▁', '本日は', 'AI', 'トレーニング', 'セッション', 'へ', 'よう', 'こそ', '!']となりました。これが、MT5 tokenizerが入力文章を分割したトークンです。
Decoded Text:
次のステップ(モデル用の入力テキストを準備する)では、トークン化されたテキストとタイトルの最大長を決定する必要があります。 トークン化された記事やタイトルの最大長かわからないので調べる必要があります。 長さが長いほど良い結果が得られますが、トレーニングに時間がかかります。そこで、長さと学習時間のバランスのいいところを見つける必要があります。
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def dist_info(text, max_bin_size=1024):
token_counts = [len(mt5_tokenizer.tokenize(content)) for content in text]
sns.histplot(token_counts, bins=100, binrange=(0, max_bin_size))
plt.title("Tokenized Text Length Distribution")
plt.xlabel("Tokenized Text Length")
plt.ylabel("Count")
plt.show()
percentiles = [25, 50, 75, 90, 95, 99]
for p in percentiles:
print(f"{p}% of the dataset has token count below: {np.percentile(token_counts, p)}")
このコードは分布を出力する関数を定義しただけですので、まだ何も出力されません。
上の関数を実行してトークンの分布を確認してみましょう。
dist_info(dataset["train"]["content"])
25% of the dataset has token count below: 398.0 50% of the dataset has token count below: 589.0 75% of the dataset has token count below: 815.0 90% of the dataset has token count below: 1066.0 95% of the dataset has token count below: 1302.0 99% of the dataset has token count below: 1969.08
上記により、オススメは:
計算に制約がある場合(GPUメモリが限られているなど)、max_lengthを1024にすることを検討してください。これでデータの90%をカバーできます。メモリ不足のエラーを防ぐために、バッチサイズを調整する必要があるかもしれません。
計算量の制約がなく、データセットを切り捨てずにできるだけ多くのデータを確実に利用したい場合は、1024以上の長さを検討するといいでしょう(例えば、95パーセンタイルまでカバーするには1280)。ただし、1024より長い配列はすべてのT5モデルでサポートされない可能性があることに注意してください。利用するモデルが1024より長い配列を扱えることを確認する必要があります。
今回は時間の都合で最大長512を使用します。あとで時間があれば1024でも試してみてください。もっといい結果になるかも知れません。
さっきの関数を使用して、タイトルのトークン数の分布を確認してみましょう。
本文でさえ90%の範囲か1024トークンに収まるので、タイトルはおそらく128トークン以下だろう、と予測して、max_bin_sizeを128で実行してみます。
dist_info(dataset["train"]["title"], max_bin_size=128)
25% of the dataset has token count below: 16.0 50% of the dataset has token count below: 21.0 75% of the dataset has token count below: 25.0 90% of the dataset has token count below: 29.800000000000182 95% of the dataset has token count below: 32.0 99% of the dataset has token count below: 40.0
上記の結果により以下のように分析します。
2. 最大長48~50の場合:
- Pros:
- データセットの99%がカバーされます
- これでも、メモリや計算量で効率は比較的よさそうです
- Cons:
- かなり短いタイトルの場合、若干パディングが多くなるが、これらは全体的に少ないことを考えれば、これは些細な問題である
以上を考慮すると:
タイトルのmax_lengthを32にするのは、効率性を考えるとよい選択でしょう。
ほぼすべてのタイトルが切り捨てられないようにしたい場合は、max_lengthを48または50にするとよいでしょう。
今回は、タイトルの最大長として64を使うつもりです。切り捨てなしでタイトルの99%(あるいはおそらく99.99%)をカバーでき、メモリと計算の面でも比較的効率的です
さて、上記の値 512, 64をセットしましょう。
SOURCE_MAX_LEN = 512
TARGET_MAX_LEN = 64
生データには、無関係な情報、HTMLタグ、特殊文字などのノイズが含まれることがあります。そのようなノイズを除去またはクリーニングすることで、モデルがデータの本質的な部分に集中できるようにします。
特定の文字やシーケンスは、トークン化やモデリング処理において特別な意味を持つ場合があります。これらは前処理で適切に処理するか、エスケープする必要があります。
text_clean_preprocess関数はテキスト入力を受け取り、それに対していくつかのクリーニング処理を行います。newlineの改行文字(\n)は空文字列に、全角スペースは半角スペースに、タブ文字は空文字列に、キャリッジリターン文字(\r)は空文字列に置き換えます。最後に、テキストを小文字に変換し、クリーニングされたテキストを返します。
NEWLINE_CHAR = "\n"
SPACE_CHAR = "\u3000"
TAB_CHAR = "\t"
CARRIAGE_RETURN_CHAR = "\r"
def text_clean_preprocess(text, newline_char=NEWLINE_CHAR, space_char=SPACE_CHAR, tab_char=TAB_CHAR, carriage_return_char=CARRIAGE_RETURN_CHAR):
text = text.replace(newline_char, "")
text = text.replace(space_char, " ")
text = text.replace(tab_char, "")
text = text.replace(carriage_return_char, "")
text = text.lower()
return text
mT5のようなモデルでは、特定の形式、つまりトークンIDのシーケンスでの入力を期待します。この関数は、生テキストをそのようなシーケンスにトークン化します。
参考情報:
Consistent Data Length: ニューラル・モデルは通常、固定長の入力を必要とします。配列をSOURCE_MAX_LEN と TARGET_MAX_LEN になるように、パディングしたり切り詰めたりすることで、入力とターゲットの長さが固定長になるようにします。
Tokenization: 前処理されたコンテンツ(入力)とdata["title"](ターゲット)のタイトルは、mt5_tokenizerを使用してトークンIDに変換されます。
Attention Masking: attention maskは実際のトークン(1)とパディングされたトークン(0)を区別し、モデルが入力の意味のある部分に集中できるようにします。
def tokenize_data(data):
input_text = [text_clean_preprocess(content) for content in data["content"]]
target_text = data["title"]
inputs = mt5_tokenizer(input_text, max_length=SOURCE_MAX_LEN, truncation=True, padding="max_length")
targets = mt5_tokenizer(target_text, max_length=TARGET_MAX_LEN, truncation=True, padding="max_length")
label_attention_mask = [1] * len(targets["input_ids"])
return {
"input_ids": inputs.input_ids,
"attention_mask": inputs.attention_mask,
"decoder_input_ids": targets["input_ids"],
"labels": targets.input_ids,
"label_attention_mask": label_attention_mask
}
前のステップでは、データを変換する関数を作成しました。今度はその関数をデータセットに適用してみましょう。 データセットに関数を適用するにはmap関数を使用します。
参考情報:
tokenized_dataset = dataset.map(
tokenize_data,
batched=True,
remove_columns=["content", "title", "url", "date", "category"]
)
Map: 0%| | 0/5893 [00:00<?, ? examples/s]
Map: 0%| | 0/736 [00:00<?, ? examples/s]
Map: 0%| | 0/738 [00:00<?, ? examples/s]
さて、いよいよ事前に学習したモデルをロードします。上で述べたように、我々はmt5-smallモデルを使用します。 以下のコードは、"google/mt5-small "モデルをロードします。一度ロードされたモデルは、翻訳、要約、テキスト生成など様々なタスクに使うことができる。
一度もモデルをダウンロードしたことのない環境でこれを実行する場合、from_pretrainedメソッドはHugging Faceモデルハブからモデルの重みをダウンロードします。この場合、インターネット接続が有効であることを確認してください。
参考情報:
AutoModelForSeq2SeqLM: これは、提供されたモデル名またはパスに基づいて、適切なSequence to sequenceモデル(T5、BERT、GPT-2など)を自動的に推論し、ロードするように設計されたクラスです。モデルの具体的なアーキテクチャはわからないが、名前かパスは知っているという場合に便利です。
from_pretrained: これは、与えられた名前もしくはパスに基づいてモデルをロードするクラスメソッドです。モデルの重みと設定をダウンロードし、適切なモデルクラスのインスタンスを返します。
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
)
pytorch_model.bin: 0%| | 0.00/1.20G [00:00<?, ?B/s]
(…)mall/resolve/main/generation_config.json: 0%| | 0.00/147 [00:00<?, ?B/s]
今回は、テキスト(コンテンツ)の並び(Sequence)を別のテキストの並び(タイトル)に変換するトレーニングを行うので、DataCollatorForSeq2Seqを使います。
DataCollator関数は、データセットからサンプルのリストを受け取り、それらをTensorsの辞書としてバッチに照合します。
Data Collatorは以下のような動作をします:
seq2seqモデルを微調整する際にDataCollatorForSeq2Seqを使用するのは良いプラクティスなので、学習のワークフローに組み込むことが推奨されます。
要約すると、以下のコードは、コンピュータプログラム(モデル)が会話の中で自分自身の応答を理解し、生成できるように、テキストデータを準備し整理するツール(data_collator)をセットアップしています。
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(
tokenizer=mt5_tokenizer,
model=model,
)
モデルを学習する際には、検証用データセットでその性能を評価することが重要です。
BERTScoreはトランスフォーマベースのモデルとして有名なBERTモデルに基づいており、要約や翻訳などのテキスト生成タスクを評価するのに適しています。
先に進む前に、BERTスコアの使い方を確認しましょう。
この関数は、BERTScore メトリックを使用して、生成された文章とラベル付きの文章の平均BERTスコアを計算します。
import evaluate
def compute_bert_score_demo(preds, labels):
bert_score_metric = evaluate.load("bertscore")
bert_score_metric.add_batch(predictions=preds, references=labels)
result = bert_score_metric.compute(lang="ja")
avg_scores = {k: sum(v) / len(v) for k, v in result.items() if k != "hashcode"}
return avg_scores
original = "リンゴが好きです。"
candidate_1 ="リンゴが大好きです。"
candidate_2 = "赤い車を買うつもりです。"
bs_results = {
"候補1": compute_bert_score_demo([candidate_1], [original]),
"候補2": compute_bert_score_demo([candidate_2], [original]),
}
bs_df = pd.DataFrame(bs_results).T
bs_df
Downloading builder script: 0%| | 0.00/7.95k [00:00<?, ?B/s]
(…)cased/resolve/main/tokenizer_config.json: 0%| | 0.00/29.0 [00:00<?, ?B/s]
(…)tilingual-cased/resolve/main/config.json: 0%| | 0.00/625 [00:00<?, ?B/s]
(…)ultilingual-cased/resolve/main/vocab.txt: 0%| | 0.00/996k [00:00<?, ?B/s]
model.safetensors: 0%| | 0.00/714M [00:00<?, ?B/s]
precision | recall | f1 | |
---|---|---|---|
候補1 | 0.966469 | 0.986506 | 0.976384 |
候補2 | 0.680357 | 0.723368 | 0.701203 |
compute_bert_score関数は、推論した値と元のラベルの両方を含むeval_preds変数を引数として受け取ります。次に、mt5_tokenizer.batch_decode関数を使用して推論値とラベルをデコードします。最後に、bert_score_metric.compute関数を使用してBERTスコアを計算します。
def compute_bert_score(eval_preds):
bert_score_metric = evaluate.load("bertscore")
predictions, labels = eval_preds
decoded_preds = mt5_tokenizer.batch_decode(predictions, skip_special_tokens=True)
decoded_labels = mt5_tokenizer.batch_decode(labels, skip_special_tokens=True)
result = bert_score_metric.compute(
predictions=decoded_preds, references=decoded_labels, lang="ja"
)
return {
"bertscore_precision": sum(result["precision"]) / len(result["precision"]),
"bertscore_recall": sum(result["recall"]) / len(result["recall"]),
"bertscore_f1": sum(result["f1"]) / len(result["f1"])
}
Seq2SeqTrainingArguments
これはtransformersライブラリが提供するクラスで、sequence-to-sequenceモデル(要約モデルのように、あるシーケンスを別のシーケンスに変換するモデル)のトレーニング関連の引数やハイパーパラメータを定義することができます。これはTrainingArgumentsのサブクラスで、トレーニング関連の引数を定義するための汎用クラスです。
パフォーマンスを向上させるために使用できるハイパーパラメータは他にもたくさんありますが、これらはデータセット、モデル、タスクに依存します。さらに、最適なハイパーパラメータを見つけるための専用のハイパーパラメータ・チューニング手法もあります。 今回は、トレーニング時間とパフォーマンスのバランスが重要なので、以下のハイパーパラメータを使うことにします。
参考情報:
per_device_train_batch_size, per_device_eval_batch_size: 学習と評価のバッチサイズ。バッチサイズとは、モデルが重みを更新する前に処理するデータサンプルの数です。
learning_rate: これは初期学習率です。
lr_scheduler_type, warmup_ratio: これらは学習率のスケジュールを定義します。"linear"は学習率がエポックごとに直線的に減少することを意味します。warmup_ratioは、学習率が直線的に増加してから減衰を開始するまでの学習ステップの割合を定義します。
num_train_epochs: モデルを訓練するエポック数。「一つの訓練データを何回繰り返して学習させるか」の数のことです。
evaluation_strategy, save_strategy, logging_strategy: これらはそれぞれ、評価、保存、ログのタイミングを定義します。"epoch "は、これらのアクションが各エポックの終了時に実行されることを意味します。
logging_steps: トレーニングメトリクスのログの頻度(100ステップごと).
logging_dir: ログの保存ディレクトリ
do_train, do_eval: それぞれ学習や評価を行うべきかどうかを判断するためのフラグ。
output_dir: モデルのチェックポイントなどのトレーニングの出力が保存されるディレクトリ。
save_total_limit: 保存されるモデルのチェックポイントの最大数。古いチェックポイントは削除されます。
load_best_model_at_end: Trueの場合、評価指標に従った最良のモデルがトレーニングの最後にロードされます。
push_to_hub: Trueの場合、モデルはHugging Face Model Hubにプッシュされます。
predict_with_generate: シーケンスを生成するためにpredictメソッドがgenerateメソッドを使用するようにします。
コメントアウトされたパラメータ(optim, gradient_accumulation_steps, weight_decay, fp16)は、さらに学習をカスタマイズするために使用できる追加のオプション引数です。例えば、fp16=Trueを指定すると、半精度(16ビット)浮動小数点のトレーニングが有効になり、より高速なパフォーマンスが得られる可能性があります。
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
NUM_EPOCHS = 5
LEARNING_RATE = 5e-4
WARMUP_RATIO = 0.1
PER_DEVICE_TRAIN_BATCH_SIZE = 8
PER_DEVICE_EVAL_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 4
WEIGHT_DECAY = 0.01
LOGGING_DIR = "./logs"
OUTPUT_DIR = "./results"
training_args = Seq2SeqTrainingArguments(
per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
learning_rate=LEARNING_RATE,
# lr_scheduler_type="linear", #comment on colab
# warmup_ratio=0.1, #comment on colab
num_train_epochs=NUM_EPOCHS,
evaluation_strategy="epoch",
save_strategy="epoch",
logging_strategy="epoch",
logging_steps=100,
logging_dir=LOGGING_DIR,
do_train=True,
do_eval=True,
output_dir=OUTPUT_DIR,
save_total_limit=2,
load_best_model_at_end=True,
push_to_hub=False,
predict_with_generate=True,
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
)
以下のコードでは、さっき定義したmodel、training_args、data_collatorを使用してTrainerインスタンスを初期化しています。Trainerはモデルの学習と評価を行います。
参考情報:
Trainerクラスは、バッチ処理、ロギング、チェックポイントの保存など、トレーニングの低レベルな詳細の多くを抽象化します。これにより、数行のコードで簡単にモデルをトレーニングできるようになります。
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"],
data_collator=data_collator,
tokenizer=mt5_tokenizer,
compute_metrics=compute_bert_score,
)
モデル学習の準備はすべて整いました。いよいよhuggingfaceライブラリを使ってモデルをトレーニングします。以下のコードを使ってモデルをトレーニングすることができます。
非常にシンプルですね。trainer.train() は、学習プロセスの多くの複雑さを抽象化し、多様なタスクで様々なモデルを学習するための統一されたインターフェースを提供します。これにより、ユーザーは学習ループの詳細や勾配計算などに煩わされることなく、モデルの学習を容易に行うことができます。もちろん、カスタマイズが必要な場合は、トレーナーとそのコンポーネントは細かく制御可能です。
trainer.train()
Epoch | Training Loss | Validation Loss | Bertscore Precision | Bertscore Recall | Bertscore F1 |
---|---|---|---|---|---|
0 | 7.068200 | 1.034466 | 0.706013 | 0.664744 | 0.684008 |
1 | 1.268600 | 0.894248 | 0.712798 | 0.681310 | 0.696195 |
2 | 1.079000 | 0.864152 | 0.712160 | 0.684890 | 0.697850 |
4 | 0.996100 | 0.854636 | 0.713781 | 0.689665 | 0.701086 |
4 | 0.956100 | 0.847545 | 0.715506 | 0.692384 | 0.703341 |
TrainOutput(global_step=920, training_loss=2.273634487649669, metrics={'train_runtime': 2265.4763, 'train_samples_per_second': 13.006, 'train_steps_per_second': 0.406, 'total_flos': 1.556004590321664e+16, 'train_loss': 2.273634487649669, 'epoch': 4.99})
この処理は20分程度かかります。
ステップごとに結果を分析してみましょう:
Training Loss, Validation Loss:
BERTScoreの値:
モデルの学習が終わったら、次はモデルの評価です。モデルを評価するには、以下のコードを使えば簡単です。
Hugging Face Trainerは、検証データセット(指定されている場合はテスト・データセット)でモデルを評価し、計算されたメトリックスを返します。このメソッドは、未知のデータに対するモデルのパフォーマンスを測定するために使用されます。
results = trainer.evaluate()
results_df = pd.DataFrame([results])
results_df
eval_loss | eval_bertscore_precision | eval_bertscore_recall | eval_bertscore_f1 | eval_runtime | eval_samples_per_second | eval_steps_per_second | epoch | |
---|---|---|---|---|---|---|---|---|
0 | 0.847545 | 0.715506 | 0.692384 | 0.703341 | 63.7571 | 11.544 | 1.443 | 4.99 |
モデルをトレーニングし、評価しました。次はモデルをテストする番です。
次のステップでは、モデルを使って予測を生成します。
予測を生成する方法はたくさんあります。
などなど。
ここではBeam Searchを使います。Beam Searchは、限られた集合の中で最も有望なノードを展開することによってグラフを探索するヒューリスティクな探索アルゴリズムです。これはベスト・ファースト探索の最適化であり、必要なメモリを削減することができます。
参考情報:
ここではBeam Searchについての説明はしません。しかし、もし興味があれば、次のブログ記事でBeam Searchやその他の方法についてもっと読むことができます。 https://huggingface.co/blog/how-to-generate
パラメータが異なれば、結果も異なります。従って、いろいろなパラメータを試してみて、それが結果にどう影響するかを見てみるのがいいです。これは正解も不正解もありません。
簡潔にまとめるとこうなります:
関数:
def generate_title(content):
# inputs = [text_clean_preprocess(content)]
inputs = [f"summarize: " + text_clean_preprocess(content)]
batch = mt5_tokenizer.batch_encode_plus(
inputs, max_length=512, truncation=True,
padding="longest", return_tensors="pt")
input_ids = batch['input_ids']
input_mask = batch['attention_mask']
model.eval()
outputs = model.generate(
input_ids=input_ids.cuda(),
attention_mask=input_mask.cuda(),
max_length=64,
temperature=1.1, #
num_beams=6, #24
diversity_penalty=3.0, #1.8
num_beam_groups=3,
num_return_sequences=3,
repetition_penalty=9.0,
# early_stopping=True, #false
# max_new_tokens=64,
# do_sample = True
)
generated_titles = [mt5_tokenizer.decode(ids, skip_special_tokens=True,
clean_up_tokenization_spaces=False)
for ids in outputs]
return generated_titles
print("Total titles in dataset:", len(dataset["test"]["title"]))
selected_index = [75, 140, 286]
for index in selected_index:
print("original: ", dataset["test"]["title"][index])
titles = generate_title(dataset["test"]["content"][index])
for i, title in enumerate(titles):
print(f"Generated title {i+1}: {title}")
print()
Total titles in dataset: 738 original: いいとも!で紹介された「ヒドすぎる」名前が話題に Generated title 1: 『ザザザの斬新!赤ちゃんネーム』で紹介された「キラキラネーム」がありえない Generated title 2: 『ザザザの斬新!赤ちゃんネーム』が「ありえない」とネットニュースで話題 Generated title 3: ネットスラングで「キラキラネーム」がありえないとネットで話題【話題】 original: 日本の引きこもりに海外から相次ぐ心配の声 Generated title 1: 日本の引きこもりが海外で話題に Generated title 2: 日本の引きこもりが海外で話題 Generated title 3: 「日本=出る釘は打たれる」など、日本の引きこもりが海外で話題に original: 甲子園出場する石巻工「約5000万円が必要」呼びかけに物議 Generated title 1: 【Sports Watch】石巻工業高校が総額約5000万円の協賛金を募っている Generated title 2: 【Sports Watch】石巻工業高校が総額約5000万円の協賛金を募っている理由 Generated title 3: 石巻工業高校が、総額約5000万円の協賛金を募っている【話題】
生成されたタイトルは実際のタイトルにかなり似ていることがわかります。これは良い兆候で、モデルが実際のタイトルに似たタイトルを生成するように学習していることを示しています。
また、生成されたタイトルは実際のタイトルと同一ではないことにも注目してください。これは想定内で、モデルは正確なタイトルを生成するようにトレーニングされていないからです。その代わりに、実際のタイトルに似たタイトルを生成するように学習されています。
パラメータが違えば、結果も違ってきます。だから、いろいろなパラメーターを試して、それが結果にどう影響するかを見てみるといいです。
我々のモデルはlivedoorニュースコーパスで学習され、livedoorのニュースでテストされました。では、livedoor以外のニュースにどう反応するか見てみたいと思います。yahooニュースから取得したニュースの内容を貼り付けました。私たちのモデルからの出力を見てみましょう。
# https://news.yahoo.co.jp/pickup/6476740
yahoo_news_1_original_title = """【速報】女川原発2号機「再稼働目標 3か月延期へ」来年5月に<東北電力>"""
yahoo_news_1 = """東北電力は来年2月を目標としていた「女川原発2号機の再稼働」について、来年5月に延期することを明らかにした。今年11月としていた安全対策工事の完了時期が来年2月に延びるため。
東北電力によると、工事が3か月延びるのは、発電所内の設備などにつながるケーブルが火災などで損傷しないようにする「火災防護対策」を追加したことが主な要因。この対策を巡っては、他の電力会社が原子力規制員会から指摘を受けた事例を踏まえ、東北電力では去年10月から追加で工事をすることを準備していた。
その工程を精査した結果、3か月ほど完了時期が延びることが判明したもので、それに伴って女川原発2号機の再稼働目標も3か月延期し、来年5月頃となった。
"""
print("Original yahoo title: ", yahoo_news_1_original_title)
for i, title in enumerate(generate_title(yahoo_news_1)):
print(f"Generated title {i+1}: {title}")
Original yahoo title: 【速報】女川原発2号機「再稼働目標 3か月延期へ」来年5月に<東北電力> Generated title 1: 東北電力、来年2月を目標としていた「女川原発2号機再稼働」を発表 Generated title 2: 東北電力、来年2月を目標としていた「女川原発2号機再稼働」について3か月延期 Generated title 3: 【ニュース】東北電力、来年2月を目標としていた「女川原発2号機再稼働」について3か月延期
もし興味があれば、あなたのメールをコピーペーストして、どのように反応するか見てみるといいです。
sample_email = """
おはようございます。
アナハイム・エレクトロニクス事務局です。
本日は、まもなく締め切りとなります、10/18(水)開催の対面イベント「モルゲンレーテ社製品無料体験会&相談会 @アナハイム・エレクトロニクスカリフォルニア本社」のご案内です!
実際に触れてみないとわからないとお困りの方、ぜひこの機会に体験してみてください。
モルゲンレーテ社製品無料体験会は特に以下のような方にお勧めです。
・電子・電気機器の新規導入でご検討中の方
※小規模から対応できます!規模は問いません。
・電子機器の更改でクラウド移行を検討している方
・顧客管理システムも含めセキュアに電子機器を構築したい方。
モルゲンレーテ社製品を検討はしているけど、実際に触れてみないとわからない、導入を検討しているけど、何から始めればよいか分からないといった課題、お困りごとのある方は、ぜひこの機会に体験してみてください。
また、体験会の後は、弊社エンジニアの個別相談会も予定しております。
皆さまのお悩みなどお気軽にお話いただければと思います。
以下、無料体験会の詳細、申し込み方法をご確認のうえ、ぜひお気軽にご参加ください!
イベントの参加は無料! 参加お申込みは10月16日(月)となります。
みなさまのご参加お待ちしています!
"""
for i, title in enumerate(generate_title(sample_email)):
print(f"Generated title {i+1}: {title}")
Generated title 1: アナハイム・エレクトロニクスカリフォルニア本社にて開催の対面イベント「モルゲンレーテ社製品無料体験会&相談会」を開催 Generated title 2: アナハイム・エレクトロニクスカリフォルニア本社にて開催の対面イベント「モルゲンレーテ社製品無料体験会&相談会】 Generated title 3: 「モルゲンレーテ社製品無料体験会&相談会 アナハイム・エレクトロニクスカリフォルニア本社」
今回のモデルを改良する方法はたくさんあります。以下にそのいくつかを示します。
転移学習:: タイトル生成タスクでファインチューニングする前に、まず関連タスク(例えば要約)でモデルをファインチューニングして、モデルをドメインに慣れさせることができます。
正則化: ドロップアウト、レイヤーの正規化、ウェイトの減衰のようなテクニックは、オーバーフィッティングを防ぐのに役立ちます。