AIによるメールの件名生成¶

1. お知らせ¶

残念なお知らせですが、今日はメールの件名生成はできません。メールには個人情報が多く含まれており、学習に利用するメールデータをみなさんと共有できないからです。

その代わり、ニュース記事の本文からニュースのタイトルを生成してみたいと思います

メールのデータセットから件名を生成する流れ

  • メール件名生成 : メール本文 -> メールの件名

ニュース記事からニュースのタイトルを生成する流れ

  • ニュースタイトル生成 : ニュース記事 -> ニュースのタイトル

コンセプトはほとんど同じです。興味があれば、後で同じコードを使って少し調整するだけで、メールのデータを学習させることが出来ます。サポートが必要な場合は遠慮なくnuwan@qualitia.comまでご連絡ください。

重要事項 1:

  • このドキュメントからgoogleのcolabにコードをコピーペーストします。このとき、コードセルを順番どおりに実行するように気をつけてください。

重要事項 2:

  • コードを実行するとき問題があればすぐに私か平野さんまで連絡してください。できるだけすぐに解決します。でなければ、続きのチュートリアルがうまく動作しなくなります。

1.2. Google ColabでGPUを有効にするには?¶

  • Step 1: Google Colabを開く
    • まず, Google Colab https://colab.research.google.com/ にアクセスしてGoogleのアカウントでログインします。



  • Step 2: 新しい「ノートブック」の作成
    • ログイン後、「ノートブックを新規作成」ボタンをクリックして新しいノートブックを作成します。



  • Step 3: Hardware AcceleratorにGPUを選択します
    • 次に、上部メニューから「ランタイム」を選択し、「ランタイムのタイプ」をクリックします。
    • ランタイムのタイプのデフォルトは 「Python 3」で、これはそのままにしておきます。
    • 「ハードウェアアクセラレータ」の一覧で、「T4 GPU」を選択し、「保存」をクリックします。

1.3. 参考情報について¶


説明中に 参考情報: を点線の間に記述しました。この内容は基本的に無視したいただいて構いません。 家に帰ってから興味があれば参照してみてください。


2. 必要なライブラリのインストール¶

まず、必要なライブラリをインストールしてみましょう。

pipコマンドを使用してライブラリをインストールすることが出来ます。 本プロジェクトではつぎのライブラリを使用します。 Google Colabでシステムのコマンドを実行する場合は、コマンドの先頭に「!」または「%」を付けてください。

使用するライブラリ一覧

  • transformers : Huggingfaceのトランスフォーマーライブラリは学習と推論に使用します。
  • datasets : Huggingfaceのデータセットライブラリはデータセットを読み込むのに使用します。
  • transformers[ja] : テキストデータをトークン化するために必要なmecab、fugashi、ipadicなどをインストールします。
  • sentencepiece : Sentencepieceはテキストデータをトークン化するために使用します。
  • torch : Pytorchはディープラーニングのフレームワークとして使用します。
  • bert_score : Bert_scoreはモデルを評価するために使用します。
  • absl-py : Abseilはログに使用されます。(今回直接は使用しませんが、別のライブラリが使用します)
  • evalute : Evaluateはモデルを評価するために使用します。
  • matplotlib : Matplotlibはグラフを表示するのに使用します。
  • seaborn : Seabornはグラフを表示するのに使用します。
In [1]:
%pip install transformers[torch,ja]==4.33.3 datasets==2.14.5 sentencepiece matplotlib seaborn evaluate absl-py bert_score pandas tokenizers==0.13.3
Collecting transformers[ja,torch]==4.33.3
  Downloading transformers-4.33.3-py3-none-any.whl (7.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.6/7.6 MB 51.3 MB/s eta 0:00:00
Collecting datasets==2.14.5
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 519.6/519.6 kB 52.3 MB/s eta 0:00:00
Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 62.1 MB/s eta 0:00:00
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (3.7.1)
Requirement already satisfied: seaborn in /usr/local/lib/python3.10/dist-packages (0.12.2)
Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.1/84.1 kB 11.4 MB/s eta 0:00:00
Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (1.4.0)
Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 61.1/61.1 kB 7.8 MB/s eta 0:00:00
Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (1.5.3)
Collecting tokenizers==0.13.3
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 100.9 MB/s eta 0:00:00
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (3.12.4)
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers[ja,torch]==4.33.3)
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 302.0/302.0 kB 32.5 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (1.23.5)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (23.2)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (6.0.1)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (2023.6.3)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (2.31.0)
Collecting safetensors>=0.3.1 (from transformers[ja,torch]==4.33.3)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 78.6 MB/s eta 0:00:00
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (4.66.1)
Requirement already satisfied: torch!=1.12.0,>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformers[ja,torch]==4.33.3) (2.1.0+cu118)
Collecting accelerate>=0.20.3 (from transformers[ja,torch]==4.33.3)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 261.4/261.4 kB 30.3 MB/s eta 0:00:00
Collecting fugashi>=1.0 (from transformers[ja,torch]==4.33.3)
  Downloading fugashi-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (600 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 600.9/600.9 kB 53.2 MB/s eta 0:00:00
Collecting ipadic<2.0,>=1.0.0 (from transformers[ja,torch]==4.33.3)
  Downloading ipadic-1.0.0.tar.gz (13.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.4/13.4 MB 77.9 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting unidic-lite>=1.0.7 (from transformers[ja,torch]==4.33.3)
  Downloading unidic-lite-1.0.8.tar.gz (47.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 47.4/47.4 MB 16.9 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting unidic>=1.0.2 (from transformers[ja,torch]==4.33.3)
  Downloading unidic-1.1.0.tar.gz (7.7 kB)
  Preparing metadata (setup.py) ... done
Collecting sudachipy>=0.6.6 (from transformers[ja,torch]==4.33.3)
  Downloading SudachiPy-0.6.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 24.0 MB/s eta 0:00:00
Collecting sudachidict-core>=20220729 (from transformers[ja,torch]==4.33.3)
  Downloading SudachiDict_core-20230927-py3-none-any.whl (71.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.7/71.7 MB 8.8 MB/s eta 0:00:00
Collecting rhoknp<1.3.1,>=1.1.0 (from transformers[ja,torch]==4.33.3)
  Downloading rhoknp-1.3.0-py3-none-any.whl (86 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.8/86.8 kB 11.5 MB/s eta 0:00:00
Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.5) (9.0.0)
Collecting dill<0.3.8,>=0.3.0 (from datasets==2.14.5)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115.3/115.3 kB 14.0 MB/s eta 0:00:00
Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.5) (3.4.1)
Collecting multiprocess (from datasets==2.14.5)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.8/134.8 kB 15.5 MB/s eta 0:00:00
Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.5) (2023.6.0)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.14.5) (3.8.6)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.1.1)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (4.43.1)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.4.5)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (9.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (3.1.1)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (2.8.2)
Collecting responses<0.19 (from evaluate)
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.3.post1)
Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.20.3->transformers[ja,torch]==4.33.3) (5.9.5)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (23.1.0)
Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (3.3.1)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (6.0.4)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (4.0.3)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (1.9.2)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (1.4.0)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.14.5) (1.3.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers[ja,torch]==4.33.3) (4.5.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[ja,torch]==4.33.3) (3.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[ja,torch]==4.33.3) (2.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[ja,torch]==4.33.3) (2023.7.22)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (1.12)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (3.2)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (3.1.2)
Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (2.1.0)
Collecting wasabi<1.0.0,>=0.6.0 (from unidic>=1.0.2->transformers[ja,torch]==4.33.3)
  Downloading wasabi-0.10.1-py3-none-any.whl (26 kB)
