การจำแนกดอกไม้ (Flower) โดยใช้ Transfer Learning บน FastAI

ใช้ FastAI version 2.2.7

จัดเรียงรูปในโฟลเดอร์ flowers ดังนี้

  • train
    • class1
    • class2
  • valid
    • class1
    • class2
  • test
    • class1
    • class2

โหลด library

from fastai.vision.all import *
from fastai.metrics import error_rate, accuracy
import torchvision.models as models

กำหนดค่า

path       = 'D:/dataset/flowers/'
test_path  = path+'test/'
img_size   = 160
batch_size = 32
epoch      = 10
fig_size   = None
tool       = 'fai'
classify   = 'flower'
model_name = 'resnet18'
res_name   = str(tool)+'_'+str(classify)+'_'+str(model_name)+'_s'+str(img_size)+'_e'+str(epoch)
res_path   = 'D:/thesis/practice/fai/result/'
pth_save   = res_path+res_name

เตรียม Datablock

db = DataBlock(blocks=(ImageBlock, CategoryBlock), 
                  get_items=get_image_files, 
                  splitter=GrandparentSplitter(),
                  get_y=parent_label,
                  item_tfms=Resize(img_size),
                  batch_tfms=[*aug_transforms(
                      size=128, 
                      min_scale=0.75,
                      do_flip=True,
                      flip_vert=False,
                      max_rotate=10.0, 
                      max_zoom=1.2, 
                      max_lighting=0.2, 
                      max_warp=0.2, 
                      p_affine=0.2),
                Normalize.from_stats(*imagenet_stats)])

dls = db.dataloaders(path,bs=batch_size,num_workers=0)

แสดงจำนวนคลาสและรูปตัวอย่าง

print("Number of classes: ", dls.c)
print("Classes: ", dls.vocab)
dls.show_batch(max_n=10, nrows=2)

โหลดโมเดล เช่น resnet18 มีอยู่ใน fastai อยู่แล้ว ส่วน mobilenetv2 เป็นของ pytorch

learn = cnn_learner(dls, resnet18, normalize=True, pretrained=True, metrics=[error_rate, accuracy])
# or
learn = cnn_learner(dls, models.mobilenet_v2, cut=-1, normalize=True, pretrained=True, metrics=[error_rate, accuracy])

การปริ้นท์โมเดล

learn.model
learn.summary()

หา Learning Rate

learn.lr_find()

method พล็อดกราฟ

@patch
@delegates(subplots)

def plot_metrics(self: Recorder, nrows=None, ncols=None, figsize=None, **kwargs):
    metrics = np.stack(self.values)
    names = self.metric_names[1:-1]
    n = len(names) - 1
    if nrows is None and ncols is None:
        nrows = int(math.sqrt(n))
        ncols = int(np.ceil(n / nrows))
    elif nrows is None: nrows = int(np.ceil(n / ncols))
    elif ncols is None: ncols = int(np.ceil(n / nrows))
    figsize = figsize or (ncols * 6, nrows * 4)
    fig, axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
    axs = [ax if i < n else ax.set_axis_off() for i, ax in enumerate(axs.flatten())][:n]
    for i, (name, ax) in enumerate(zip(names, [axs[0]] + axs)):
        ax.plot(metrics[:, i], color='#1f77b4' if i == 0 else '#ff7f0e', label='valid' if i > 0 else 'train')
        ax.set_title(name if i > 1 else 'losses')
        ax.legend(loc='best')
    plt.show()

train

import time
startTime = time.time()

learn.fit_one_cycle(epoch)

endTime = time.time()

แสดงเวลา

def printTime(startTime, endTime):
    print("Strat time = "+str(startTime))
    print("End time = "+str(endTime))
    print("Use time = "+str(endTime-startTime))
printTime(startTime, endTime)

พล็อต Loss ของ FastAI และกราฟที่เราสร้างขึ้น

learn.recorder.plot_loss()
learn.recorder.plot_metrics()

