๋ชฉ๋ก2024/03/28 (1)

SJ_Koding

๊ฐœ์ธ ๊ธฐ๋ก์šฉ PyTorch ์˜ค๋ฅ˜ ๋ชจ์Œ

1. timm์—์„œ model๋ฅผ loadํ•œ ๋’ค, ์ €์žฅ๋œ ptํŒŒ์ผ์„ ๋ถˆ๋Ÿฌ์™€ inference๋ฅผ ์‹œํ‚ค๋ฉด ์„ฑ๋Šฅ์ด ํฌ๊ฒŒ ๊ฐ์†Œํ–ˆ์Œ. --> model.eval()๋ฅผ ๋ฐ˜๋“œ์‹œ ์‹คํ–‰ํ•ด์ค˜์•ผํ•จ. with torch.no_grad๋Š” parameter update๋ฅผ ํ†ต์ œํ•  ๋ฟ, forward ๊ณผ์ •๊นŒ์ง€์˜ update๋ฅผ ํ†ต์ œํ•˜์ง€๋Š” ์•Š๋Š”๋‹ค. timm์€ default mode๋กœ train mode๋กœ ์„ค์ • ๋˜์–ด์žˆ์œผ๋ฏ€๋กœ eval()์„ ํ†ตํ•ด ๋ชจ๋“œ๋ฅผ ๋ฐ”๊ฟ”์ค˜์•ผํ•œ๋‹ค. eval()๋ชจ๋“œ๋Š” BatchNormalization์˜ ํŒŒ๋ผ๋ฉ”ํ„ฐ๋ฅผ Train์‹œ ์…‹ํŒ…ํ•œ ๊ฐ’์„ ๊ทธ๋Œ€๋กœ ๊ฐ€์ ธ์˜ค์ง€๋งŒ, train()์€ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์— ๋”ฐ๋ผ BN์˜ parameter๋ฅผ ๋ณ€ํ™”์‹œํ‚ค๊ธฐ ๋•Œ๋ฌธ์—, ์„ฑ๋Šฅ์ด ์—‰๋ง์ด ๋˜์—ˆ๋˜ ๊ฒƒ์ด๋‹ค. (์ถ”๊ฐ€๋กœ dropout ๋ ˆ์ด์–ด๋„ ์™„์ „ํžˆ ๋ฌด์‹œํ•ด์ค€๋‹ค.) ์•„๋งˆ ..

PyTorch Code/Pytorch 2024. 3. 28. 14:08