Collecting plac<2.0.0,>=1.1.3 (from unidic>=1.0.2->transformers[ja,torch]==4.33.3)
  Downloading plac-1.4.1-py2.py3-none-any.whl (22 kB)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch!=1.12.0,>=1.10->transformers[ja,torch]==4.33.3) (1.3.0)
Building wheels for collected packages: ipadic, unidic, unidic-lite
  Building wheel for ipadic (setup.py) ... done
  Created wheel for ipadic: filename=ipadic-1.0.0-py3-none-any.whl size=13556703 sha256=c734594f6ebc044bf9e2ba4e0351f66bffea4a62d1c1c56f2c5445abd1449190
  Stored in directory: /root/.cache/pip/wheels/5b/ea/e3/2f6e0860a327daba3b030853fce4483ed37468bbf1101c59c3
  Building wheel for unidic (setup.py) ... done
  Created wheel for unidic: filename=unidic-1.1.0-py3-none-any.whl size=7406 sha256=21c1c4b149b24a152647c2b7122b508f03f674ab6c7275fbd942ecf8992bfd0b
  Stored in directory: /root/.cache/pip/wheels/7a/72/72/1f3d654c345ea69d5d51b531c90daf7ba14cc555eaf2c64ab0
  Building wheel for unidic-lite (setup.py) ... done
  Created wheel for unidic-lite: filename=unidic_lite-1.0.8-py3-none-any.whl size=47658816 sha256=bbdeafc6a0e71520c94eb7be23a273d26c3e3b6792b08dcaba6da4902f71935e
  Stored in directory: /root/.cache/pip/wheels/89/e8/68/f9ac36b8cc6c8b3c96888cd57434abed96595d444f42243853
Successfully built ipadic unidic unidic-lite
Installing collected packages: wasabi, unidic-lite, tokenizers, sudachipy, sentencepiece, plac, ipadic, sudachidict-core, safetensors, rhoknp, fugashi, dill, unidic, responses, multiprocess, huggingface-hub, transformers, accelerate, datasets, bert_score, evaluate
  Attempting uninstall: wasabi
    Found existing installation: wasabi 1.1.2
    Uninstalling wasabi-1.1.2:
      Successfully uninstalled wasabi-1.1.2
Successfully installed accelerate-0.24.1 bert_score-0.3.13 datasets-2.14.5 dill-0.3.7 evaluate-0.4.1 fugashi-1.3.0 huggingface-hub-0.18.0 ipadic-1.0.0 multiprocess-0.70.15 plac-1.4.1 responses-0.18.0 rhoknp-1.3.0 safetensors-0.4.0 sentencepiece-0.1.99 sudachidict-core-20230927 sudachipy-0.6.7 tokenizers-0.13.3 transformers-4.33.3 unidic-1.1.0 unidic-lite-1.0.8 wasabi-0.10.1

3. GPUが利用できるか確認¶

さて、始める前にGPUが利用可能かどうか確認します。 次のコードで確認できます。

In [2]:
import torch

if torch.cuda.is_available():
    status = "GPU is enabled."
    device_count = torch.cuda.device_count()
    current_device = torch.cuda.current_device()
    print(f"{status}\ndevice count: {device_count}, current device: {current_device}")
else:
    print("GPU is disabled.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")
GPU is enabled.
device count: 1, current device: 0
device: cuda

重要 !!!:

  • もし次のような出力が出来たら、私か平野さんにすぐに連絡してください。
GPU is disabled.
Device: cpu

4. シードの設定 (乱数の制御)¶

機械学習やディープラーニングでは結果の再現性を確保するために乱数のシードの設定が重要になります。 今回のプロジェクトでは42を使用します。

これらのシードを設定しても、異なるプラットフォームやGPUアーキテクチャ間で再現性を実現することは依然として困難であることに注意してください。

In [3]:
import random
import numpy as np
import torch

import warnings
warnings.filterwarnings('ignore')
def seed_everything(seed_value):
    random.seed(seed_value) # Python
    np.random.seed(seed_value) # Numpy
    torch.manual_seed(seed_value) # CPU
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value) # GPU if available
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

seed_value = 42
seed_everything(seed_value)

5. Data¶

先ほどお知らせしたように、ニュースのデータセットを使用します。データセットは日本語のLivedoorニュースのコーパスです。

5.1 データの読み込み¶

データセットをロードするには、load_dataset関数を使います。 load_datasetを使用して、Hugging Face Hubからデータセットをロードします。この関数は、多くのデータセットに簡単にアクセスする方法を提供します。また、トレーニング、検証、テストセットへの分割など、標準的なデータ処理を実現することができます。

要約すると、以下のコードは "livedoor-news-corpus"データセットをロードし、シャッフルなし、固定ランダムシードで、トレーニング(80%)、検証(10%)、テスト(10%)のセットに分割します。

参考情報:

  • llm-book/livedoor-news-corpus: データセットの識別子や場所

  • train_ratio=0.8: データセットの80%を学習に割り当てる

  • validation_ratio=0.1: データセットの10%を検証に割り当てる

  • seed=42: 再現性確保のために乱数のシードに固定値を使用します

  • shuffle=False: 学習データ、検証データ、テストデータに分割する際にシャッフルしないようにします


In [4]:
from datasets import load_dataset

