Copyright (c) 2020 BioDatarium (Authors: Dalila Rendon, Ewald Enzinger)
For this project, we are fine-tuning a pretrained faster_rcnn model to accurately identify males and females of the spotted-wing drosophila (Drosophila suzukii) and individuals of Drosophila melanogaster in mixed images.
Spotted-wing drosophila is an invasive fly pest of berries, and monitoring for its presence in crop fields is an important element of management. Spotted-wing drosophila is commonly collected with drowning cup traps. One downside is that these traps also collect a lot of bycatch, including other drosophila species, so trap samples are typically inspected under a dissecting microscope to detect and count spotted-wing drosophila.
For this project, we aim to build an automated species recognition model that can identify and count males and females of spotted-wing drosophila in microscope images of mixed samples. The training data consists of 353 images of mixed flies, with a total of 1719 male spotted-wing drosophila, 2390 female spotted-wing drosophila, and 3381 Drosophila melanogaster labeled individuals.
This script presents one of many possible approaches to this problem; optimizing models for different types of images is an ongoing effort. Some of the code used here was adapted from https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
Image copyright: Dalila Rendon, 2020
!pip install pycocotools
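The tutorial linked above also relies on a few helper scripts (engine.py, utils.py, transforms.py, plus coco_utils.py and coco_eval.py) from the references/detection folder of the torchvision repository; train_one_epoch, utils.collate_fn, and the ToTensor/RandomHorizontalFlip/Compose transforms used below come from those files. A minimal way to fetch them in this Colab environment (assuming the repository layout has not changed) is:
!git clone --depth 1 https://github.com/pytorch/vision.git
!cp vision/references/detection/engine.py .
!cp vision/references/detection/utils.py .
!cp vision/references/detection/transforms.py .
!cp vision/references/detection/coco_utils.py .
!cp vision/references/detection/coco_eval.py .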
#######################################################################
# Importing python packages
#######################################################################
# os, used for listing files from directories, checking if files exist, etc.
import os
# Numpy for linear algebra and n-dimension arrays
import numpy as np
# Pytorch, our deep learning framework
import torch
# PIL (pillow) for image loading/resizing
from PIL import Image
# Path is used similar to "os" for listing files from directories, but with
# a pattern, e.g. "*.xml" to only list XML files
from pathlib import Path
# XML library for parsing the annotation XML files
import xml.etree.ElementTree as ET
# Pre-existing object detection model (Faster-RCNN)
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# Helper modules copied from the torchvision references/detection folder (see above):
# data transforms, batch collation for the data loader, and the training loop
import transforms as T
import utils
from engine import train_one_epoch

def get_transform(train):
    # Build the list of transforms applied to each (image, target) pair
    transforms = []
    # Convert the PIL image into a torch.Tensor
    transforms.append(T.ToTensor())
    if train:
        # For training data, randomly flip images (and their boxes)
        # horizontally with 50% probability as data augmentation
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)
#######################################################################
# Define a "CustomDataset" class that defines __init__, __getitem__,
# and __len__ methods (functions). These are later used by Pytorch
# during training when data batches are being requested.
#######################################################################
class CustomDataset(object):
    # In the __init__ method, we receive as input the data folder.
    # Optionally, we can also define "transforms", which is a list
    # of functions that "transform" the data, e.g. for augmentation.
    # For example, one could use a RandomHorizontalFlip function from
    # torchvision to randomly flip images horizontally to create augmented
    # training data.
    def __init__(self, data_folder, transforms=None):
        self.data_folder = data_folder
        self.transforms = transforms
        # Create a list of all image files, sorting them to
        # ensure that they are aligned with the XML files.
        # NOTE: This requires that there is exactly one XML file for each
        # image file, with the same name!
        self.image_file_names = list(sorted(Path(data_folder).glob("*.JPG")))
        self.xml_file_names = list(sorted(Path(data_folder).glob("*.xml")))
        print(f"{len(self.image_file_names)} JPG files")
        print(f"{len(self.xml_file_names)} xml files")
    def __getitem__(self, idx):
        # Create full image file path from data_folder and the current file name
        # referenced by idx
        image_path = Path(self.data_folder).joinpath(self.image_file_names[idx])
        # Load actual image data
        image = Image.open(image_path).convert("RGB")
        # Create full XML file path from data_folder and the current file name
        # referenced by idx
        label_path = Path(self.data_folder).joinpath(self.xml_file_names[idx])
        # Get root element of XML file (named "annotation" in the file)
        xml_tree = ET.parse(label_path).getroot()
        # Find all the "<object>..</object>" tags in the XML file
        objects = xml_tree.findall("object")
        boxes = []
        labels = []
        for entry in objects:
            # Get "<bndbox>..</bndbox>" tag inside the "<object>"
            bounding_box = entry.find("bndbox")
            # Get coordinates for bounding boxes
            xmin = int(bounding_box.find("xmin").text)
            xmax = int(bounding_box.find("xmax").text)
            ymin = int(bounding_box.find("ymin").text)
            ymax = int(bounding_box.find("ymax").text)
            # Convert label text to integer class indices:
            # 1 for female
            # 2 for male
            # 3 for drosophila
            label_text = entry.find("name").text
            if label_text == "f":
                label = 1
            elif label_text == "m":
                label = 2
            elif label_text == "d":
                label = 3
            else:
                # Report and discard objects with unexpected label text
                print(f"Unexpected label text for image {image_path}: {label_text}")
                continue
            # Add bounding boxes and labels
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(label)
        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # Build the target dictionary expected by the torchvision detection models
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        # Apply transforms, if there are any
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        # Return image and target data
        return image, target

    # The __len__ method returns how many image/XML file pairs
    # there are in the dataset
    def __len__(self):
        return len(self.image_file_names)
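Before training, it can be worth instantiating the dataset once and checking a single sample to confirm that the XML parsing produces the expected boxes and labels. A minimal sanity check (using the same data folder as in the training section below, and no transforms) might look like this:
# Quick check: load one sample and inspect its target dictionary
sample_dataset = CustomDataset('/content/data/SWD')
sample_image, sample_target = sample_dataset[0]
print(sample_target["boxes"].shape)   # (number of labeled flies in the image, 4)
print(sample_target["labels"])        # class indices 1 (f), 2 (m), or 3 (d)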
# Replace the classifier head with a new one that has num_classes outputs,
# where class index 0 is reserved for the background:
# 1 for female (Drosophila suzukii)
# 2 for male (Drosophila suzukii)
# 3 for drosophila (Drosophila melanogaster)
num_classes = 4  # background + female + male + drosophila
# load an object detection model (Faster R-CNN) pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
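To confirm that the head was actually replaced, one can print the new predictor: it should show a classification layer with num_classes outputs and a box-regression layer with 4 * num_classes outputs.
# Optional check of the new detection head
print(model.roi_heads.box_predictor)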
# Do training and evaluation:
# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# use our dataset and the transformations defined above;
# train=True enables the horizontal-flip augmentation for training
dataset = CustomDataset('/content/data/SWD', get_transform(train=True))
# define the training data loader
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn)
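The detection model expects each batch as a tuple of images and a tuple of targets rather than stacked tensors (the images can have different sizes), which is why the default collate function cannot be used. The collate_fn provided by the reference utils.py is essentially just:
# Equivalent of utils.collate_fn from the torchvision detection references:
# turn a list of (image, target) pairs into (tuple_of_images, tuple_of_targets)
def collate_fn(batch):
    return tuple(zip(*batch))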
# Sanity check: pull one batch from the data loader and run a forward pass.
# In training mode (with targets), the model returns a dictionary of losses.
images, targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images, targets)  # Returns a dict of training losses
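Printing the keys of the returned dictionary shows what is being optimized; for Faster R-CNN these are the classification and box-regression losses of the detection head and of the region proposal network.
# Inspect the training losses returned by the forward pass
print(output.keys())
# e.g. dict_keys(['loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg'])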
# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x) # Returns predictions
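In eval mode the model takes only images and returns one dictionary per input image, each holding the predicted boxes, labels, and confidence scores; this is the structure used further below when counting and drawing detections.
# Each prediction is a dict with 'boxes' (N x 4), 'labels' (N), and 'scores' (N)
print(predictions[0].keys())
print(predictions[0]['boxes'].shape)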
# move model to the right device
model.to(device)
# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
momentum=0.9, weight_decay=0.0005)
# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
step_size=3,
gamma=0.1)
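With step_size=3 and gamma=0.1, StepLR multiplies the learning rate by 0.1 every 3 epochs, so over the 10 epochs below the rate is 0.005 (epochs 0-2), 0.0005 (epochs 3-5), 0.00005 (epochs 6-8), and 0.000005 (epoch 9). A small illustration of that arithmetic:
# Illustration only: learning rate per epoch under this StepLR schedule
for epoch in range(10):
    print(epoch, 0.005 * (0.1 ** (epoch // 3)))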
# let's train it for 10 epochs
num_epochs = 10
for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    # (train_one_epoch comes from the reference engine.py imported above)
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
print("That's it!")
The output above also shows some mistakes in the labels (some were accidentally entered as "\" or "mw"), but these objects are discarded by the dataset class.
# Save the trained model weights (make sure the output directory exists first)
os.makedirs("/content/models", exist_ok=True)
torch.save(model.state_dict(), "/content/models/swd.pth")
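To reuse the trained weights later (for example, in a separate inference session), rebuild the same architecture and load the saved state dict. A minimal sketch, assuming the same num_classes and file path as above:
# Rebuild the architecture and load the fine-tuned weights
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model.load_state_dict(torch.load("/content/models/swd.pth", map_location=device))
model.to(device)
model.eval()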
Here we test the trained model on an image it has never seen, to see how well it identifies and counts the flies.
# Run evaluation data
from torchvision.transforms.functional import to_tensor
image_path = '/content/data/SWD/DMS10001.JPG'
image = Image.open(image_path)
# Put model in "eval" mode for prediction
model.eval()
image_tensor = to_tensor(image).to(device)
# Get predictions (no gradients needed for inference)
with torch.no_grad():
    predictions = model([image_tensor])
# Draw bounding boxes
import cv2
# Load the image with OpenCV and convert it from BGR to RGB for drawing
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
from collections import Counter
label_counter = Counter()
for i in range(len(predictions[0]['boxes'])):
    # Bounding box coordinates are [xmin, ymin, xmax, ymax]
    x1, y1, x2, y2 = map(int, predictions[0]['boxes'][i].tolist())
    target = predictions[0]['labels'][i].item()
    score = predictions[0]['scores'][i].item()
    # Threshold predictions to remove low-confidence detections
    if score < 0.5:
        continue
    if target == 1:
        label_text = "f"
    elif target == 2:
        label_text = "m"
    elif target == 3:
        label_text = "d"
    else:
        continue
    label_counter[label_text] += 1
    image = cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 1)
    image = cv2.putText(image, f"{label_text} ({score:.2f})", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
print(f'There are {label_counter["f"]} Drosophila suzukii females, {label_counter["m"]} Drosophila suzukii males, and {label_counter["d"]} Drosophila melanogaster.')
# Display the annotated image; cv2.imshow is not supported in Colab, so use
# cv2_imshow from google.colab.patches (convert back to BGR for display)
from google.colab.patches import cv2_imshow
cv2_imshow(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
The model correctly identified and counted most of the flies in this image, with the exception of one that is blocked by the file details.