cygnus77/nerve-segmentation

Semantic Segmentation in PyTorch: Ultrasound Nerve Segmentation

Ultrasound Nerve Segmentation is a Kaggle challenge to build a model that can identify nerve structures in a dataset of ultrasound images of the neck. The dataset in this challenge is a great resource for learning and testing semantic segmentation algorithms. Here, I use PyTorch and Keras to explore semantic segmentation on this dataset.

Sample Results

  • Purple: True Positive (model prediction matches nerve area marked by human)
  • Yellow: False Negative (model missed nerve in the area)
  • Green: False Positive (model incorrectly predicted nerve in the area)

After 15 epochs of training, the model achieves an F-score of 0.75.

The Data

Ultrasound images are provided as 8-bit-per-pixel grayscale, LZW-compressed TIFF images of 580 x 420 pixels. Each ultrasound image comes with a mask image of the same type and dimensions. Each mask contains the manually marked areas (if any) that show the location of nerve segments in the corresponding ultrasound image.

Sample: ultrasound image and its label mask

OpenCV decodes each image into a 420x580x3 (height x width x channels) NumPy uint8 array. Although the images are grayscale, I kept all 3 channels because the pre-trained VGG model expects 3-channel input.

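For clean downsampling and upsampling, I cropped the images to 576 x 416 (width x height). Both are multiples of 32: VGG-16's five 2x2 max-pooling stages halve each side five times, a total factor of 2^5 = 32.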
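A minimal NumPy sketch of these two preprocessing steps. It assumes a top-left crop (the README does not say which border pixels are removed), and `to_three_channels` and `crop_to_multiple` are hypothetical helper names. Note that `cv2.imread` with its default `IMREAD_COLOR` flag already returns three identical channels for a grayscale file; `to_three_channels` shows the equivalent operation for an image loaded as a single channel.

```python
import numpy as np


def to_three_channels(gray: np.ndarray) -> np.ndarray:
    """Replicate an (H, W) grayscale image into (H, W, 3) for a 3-channel VGG input."""
    return np.repeat(gray[:, :, np.newaxis], 3, axis=2)


def crop_to_multiple(img: np.ndarray, multiple: int = 32) -> np.ndarray:
    """Crop height and width down to the nearest multiple (top-left anchored: an assumption)."""
    h, w = img.shape[:2]
    return img[: h - h % multiple, : w - w % multiple]


# Stand-in for a decoded 580 x 420 ultrasound frame (NumPy shape is height x width).
gray = np.zeros((420, 580), dtype=np.uint8)
img = crop_to_multiple(to_three_channels(gray))
print(img.shape)  # (416, 576, 3)

# Five 2x2 max-pooling stages divide each side by 2**5 = 32 with no remainder.
print(img.shape[0] // 32, img.shape[1] // 32)  # 13 18
```

After cropping, the deepest VGG-16 feature map would be 13 x 18, so every upsampling step in a decoder lands exactly back on the input size.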

I developed a tool image-analysis.py to:

  • Identify duplicate images whose masks may differ (labeling errors).
  • View duplicate images and make corrections.
  • Report the distribution of frames with and without nerve segments, to help balance the dataset before training.
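The distribution check in the last bullet can be sketched as follows. The balancing strategy shown (keep every frame with a nerve and an equal-sized random sample of empty frames) is an assumption for illustration only; the README does not say how the dataset was balanced, and `count_nerve_frames` and `balance_indices` are hypothetical names.

```python
import numpy as np


def count_nerve_frames(masks):
    """Return (frames_with_nerve, frames_without_nerve) for a list of mask arrays."""
    has_nerve = np.array([bool(m.any()) for m in masks])
    return int(has_nerve.sum()), int((~has_nerve).sum())


def balance_indices(masks, seed=0):
    """Hypothetical balancing: all nerve frames plus an equal-sized random subset of empty frames."""
    rng = np.random.default_rng(seed)
    has_nerve = np.array([bool(m.any()) for m in masks])
    pos = np.flatnonzero(has_nerve)
    neg = np.flatnonzero(~has_nerve)
    keep_neg = rng.choice(neg, size=min(len(neg), len(pos)), replace=False)
    return np.sort(np.concatenate([pos, keep_neg]))


# Toy masks: two frames with a marked nerve, three empty frames.
empty = np.zeros((4, 4), dtype=np.uint8)
nerve = empty.copy()
nerve[1:3, 1:3] = 255
masks = [nerve, empty, empty, empty, nerve]

print(count_nerve_frames(masks))    # (2, 3)
print(len(balance_indices(masks)))  # 4
```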

Duplicates:

Output of image-analysis.py

image-analysis.py works in two modes: scanning and analysis. The '-scan' option scans all data files for duplicates by computing the difference between every pair of images and storing the results in a file. In the analysis mode, it shows side-by-side comparisons of duplicates and a histogram of difference values, and lets the user enter corrections.
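A minimal sketch of the scanning step. It assumes the difference metric is the sum of absolute pixel differences; the README does not name the metric, so the thresholds of 100 and 1000 quoted below may not correspond to this one. `image_difference` and `scan_pairs` are hypothetical names, and writing results to a file is omitted.

```python
from itertools import combinations

import numpy as np


def image_difference(a: np.ndarray, b: np.ndarray) -> int:
    """Assumed metric: sum of absolute pixel differences (cast up to avoid uint8 wrap-around)."""
    return int(np.abs(a.astype(np.int64) - b.astype(np.int64)).sum())


def scan_pairs(images):
    """Compare every unordered pair of images; returns a list of (i, j, difference)."""
    return [(i, j, image_difference(images[i], images[j]))
            for i, j in combinations(range(len(images)), 2)]


base = np.arange(64, dtype=np.uint8).reshape(8, 8)
near_copy = base.copy()
near_copy[0, 0] += 1  # a one-level change in a single pixel
other = base + 100    # every pixel differs by 100

results = scan_pairs([base, near_copy, other])
duplicates = [(i, j) for i, j, d in results if d < 100]
print(results)     # [(0, 1, 1), (0, 2, 6400), (1, 2, 6399)]
print(duplicates)  # [(0, 1)]
```

Comparing every pair is O(n²) in the number of images, which is why storing the scan results and analyzing them in a separate step makes sense.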

  • 63 duplicate images with difference < 100, 52 of which have mismatching masks
  • 131 duplicate images with difference < 1000, 107 of which have mismatching masks
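Given the (i, j, difference) results of a pairwise scan like the one image-analysis.py performs, counts of this kind can be tallied per threshold. The sketch assumes that a duplicate pair's masks "mismatch" when they are not pixel-identical and that duplicates are counted as pairs; the README defines neither, and `count_duplicates` and the toy data are hypothetical.

```python
import numpy as np


def count_duplicates(pairs, masks, threshold):
    """Count pairs below `threshold` and how many of those have differing masks.

    `pairs` holds (i, j, difference) tuples from a pairwise image scan.
    """
    dup = [(i, j) for i, j, d in pairs if d < threshold]
    mismatched = sum(not np.array_equal(masks[i], masks[j]) for i, j in dup)
    return len(dup), mismatched


empty = np.zeros((4, 4), dtype=np.uint8)
marked = empty.copy()
marked[1:3, 1:3] = 255
masks = [marked, empty, marked, marked]
pairs = [(0, 1, 40), (2, 3, 500), (0, 3, 5000)]

for t in (100, 1000):
    print(t, count_duplicates(pairs, masks, t))
# 100 (1, 1)
# 1000 (2, 1)
```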

Histogram of image differences (log scale)

Neural Net Architecture

VGG-16 is a fairly simple deep network that is commonly used as the encoder (backbone) for image segmentation. Although VGG-16 is less accurate than the larger ResNet or Inception networks and slower than MobileNets, its simple architecture lends itself to extension by adding layers, introducing skip connections, and so on.
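The model diagram did not survive, so the following is not the author's network. It is a minimal PyTorch sketch of the kind of extension described above: a VGG-16 convolutional encoder with FCN-8s-style skip connections from the pool3 and pool4 stages, ending in a one-channel sigmoid mask. The layer list is built to follow the same order as torchvision's `vgg16().features`, which is what would make copying pre-trained weights into it possible; treat that correspondence as an assumption to verify. `VGG16Seg` and `vgg16_features` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# VGG-16 convolutional configuration: numbers are output channels, "M" is 2x2 max-pooling.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]


def vgg16_features() -> nn.Sequential:
    layers, in_ch = [], 3
    for v in VGG16_CFG:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers += [nn.Conv2d(in_ch, v, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
            in_ch = v
    return nn.Sequential(*layers)


class VGG16Seg(nn.Module):
    """Hypothetical VGG-16 encoder with FCN-8s-style skips (illustrative, not the author's model)."""

    def __init__(self):
        super().__init__()
        features = vgg16_features()
        self.to_pool3 = features[:17]    # 256 channels, H/8 x W/8
        self.to_pool4 = features[17:24]  # 512 channels, H/16 x W/16
        self.to_pool5 = features[24:]    # 512 channels, H/32 x W/32
        self.score3 = nn.Conv2d(256, 1, kernel_size=1)
        self.score4 = nn.Conv2d(512, 1, kernel_size=1)
        self.score5 = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x):
        p3 = self.to_pool3(x)
        p4 = self.to_pool4(p3)
        p5 = self.to_pool5(p4)
        # Upsample the coarse scores and add finer-scale scores (skip connections).
        s = F.interpolate(self.score5(p5), scale_factor=2, mode="bilinear", align_corners=False)
        s = F.interpolate(s + self.score4(p4), scale_factor=2, mode="bilinear", align_corners=False)
        s = F.interpolate(s + self.score3(p3), scale_factor=8, mode="bilinear", align_corners=False)
        return torch.sigmoid(s)  # per-pixel nerve probability in [0, 1]


net = VGG16Seg()
with torch.no_grad():
    out = net(torch.zeros(1, 3, 64, 64))  # both sides must be multiples of 32
print(out.shape)  # torch.Size([1, 1, 64, 64])
```

A real 416 x 576 input gives 13 x 18 pool5 maps and a full-resolution 416 x 576 mask; the skip connections are what let the decoder recover detail that five pooling stages discard.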

Model

Loss Functions and Metrics

During training, the following metrics are computed every 30 batches.

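  • IOU and Dice coefficient:

import torch

# Soft IoU (Jaccard index); the +1 smoothing terms avoid division by zero on empty masks
def iou(y_true, y_pred):
    intersection = torch.sum(y_true * y_pred)
    return (intersection + 1.) / (
        torch.sum(y_true) + torch.sum(y_pred) - intersection + 1.)

# Soft Dice coefficient, with the same +1 smoothing
def dice_coef(y_true, y_pred):
    intersection = torch.sum(y_true * y_pred)
    return (2. * intersection + 1.) / (
        torch.sum(y_true) + torch.sum(y_pred) + 1.)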

  • False positives and false negatives:

# Predicted nerve pixels that fall outside the labeled mask
def falsepos(y_true, y_pred):
    intersection = torch.sum(y_true * y_pred)
    return torch.sum(y_pred) - intersection

# Labeled nerve pixels that the prediction missed
def falseneg(y_true, y_pred):
    intersection = torch.sum(y_true * y_pred)
    return torch.sum(y_true) - intersection
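
  • Precision, recall and F1-score:

# Fraction of predicted nerve pixels that are correct (+1 smoothing in the denominator)
def precision(y_true, y_pred):
    intersection = torch.sum(y_true * y_pred)
    return intersection / (torch.sum(y_pred) + 1.)

# Fraction of labeled nerve pixels that the prediction found
def recall(y_true, y_pred):
    intersection = torch.sum(y_true * y_pred)
    return intersection / (torch.sum(y_true) + 1.)

# Harmonic mean of precision and recall; 1e-7 avoids 0/0 when both are zero
def fscore(y_true, y_pred):
    presci = precision(y_true, y_pred)
    rec = recall(y_true, y_pred)
    return 2 * (presci * rec) / (presci + rec + 1e-7)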
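The same metrics in formula form, where |T| = sum(y_true), |P| = sum(y_pred) and |T ∩ P| = sum(y_true * y_pred), evaluated on a hypothetical toy case with two labeled pixels, two predicted pixels and one overlap. On a mask this small the +1 smoothing terms are far from negligible (precision and recall come out at 1/3 rather than the unsmoothed 1/2); on masks of realistic size their effect shrinks accordingly.

```latex
\begin{aligned}
\mathrm{IoU}       &= \frac{|T \cap P| + 1}{|T| + |P| - |T \cap P| + 1}
                    = \frac{1 + 1}{2 + 2 - 1 + 1} = 0.5 \\
\mathrm{Dice}      &= \frac{2\,|T \cap P| + 1}{|T| + |P| + 1}
                    = \frac{2 + 1}{2 + 2 + 1} = 0.6 \\
\mathrm{FP}        &= |P| - |T \cap P| = 1, \qquad
\mathrm{FN}         = |T| - |T \cap P| = 1 \\
\mathrm{Precision} &= \frac{|T \cap P|}{|P| + 1} = \frac{1}{3}, \qquad
\mathrm{Recall}     = \frac{|T \cap P|}{|T| + 1} = \frac{1}{3} \\
F_1                &= \frac{2 \cdot \mathrm{Precision} \cdot \mathrm{Recall}}{\mathrm{Precision} + \mathrm{Recall}}
                    = \frac{2 \cdot \tfrac{1}{9}}{\tfrac{2}{3}} = \frac{1}{3}
\end{aligned}
```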

  • Loss function:
    • The F1-score is adapted into a loss and weighted to favor recall. With weight = β², the expression is the F-β score, so weight > 1 emphasizes recall over precision.
    • It returns the negative score: the optimizer minimizes the loss, so driving the negative F-score down drives the F-score up.
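def weighted_fscore_loss(weight):
    # weight plays the role of beta^2 in the F-beta score; weight > 1 favors recall
    def fscore_loss(y_true, y_pred):
        presci = precision(y_true, y_pred)
        rec = recall(y_true, y_pred)
        # 1e-7 avoids 0/0 (NaN) when prediction and labels do not overlap,
        # e.g. a batch with no nerve pixels
        return -(1 + weight) * (presci * rec) / (weight * presci + rec + 1e-7)
    return fscore_loss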
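A quick check of how the weight behaves, as a minimal sketch on hypothetical toy tensors. The functions are restated so the block runs on its own, with a small `EPS` in the denominator as an added safeguard (not in the original) against 0/0 when prediction and labels do not overlap.

```python
import torch

EPS = 1e-7  # added safeguard against 0/0 for batches where prediction and labels do not overlap


def precision(y_true, y_pred):
    return torch.sum(y_true * y_pred) / (torch.sum(y_pred) + 1.)


def recall(y_true, y_pred):
    return torch.sum(y_true * y_pred) / (torch.sum(y_true) + 1.)


def weighted_fscore_loss(weight):
    def fscore_loss(y_true, y_pred):
        p, r = precision(y_true, y_pred), recall(y_true, y_pred)
        return -(1 + weight) * (p * r) / (weight * p + r + EPS)
    return fscore_loss


# Toy batch: 4 labeled nerve pixels; the prediction hits 2 of them and adds 6 extra pixels,
# so smoothed precision (2/9) is lower than smoothed recall (2/5).
y_true = torch.tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])
y_pred = torch.tensor([1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0.])

for w in (0.5, 1.0, 2.0):
    print(w, round(weighted_fscore_loss(w)(y_true, y_pred).item(), 4))
```

With `weight = 1` the loss is exactly the negative F1-score. Because recall exceeds precision in this toy batch, raising the weight makes the loss more negative (about -0.26, -0.29 and -0.32 for weights 0.5, 1 and 2), so the optimizer gains more from improving recall than from improving precision.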

Training

dataset.py: The dataset is randomly split into training (80%) and validation (20%) sets.
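A minimal sketch of an 80/20 random split with `torch.utils.data.random_split`, followed by a loader using the batch size of 10 mentioned under Training. The README does not say whether dataset.py uses this function; the toy `TensorDataset` and the fixed seed are assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for the ultrasound dataset: 50 images (3 x 32 x 32) with 1-channel masks.
images = torch.zeros(50, 3, 32, 32)
masks = torch.zeros(50, 1, 32, 32)
dataset = TensorDataset(images, masks)

n_train = int(0.8 * len(dataset))  # 80% training
n_val = len(dataset) - n_train     # remaining 20% validation
train_set, val_set = random_split(
    dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_set, batch_size=10, shuffle=True)
print(len(train_set), len(val_set), len(train_loader))  # 40 10 4
```

Seeding the generator keeps the split reproducible across runs, so validation metrics from different training sessions are computed on the same images.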

Batches of data are transformed with torchvision.transforms:

  • Converted into tensors (which also scales uint8 pixel values to [0, 1])
  • Normalized to the range [-1, 1] with mean 0.5 and standard deviation 0.5, i.e. (x - 0.5) / 0.5
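The arithmetic behind these two steps, written with plain tensor operations so it runs without torchvision: `transforms.ToTensor()` turns an H x W x C uint8 array into a C x H x W float tensor in [0, 1], and `transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))` then computes (x - 0.5) / 0.5 per channel. `to_normalized_tensor` is a hypothetical helper, not part of torchvision.

```python
import numpy as np
import torch


def to_normalized_tensor(img: np.ndarray, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    """Mimic ToTensor() followed by Normalize(mean, std) for an H x W x C uint8 image."""
    x = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0  # C x H x W, values in [0, 1]
    return (x - mean) / std                                       # [0, 1] -> [-1, 1]


img = np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)  # 1 x 2 x 3 (H x W x C)
x = to_normalized_tensor(img)
print(x.shape)        # torch.Size([3, 1, 2])
print(x[0].tolist())  # [[-1.0, 1.0]]
```

Black pixels map to -1 and white pixels to 1. The README does not say how the masks are scaled; since the metrics above treat labels as 0/1 values, the masks presumably need a plain [0, 1] scaling rather than this normalization.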

train.py: Each batch of training data is loaded into the GPU. The forward pass produces the network's output, the loss function compares it with the labeled masks, the backward pass computes the gradients, and the optimizer updates the weights.

for i, (inputs, labels) in enumerate(train_loader, 0):

    # map to gpu
    inputs, labels = inputs.cuda(), labels.cuda()

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

Since I chose not to resize the images, I could fit only 10 images per batch in GPU memory. Downscaling the images would have allowed larger batches, at the expense of spatial resolution.

After about 10 epochs of training, both the validation and training losses plateaued.

Optimizer and learning rate:

PyTorch offers several optimization algorithms; I used the popular Adam optimizer because it tends to converge quickly. Through trial and error, I arrived at a learning rate of 1e-5. I also used PyTorch's ReduceLROnPlateau learning-rate scheduler to reduce the learning rate automatically when the validation loss stops improving. However, during a 10-epoch test run, the learning rate was never reduced.
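A self-contained sketch of the optimizer setup and epoch loop described above: Adam at a learning rate of 1e-5, `ReduceLROnPlateau` stepped on the mean validation loss, and a report every 30 batches. The one-layer model, random data, `factor`/`patience` values, `BCELoss` criterion and the `print` standing in for the Visdom call are all assumptions for illustration; the README's network and weighted F-score loss would take the places of `net` and `criterion`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins: a one-layer model and random 32 x 32 images with binary masks.
net = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1), nn.Sigmoid()).to(device)
criterion = nn.BCELoss()
train_data = TensorDataset(torch.rand(40, 3, 32, 32), (torch.rand(40, 1, 32, 32) > 0.5).float())
val_data = TensorDataset(torch.rand(20, 3, 32, 32), (torch.rand(20, 1, 32, 32) > 0.5).float())
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
val_loader = DataLoader(val_data, batch_size=10)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-5)
# Multiply the learning rate by `factor` once the validation loss has not improved for `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2)

LOG_EVERY = 30  # batches between metric reports, as in the README
history = []

for epoch in range(3):
    net.train()
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()
        if i % LOG_EVERY == LOG_EVERY - 1:
            # The README sends metrics to Visdom here; print() stands in for that call.
            print(f"epoch {epoch} batch {i + 1}: loss {loss.item():.4f}")

    net.eval()
    with torch.no_grad():
        batch_losses = [criterion(net(x.to(device)), y.to(device)).item() for x, y in val_loader]
    val_loss = sum(batch_losses) / len(batch_losses)
    scheduler.step(val_loss)  # may lower the learning rate if validation loss has stalled
    lr = optimizer.param_groups[0]["lr"]
    history.append((val_loss, lr))
    print(f"epoch {epoch}: val loss {val_loss:.4f}, lr {lr:.1e}")
```

Because the scheduler only acts after `patience` epochs without improvement, a short run in which the validation loss is still falling, or only just flattening, will leave the learning rate untouched.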

Evaluation

Training metrics are sent to a Visdom server for visualization.

Mean Training and Validation Losses

  • Orange: Validation loss
  • Blue: Training loss

IOU and Dice coefficient improve with training

  • Orange: IOU
  • Blue: Dice coefficient

False Positives and False Negatives drop

  • Orange: False negatives
  • Blue: False positives

Precision, Recall and F1-Score improve with training

  • Orange: Recall
  • Blue: Precision
  • Green: F1-score
