Implementing this hybrid approach involves extending the model's output to include the bounding box coordinates, an object presence score, and class logits, and then training the model with a combined (weighted) loss function.
The code below loads a pre-trained VGG16 model, modifies it to predict bounding box coordinates, an object presence score, and a class, and sets up a basic training loop. It uses a hypothetical dataset and assumes your data is already loaded and preprocessed into a PyTorch DataLoader object named `dataloader`.
import torch
import torchvision
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
# Template dataset for your localization and classification task
class MyLocalizationAndClassificationDataset(Dataset):
    """Template dataset for joint localization, presence and classification.

    Fill in the loading logic in ``__init__`` and implement ``__len__`` and
    ``__getitem__``. Until then the stubs raise ``NotImplementedError``, so the
    module still imports cleanly.

    Args:
        annotations_file: Path to the annotations (format is up to you).
        img_dir: Directory containing the images.
        transform: Optional callable applied to each image. When ``None``, a
            default pipeline is used: resize to 224x224, ``ToTensor``, then
            mean/std normalization.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        # Keep the constructor arguments so __len__/__getitem__ can use them.
        self.annotations_file = annotations_file
        self.img_dir = img_dir
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform
        # Load your dataset here

    def __len__(self):
        # Return the size of your dataset
        raise NotImplementedError("Implement __len__ to return the number of samples.")

    def __getitem__(self, idx):
        # Return the image, the bounding box [x, y, width, height], the label for
        # presence, and the class label
        raise NotImplementedError(
            "Implement __getitem__ to return (image, bbox, presence, class_label)."
        )
# Modify the VGG model for object localization, presence check, and multi-class classification
class ModifiedVGG(nn.Module):
    """VGG backbone with three heads: bounding box, presence score and class logits.

    Args:
        pretrained_vgg: A VGG-style model exposing ``features``, ``avgpool`` and
            ``classifier``. The last classifier layer is dropped. The remaining
            classifier layers must output 4096 features, which is the input width
            of the new heads.
        num_classes: Number of classes for the classification head.
    """

    def __init__(self, pretrained_vgg, num_classes=10):
        super().__init__()
        self.features = pretrained_vgg.features
        self.avgpool = pretrained_vgg.avgpool
        # Remove the last layer; the remaining layers produce the embedding the heads consume.
        self.classifier = nn.Sequential(*list(pretrained_vgg.classifier.children())[:-1])
        self.bbox = nn.Linear(4096, 4)  # Bounding box regression layer
        self.presence = nn.Linear(4096, 1)  # Object presence layer
        self.classes = nn.Linear(4096, num_classes)  # Multi-class classification layer

    def forward(self, x):
        """Return ``(bbox, presence, class_logits)`` for a batch of images.

        ``bbox`` has shape ``(N, 4)``. ``presence`` is a sigmoid probability of
        shape ``(N, 1)``. ``class_logits`` has shape ``(N, num_classes)`` and holds
        raw logits; softmax is left to the loss function.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        # The heads expect the classifier's 4096-d embedding. Feeding them the
        # flattened pooled features directly would be a shape mismatch.
        x = self.classifier(x)
        bbox = self.bbox(x)
        presence = torch.sigmoid(self.presence(x))  # Use sigmoid for presence score
        classes = self.classes(x)  # No activation, softmax will be applied in loss function
        return bbox, presence, classes
# Custom Hybrid Loss
class HybridLoss(nn.Module):
    """Weighted sum of box-regression, presence and classification losses.

    ``total = w0 * MSE(bbox) + w1 * BCE(presence) + w2 * CrossEntropy(classes)``

    Args:
        lambda_val: Sequence of three weights ``(w0, w1, w2)``. It is copied into
            a fresh list, so the caller's sequence is never shared or mutated.

    Raises:
        ValueError: If ``lambda_val`` does not contain exactly three weights.
    """

    def __init__(self, lambda_val=(0.5, 0.25, 0.25)):
        super().__init__()
        weights = list(lambda_val)
        if len(weights) != 3:
            raise ValueError(f"lambda_val must contain exactly 3 weights, got {len(weights)}")
        self.mse_loss = nn.MSELoss()
        # BCELoss expects probabilities, i.e. presence predictions after a sigmoid.
        self.binary_cross_entropy_loss = nn.BCELoss()
        # CrossEntropyLoss expects raw logits and applies log-softmax itself.
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.lambda_val = weights

    def forward(self, bbox_preds, presence_preds, class_preds, bbox_targets, presence_targets, class_targets):
        """Compute the weighted hybrid loss for one batch.

        Presence targets may be given per sample with shape ``(N,)`` or as
        integer labels. They are cast and reshaped to match ``presence_preds``,
        because ``BCELoss`` requires identical shapes and a floating dtype.

        Returns:
            A scalar tensor holding the weighted total loss.
        """
        regression_loss = self.mse_loss(bbox_preds, bbox_targets.to(bbox_preds.dtype))
        presence_targets = presence_targets.to(presence_preds.dtype).reshape_as(presence_preds)
        presence_loss = self.binary_cross_entropy_loss(presence_preds, presence_targets)
        classification_loss = self.cross_entropy_loss(class_preds, class_targets)
        w_box, w_presence, w_class = self.lambda_val
        return (w_box * regression_loss
                + w_presence * presence_loss
                + w_class * classification_loss)
