การจำแนกตัวเลข (Mnist) from scratch บน FastAI

 Ref: https://github.com/fastai/fastbook


อธิบายการทำงานส่วนการเตรียมข้อมูล

ส่วน image

  1. ได้เป็น list ของ Path
  2. วนลูปเพื่อแปลงที่ละภาพเป็น tensor จะได้เป็น list ที่มีสมาชิกเป็น tensor [28, 28]
  3. ใช้ stack แปลง list เป็น tensor 3 มิติ ของแต่ละคลาส คลาสสามที่มีขนาด tensor [6131, 28, 28]
  4. แปลงเป็น float แล้ว normalize ด้วย /255
  5. ใช้ cat รวมทุกคลาสเข้าด้วยกัน เป็น tensor [12396, 28, 28]
  6. ใช้ view เพื่อแปลงขนาดรูปให้เหลือมิติเดียว จะได้เป็น tensor [12396, 784]

ส่วน label

  1. กำหนดตัวเลขคลาส แล้วคูณจำนวนรูป จะได้เวคเตอร์สาม tensor [6131]
  2. นำทุกคลาสมาบวกกัน จะได้เวคเตอร์ [12396]
  3. ใช้ unsqueeze เพื่อแปลงให้เป็น [12396, 1]

from fastai.vision.all import *

path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path


# สร้าง Train set
threes = (path/'train'/'3').ls().sorted()  
sevens = (path/'train'/'7').ls().sorted()

three_tensors  = [tensor(Image.open(o)) for o in threes]  
seven_tensors  = [tensor(Image.open(o)) for o in sevens]

stacked_threes = torch.stack(three_tensors).float()/255  
stacked_sevens = torch.stack(seven_tensors).float()/255

train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28) 
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)

# สร้าง Valid set
valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()])
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()])

valid_3_tens = valid_3_tens.float()/255
valid_7_tens = valid_7_tens.float()/255

valid_x    = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y    = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)


# สร้าง Dataset และ DataLoader
dset       = list(zip(train_x,train_y))
valid_dset = list(zip(valid_x,valid_y))

dl         = DataLoader(dset, batch_size=256)
valid_dl   = DataLoader(valid_dset, batch_size=256)


# กำหนดเมธอด

def init_params(size, std=1.0): 
     return (torch.randn(size)*std).requires_grad_()
    
def linear1(xb): 
    return xb@weights + bias

def sigmoid(x): 
    return 1/(1+torch.exp(-x))

def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()

def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()

def train_epoch(model, lr, params):
    for xb,yb in dl:
        calc_grad(xb, yb, model)
        for p in params:
            p.data -= p.grad*lr
            p.grad.zero_()

def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()

def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)


# กำหนดพารามิเตอร์
weights    = init_params((28*28,1))
bias       = init_params(1)
lr         = 1.
params     = weights,bias


# train and validate
for i in range(20):
    train_epoch(linear1, lr, params)
    print(validate_epoch(linear1), end=' ')

แบบสรุปย่อไม่ใช้เมธอด

from fastai.vision.all import *

path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path

train_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'train'/'3').ls()]).float()/255
train_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'train'/'7').ls()]).float()/255

valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()]).float()/255
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()]).float()/255

train_x = torch.cat([train_3_tens, train_7_tens]).view(-1, 28*28)
train_y = tensor([1]*len(train_3_tens) + [0]*len(train_7_tens)).unsqueeze(1)

valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)

dset       = list(zip(train_x,train_y))
valid_dset = list(zip(valid_x,valid_y))

dl         = DataLoader(dset, batch_size=256)
valid_dl   = DataLoader(valid_dset, batch_size=256)

def init_params(size, std=1.0): 
     return (torch.randn(size)*std).requires_grad_()
    
weights    = init_params((28*28,1))
bias       = init_params(1)
lr         = 1.
params     = weights,bias

for i in range(20):
    for xb,yb in dl:
        preds = xb@weights + bias
        preds = preds.sigmoid()
        loss  = torch.where(yb==1, 1-preds, preds).mean()
        loss.backward()
        for p in params:
            p.data -= p.grad*lr
            p.grad.zero_()
    accs = []
    for xb,yb in valid_dl:
        preds = xb@weights + bias
        preds = preds.sigmoid()
        correct = (preds>0.5) == yb
        acc = correct.float().mean()
        accs.append(acc)
        all_acc = torch.stack(accs).mean().item()
    print(round(all_acc, 4))
Previous
Next Post »