dataset = load_dataset(
    "llm-book/livedoor-news-corpus",
    train_ratio = 0.8,
    validation_ratio = 0.1,
    seed=42,
    shuffle=False,

)
Downloading builder script:   0%|          | 0.00/3.88k [00:00<?, ?B/s]
Downloading readme:   0%|          | 0.00/823 [00:00<?, ?B/s]
Downloading data:   0%|          | 0.00/8.86M [00:00<?, ?B/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating validation split: 0 examples [00:00, ? examples/s]
Generating test split: 0 examples [00:00, ? examples/s]
「dataset」と実行して中身がどうなってるか確認してみましょう。
In [5]:
dataset
Out[5]:
DatasetDict({
    train: Dataset({
        features: ['url', 'date', 'title', 'content', 'category'],
        num_rows: 5893
    })
    validation: Dataset({
        features: ['url', 'date', 'title', 'content', 'category'],
        num_rows: 736
    })
    test: Dataset({
        features: ['url', 'date', 'title', 'content', 'category'],
        num_rows: 738
    })
})
これはdatasetsライブラリのcustom dictionaryオブジェクトです。学習データ、検証データのように分割されてデータセットをまとめて監理します。

それぞれのデータセットは次のような要素で構成されています。

  • url: ニュース記事が掲載されたURL。
  • date: ニュース記事が公開された日付。
  • title: ニュース記事のタイトルまたは見出し。
  • content: ニュース記事の本文。
  • category: ニュース記事が該当するカテゴリーまたはトピック。

5.2 Simple Data Exploration¶

とりあえず、テストデータをpandasのデータフレームにロードし、それがどのように見えるか確認してみましょう。

dataframeは2次元のラベル付きデータ構造で、異なるタイプの列を持つことができます。Excelの表のようなものだと考えてください。

要約すると、このコードはpandasライブラリをインポートし、データセットの'test'部分をpandas DataFrameに変換し、そのDataFrameの最初の5行を表示しています。


参考情報:

  1. import pandas as pd: この行はpandasライブラリをインポートし、pdという別名で使えるようにします。Pandasは、データ分析や操作、特に表形式のデータで広く使われているPythonライブラリです。

  2. test_df = pd.DataFrame(dataset['test']): This line does two main things:

    • dataset['test']: データセットの「test」部分を取り出します。データセットは辞書のような構造になっていて、キーが 'test' のデータを取り出します。
    • pd.DataFrame(...): これは、抽出された 'test' データを pandas DataFrame に変換します。DataFrame は、2 次元のラベル付きデータ構造で、データベースのテーブルや Excel のスプレッドシート、R のデータフレームに似ています。


  1. test_df.head(): この行は、test_df DataFrame の最初の5行を表示します。pandasのhead()メソッドは、DataFrameの最初の行を素早く確認するのに便利です。

In [6]:
import pandas as pd
test_df = pd.DataFrame(dataset['test'])
test_df.head()
Out[6]:
url date title content category
0 http://news.livedoor.com/article/detail/5936102/ 2011-10-14T09:11:00+0900 ゼンショー「事実無根」と反論 10月13日の夜、ゼンショーの広報室長がTwitterで読売新聞の報道に「事実無根」と反論し... topic-news
1 http://news.livedoor.com/article/detail/5936557/ 2011-10-14T11:16:00+0900 「報ステ」OP曲演奏のジャズミュージシャンに“売名行為”と批判相次ぐ 先日、福島県が行っている新米の放射性物質本検査が全て終了した。規制値を超える放射性セシウムは... topic-news
2 http://news.livedoor.com/article/detail/5936721/ 2011-10-14T11:46:00+0900 「何のための“予約”なんですか」孫社長に批判殺到 ソフトバンクは、今朝から発売が始まった“iPhone4S”をはじめとする、ソフトバンク全ての... topic-news
3 http://news.livedoor.com/article/detail/5937177/ 2011-10-15T10:00:00+0900 あまりにも多すぎる「会いたくて」への皮肉か!? 「西野カナゲーム」が流行 いま巷で「西野カナゲーム」なるものが流行してるという。 簡単に説明すると、1人目、2人目は... topic-news
4 http://news.livedoor.com/article/detail/5937649/ 2011-10-14T16:03:00+0900 憶測呼ぶ紳助さんの“天敵”引退 10月13日発売の東スポに「紳助の天敵が引退」との見出しが躍った。その天敵とは、警察庁の安藤... topic-news
これでデータフレーム内のデータ見ることができました。中身を確認することでデータセットに関するよいアイデアを得ることができます。

¶

6. Model and Tokenizer¶

このタスクではgoogleのmt5-smallモデルを使用します。

mt5とは?¶

mt5はgoogleが公開したText-to-textトランスフォーマーモデルです。多言語モデルで、翻訳、要約、分類、その他多くのタスクに使用できます。

Text to text トランスフォーマーとは?¶

Text-to-textトランスフォーマーとは、テキストを入力してテキスト出力するような、さまざまなタスクに使用できるトランスフォーマーモデルのクラスです。例えば、

  • 例えば、質問を入力して、答えを出力する
  • 英語の文章を入力して、フランス語を出力する
  • ニュース記事を入力して、見出しを出力する

など。

今回はm5-smallを使用しますが、ほかにもたくさんの同じようなモデルがあります。例えば、mt5-baseです。 mt5-baseはmt5-smallより大きなモデルで、もっとたくさんのパラメータがあり学習時間もかかりますが、学習能力は高いです。

6.1. Set Tokenizer¶

モデルのトークナイザをロードするためにAutoTokenizerを使用します。 次のコードは mt5-small モデルのトークナイザをロードします。

トークナイザとは?¶

トークナイザとはとは文字列をトークンに分割する機能です。例えば、"Hello world!"という文字列は['Hello', 'world', '!']というトークンのリストに分割されます。

このコードでは、transformersライブラリからAutoTokenizerクラスをインポートし、"google/mt5-small"モデルを指定して、関連するトークナイザをmt5_tokenizerという変数にセットしますします。このトークナイザーはmT5 smallモデルのテキスト入力時に入力をトークン化します。


参考情報:

コードの内訳を見てみましょう:

  1. from transformers import AutoTokenizer: transformersライブラリからAutoTokenizerクラスをimportします。このクラスは、与えられたモデルに対して適切なトークナイザを自動的に取得するように設計されています。

  2. MODEL_NAME = "google/mt5-small": この行は、定数MODEL_NAMEにGoogleが提供する小さなバージョンであるmT5(多言語(multilingual) T5)モデルの識別子を設定します。T5(Text-to-Textトランスフォーマー)は、様々なNLPタスクのために設計されたトランスフォーマーベースの一般的なモデルで、多言語バージョン(mT5)は複数の言語で学習されています。

  3. mt5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME): この行はmT5モデルに関連付けられたトークナイザを初期化します。from_pretrainedメソッドは、指定されたモデル(google/mt5-small)のトークナイザをHugging Faceのモデルハブから取得し、mt5_tokenizer変数に代入します。


In [7]:
from transformers import AutoTokenizer

MODEL_NAME = "google/mt5-small"

mt5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
mt5_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
(…)small/resolve/main/tokenizer_config.json:   0%|          | 0.00/82.0 [00:00<?, ?B/s]
(…)oogle/mt5-small/resolve/main/config.json:   0%|          | 0.00/553 [00:00<?, ?B/s]
(…)ogle/mt5-small/resolve/main/spiece.model:   0%|          | 0.00/4.31M [00:00<?, ?B/s]
(…)all/resolve/main/special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565

トークナイザ詳細¶

トークナイザーは、特にMT5のようなトランスフォーマベースのモデルを扱う場合、自然言語処理パイプラインにおいて非常に重要なコンポーネントです。これは、人間が読めるテキストからモデルが期待するフォーマットへの変換をします。

トークナイザの使い方:¶

  • Encoding: テキスト(この場合は日本語)をモデルが理解できる形式に変換します。これには、テキストを単語や形態素などに分割し、モデルか理解できるIDにマッピングすることが含まれます。

  • Decoding: モデルの出力(トークンID)を人間が読めるテキストに変換します。

Example:

たとえば、次のような日本語の文章があるとしましょう。

「本日はAIトレーニングセッションへようこそ!」

In [8]:
text = "本日はAIトレーニングセッションへようこそ!"

encoded_text = mt5_tokenizer(text)
print("Encoded text: ", encoded_text)

tokenized_text = mt5_tokenizer.tokenize(text)
print("Tokenized text: ", tokenized_text)

