Ref: https://github.com/fastai/fastbook
อธิบายการทำงานส่วนการเตรียมข้อมูล
ส่วน image
- ได้เป็น list ของ Path
- วนลูปเพื่อแปลงที่ละภาพเป็น tensor จะได้เป็น list ที่มีสมาชิกเป็น tensor [28, 28]
- ใช้ stack แปลง list เป็น tensor 3 มิติ ของแต่ละคลาส คลาสสามที่มีขนาด tensor [6131, 28, 28]
- แปลงเป็น float แล้ว normalize ด้วย /255
- ใช้ cat รวมทุกคลาสเข้าด้วยกัน เป็น tensor [12396, 28, 28]
- ใช้ view เพื่อแปลงขนาดรูปให้เหลือมิติเดียว จะได้เป็น tensor [12396, 784]
ส่วน label
- กำหนดตัวเลขคลาส แล้วคูณจำนวนรูป จะได้เวคเตอร์สาม tensor [6131]
- นำทุกคลาสมาบวกกัน จะได้เวคเตอร์ [12396]
- ใช้ unsqueeze เพื่อแปลงให้เป็น [12396, 1]
from fastai.vision.all import *
path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path
# สร้าง Train set
threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()
three_tensors = [tensor(Image.open(o)) for o in threes]
seven_tensors = [tensor(Image.open(o)) for o in sevens]
stacked_threes = torch.stack(three_tensors).float()/255
stacked_sevens = torch.stack(seven_tensors).float()/255
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)
# สร้าง Valid set
valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()])
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()])
valid_3_tens = valid_3_tens.float()/255
valid_7_tens = valid_7_tens.float()/255
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
# สร้าง Dataset และ DataLoader
dset = list(zip(train_x,train_y))
valid_dset = list(zip(valid_x,valid_y))
dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)
# กำหนดเมธอด
def init_params(size, std=1.0):
return (torch.randn(size)*std).requires_grad_()
def linear1(xb):
return xb@weights + bias
def sigmoid(x):
return 1/(1+torch.exp(-x))
def mnist_loss(predictions, targets):
predictions = predictions.sigmoid()
return torch.where(targets==1, 1-predictions, predictions).mean()
def calc_grad(xb, yb, model):
preds = model(xb)
loss = mnist_loss(preds, yb)
loss.backward()
def train_epoch(model, lr, params):
for xb,yb in dl:
calc_grad(xb, yb, model)
for p in params:
p.data -= p.grad*lr
p.grad.zero_()
def batch_accuracy(xb, yb):
preds = xb.sigmoid()
correct = (preds>0.5) == yb
return correct.float().mean()
def validate_epoch(model):
accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]
return round(torch.stack(accs).mean().item(), 4)
# กำหนดพารามิเตอร์
weights = init_params((28*28,1))
bias = init_params(1)
lr = 1.
params = weights,bias
# train and validate
for i in range(20):
train_epoch(linear1, lr, params)
print(validate_epoch(linear1), end=' ')
แบบสรุปย่อไม่ใช้เมธอด
from fastai.vision.all import *
path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path
train_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'train'/'3').ls()]).float()/255
train_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'train'/'7').ls()]).float()/255
valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()]).float()/255
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()]).float()/255
train_x = torch.cat([train_3_tens, train_7_tens]).view(-1, 28*28)
train_y = tensor([1]*len(train_3_tens) + [0]*len(train_7_tens)).unsqueeze(1)
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
dset = list(zip(train_x,train_y))
valid_dset = list(zip(valid_x,valid_y))
dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)
def init_params(size, std=1.0):
return (torch.randn(size)*std).requires_grad_()
weights = init_params((28*28,1))
bias = init_params(1)
lr = 1.
params = weights,bias
for i in range(20):
for xb,yb in dl:
preds = xb@weights + bias
preds = preds.sigmoid()
loss = torch.where(yb==1, 1-preds, preds).mean()
loss.backward()
for p in params:
p.data -= p.grad*lr
p.grad.zero_()
accs = []
for xb,yb in valid_dl:
preds = xb@weights + bias
preds = preds.sigmoid()
correct = (preds>0.5) == yb
acc = correct.float().mean()
accs.append(acc)
all_acc = torch.stack(accs).mean().item()
print(round(all_acc, 4))
Sign up here with your email
ConversionConversion EmoticonEmoticon