Fine-tune là gì

     

1.1 Fine-tuning là gì ?

Chắc hẳn hồ hết ai thao tác với các model trong deep learning gần như đã nghe/quen với khái niệm Transfer learningFine tuning. Khái niệm tổng quát: Transfer learning là tận dụng học thức học được từ là một vấn đề để vận dụng vào 1 vấn đề có tương quan khác. Một ví dụ đối chọi giản: thay bởi train 1 mã sản phẩm mới trọn vẹn cho việc phân các loại chó/mèo, người ta có thể tận dụng 1 model đã được train bên trên ImageNet dataset với hằng triệu ảnh. Pre-trained model này sẽ được train tiếp bên trên tập dataset chó/mèo, quy trình train này diễn ra nhanh hơn, tác dụng thường tốt hơn. Có rất nhiều kiểu Transfer learning, các chúng ta cũng có thể tham khảo trong bài xích này: Tổng hợp Transfer learning. Trong bài này, mình đã viết về 1 dạng transfer learning phổ biến: Fine-tuning.

Bạn đang xem: Fine-tune là gì

Bạn đã xem: Fine tune là gì

Đang xem: Fine tuning là gì

Hiểu đơn giản, fine-tuning là các bạn lấy 1 pre-trained model, tận dụng một trong những phần hoặc tổng thể các layer, thêm/sửa/xoá 1 vài ba layer/nhánh để tạo thành 1 model mới. Thường các layer đầu của mã sản phẩm được freeze (đóng băng) lại – tức weight các layer này sẽ không bị đổi khác giá trị trong quy trình train. Lý do bởi những layer này đã có chức năng trích xuất tin tức mức trìu tượng thấp , kỹ năng này được học tập từ quy trình training trước đó. Ta freeze lại nhằm tận dụng được tài năng này và giúp việc train ra mắt nhanh rộng (model chỉ bắt buộc update weight ở những layer cao). Có rất nhiều các Object detect model được xây dựng dựa trên các Classifier model. VD Retina model (Object detect) được xây dựng với backbone là Resnet.


*

1.2 vì sao pytorch thay bởi vì Keras ?

Chủ đề bài viết hôm nay, mình sẽ trả lời fine-tuning Resnet50 – 1 pre-trained mã sản phẩm được cung cấp sẵn vào torchvision của pytorch. Lý do là pytorch mà không hẳn Keras ? lý do bởi việc fine-tuning mã sản phẩm trong keras rất đối chọi giản. Dưới đấy là 1 đoạn code minh hoạ cho vấn đề xây dựng 1 Unet dựa vào Resnet trong Keras:

from tensorflow.keras import applicationsresnet = applications.resnet50.ResNet50()layer_3 = resnet.get_layer(“activation_9”).outputlayer_7 = resnet.get_layer(“activation_21”).outputlayer_13 = resnet.get_layer(“activation_39”).outputlayer_16 = resnet.get_layer(“activation_48”).output#Adding outputs decoder with encoder layersfcn1 = Conv2D(…)(layer_16)fcn2 = Conv2DTranspose(…)(fcn1)fcn2_skip_connected = Add()()fcn3 = Conv2DTranspose(…)(fcn2_skip_connected)fcn3_skip_connected = Add()()fcn4 = Conv2DTranspose(…)(fcn3_skip_connected)fcn4_skip_connected = Add()()fcn5 = Conv2DTranspose(…)(fcn4_skip_connected)Unet = Model(inputs = resnet.input, outputs=fcn5)Bạn có thể thấy, fine-tuning mã sản phẩm trong Keras thực thụ rất đối kháng giản, dễ dàng làm, dễ dàng hiểu. Việc add thêm các nhánh rất giản đơn bởi cú pháp 1-1 giản. Vào pytorch thì ngược lại, tạo ra 1 mã sản phẩm Unet tương tự sẽ khá vất vả cùng phức tạp. Người mới học sẽ gặp mặt khó khăn vày trên mạng không nhiều các hướng dẫn cho việc này. Vậy nên bài bác này mình đã hướng dẫn cụ thể cách fine-tune vào pytorch để áp dụng vào bài toán Visual Saliency prediction