decoded_text = mt5_tokenizer.decode(encoded_text["input_ids"], skip_special_tokens=True)
print("Decoded Text: ", decoded_text)
Encoded text:  {'input_ids': [259, 212152, 15428, 96992, 191286, 6031, 15578, 68875, 309, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
Tokenized text:  ['▁', '本日は', 'AI', 'トレーニング', 'セッション', 'へ', 'よう', 'こそ', '!']
Decoded Text:  本日はAIトレーニングセッションへようこそ!
さて、文章が小さなパーツに分割されたと思います。このパーツがトークンと呼ばれます。 また、トークンはそれぞれユニークなIDを持っています。

MT5トークナイザは「本日はAIトレーニングセッションへようこそ!」という文章を単語などのモデルが理解できる単位に分割しました。 それぞれのトークンにはユニークなIDが割り振られていて、decodeするときにはこれを逆に利用して元の文章に戻します。トークン化とdecodeは可逆的に設計されていて、元の文章は完全に復元可能です。


参考情報:

Encoded Text:

  • input_ids: これらはそれぞれのトークンに割り当てられたIDです。 [259, 212152, 15428, 96992, 191286, 6031, 15578, 68875, 309, 1]の並びはTokenized textのトークンと対応します。

  • attention_mask: この[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]はinput_idsに含まれる全てのトークンがモデルに対応していることを示しています。

Tokenized Text:

  • Tokenized textは['▁', '本日は', 'AI', 'トレーニング', 'セッション', 'へ', 'よう', 'こそ', '!']となりました。これが、MT5 tokenizerが入力文章を分割したトークンです。

    • '▁': '▁' (アンダースコア)はMT5が使用しているSentencePieceのようなトークナイザでは空白を意味します。日本語のような空白のない言語では文頭か文中か、語頭か語中かを区別するのに役立ちます。
    • '本日は': 切り出されたトークンです。形態素解析で分割した場合は「本日」「は」はそれぞれ別々の形態素として分割されますが、機械学習のトークナイザの場合には、必ずしも単語に分割されるとは限りません。
    • 'よう': こちらも同様です。形態素解析であれば、「ようこそ」で1つの分割としたいところですが、「よう」「こそ」に分割されてそれぞれ別々のトークンとして解析されています。
    • '!': 元の文章では「!」(全角)でしたが、トークン化された結果は「!」半角になっています。このように汎化される場合もあります。

Decoded Text:

  • Decoded Textは分割されたトークンIDを逆引きして結合したものです。元の文章「本日はAIトレーニングセッションへようこそ!」に戻りました。ただし、'▁' は削除しています。また、「!」はトークンとして使用された半角の「!」に変わっています。

6.2 ニュース記事やタイトルのトークン数の分布¶

次のステップ(モデル用の入力テキストを準備する)では、トークン化されたテキストとタイトルの最大長を決定する必要があります。 トークン化された記事やタイトルの最大長かわからないので調べる必要があります。 長さが長いほど良い結果が得られますが、トレーニングに時間がかかります。そこで、長さと学習時間のバランスのいいところを見つける必要があります。

In [9]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def dist_info(text, max_bin_size=1024):
    token_counts = [len(mt5_tokenizer.tokenize(content)) for content in text]

    sns.histplot(token_counts, bins=100, binrange=(0, max_bin_size))
    plt.title("Tokenized Text Length Distribution")
    plt.xlabel("Tokenized Text Length")
    plt.ylabel("Count")
    plt.show()

    percentiles = [25, 50, 75, 90, 95, 99]
    for p in percentiles:
        print(f"{p}% of the dataset has token count below: {np.percentile(token_counts, p)}")

このコードは分布を出力する関数を定義しただけですので、まだ何も出力されません。

6.2.1 記事のトークンの長さの分布¶

上の関数を実行してトークンの分布を確認してみましょう。

In [10]:
dist_info(dataset["train"]["content"])
25% of the dataset has token count below: 398.0
50% of the dataset has token count below: 589.0
75% of the dataset has token count below: 815.0
90% of the dataset has token count below: 1066.0
95% of the dataset has token count below: 1302.0
99% of the dataset has token count below: 1969.08
トークン化されたコンテンツを512と1024の長さでそれぞれを別々に分析してみましょう。
  1. 最大長512にした場合:
    • 長所:
      • 学習や評価が速い
      • メモリをあまり使わない
    • 短所:
      • 中央値が519であるため、データセットの約50%が捨てられてしまう。

  2. 最大長1024にした場合:
    • 長所:
      • データセットの90%を利用できる
      • メモリと計算量の点も比較的効率的だ
    • 短所:
      • 長さが長くなるため、512に比べて学習が遅くなる可能性がある
      • メモリ使用量が増えるため、バッチサイズを小さくする必要があるかもしれない

  1. 最大長1024以上にした場合:
    • 長所:
      • データセットを捨てることなく利用できる
    • 短所:
      • 計算量が大幅に増加する
      • メモリ使用量が大幅に増加する
      • 1024より長い配列を扱えないモデルもある

上記により、オススメは:

  • 計算に制約がある場合(GPUメモリが限られているなど)、max_lengthを1024にすることを検討してください。これでデータの90%をカバーできます。メモリ不足のエラーを防ぐために、バッチサイズを調整する必要があるかもしれません。

  • 計算量の制約がなく、データセットを切り捨てずにできるだけ多くのデータを確実に利用したい場合は、1024以上の長さを検討するといいでしょう(例えば、95パーセンタイルまでカバーするには1280)。ただし、1024より長い配列はすべてのT5モデルでサポートされない可能性があることに注意してください。利用するモデルが1024より長い配列を扱えることを確認する必要があります。

今回は時間の都合で最大長512を使用します。あとで時間があれば1024でも試してみてください。もっといい結果になるかも知れません。

6.2.2 タイトルのトークン数の分布¶

さっきの関数を使用して、タイトルのトークン数の分布を確認してみましょう。

本文でさえ90%の範囲か1024トークンに収まるので、タイトルはおそらく128トークン以下だろう、と予測して、max_bin_sizeを128で実行してみます。

In [11]:
dist_info(dataset["train"]["title"], max_bin_size=128)
25% of the dataset has token count below: 16.0
50% of the dataset has token count below: 21.0
75% of the dataset has token count below: 25.0
90% of the dataset has token count below: 29.800000000000182
95% of the dataset has token count below: 32.0
99% of the dataset has token count below: 40.0

上記の結果により以下のように分析します。

  1. 最大長32の場合:
    • 長所:
      • タイトルのデータセットの90%をカバーします
      • メモリや計算量で効率がよさそうです
    • 短所:
      • 10%のタイトルは切り捨てられてしまいます。



2. 最大長48~50の場合: - Pros: - データセットの99%がカバーされます - これでも、メモリや計算量で効率は比較的よさそうです - Cons: - かなり短いタイトルの場合、若干パディングが多くなるが、これらは全体的に少ないことを考えれば、これは些細な問題である

以上を考慮すると:

  • タイトルのmax_lengthを32にするのは、効率性を考えるとよい選択でしょう。

  • ほぼすべてのタイトルが切り捨てられないようにしたい場合は、max_lengthを48または50にするとよいでしょう。

今回は、タイトルの最大長として64を使うつもりです。切り捨てなしでタイトルの99%(あるいはおそらく99.99%)をカバーでき、メモリと計算の面でも比較的効率的です

さて、上記の値 512, 64をセットしましょう。

In [12]:
SOURCE_MAX_LEN = 512
TARGET_MAX_LEN = 64

7. 前処理とNormalization¶

生データには、無関係な情報、HTMLタグ、特殊文字などのノイズが含まれることがあります。そのようなノイズを除去またはクリーニングすることで、モデルがデータの本質的な部分に集中できるようにします。

特定の文字やシーケンスは、トークン化やモデリング処理において特別な意味を持つ場合があります。これらは前処理で適切に処理するか、エスケープする必要があります。

次の関数は、非常に単純な前処理を行います。高度なモデルを作るときには、もっと高度な前処理が必要になるかもしれませんが、今回は、非常に単純な前処理のみ行います。

text_clean_preprocess関数はテキスト入力を受け取り、それに対していくつかのクリーニング処理を行います。newlineの改行文字(\n)は空文字列に、全角スペースは半角スペースに、タブ文字は空文字列に、キャリッジリターン文字(\r)は空文字列に置き換えます。最後に、テキストを小文字に変換し、クリーニングされたテキストを返します。

In [13]:
NEWLINE_CHAR = "\n"
SPACE_CHAR = "\u3000"
TAB_CHAR = "\t"
CARRIAGE_RETURN_CHAR = "\r"

def text_clean_preprocess(text, newline_char=NEWLINE_CHAR, space_char=SPACE_CHAR, tab_char=TAB_CHAR, carriage_return_char=CARRIAGE_RETURN_CHAR):
    text = text.replace(newline_char, "")
    text = text.replace(space_char, " ")
    text = text.replace(tab_char, "")
    text = text.replace(carriage_return_char, "")
    text = text.lower()

    return text
さて、次はモデルが期待するデータを準備しましょう。

mT5のようなモデルでは、特定の形式、つまりトークンIDのシーケンスでの入力を期待します。この関数は、生テキストをそのようなシーケンスにトークン化します。

次の関数は、要約モデルを学習するためのデータを準備し、構造化します。このモデルでは、テキストの内容からタイトルを予測または生成するために使用されます。

参考情報:

  • Consistent Data Length: ニューラル・モデルは通常、固定長の入力を必要とします。配列をSOURCE_MAX_LEN と TARGET_MAX_LEN になるように、パディングしたり切り詰めたりすることで、入力とターゲットの長さが固定長になるようにします。

  • Tokenization: 前処理されたコンテンツ(入力)とdata["title"](ターゲット)のタイトルは、mt5_tokenizerを使用してトークンIDに変換されます。

  • Attention Masking: attention maskは実際のトークン(1)とパディングされたトークン(0)を区別し、モデルが入力の意味のある部分に集中できるようにします。


In [14]:
def tokenize_data(data):

    input_text = [text_clean_preprocess(content) for content in data["content"]]

    target_text = data["title"]

    inputs = mt5_tokenizer(input_text, max_length=SOURCE_MAX_LEN, truncation=True, padding="max_length")
    targets = mt5_tokenizer(target_text, max_length=TARGET_MAX_LEN, truncation=True, padding="max_length")
    label_attention_mask = [1] * len(targets["input_ids"])

    return {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "decoder_input_ids": targets["input_ids"],
        "labels": targets.input_ids,
        "label_attention_mask": label_attention_mask

    }

前のステップでは、データを変換する関数を作成しました。今度はその関数をデータセットに適用してみましょう。 データセットに関数を適用するにはmap関数を使用します。

要約すると、次のコードでは、さっき定義したtokenize_data関数を使用してデータセットの前処理とトークン化を行います。そして、機械学習タスクに適したデータセットにするために、特定の列を削除します。

参考情報:

  • mapメソッドはtokenize_data関数を使用してデータセットをバッチ処理します。特に大規模なデータセットでは効率がいいです。
  • remove_columnsパラメータは、マッピング後に削除する列を指定します。トークン化されたデータのみを残し、余計な列を削除します。
  • tokenized_datasetは、トークン化された内容で、remove_columnsで指定された列を除いた、元のデータを処理したものとなります。

In [15]:
tokenized_dataset = dataset.map(
    tokenize_data,
    batched=True,
    remove_columns=["content", "title", "url", "date", "category"]
)
Map:   0%|          | 0/5893 [00:00<?, ? examples/s]
Map:   0%|          | 0/736 [00:00<?, ? examples/s]
Map:   0%|          | 0/738 [00:00<?, ? examples/s]

8. モデルの読み込み¶

さて、いよいよ事前に学習したモデルをロードします。上で述べたように、我々はmt5-smallモデルを使用します。 以下のコードは、"google/mt5-small "モデルをロードします。一度ロードされたモデルは、翻訳、要約、テキスト生成など様々なタスクに使うことができる。

一度もモデルをダウンロードしたことのない環境でこれを実行する場合、from_pretrainedメソッドはHugging Faceモデルハブからモデルの重みをダウンロードします。この場合、インターネット接続が有効であることを確認してください。


参考情報:

  • AutoModelForSeq2SeqLM: これは、提供されたモデル名またはパスに基づいて、適切なSequence to sequenceモデル(T5、BERT、GPT-2など)を自動的に推論し、ロードするように設計されたクラスです。モデルの具体的なアーキテクチャはわからないが、名前かパスは知っているという場合に便利です。

  • from_pretrained: これは、与えられた名前もしくはパスに基づいてモデルをロードするクラスメソッドです。モデルの重みと設定をダウンロードし、適切なモデルクラスのインスタンスを返します。


In [16]:
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
pytorch_model.bin:   0%|          | 0.00/1.20G [00:00<?, ?B/s]
(…)mall/resolve/main/generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]
次に、HuggingfaceのData Collatorを用意します。

今回は、テキスト(コンテンツ)の並び(Sequence)を別のテキストの並び(タイトル)に変換するトレーニングを行うので、DataCollatorForSeq2Seqを使います。

DataCollator関数は、データセットからサンプルのリストを受け取り、それらをTensorsの辞書としてバッチに照合します。

Data Collatorは以下のような動作をします:

  • 必要であれば、入力を同じ長さにパディングまたは切り詰める。
  • 入力をモデルに適したデータ型と形式に変換する。
  • 入力をバッチ処理する。
  • 入力に含まれる特殊なキー(ラベルなど)を処理する。

seq2seqモデルを微調整する際にDataCollatorForSeq2Seqを使用するのは良いプラクティスなので、学習のワークフローに組み込むことが推奨されます。

要約すると、以下のコードは、コンピュータプログラム(モデル)が会話の中で自分自身の応答を理解し、生成できるように、テキストデータを準備し整理するツール(data_collator)をセットアップしています。

In [17]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(
    tokenizer=mt5_tokenizer,
    model=model,
)

9. 評価の指標¶

モデルを学習する際には、検証用データセットでその性能を評価することが重要です。

これは、モデルがどの程度学習しているのか、学習不足なのか過学習なのかを理解するのに役立ちます。また、異なるモデルを比較し、最適なモデルを選択するのにも役立ちます。 評価を測定するために、多くのメトリクスがあります。例えば、BLEU、ROUGE、METEORなどです
今回はBERTScoreを使用します。

9.1 Bert Score¶

BERTScoreは2つのテキストの類似度を測る指標です。

BERTScoreはトランスフォーマベースのモデルとして有名なBERTモデルに基づいており、要約や翻訳などのテキスト生成タスクを評価するのに適しています。

先に進む前に、BERTスコアの使い方を確認しましょう。

この関数は、BERTScore メトリックを使用して、生成された文章とラベル付きの文章の平均BERTスコアを計算します。

In [18]:
import evaluate

def compute_bert_score_demo(preds, labels):
    bert_score_metric = evaluate.load("bertscore")
    bert_score_metric.add_batch(predictions=preds, references=labels)
    result = bert_score_metric.compute(lang="ja")
    avg_scores = {k: sum(v) / len(v) for k, v in result.items() if k != "hashcode"}

    return avg_scores

original = "リンゴが好きです。"
candidate_1 ="リンゴが大好きです。"
candidate_2 = "赤い車を買うつもりです。"

bs_results = {
    "候補1": compute_bert_score_demo([candidate_1], [original]),
    "候補2": compute_bert_score_demo([candidate_2], [original]),
}

bs_df = pd.DataFrame(bs_results).T
bs_df
Downloading builder script:   0%|          | 0.00/7.95k [00:00<?, ?B/s]
(…)cased/resolve/main/tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]
(…)tilingual-cased/resolve/main/config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]
(…)ultilingual-cased/resolve/main/vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]
model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]
Out[18]:
precision recall f1
候補1 0.966469 0.986506 0.976384
候補2 0.680357 0.723368 0.701203
結果を見ると、候補1は候補2よりBERTScoreが高く、候補1の方が元の文により類似していると言えます。実際、候補1は元の文により近く、候補2は全く異なる文であるため、正しく評価されたようです。
では、compute_bert_score_demoをリファクタリングして、学習済みモデルの評価に再利用できる関数にしてみましょう。