# Load pre-trained VGG and modify it.
# NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.VGG16_Weights.DEFAULT`; it is kept here so older torchvision
# versions still work.
pretrained_vgg = models.vgg16(pretrained=True)
modified_vgg = ModifiedVGG(pretrained_vgg, num_classes=10)

# Loss and optimizer
hybrid_loss = HybridLoss(lambda_val=[0.5, 0.25, 0.25])
optimizer = optim.Adam(modified_vgg.parameters(), lr=0.001)

# Training loop.
# `dataloader` is assumed to be defined elsewhere and to yield
# (images, bbox_targets, presence_targets, class_targets) batches.
num_epochs = 10  # placeholder; set this to your training schedule

# Put any dropout/batch-norm layers in training mode (the model may have been
# switched to eval mode by an earlier prediction run).
modified_vgg.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for images, bbox_targets, presence_targets, class_targets in dataloader:
        optimizer.zero_grad()
        bbox_preds, presence_preds, class_preds = modified_vgg(images)
        loss = hybrid_loss(bbox_preds, presence_preds, class_preds,
                           bbox_targets, presence_targets, class_targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean batch loss over the epoch, not just the last batch's loss.
    mean_loss = running_loss / max(num_batches, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')
The following code tests the trained model on a single image: it predicts the bounding box, presence score, and class, and then plots the result.
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
# Helper: read an image from disk and preprocess it into a model-ready batch
def load_image(image_path):
    """Load an image file and turn it into a batch containing a single image.

    The image is converted to RGB, resized to 224x224, converted to a tensor,
    and normalized with the same mean/std as the training dataset's default
    transform. A leading batch dimension is then added.

    Args:
        image_path: Path to the image file to open.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),  # the input size the network expects
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    rgb_image = Image.open(image_path).convert("RGB")
    batch = preprocess(rgb_image)[None, ...]  # prepend the batch dimension
    return batch
# Function to predict and draw bounding box, display presence score and class prediction
def predict_and_draw_bbox(image_path, model, threshold=0.5):
    """Run ``model`` on one image and plot its box, presence score and class.

    ``model`` must return ``(bbox, presence, class_logits)``, where ``presence``
    is already a probability (``ModifiedVGG.forward`` applies the sigmoid).

    Args:
        image_path: Path to the image to analyse.
        model: The localization model. It is switched to eval mode.
        threshold: Minimum presence probability required to draw the box.
    """
    image_tensor = load_image(image_path)
    # Run on whichever device the model's parameters live on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        bbox_pred, presence_pred, class_pred = model(image_tensor)
    bbox = bbox_pred[0].tolist()
    # presence_pred is already a sigmoid output. Applying sigmoid again would
    # squash every score into (0.5, 0.73), so the threshold test would always pass.
    presence_score = presence_pred.item()
    predicted_class = class_pred.argmax(dim=1).item()

    # Load the image again (as RGB, like the model input) to draw on it.
    image = Image.open(image_path).convert("RGB")
    # plt.subplots creates the figure; an extra plt.figure() would leave a blank one.
    fig, ax = plt.subplots(1)
    ax.imshow(image)
    if presence_score > threshold:
        # NOTE(review): the box is drawn in the original image's pixel space, while
        # the model saw a 224x224 resize. If the training box targets were in resized
        # or normalized coordinates, rescale them here. Confirm against the dataset.
        rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        ax.set_title(f"Predicted class: {predicted_class} with presence score: {presence_score:.2f}")
    else:
        ax.set_title("No object detected")
    plt.show()
# Example usage: replace the placeholder path with an image you want to test.
image_path = 'path/to/your/image.jpg'
predict_and_draw_bbox(image_path=image_path, model=modified_vgg)