pytorch_lightning エラー|validation_epoch_endで止まる

現象

validation_epoch_end内で、引数のリストの長さをべた書きして、処理を書いていました。

下図で言うと、valデータ合計÷バッチサイズ=25が、リストの長さだと想定していました。

f:id:pizza3900:20210813164258p:plain
イメージ図
(補足|validation_epoch_endの引数としては、validation_stepの戻り値のリストが返ってくる。)

しかし、その処理でエラーになっていました。

原因

原因は、学習前に自動的に動く、"sanity check" という機能でした。 学習前に、評価データを少し使って、挙動を確認するようです。 "少し"とは、デフォルトでは2バッチ分です。

対策

  1. 集計を工夫する 引数で来るリストの長さが変わってもよいように、集計の計算を変更すればよいです。

例えば、リストの長さが変わるので、for文で回すときはlen(リスト)とします。

for i in range(len(リスト)):
     ....

そもそもここで詰まったのは、validation_epoch_end引数のサイズが変わることを、考慮してなかったからでした。 実装の甘さが露呈してしまいました・・・。


一応、他の方法も記しておきます。

2. フラグを作って、1回目のvalidation_epoch_end呼び出しは、そのまま戻るようにする。

下記のイメージです。

IS_DONE_CHEK = 0    # 初期化
...
def validation_epoch_end(self, outputs):
        if IS_DONE_CHEK==0:
            IS_DONE_CHEK=1
            return    # 関数の呼び出しが初の時は、何もしない
        ...


3. num_sanity_val_stepsを-1にする。

Trainerの引数で、num_sanity_val_steps=-1とすれば、validation dataすべてに対してsanity checkが動くようです。 すべてに対して動いてくれれば、集計等の計算で困ることはないかもしれません。

参考

pytorch-lightning.readthedocs.io