compute_bert_score関数は、推論した値と元のラベルの両方を含むeval_preds変数を引数として受け取ります。次に、mt5_tokenizer.batch_decode関数を使用して推論値とラベルをデコードします。最後に、bert_score_metric.compute関数を使用してBERTスコアを計算します。

In [19]:
def compute_bert_score(eval_preds):
    bert_score_metric = evaluate.load("bertscore")
    predictions, labels = eval_preds

    decoded_preds = mt5_tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = mt5_tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = bert_score_metric.compute(
        predictions=decoded_preds, references=decoded_labels, lang="ja"
    )

    return {
        "bertscore_precision": sum(result["precision"]) / len(result["precision"]),
        "bertscore_recall": sum(result["recall"]) / len(result["recall"]),
        "bertscore_f1": sum(result["f1"]) / len(result["f1"])
    }

10. Training¶

これでついにモデルの評価に使える関数ができました。

  • 次はトレーニング引数を定義します。

10.1 トレーニング引数¶

Seq2SeqTrainingArguments

これはtransformersライブラリが提供するクラスで、sequence-to-sequenceモデル(要約モデルのように、あるシーケンスを別のシーケンスに変換するモデル)のトレーニング関連の引数やハイパーパラメータを定義することができます。これはTrainingArgumentsのサブクラスで、トレーニング関連の引数を定義するための汎用クラスです。

