2022年3月21日 星期一

PyTorch Lightning 初探--高度抽象化的機器學習架構

pytorch  的程式碼已經相當簡潔了,但是面對海量的架構,還有沒有更更更更簡化的可能呢? 有的, PyTorch Lightning 就是你要的。 假日硬啃了一遍,真是被網路上一些過時的教程害死。接下來就是浪費了一個週末得到的血淚心得XD

  1. 你的聖經就是 HOW TO ORGANIZE PYTORCH INTO LIGHTNING ,因為版本迭代實在太快,大概半年就翻了一翻,千萬不要,再次強調,千萬不要看半年之前其它神人的神文。神人再神,也沒時間去更新這些教程:)
  2. 安裝套件的最新穩定版本。 anaconda  的管理程式不知為何建議了一個非常舊的版本,加倍了我除錯的工作,令人印象深刻。
  3. 你會想用  tensorboard 的
    from pytorch_lightning.loggers import TensorBoardLogger
    logger = TensorBoardLogger("tb_logs", name="my_model")
    trainer = Trainer(logger=logger)
  4. 你會想限制 epoch 次數/時間的
    trainer = Trainer(max_time="00:12:00:00", max_epochs=10)
  5. 請在 LightningModule 乖乖的實作 train_dataloader 等函式,那是讓程式決定 batch size 等參數的途徑。不要沿用原來的載入方式
  6. training, validation, predict 是主要的需實作函式, test 不是;不少 pytorch 的 test 其實指涉到 lightning 的 predict ,請格外注意;會實作 training, validation, test 的可能偏向 nlp 的 transformer based language model
  7. lightning 本身沒有辦法推進準確率或降低 loss,有時間我會建議試試 Optuna 

沒有留言:

張貼留言