2. Visual Saliency prediction

2.1 What is Visual Saliency ?


*

Khi nhìn vào 1 bức ảnh, mắt thường có xu thế tập trung chú ý vào 1 vài chủ thể chính. Ảnh trên đây là 1 minh hoạ, màu tiến thưởng được áp dụng để biểu lộ mức độ thu hút. Saliency prediction là việc mô phỏng sự triệu tập của mắt tín đồ khi quan tiếp giáp 1 bức ảnh. Cầm thể, bài toán yên cầu xây dựng 1 model, mã sản phẩm này nhận hình ảnh đầu vào, trả về 1 mask mô rộp mức độ thu hút. Như vậy, mã sản phẩm nhận vào 1 input đầu vào image và trả về 1 mask có kích thước tương đương.

Để rõ rộng về câu hỏi này, bạn có thể đọc bài: Visual Saliency Prediction with Contextual Encoder-Decoder Network.Dataset phổ biến nhất: SALICON DATASET

2.2 Unet

Note: Bạn có thể bỏ qua phần này nếu đã biết về Unet

Đây là một trong bài toán Image-to-Image. Để giải quyết bài toán này, mình sẽ xây dựng 1 mã sản phẩm theo kiến trúc Unet. Unet là 1 kiến trúc được sử dụng nhiều trong câu hỏi Image-to-image như: semantic segmentation, tự động color, super resolution … bản vẽ xây dựng của Unet có điểm tựa như với bản vẽ xây dựng Encoder-Decoder đối xứng, được thêm những skip connection tự Encode thanh lịch Decode tương ứng. Về cơ bản, những layer càng tốt càng trích xuất tin tức ở mức trìu tượng cao, điều này đồng nghĩa cùng với việc các thông tin nút trìu tượng rẻ như đường nét, màu sắc sắc, độ phân giải… sẽ bị mất non đi trong quá trình lan truyền. Tín đồ ta thêm những skip-connection vào để giải quyết vấn đề này.

Với phần Encode, feature-map được downscale bằng các Convolution. Ngược lại, tại phần decode, feature-map được upscale bởi các Upsampling layer, trong bài này bản thân sử dụng những Convolution Transpose.


*

2.3 Resnet

Để giải quyết và xử lý bài toán, mình sẽ xây dựng dựng mã sản phẩm Unet với backbone là Resnet50. Bạn nên tò mò về Resnet nếu chưa biết về phong cách thiết kế này. Hãy quan gần kề hình minh hoạ dưới đây. Resnet50 được phân thành các khối phệ . Unet được gây ra với Encoder là Resnet50. Ta sẽ lôi ra output của từng khối, tạo những skip-connection kết nối từ Encoder lịch sự Decoder. Decoder được xây đắp bởi các Convolution Transpose layer (xen kẽ trong các số đó là các lớp Convolution nhằm mục đích mục đích sút số chanel của feature map -> giảm số lượng weight mang đến model).

Theo ý kiến cá nhân, pytorch rất đơn giản code, dễ dàng nắm bắt hơn tương đối nhiều so với Tensorflow 1.x hoặc ngang ngửa Keras. Tuy nhiên, việc fine-tuning model trong pytorch lại khó khăn hơn tương đối nhiều so với Keras. Trong Keras, ta không bắt buộc quá thân mật tới loài kiến trúc, luồng cách xử trí của model, chỉ việc lấy ra những output tại một số ít layer nhất định làm skip-connection, ghép nối và tạo nên ra mã sản phẩm mới.

Xem thêm: Thay Thế Hàm Getchar() Trong C, Hỏi Về Return 0 Và Getch()


*

3. Code

Tất cả code của chính mình được gói gọn trong tệp tin notebook Salicon_main.ipynb. Bạn cũng có thể tải về với run code theo links github: github/trungthanhnguyen0502 . Trong nội dung bài viết mình vẫn chỉ gửi ra gần như đoạn code chính.

Import những package