パフォーマンスを向上させるために使用できるハイパーパラメータは他にもたくさんありますが、これらはデータセット、モデル、タスクに依存します。さらに、最適なハイパーパラメータを見つけるための専用のハイパーパラメータ・チューニング手法もあります。 今回は、トレーニング時間とパフォーマンスのバランスが重要なので、以下のハイパーパラメータを使うことにします。

要約するとこのコードはモデルのコンフィグを設定しています。

参考情報:

  • per_device_train_batch_size, per_device_eval_batch_size: 学習と評価のバッチサイズ。バッチサイズとは、モデルが重みを更新する前に処理するデータサンプルの数です。

  • learning_rate: これは初期学習率です。

  • lr_scheduler_type, warmup_ratio: これらは学習率のスケジュールを定義します。"linear"は学習率がエポックごとに直線的に減少することを意味します。warmup_ratioは、学習率が直線的に増加してから減衰を開始するまでの学習ステップの割合を定義します。

  • num_train_epochs: モデルを訓練するエポック数。「一つの訓練データを何回繰り返して学習させるか」の数のことです。

  • evaluation_strategy, save_strategy, logging_strategy: これらはそれぞれ、評価、保存、ログのタイミングを定義します。"epoch "は、これらのアクションが各エポックの終了時に実行されることを意味します。

  • logging_steps: トレーニングメトリクスのログの頻度(100ステップごと).

  • logging_dir: ログの保存ディレクトリ

  • do_train, do_eval: それぞれ学習や評価を行うべきかどうかを判断するためのフラグ。

  • output_dir: モデルのチェックポイントなどのトレーニングの出力が保存されるディレクトリ。

  • save_total_limit: 保存されるモデルのチェックポイントの最大数。古いチェックポイントは削除されます。

  • load_best_model_at_end: Trueの場合、評価指標に従った最良のモデルがトレーニングの最後にロードされます。

  • push_to_hub: Trueの場合、モデルはHugging Face Model Hubにプッシュされます。

  • predict_with_generate: シーケンスを生成するためにpredictメソッドがgenerateメソッドを使用するようにします。

  • コメントアウトされたパラメータ(optim, gradient_accumulation_steps, weight_decay, fp16)は、さらに学習をカスタマイズするために使用できる追加のオプション引数です。例えば、fp16=Trueを指定すると、半精度(16ビット)浮動小数点のトレーニングが有効になり、より高速なパフォーマンスが得られる可能性があります。


In [20]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

NUM_EPOCHS = 5
LEARNING_RATE = 5e-4
WARMUP_RATIO = 0.1
PER_DEVICE_TRAIN_BATCH_SIZE = 8
PER_DEVICE_EVAL_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 4
WEIGHT_DECAY = 0.01
LOGGING_DIR = "./logs"
OUTPUT_DIR = "./results"

training_args = Seq2SeqTrainingArguments(
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,

    learning_rate=LEARNING_RATE,
    # lr_scheduler_type="linear", #comment on colab
    # warmup_ratio=0.1, #comment on colab

    num_train_epochs=NUM_EPOCHS,

    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    logging_steps=100,

    logging_dir=LOGGING_DIR,
    do_train=True,
    do_eval=True,
    output_dir=OUTPUT_DIR,

    save_total_limit=2,
    load_best_model_at_end=True,

    push_to_hub=False,
    predict_with_generate=True,

    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
)

10.2 Trainer¶

以下のコードでは、さっき定義したmodel、training_args、data_collatorを使用してTrainerインスタンスを初期化しています。Trainerはモデルの学習と評価を行います。


参考情報:

  • Trainer: これはtransformersライブラリが提供するクラスで、モデルの学習と評価を処理します。バッチ、ログ、チェックポイントの保存など、トレーニングの低レベルの機能の多くを抽象化した高レベルのAPIです。
  • model: これは学習されるモデルです
  • args: これはトレーニング関連の引数を含むSeq2SeqTrainingArgumentsのインスタンスです。
  • data_collator: データサンプルをバッチに照合するDataCollatorForSeq2Seqのインスタンスです。
  • compute_metrics: これは評価指標を計算する関数です。今回は、先に定義したcompute_bert_score関数です。
  • train_dataset: 学習用データセットです。
  • eval_dataset: 評価用のデータセットです。
  • tokenizer: トークナイザです。

Trainerクラスは、バッチ処理、ロギング、チェックポイントの保存など、トレーニングの低レベルな詳細の多くを抽象化します。これにより、数行のコードで簡単にモデルをトレーニングできるようになります。

In [21]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    tokenizer=mt5_tokenizer,
    compute_metrics=compute_bert_score,
)

10.3 Model Training¶

モデル学習の準備はすべて整いました。いよいよhuggingfaceライブラリを使ってモデルをトレーニングします。以下のコードを使ってモデルをトレーニングすることができます。

