The `checkpoint_blocks` function implements a chunked gradient-checkpointing mechanism: it reduces memory usage by executing a neural network's modules in chunks. In deep-learning training, gradient checkpointing (activation checkpointing) is a GPU-memory optimization technique. This code can:
- Split a network's blocks into chunks as needed and apply gradient checkpointing to each chunk.
- Dynamically tune the trade-off between compute overhead and GPU-memory usage.
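Before looking at the source, a minimal sketch of plain (un-chunked) activation checkpointing with `torch.utils.checkpoint.checkpoint` may help; the toy `layer` and tensor shapes here are illustrative, not from the original code:

```python
import torch
import torch.utils.checkpoint

torch.manual_seed(0)

# A single layer standing in for an expensive sub-network.
layer = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

# Plain forward: intermediate activations are stored for backward.
y_plain = layer(x)

# Checkpointed forward: activations inside the segment are discarded
# and recomputed when backward runs, trading compute for memory.
y_ckpt = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
```

Both paths produce identical outputs and gradients; the checkpointed path simply pays a second forward pass during backward instead of storing activations.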
1. Source code:
from typing import Any, Tuple, List, Callable, Optional
import torch
import torch.utils.checkpoint
import functools

try:
    import deepspeed
    deepspeed_is_installed = True
except ImportError:
    deepspeed_is_installed = False

BLOCK_ARG = Any
BLOCK_ARGS = Tuple[BLOCK_ARG, ...]  # List[BLOCK_ARGS]

def get_checkpoint_fn():
    return torch.utils.checkpoint.checkpoint  # deepspeed.checkpointing.checkpoint

def checkpoint_blocks(
    blocks: List[Callable],
    args: BLOCK_ARGS,
    blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:
    """Chunk a list of b
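The listing above is cut off mid-docstring. Based only on the signature, one plausible sketch of what chunked checkpointing does is the following; the helper name `chunked_checkpoint` and its exact behavior are assumptions for illustration, not the original implementation:

```python
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.utils.checkpoint

def chunked_checkpoint(
    blocks: List[Callable],
    args: Tuple[Any, ...],
    blocks_per_ckpt: Optional[int],
) -> Tuple[Any, ...]:
    """Run `blocks` sequentially, checkpointing `blocks_per_ckpt`
    blocks at a time (hypothetical sketch of the truncated function)."""

    def run_slice(s: int, e: int) -> Callable:
        # Fold blocks[s:e] into one callable so the whole slice
        # shares a single checkpoint (one recomputation segment).
        def wrapper(*a):
            for block in blocks[s:e]:
                a = block(*a)
                if not isinstance(a, tuple):
                    a = (a,)
            return a
        return wrapper

    if blocks_per_ckpt is None:
        # No checkpointing requested: plain sequential execution.
        return run_slice(0, len(blocks))(*args)

    for s in range(0, len(blocks), blocks_per_ckpt):
        # Activations inside each slice are recomputed during backward.
        args = torch.utils.checkpoint.checkpoint(
            run_slice(s, s + blocks_per_ckpt), *args, use_reentrant=False
        )
    return args

# Usage sketch: four toy blocks, checkpointed two at a time.
torch.manual_seed(0)
blocks = [torch.nn.Linear(4, 4) for _ in range(4)]
x = torch.randn(2, 4, requires_grad=True)
out_ckpt = chunked_checkpoint(blocks, (x,), blocks_per_ckpt=2)[0]

# Reference: run the same blocks without any checkpointing.
out_plain = x
for b in blocks:
    out_plain = b(out_plain)
```

A smaller `blocks_per_ckpt` stores fewer activations (less memory, more recomputation); `blocks_per_ckpt=None` disables checkpointing entirely, which is the tuning knob the prose above refers to.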