画像の4倍超解像と圧縮劣化画質改善の深層学習プログラムをPyTorchに移植した。

“SRNet-DR” の “GitHub” ページ:
https://github.com/ImpactCrater/SRNet-DR

私は2016年に1枚画像の超解像プログラムとして話題になった “waifu2x” を使用して機械学習に興味を持ち、その後は更に高度な “SRGAN” を試し、以後は画像の4倍超解像と圧縮劣化画質改善の深層学習プログラムを自作するようになりました。

ネットワークはResNet系のResidual of Residual Network(ショートカットは3層構造)です。

先日までは機械学習フレームワークとしてTensorFlow(以下TF) v1.x系でスクリプトを書いていたのですが、所有PCのOSがUbuntu 20.04 LTSにアップグレードしてからPythonの互換性の問題で古いTF v1.x系が利用出来なくなったので、TF v2.x系に移植を試みました。

TF v1.x系は静的なグラフを構築してから動作する “Define and Run” 方式だったのが、 TF v2.x系は動的な “Eager Execution” であり、 “Define by Run” 方式となり、仕組みが一変しました。
それ自体は問題無いのですが、移植して実行してみると、メモリー使用量がとても多くなってしまい、それまでのモデルがRAMに収まらずPCがフリーズしてしまいました。
コードを見直しても改善せず、4日ほど前ですが、諦めて現在人気のある機械学習フレームワークである “PyTorch” に移行する事に致しました。

私はPyTorchは初めてでしたが、非常に簡単でした。

まずは、サクッとインストール。

学習に使う画像データ一式があるフォルダーから画像データを取り出すDatasetクラスを継承してデータ オーグメンテーションや変換のメソッドを組み込みました。

それだけで既定のDataloaderクラスでシャッフルしたりしてミニバッチ単位でデータを取り出せるようになりました。

モデルも torch.nn.Module クラスを継承して、 __init__ にて layersList[] に次々と torch.nn.Conv2d や Swish() 活性化関数や torch.nn.GroupNorm や torch.nn.PixelShuffle などのパーツを .append して行きまして、 def forward(self, x): の中で、 x = self.layers[i](x) のようにデータの流れを記述すると、スキップ コネクションが3層構造のResidual of Residual Networkも簡単に構築出来ました。

圧縮画像ファイル形式であるwebpを利用し、画像を非可逆圧縮劣化させ、それを元に戻すような変換も学習させています。

[プログラム内の学習の流れ]
モデルのインスタンスを作成
GPUが利用可能ならGPUを利用するようにモデルを .to(device)
オプティマイザーを作成
SSIM損失関数のインスタンスを作成
評価用画像のDatasetを作成
評価用画像のDataLoaderを作成
utils.save_image() で評価用画像を保存
学習用画像のDatasetを作成
学習用画像のDataLoaderを作成
epoch分のforループ
Datasetのリストを更新
DataLoaderからシャッフルしながらミニバッチ単位で入力画像データと正解画像データのペアを取り出し
model.train() でtraining モードに設定
データを .to(device) してからモデルに入力
SSIM Lossを計算
optimizer.zero_grad() で勾配を初期化
loss.backward() で誤差逆伝播により勾配を計算
optimizer.step() でパラメーターを更新
一定のステップが回った所でValidationを実行
model.eval() でevaluation モードに設定
with torch.no_grad(): でValidationのスコープ内では勾配計算をさせないように設定。
評価用画像データのミニバッチをモデルに入力
utils.save_image() で生成画像を保存
torch.save(model.to(“cpu”).state_dict(), savePath) でモデル データを保存

といった具合にとても簡潔です。

[プログラム作成の注意点]
Validation実行時に、 with torch.no_grad() をして置かないとメモリーを消費してしまう。

epoch毎に各ステップの損失関数の値を累積してその平均値を出す時に、単純に totalLoss += loss などとすると、lossには計算グラフの勾配情報があり、totalLossにまで計算グラフの勾配情報が累積的に記録されてしまう事でメモリー消費量が増えてしまう。
よって、 totalLoss += float(loss) や totalLoss += loss.detach().clone() などとして勾配情報の履歴を追跡させないようにする必要がある。

落とし穴として、 DataLoader(datasetValidation, batch_size=miniBatchSize, shuffle=False, num_workers=0, drop_last=False) のオプションの num_workers=0 を 0以外に設定すると、メモリーアクセスの問題が発生して2epoch目でメモリー使用量が急増する。
https://github.com/pytorch/pytorch/issues/13246

因みにこのプログラムではモデルが大きく、CPU処理の場合はメモリーは24GiBくらいはあった方が良さそうです。
GPU処理の場合は16GiBで足りる事を確認しました。

自作の人工ニューラル ネットワークのパラメーター数の確認の為に、”torchsummary” をインストールしてモデル情報を出力してみました。

================================================================
Total params: 78,502,467
Trainable params: 78,502,467
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.11
Forward/backward pass size (MB): 7530.75
Params size (MB): 299.46
Estimated Total Size (MB): 7830.32

パラメーター数は凡そ7千8百万超でした。
巨大なモデルと言えるでしょう。

[2021年10月19日追記]
{
現在は私の超解像復元モデルのパラメーター数は凡そ2億5千万超になっております。
更に、世界的には1千億以上のパラメーター数のディープ ラーニング モデルも登場しており、ネットワークのモデルは指数関数的に巨大化の一途を辿っているようです。
}

コメント

タイトルとURLをコピーしました