非常にシンプルですね。trainer.train() は、学習プロセスの多くの複雑さを抽象化し、多様なタスクで様々なモデルを学習するための統一されたインターフェースを提供します。これにより、ユーザーは学習ループの詳細や勾配計算などに煩わされることなく、モデルの学習を容易に行うことができます。もちろん、カスタマイズが必要な場合は、トレーナーとそのコンポーネントは細かく制御可能です。

In [22]:
trainer.train()
[920/920 37:43, Epoch 4/5]
Epoch Training Loss Validation Loss Bertscore Precision Bertscore Recall Bertscore F1
0 7.068200 1.034466 0.706013 0.664744 0.684008
1 1.268600 0.894248 0.712798 0.681310 0.696195
2 1.079000 0.864152 0.712160 0.684890 0.697850
4 0.996100 0.854636 0.713781 0.689665 0.701086
4 0.956100 0.847545 0.715506 0.692384 0.703341

Out[22]:
TrainOutput(global_step=920, training_loss=2.273634487649669, metrics={'train_runtime': 2265.4763, 'train_samples_per_second': 13.006, 'train_steps_per_second': 0.406, 'total_flos': 1.556004590321664e+16, 'train_loss': 2.273634487649669, 'epoch': 4.99})

この処理は20分程度かかります。

ステップごとに結果を分析してみましょう:

  1. Training Loss, Validation Loss:

    • Training Lossは、最初のエポックから2番目のエポックまで大きく減少し、2番目のエポックから3番目のエポックまでわずかに減少しています。これはモデルの学習が進んでいることを示しています。
    • Validation Lossもエポックごとに減少しています。これはモデルが検証データに対してうまく一般化されていることを示しています。

  2. BERTScoreの値:

    • BERTScoreのPrecision、Recall、F1はいずれもエポックごとに増加しています。これは、モデルの生成したタイトルが、元のタイトルに似てきていることを示しています。

11. モデルの評価¶

モデルの学習が終わったら、次はモデルの評価です。モデルを評価するには、以下のコードを使えば簡単です。

Hugging Face Trainerは、検証データセット(指定されている場合はテスト・データセット)でモデルを評価し、計算されたメトリックスを返します。このメソッドは、未知のデータに対するモデルのパフォーマンスを測定するために使用されます。

In [23]:
results = trainer.evaluate()

results_df = pd.DataFrame([results])
results_df
[92/92 00:55]
Out[23]:
eval_loss eval_bertscore_precision eval_bertscore_recall eval_bertscore_f1 eval_runtime eval_samples_per_second eval_steps_per_second epoch
0 0.847545 0.715506 0.692384 0.703341 63.7571 11.544 1.443 4.99
以上の結果から、モデルはうまく機能していると言えます。

12. モデルのテスト¶

モデルをトレーニングし、評価しました。次はモデルをテストする番です。

12.1. Generate Predictions (予測の生成)¶

次のステップでは、モデルを使って予測を生成します。

予測を生成する方法はたくさんあります。

  • Greedy Search
  • Beam Search
  • Top-K Sampling
  • Top-P Sampling

などなど。

ここではBeam Searchを使います。Beam Searchは、限られた集合の中で最も有望なノードを展開することによってグラフを探索するヒューリスティクな探索アルゴリズムです。これはベスト・ファースト探索の最適化であり、必要なメモリを削減することができます。


参考情報:

ここではBeam Searchについての説明はしません。しかし、もし興味があれば、次のブログ記事でBeam Searchやその他の方法についてもっと読むことができます。 https://huggingface.co/blog/how-to-generate


パラメータが異なれば、結果も異なります。従って、いろいろなパラメータを試してみて、それが結果にどう影響するかを見てみるのがいいです。これは正解も不正解もありません。

generate_title関数は、我々のモデルモデルを使用して、与えられたテキスト(コンテンツ)を処理し、そのコンテンツにに合ったタイトルのリストを生成します。

簡潔にまとめるとこうなります:

関数:

  • 提供されたコンテンツを前処理します。
  • mt5_tokenizerを使用して、クリーニングされたコンテンツをトークン化する。
  • トークン化された入力をモデルに食わせて、max_length、Beam Searchの設定、ランダムなtempratureなどの制約やプリファレンスを使用して、タイトルの候補を生成します。
In [24]:
def generate_title(content):

    # inputs = [text_clean_preprocess(content)]
    inputs = [f"summarize: " + text_clean_preprocess(content)]

    batch = mt5_tokenizer.batch_encode_plus(
        inputs, max_length=512, truncation=True,
        padding="longest", return_tensors="pt")

    input_ids = batch['input_ids']
    input_mask = batch['attention_mask']


    model.eval()

    outputs = model.generate(
        input_ids=input_ids.cuda(),
        attention_mask=input_mask.cuda(),
        max_length=64,
        temperature=1.1, #
        num_beams=6, #24
        diversity_penalty=3.0, #1.8
        num_beam_groups=3,
        num_return_sequences=3,
        repetition_penalty=9.0,
        # early_stopping=True, #false
        # max_new_tokens=64,
        # do_sample = True
    )

    generated_titles = [mt5_tokenizer.decode(ids, skip_special_tokens=True,
                            clean_up_tokenization_spaces=False)
                        for ids in outputs]

    return generated_titles
では、テストセットからランダムに記事を選び、タイトルを生成してみましょう。
In [25]:
print("Total titles in dataset:", len(dataset["test"]["title"]))

selected_index = [75, 140, 286]

for index in selected_index:
    print("original: ", dataset["test"]["title"][index])
    titles = generate_title(dataset["test"]["content"][index])
    for i, title in enumerate(titles):
        print(f"Generated title {i+1}: {title}")
    print()
Total titles in dataset: 738
original:  いいとも!で紹介された「ヒドすぎる」名前が話題に
Generated title 1: 『ザザザの斬新!赤ちゃんネーム』で紹介された「キラキラネーム」がありえない
Generated title 2: 『ザザザの斬新!赤ちゃんネーム』が「ありえない」とネットニュースで話題
Generated title 3: ネットスラングで「キラキラネーム」がありえないとネットで話題【話題】

original:  日本の引きこもりに海外から相次ぐ心配の声
Generated title 1: 日本の引きこもりが海外で話題に
Generated title 2: 日本の引きこもりが海外で話題
Generated title 3: 「日本=出る釘は打たれる」など、日本の引きこもりが海外で話題に

original:  甲子園出場する石巻工「約5000万円が必要」呼びかけに物議
Generated title 1: 【Sports Watch】石巻工業高校が総額約5000万円の協賛金を募っている
Generated title 2: 【Sports Watch】石巻工業高校が総額約5000万円の協賛金を募っている理由
Generated title 3: 石巻工業高校が、総額約5000万円の協賛金を募っている【話題】

生成されたタイトルは実際のタイトルにかなり似ていることがわかります。これは良い兆候で、モデルが実際のタイトルに似たタイトルを生成するように学習していることを示しています。

また、生成されたタイトルは実際のタイトルと同一ではないことにも注目してください。これは想定内で、モデルは正確なタイトルを生成するようにトレーニングされていないからです。その代わりに、実際のタイトルに似たタイトルを生成するように学習されています。

パラメータが違えば、結果も違ってきます。だから、いろいろなパラメーターを試して、それが結果にどう影響するかを見てみるといいです。

12.2. Yahoo! Newsで試してみる¶