import albumentations as Aimport numpy as npimport torchimport torchvisionimport torch.nn as nn import torchvision.transforms as Timport torchvision.models as modelsfrom torch.utils.data import DataLoader, Datasetimport ….

3.1 utils functions

Trong pytorch, tài liệu có máy tự dimension không giống với Keras/TF/numpy. Thông thường với numpy hay keras, hình ảnh có dimension theo thiết bị tự (batchsize,h,w,chanel)(batchsize, h, w, chanel)(batchsize,h,w,chanel). đồ vật tự trong Pytorch trái lại là (batchsize,chanel,h,w)(batchsize, chanel, h, w)(batchsize,chanel,h,w). Mình sẽ xây dựng 2 hàm toTensor với toNumpy để thay đổi qua lại thân hai format này.

def toTensor(np_array, axis=(2,0,1)): return torch.tensor(np_array).permute(axis)def toNumpy(tensor, axis=(1,2,0)): return tensor.detach().cpu().permute(axis).numpy() ## display one image in notebookdef plot_img(img): … ## display multi imagedef plot_imgs(imgs): …

3.2 Define model

3.2.1 Conv and Deconv

Mình sẽ xây dựng 2 function trả về module Convolution cùng Convolution Transpose (Deconv)

def Deconv(n_input, n_output, k_size=4, stride=2, padding=1): Tconv = nn.ConvTranspose2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = return nn.Sequential(*block) def Conv(n_input, n_output, k_size=4, stride=2, padding=0, bn=False, dropout=0): conv = nn.Conv2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = return nn.Sequential(*block)

3.2.2 Unet model

Init function: ta đang copy những layer phải giữ từ bỏ resnet50 vào unet. Tiếp nối khởi tạo những Conv / Deconv layer và các layer nên thiết.

Forward function: cần bảo đảm luồng cách xử trí của resnet50 được không thay đổi giống code nơi bắt đầu (trừ Fully-connected layer). Tiếp đến ta ghép nối các layer lại theo kiến trúc Unet đã biểu lộ trong phần 2.

Xem thêm: Những Câu Nói Vần Hay - Những Câu Nói Vần Hài Hước

class Unet(nn.Module): def __init__(self, resnet): super().__init__() self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.tanh = nn.Tanh() self.sigmoid = nn.Sigmoid() # get some layer from resnet to lớn make skip connection self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 # convolution layer, use to reduce the number of channel => reduce weight number self.conv_5 = Conv(2048, 512, 1, 1, 0) self.conv_4 = Conv(1536, 512, 1, 1, 0) self.conv_3 = Conv(768, 256, 1, 1, 0) self.conv_2 = Conv(384, 128, 1, 1, 0) self.conv_1 = Conv(128, 64, 1, 1, 0) self.conv_0 = Conv(32, 1, 3, 1, 1) # deconvolution layer self.deconv4 = Deconv(512, 512, 4, 2, 1) self.deconv3 = Deconv(512, 256, 4, 2, 1) self.deconv2 = Deconv(256, 128, 4, 2, 1) self.deconv1 = Deconv(128, 64, 4, 2, 1) self.deconv0 = Deconv(64, 32, 4, 2, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) skip_1 = x x = self.maxpool(x) x = self.layer1(x) skip_2 = x x = self.layer2(x) skip_3 = x x = self.layer3(x) skip_4 = x x5 = self.layer4(x) x5 = self.conv_5(x5) x4 = self.deconv4(x5) x4 = torch.cat(, dim=1) x4 = self.conv_4(x4) x3 = self.deconv3(x4) x3 = torch.cat(, dim=1) x3 = self.conv_3(x3) x2 = self.deconv2(x3) x2 = torch.cat(, dim=1) x2 = self.conv_2(x2) x1 = self.deconv1(x2) x1 = torch.cat(, dim=1) x1 = self.conv_1(x1) x0 = self.deconv0(x1) x0 = self.conv_0(x0) x0 = self.sigmoid(x0) return x0 device = torch.device(“cuda”)resnet50 = models.resnet50(pretrained=True)model = Unet(resnet50)model.to(device)## Freeze resnet50″s layers in Unetfor i, child in enumerate(model.children()): if i 7: for param in child.parameters(): param.requires_grad = False

