pytorch_lightning エラー|validation_epoch_endで止まる
現象
validation_epoch_end内で、引数のリストの長さをべた書きして、処理を書いていました。
下図で言うと、valデータ合計÷バッチサイズ=25が、リストの長さだと想定していました。
(補足|validation_epoch_endの引数としては、validation_stepの戻り値のリストが返ってくる。)
しかし、その処理でエラーになっていました。
原因
原因は、学習前に自動的に動く、"sanity check" という機能でした。 学習前に、評価データを少し使って、挙動を確認するようです。 "少し"とは、デフォルトでは2バッチ分です。
対策
- 集計を工夫する 引数で来るリストの長さが変わってもよいように、集計の計算を変更すればよいです。
例えば、リストの長さが変わるので、for文で回すときはlen(リスト)とします。
for i in range(len(リスト)): ....
そもそもここで詰まったのは、validation_epoch_end引数のサイズが変わることを、考慮してなかったからでした。 実装の甘さが露呈してしまいました・・・。
一応、他の方法も記しておきます。
2. フラグを作って、1回目のvalidation_epoch_end呼び出しは、そのまま戻るようにする。
下記のイメージです。
IS_DONE_CHEK = 0 # 初期化 ... def validation_epoch_end(self, outputs): if IS_DONE_CHEK==0: IS_DONE_CHEK=1 return # 関数の呼び出しが初の時は、何もしない ...
3. num_sanity_val_stepsを-1にする。
Trainerの引数で、num_sanity_val_steps=-1とすれば、validation dataすべてに対してsanity checkが動くようです。 すべてに対して動いてくれれば、集計等の計算で困ることはないかもしれません。