我々のモデルはlivedoorニュースコーパスで学習され、livedoorのニュースでテストされました。では、livedoor以外のニュースにどう反応するか見てみたいと思います。yahooニュースから取得したニュースの内容を貼り付けました。私たちのモデルからの出力を見てみましょう。

In [26]:
# https://news.yahoo.co.jp/pickup/6476740

yahoo_news_1_original_title = """【速報】女川原発2号機「再稼働目標 3か月延期へ」来年5月に<東北電力>"""

yahoo_news_1 = """東北電力は来年2月を目標としていた「女川原発2号機の再稼働」について、来年5月に延期することを明らかにした。今年11月としていた安全対策工事の完了時期が来年2月に延びるため。

東北電力によると、工事が3か月延びるのは、発電所内の設備などにつながるケーブルが火災などで損傷しないようにする「火災防護対策」を追加したことが主な要因。この対策を巡っては、他の電力会社が原子力規制員会から指摘を受けた事例を踏まえ、東北電力では去年10月から追加で工事をすることを準備していた。
その工程を精査した結果、3か月ほど完了時期が延びることが判明したもので、それに伴って女川原発2号機の再稼働目標も3か月延期し、来年5月頃となった。
"""

print("Original yahoo title: ", yahoo_news_1_original_title)
for i, title in enumerate(generate_title(yahoo_news_1)):
    print(f"Generated title {i+1}: {title}")
Original yahoo title:  【速報】女川原発2号機「再稼働目標 3か月延期へ」来年5月に<東北電力>
Generated title 1: 東北電力、来年2月を目標としていた「女川原発2号機再稼働」を発表
Generated title 2: 東北電力、来年2月を目標としていた「女川原発2号機再稼働」について3か月延期
Generated title 3: 【ニュース】東北電力、来年2月を目標としていた「女川原発2号機再稼働」について3か月延期
livedoorのニュースでなくても、我々のモデルがタイトルを生成していることがわかります。しかもなかなか良いタイトルです。つまり、我々のモデルがうまく機能しているということです。

12.3. メールで試してみる (Just for fun)¶

もし興味があれば、あなたのメールをコピーペーストして、どのように反応するか見てみるといいです。

In [27]:
sample_email = """
おはようございます。
アナハイム・エレクトロニクス事務局です。

本日は、まもなく締め切りとなります、10/18(水)開催の対面イベント「モルゲンレーテ社製品無料体験会&相談会 @アナハイム・エレクトロニクスカリフォルニア本社」のご案内です!

実際に触れてみないとわからないとお困りの方、ぜひこの機会に体験してみてください。

モルゲンレーテ社製品無料体験会は特に以下のような方にお勧めです。

・電子・電気機器の新規導入でご検討中の方
 ※小規模から対応できます!規模は問いません。
・電子機器の更改でクラウド移行を検討している方
・顧客管理システムも含めセキュアに電子機器を構築したい方。

モルゲンレーテ社製品を検討はしているけど、実際に触れてみないとわからない、導入を検討しているけど、何から始めればよいか分からないといった課題、お困りごとのある方は、ぜひこの機会に体験してみてください。

また、体験会の後は、弊社エンジニアの個別相談会も予定しております。

皆さまのお悩みなどお気軽にお話いただければと思います。

以下、無料体験会の詳細、申し込み方法をご確認のうえ、ぜひお気軽にご参加ください!

イベントの参加は無料! 参加お申込みは10月16日(月)となります。
みなさまのご参加お待ちしています!
"""

for i, title in enumerate(generate_title(sample_email)):
    print(f"Generated title {i+1}: {title}")
Generated title 1: アナハイム・エレクトロニクスカリフォルニア本社にて開催の対面イベント「モルゲンレーテ社製品無料体験会&相談会」を開催
Generated title 2: アナハイム・エレクトロニクスカリフォルニア本社にて開催の対面イベント「モルゲンレーテ社製品無料体験会&相談会】
Generated title 3: 「モルゲンレーテ社製品無料体験会&相談会 アナハイム・エレクトロニクスカリフォルニア本社」
  • 適切に学習されたモデルであれば、ニュースのタイトルのような文章が生成されるはずです。
  • つまり、特定のコーパスからどんなスタイルでも生成できるということです(メール・コーパスを使えば、メールの件名を生成できる)。

13. モデルを改良する方法¶

今回のモデルを改良する方法はたくさんあります。以下にそのいくつかを示します。

13.1. データ¶

  • データサイズ:今回は比較的小さなデータセットを使用しました。データが多ければ多いほど、モデルはより良いものになります。
  • クリーニング:今回はデータのクリーニングを行っていません。モデルを改善するために、いくつかのクリーニングを行うことができます。
  • 正規化: 小文字データ、全角スペースなど、非常に単純なことだけを行いました。しかし、もっとできることがあります。
    • 例えば、日本語の引用符は「」や『』など様々な種類があります。これらを一つのタイプに統一して正規化することもできます。

13.2. モデルとパラメータ¶

  • モデルのサイズ: 計算リソースがあれば、より良いパフォーマンスを得るために、より大きなバージョンを使用することを検討します。
  • ハイパーパラメータ: 学習率、バッチサイズ、最適化アルゴリズムを変えて実験します。
  • モデル・アーキテクチャ: さまざまなモデル・アーキテクチャを試すことができます。例えば、BART、T5、Pegasusなどを試すことができます。
  • mt5 は多言語に対応しているので、日本語の T5 モデルを使ってみると、より良いパフォーマンスが得られます。

13.3. 学習戦略¶

  • 転移学習:: タイトル生成タスクでファインチューニングする前に、まず関連タスク(例えば要約)でモデルをファインチューニングして、モデルをドメインに慣れさせることができます。

    • ファインチューニング [記事 → 要約]
    • モデルを保存
    • 保存したモデルを使って、再度 [記事 -> タイトル] をファインチューニング
  • 正則化: ドロップアウト、レイヤーの正規化、ウェイトの減衰のようなテクニックは、オーバーフィッティングを防ぐのに役立ちます。

13.4. カリキュラム学習¶

  • 動的サンプリング: BERTScoreを使用して、トレーニングサンプルを難易度に基づいてランク付けします。「簡単な」サンプルでトレーニングを開始し、徐々に難易度を上げます。

13.5. 後処理とルール¶

  • 長さのコントロール: タイトルが一定の長さに収まるようにルールを導入することができます。
  • 用語: タイトルに表示すべき(または表示すべきでない)特定の用語や語句がある場合、これらのルールを後処理で強制することができます。

13.6. デコーディング戦略を試す:¶

  • ビーム検索: 各ステップで最も可能性の高い次の単語を選ぶ代わりに、ビームサーチは複数のシークエンスを考慮するため、より良い結果を得ることができます。
  • 温度サンプリング: 温度パラメータを調整することで、出力のランダム性を制御できます。高い値を設定すると出力はよりランダムになり、低い値を設定するとより決定論的になります。

14. まとめ¶

  • Google Colabの使い方を学びました
  • Huggingface transformersライブラリとDatasetsライブラリのの使い方を学びました
  • 評価関数の使い方について話しました
  • Bert Scoreについて話しました
  • トークン化とモデルトレーニングに慣れました
  • Huggingface Trainerの使い方を学びました
  • テキスト生成タスクのためにディープラーニングモデルをファインチューニングする経験を得ました
  • 最後に、学習済みモデルを使用して予測を生成する方法を学びました