3.3 Dataset and Dataloader

Dataset trả dấn 1 list những image_path và mask_dir, trả về image và mask tương ứng.

Define MaskDataset

class MaskDataset(Dataset): def __init__(self, img_fns, mask_dir, transforms=None): self.img_fns = img_fns self.transforms = transforms self.mask_dir = mask_dir def __getitem__(self, idx): img_path = self.img_fns img_name = img_path.split(“/”).split(“.”) mask_fn = f”self.mask_dir/img_name.png” img = cv2.imread(img_path) mask = cv2.imread(mask_fn) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if self.transforms: sample = “image”: img, “mask”: mask sample = self.transforms(**sample) img = sample mask = sample # khổng lồ Tensor img = img/255.0 mask = np.expand_dims(mask, axis=-1)/255.0 mask = toTensor(mask).float() img = toTensor(img).float() return img, mask def __len__(self): return len(self.img_fns)Test dataset

img_fns = glob(“./Salicon_dataset/image/train/*.jpg”)mask_dir = “./Salicon_dataset/mask/train”train_transform = A.Compose(, height=256, width=256, p=0.4), A.HorizontalFlip(p=0.5), A.Rotate(limit=(-10,10), p=0.6),>)train_dataset = MaskDataset(img_fns, mask_dir, train_transform)train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)# thử nghiệm datasetimg, mask = next(iter(train_dataset))img = toNumpy(img)mask = toNumpy(mask)img = (img*255.0).astype(np.uint8)mask = (mask*255.0).astype(np.uint8)heatmap_img = cv2.applyColorMap(mask, cv2.COLORMAP_JET)combine_img = cv2.addWeighted(img, 0.7, heatmap_img, 0.3, 0)plot_imgs(

3.4 Train model

Vì bài toán dễ dàng và đơn giản và để cho dễ hiểu, mình đang train theo cách đơn giản dễ dàng nhất, không validate vào qúa trình train mà chỉ lưu model sau 1 số epoch độc nhất định

train_params = optimizer = torch.optim.Adam(train_params, lr=0.001, betas=(0.9, 0.99))epochs = 5model.train()saved_dir = “model”os.makedirs(saved_dir, exist_ok=True)loss_function = nn.MSELoss(reduce=”mean”)for epoch in range(epochs): for imgs, masks in tqdm(train_loader): imgs_gpu = imgs.to(device) outputs = model(imgs_gpu) masks = masks.to(device) loss = loss_function(outputs, masks) loss.backward() optimizer.step()

3.5 demo model

img_fns = glob(“./Salicon_dataset/image/val/*.jpg”)mask_dir = “./Salicon_dataset/mask/val”val_transform = A.Compose()model.eval()val_dataset = MaskDataset(img_fns, mask_dir, val_transform)val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, drop_last=True)imgs, mask_targets = next(iter(val_loader))imgs_gpu = imgs.to(device)mask_outputs = model(imgs_gpu)mask_outputs = toNumpy(mask_outputs, axis=(0,2,3,1))imgs = toNumpy(imgs, axis=(0,2,3,1))mask_targets = toNumpy(mask_targets, axis=(0,2,3,1))for i, img in enumerate(imgs): img = (img*255.0).astype(np.uint8) mask_output = (mask_outputs*255.0).astype(np.uint8) mask_target = (mask_targets*255.0).astype(np.uint8) heatmap_label = cv2.applyColorMap(mask_target, cv2.COLORMAP_JET) heatmap_pred = cv2.applyColorMap(mask_output, cv2.COLORMAP_JET) origin_img = cv2.addWeighted(img, 0.7, heatmap_label, 0.3, 0) predict_img = cv2.addWeighted(img, 0.7, heatmap_pred, 0.3, 0) result = np.concatenate((img,origin_img, predict_img),axis=1) plot_img(result)Kết trái thu được: