Deep Learning on multi-label text classification with FastAi

In multi-label classification, each sample in the dataset can have several target labels at once. Instead of predicting a single class, the model predicts a probability for each class.
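To make the difference from single-class prediction concrete, here is a minimal sketch in plain Python. It assumes the model emits one raw score (logit) per label and that each score is passed through a sigmoid independently; the scores are hypothetical, and the label names are the six used later in this post.

```python
import math

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def sigmoid(x: float) -> float:
    """Squash a raw score into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def predict_probabilities(logits):
    """Map one raw score per label to an independent probability per label."""
    return {label: sigmoid(z) for label, z in zip(LABELS, logits)}


# Hypothetical raw scores for a single comment (not real model output).
example_logits = [2.0, -3.0, 1.0, -4.0, 0.5, -2.0]
probs = predict_probabilities(example_logits)

# Each label is decided on its own, so several can be active at once.
active = [label for label, p in probs.items() if p >= 0.5]
print(active)
```

Because each label gets its own sigmoid, the probabilities do not have to sum to 1, which is what separates this setup from a softmax over mutually exclusive classes.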

In this post, I will explain the multi-label text classification problem with fastai, using the Toxic Comment Classification Challenge to show how fastai handles a multi-label problem.

Let’s look at the data

Let’s start with an overview of the data and the data type of each feature, to understand which features matter.
For this problem, we have 6 label classes, i.e. 6 different types of toxicity:

  • toxic
  • severe_toxic
  • obscene
  • threat
  • insult
  • identity_hate
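As a quick way to get this overview, the sketch below counts how often each label occurs and how many comments carry more than one label. It uses only the standard library; the three rows are hypothetical placeholders laid out like the competition's `train.csv` (an `id` column, a `comment_text` column, then the six label columns).

```python
import csv
import io

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# Hypothetical rows in the column layout of train.csv (not real data).
SAMPLE_CSV = """id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
a1,placeholder polite comment,0,0,0,0,0,0
a2,placeholder rude comment,1,0,0,0,1,0
a3,placeholder very rude comment,1,1,1,0,1,0
"""


def label_counts(csv_file):
    """Return (positives per label, number of comments with 2+ labels)."""
    counts = {label: 0 for label in LABELS}
    multi_label_rows = 0
    for row in csv.DictReader(csv_file):
        flags = [int(row[label]) for label in LABELS]
        for label, flag in zip(LABELS, flags):
            counts[label] += flag
        if sum(flags) > 1:
            multi_label_rows += 1
    return counts, multi_label_rows


counts, multi_label_rows = label_counts(io.StringIO(SAMPLE_CSV))
print(counts)
print(multi_label_rows)
```

For the real data, pass an open file handle for `train.csv` instead of the in-memory sample.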

We have to create a model that predicts the probability of each type of toxicity for each comment.

Load and analyse data

fastai expects the data to be loaded as a DataBunch, which a fastai Learner can then use to train a model. Here, we first create a DataBunch from our training dataset.
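The original code for this step is not reproduced here, so the following is only a sketch of how such a DataBunch might be built. It assumes fastai v1 (the `fastai.text` data block API) and a pandas DataFrame whose text column is named `comment_text`; the function name, validation split, and batch size are illustrative choices, not values from this post.

```python
def build_lm_databunch(train_df, text_col="comment_text", valid_pct=0.1, bs=48):
    """Build a language-model DataBunch from the comment text only (no labels)."""
    # Imported inside the function so this file loads even without fastai.
    from fastai.text import TextList  # assumes fastai v1

    return (
        TextList.from_df(train_df, cols=text_col)  # one item per comment
        .split_by_rand_pct(valid_pct)              # hold out a validation set
        .label_for_lm()                            # targets are the next tokens
        .databunch(bs=bs)
    )
```

A call would look like `data_lm = build_lm_databunch(train_df)`, where `train_df` is the training data read with pandas. On a small GPU, a lower `bs` reduces memory use.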

Fit the deep learning model with domain-specific data

First, we fit the model on the training text alone, without the target values, so that it learns the vocabulary and style of our comments before it sees any labels.
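In fastai v1 this kind of label-free fitting is usually done by fine-tuning a pretrained language model. The sketch below assumes that approach, an `AWD_LSTM` architecture, and a language-model DataBunch passed in as `data_lm`; the function name, encoder name, epochs, and learning rates are hypothetical and not taken from this post.

```python
def fine_tune_language_model(data_lm, encoder_name="ft_enc", epochs=1, lr=1e-2):
    """Fine-tune a pretrained language model on the comments and save its encoder."""
    # Imported inside the function so this file loads even without fastai.
    from fastai.text import AWD_LSTM, language_model_learner  # assumes fastai v1

    learn = language_model_learner(data_lm, AWD_LSTM, drop_mult=0.3)
    learn.fit_one_cycle(epochs, lr)       # pretrained layers start frozen
    learn.unfreeze()                      # then train every layer group
    learn.fit_one_cycle(epochs, lr / 10)  # with a smaller learning rate
    learn.save_encoder(encoder_name)      # reused later by the classifier
    return learn
```

The saved encoder is the part of the model that has adapted to the comment text; the classifier can load it instead of starting from the generic pretrained weights.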

Re-fit the model with classification labels

Here we re-fit the model with our target values and tune it for better accuracy.
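A sketch of this step, again assuming fastai v1: the six label columns are passed together so that fastai treats the task as multi-label, and a previously saved language-model encoder is loaded before training. The function name, the `vocab` and `encoder_name` parameters, and all hyperparameters are assumptions for illustration.

```python
LABEL_COLS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def train_classifier(train_df, vocab, encoder_name="ft_enc", bs=32, epochs=1, lr=1e-2):
    """Train a multi-label text classifier on top of a fine-tuned encoder."""
    # Imported inside the function so this file loads even without fastai.
    from fastai.text import AWD_LSTM, TextList, text_classifier_learner  # fastai v1

    data_clas = (
        # Reuse the language model's vocab so token ids match the encoder.
        TextList.from_df(train_df, cols="comment_text", vocab=vocab)
        .split_by_rand_pct(0.1)
        # Several 0/1 columns at once: one multi-hot target per comment.
        .label_from_df(cols=LABEL_COLS)
        .databunch(bs=bs)
    )
    learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5)
    learn.load_encoder(encoder_name)
    learn.fit_one_cycle(epochs, lr)
    return learn
```

Further tuning typically means unfreezing more layers and training again with smaller learning rates; how far to go depends on the validation score.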

Let’s predict the target values and compare them with the original target values.
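This post does not state which metric it used for the comparison. The competition itself is scored by mean column-wise ROC AUC, so the sketch below computes that in plain Python. With a fastai v1 learner, `learn.get_preds()` returns the predictions and targets for the validation set; the small arrays here are hypothetical and use only two labels to keep the example short.

```python
def roc_auc(y_true, y_score):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    # A tie between a positive and a negative counts as half a correct pair.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def mean_columnwise_auc(targets, preds):
    """Average the per-label AUC over all label columns."""
    n_labels = len(targets[0])
    aucs = [
        roc_auc([row[j] for row in targets], [row[j] for row in preds])
        for j in range(n_labels)
    ]
    return sum(aucs) / len(aucs)


# Hypothetical targets and predicted probabilities: 4 comments, 2 labels.
targets = [[1, 0], [0, 1], [1, 1], [0, 0]]
preds = [[0.9, 0.2], [0.3, 0.8], [0.6, 0.4], [0.7, 0.1]]
score = mean_columnwise_auc(targets, preds)
print(score)
```

The pairwise loop is quadratic in the number of comments, so it is meant for understanding the metric; for the full validation set a library implementation is the practical choice.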

Get Prediction

Let’s get the predictions and create the submission file to submit to Kaggle.
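The submission for this competition is a CSV with an `id` column followed by one probability column per label. The sketch below writes that layout with the standard library; the ids and probabilities are hypothetical, and in practice the rows would come from the model's predictions on the test set.

```python
import csv
import io

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def write_submission(ids, probabilities, out_file):
    """Write one row per comment: its id, then one probability per label."""
    writer = csv.writer(out_file)
    writer.writerow(["id"] + LABELS)
    for comment_id, row in zip(ids, probabilities):
        if len(row) != len(LABELS):
            raise ValueError("expected one probability per label")
        writer.writerow([comment_id] + [f"{p:.6f}" for p in row])


# Hypothetical ids and predicted probabilities for two test comments.
ids = ["x1", "x2"]
probabilities = [
    [0.9, 0.1, 0.8, 0.0, 0.7, 0.05],
    [0.02, 0.0, 0.01, 0.0, 0.03, 0.0],
]

buffer = io.StringIO()
write_submission(ids, probabilities, buffer)
print(buffer.getvalue())
```

To write a real file, open it with `open("submission.csv", "w", newline="")` and pass the handle as `out_file`; the file name is only an example.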

All the code

All the code for this task can be found here on Kaggle kernels:

1 thought on “Deep Learning on multi-label text classification with FastAi”

  1. Hi Nikita,
    I’ve been doing the fastai course, and having lots of memory problems on the little 6GB GPU I have. Your writing has helped me a lot. Your code is excellent, your blog is fantastic, I cannot thank you enough! Thank you, and keep up the awesome work!
    Hon
