hye-log

[Boostcamp AI Tech] WEEK 02_DAY 08

iihye_ 2022. 9. 29. 01:29

🎀 Individual study


[6] Loading models

1. model.save()

1) A function for saving the results of training

2) Saves the model's architecture and parameters

# Print the model's state_dict
import os
import torch

# state_dict: maps each parameter name to its tensor
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Save only the model's parameters
torch.save(model.state_dict(), os.path.join(MODEL_PATH, "model.pt"))

# Load only the parameters into a model with the same architecture
new_model = TheModelClass()
new_model.load_state_dict(torch.load(os.path.join(MODEL_PATH, "model.pt")))

# Save the model together with its architecture
torch.save(model, os.path.join(MODEL_PATH, "model.pt"))
# Load the model together with its architecture
# (recent PyTorch releases may require torch.load(..., weights_only=False) for a full model)
model = torch.load(os.path.join(MODEL_PATH, "model.pt"))
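The parameter-only save and load path can be checked end to end with a tiny model. This is a minimal sketch, not the lecture's code: `TinyNet` is a hypothetical stand-in for `TheModelClass`, and a temporary directory stands in for `MODEL_PATH`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()

with tempfile.TemporaryDirectory() as model_dir:
    path = os.path.join(model_dir, "model.pt")
    # Save only the parameters (the state_dict).
    torch.save(model.state_dict(), path)

    # Rebuild the same architecture, then load the parameters into it.
    new_model = TinyNet()
    new_model.load_state_dict(torch.load(path))

# Every tensor in the restored state_dict should equal the original one.
params_match = all(
    torch.equal(model.state_dict()[name], new_model.state_dict()[name])
    for name in model.state_dict()
)
print(params_match)
```

Because only tensors are stored, the class definition has to be available when loading. That is the trade-off against saving the whole model object.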

 

2. checkpoints

1) Save intermediate training results so the best one can be selected

2) When using early stopping, save the results of earlier epochs

3) Save epoch, loss, and metric together so they can be checked later

# checkpoint

# Save the model's state together with the epoch
torch.save({'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss},
           f'saved/checkpoint_model_{e}_{epoch_loss/len(dataloader)}_{epoch_acc/len(dataloader)}.pt')
            
# load checkpoint
checkpoint = torch.load(PATH)
# load model of checkpoint
model.load_state_dict(checkpoint['model_state_dict'])
# load optimizer of checkpoint
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# load epoch of checkpoint
epoch = checkpoint['epoch']
# load loss of checkpoint
loss = checkpoint['loss']
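How checkpoints and early stopping fit together can be sketched in plain Python. The function name `find_best_epoch`, the `patience` parameter, and the loss values below are assumptions made for illustration. In a real loop, the marked branch is where `torch.save` would write the checkpoint.

```python
def find_best_epoch(val_losses, patience=2):
    """Return (best_epoch, best_loss, stopped_epoch) for per-epoch losses.

    `val_losses` stands in for the loss measured after each epoch.
    `stopped_epoch` is None when training ran to the end.
    """
    best_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    stopped_epoch = None

    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # Improvement: this is where a checkpoint would be saved.
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                # No improvement for `patience` epochs in a row: stop early.
                stopped_epoch = epoch
                break

    return best_epoch, best_loss, stopped_epoch


result = find_best_epoch([0.9, 0.7, 0.6, 0.65, 0.62, 0.5], patience=2)
print(result)  # (3, 0.6, 5)
```

Training stops at epoch 5, and the checkpoint from epoch 3 is the one to restore. The later loss of 0.5 is never reached, which is the cost of a small patience value.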

 

3. Transfer Learning

1) Apply a model built on a different dataset to the current data

2) Models trained on large datasets generally perform better

3) A common training technique in deep learning

4) TorchVision provides a variety of base models

- Reference: https://pytorch.org/vision/0.8/models.html

 

5) freezing: when using a pretrained model, part of the model is frozen

# transfer learning

import torch.nn as nn
from torchvision import models

# Assign the vgg16 model to vgg
# (newer torchvision releases use the weights= argument instead of pretrained=True)
vgg = models.vgg16(pretrained=True).to(device)

class MyNewNet(nn.Module):
    def __init__(self):
        super(MyNewNet, self).__init__()
        self.vgg19 = models.vgg19(pretrained=True)
        # Add a linear layer at the end of the model
        self.linear_layers = nn.Linear(1000, 1)

    # Defining the forward pass
    def forward(self, x):
        x = self.vgg19(x)
        return self.linear_layers(x)

# Instantiate the new model
my_model = MyNewNet().to(device)

# Freeze everything except the last layer
for param in my_model.parameters():
    param.requires_grad = False

for param in my_model.linear_layers.parameters():
    param.requires_grad = True
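The effect of freezing can be verified without downloading pretrained weights. The sketch below assumes a small `nn.Sequential` as a stand-in for the pretrained backbone. The layer sizes and the names `backbone` and `head` are illustrative, not from the lecture.

```python
import torch
import torch.nn as nn

# Small stand-in for a pretrained backbone, so nothing is downloaded.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
head = nn.Linear(4, 1)
model = nn.Sequential(backbone, head)

# Freeze everything, then unfreeze only the new head.
for param in model.parameters():
    param.requires_grad = False
for param in head.parameters():
    param.requires_grad = True

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(trainable, n_trainable, n_total)

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.01
)

# One training step: the frozen backbone must stay unchanged.
before = backbone[0].weight.clone()
loss = model(torch.randn(3, 8)).pow(2).mean()
loss.backward()
optimizer.step()
backbone_unchanged = torch.equal(before, backbone[0].weight)
print(backbone_unchanged)
```

Only the 5 parameters of the head out of 217 receive gradients, so the backbone keeps whatever it learned on the original dataset.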

[7] Monitoring tools for PyTorch

1. Tensorboard

1) A visualization tool created as part of the TensorFlow project

2) Supports visualizing the training graph, metrics, and training results

3) Can be connected to PyTorch

4) Types

- scalar: displays a sequence of scalar values such as metrics

- graph: displays the model's computational graph

- histogram: shows the distribution of values such as weights

- image: displays a comparison of predicted and actual values

- mesh: represents data in 3D form

# Tensorboard

import os
logs_base_dir = "logs"
# Create a directory for the Tensorboard logs
os.makedirs(logs_base_dir, exist_ok=True)

from torch.utils.tensorboard import SummaryWriter
import numpy as np

# Create a SummaryWriter, the object that writes the logs
exp = f"{logs_base_dir}/ex3"
writer = SummaryWriter(exp)
for n_iter in range(100):
    # add_scalar: record a scalar value
    writer.add_scalar('Loss/train', np.random.random(), n_iter)
    writer.add_scalar('Loss/test', np.random.random(), n_iter)
    writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
    writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
# Write the recorded values to disk
writer.flush()

# Run tensorboard
%load_ext tensorboard
%tensorboard --logdir "logs"
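Besides scalars, the histogram type listed above is logged in much the same way. This is a rough sketch assuming the `tensorboard` package is installed. The tag names and the simulated weight drift are made up for illustration.

```python
import glob
import os
import tempfile

import torch
from torch.utils.tensorboard import SummaryWriter

# Temporary directory used here in place of the "logs" folder.
log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir)

weights = torch.zeros(1000)
for step in range(5):
    # Simulate weights drifting during training.
    weights = weights + 0.1 * torch.randn(1000)
    # add_histogram: record the distribution of a tensor's values
    writer.add_histogram("layer1/weights", weights, step)
    writer.add_scalar("layer1/weight_std", weights.std().item(), step)

# close() flushes pending events to disk.
writer.close()

event_files = glob.glob(os.path.join(log_dir, "events.out.tfevents.*"))
print(len(event_files))
```

Pointing `%tensorboard --logdir` at this directory should show the histogram widening step by step under the Histograms tab.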

 

2. Weights&Biases

1) A commercial tool for supporting machine learning experiments smoothly

2) Provides collaboration, code versioning, experiment result logging, and more

# wandb
!pip install wandb -q

import torch
import wandb

# Set the config
config = {'epochs': EPOCHS, 'batch_size': BATCH_SIZE, 'learning_rate': LEARNING_RATE}
wandb.init(project='my-test-project', config=config)

for e in range(1, EPOCHS+1):
    epoch_loss = 0
    epoch_acc = 0
    for x_batch, y_batch in train_dataset:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device).type(torch.cuda.FloatTensor)
        #...
        optimizer.step()
        #...

    # Log to wandb (train_acc and train_loss come from the elided steps above)
    wandb.log({'accuracy': train_acc, 'loss': train_loss})


🎀 Today's retrospective

In the morning I watched PyTorch lecture 6, and in the afternoon I went through lecture 7 and worked on Basic Assignment 2. It definitely looked shorter than Basic Assignment 1, the one where Budeogi does not appear, but it is hard.. Loading the data, building the dataset, and writing the code piece by piece, I feel like I am starting to understand it, and yet I still don't really get it ㅠㅠ In the peer session we seem to have talked only about the assignments. It's already Wednesday, and since the assignment deadline is early this week, my goal is to re-check Basic Assignments 1 and 2 and submit them by tomorrow. In today's first Dureon-dureon session, Master Byeon Seong-yun told us about the Master's own life. It was a life of repeated failures and successes. To us it looks like a successful life now, but the Master seems to have grown strong by going through countless failures, and I envied that..!! Someday I can become a strong person like the Master too, right..