ดูการแสดงผลทดสอบ validation

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=fig_size)
interp.print_classification_report()
interp.plot_top_losses(9)
interp.most_confused(min_val=1)

เซฟโมเดล

learn.save(pth_save)
learn.export(pth_save+'.pkl')

สร้าง Test Set

# สร้าง test_dl
Path.BASE_PATH = test_path
test_files     = get_image_files(test_path)
test_dl        = dls.test_dl(test_files, with_labels=True)

# สร้าง df กับคอลัมน์ image และ label
image = []
label = []
for i in range(len(test_files)):
    image.append(os.path.split(test_files[i])[1])
    label.append(test_files[i].parent.stem)
data = {'image':image, 
        'y_label':label} 
test_df = pd.DataFrame(data) 

ทำนาย Test Set ด้วย 3 วิธี วิธี 1 และ 2 ได้ค่าเท่ากัน เพียงแต่อันที่ 2 ให้ คำทำนายเป็นชื่อคลาสด้วย

# 1) ทำนายด้วย get_preds
y_prob, y_true, y_pred = learn.get_preds(dl=test_dl, with_decoded=True)
test_df["y1_true"] = pd.DataFrame(list(y_true))
test_df["y1_pred"] = pd.DataFrame(list(y_pred))
test_df["y1_prob"] = pd.DataFrame(torch.max(y_prob,1)[0].tolist())

# ----------------------------------------------------------
# 2) ทำนายด้วย predict
test_dir_path = Path(test_path)
test_df_image = (test_df['y_label']+'/'+test_df['image']).apply(lambda x: test_dir_path/x)
y_labels = []
y_preds = []
y_probs = []
for img in test_df_image:
    cls,y_pred,y_prob = learn.predict(img)
    y_labels.append(cls)
    #y_preds.append(y_pred.item())
    #y_probs.append(torch.max(y_prob).item())
test_df["y2_cls"] = pd.DataFrame(list(y_labels))
#test_df["y2_pred"] = pd.DataFrame(list(y_preds))
#test_df["y2_prob"] = pd.DataFrame(list(y_probs))

# ----------------------------------------------------------
# 3) ทำนายด้วย tta
pred_tta = learn.tta(dl=test_dl)
#test_df["y3_true"] = pd.DataFrame(list(pred_tta[1]))
test_df["y3_pred"] = pd.DataFrame(np.argmax(pred_tta[0],axis=1).tolist())
test_df["y3_prob"] = pd.DataFrame(torch.max(pred_tta[0],1)[0].tolist())

เซฟผลป็น csv

csv_save = test_df.to_csv(r''+pth_save+'_pred_test.csv', index = None, header=True)

สร้าง method แสดงผล

import pandas as pd
import numpy as np
import seaborn as sn
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
%matplotlib inline 

def showReport(y_true, y_pred, labels, name):
    print(classification_report(y_true, y_pred))
    cf = confusion_matrix(y_true, y_pred)
    df_cm = pd.DataFrame(cf, index = labels, columns = labels)
    plt.figure(figsize=fig_size)
    sn.heatmap(df_cm, annot=True, cmap="Blues", fmt='d')
    plt.title(name,fontsize = 15)
    plt.show()

แสดงผล

labels = test_dl.vocab
pred   = pd.read_csv(pth_save+'_pred_test.csv')

showReport(pred.y1_true, pred.y1_pred, labels, model_name+' test by get_preds')
#showReport(pred.y1_true, pred.y2_pred, labels, model_name+' test By predict')
showReport(pred.y1_true, pred.y3_pred, labels, model_name+' test by tta')

หรือจะแสดงผลด้วย default ของ FastAI ก็ได้

interp_test = ClassificationInterpretation.from_learner(learn,dl=test_dl)
interp_test.plot_confusion_matrix(figsize=fig_size)
interp_test.print_classification_report()
interp_test.plot_top_losses(9)
interp_test.most_confused(min_val=1)
Previous
Next Post »