
Insurance cost prediction using linear regression

Make a submission here: https://jovian.ai/learn/deep-learning-with-pytorch-zero-to-gans/assignment/assignment-2-train-your-first-model

In this assignment, we're going to use information such as a person's age, sex, BMI, number of children, and smoking habits to predict their yearly medical bills. Insurance companies can use this kind of model to determine a person's yearly insurance premium. The dataset for this problem is taken from Kaggle.

We will create a model with the following steps:

  1. Download and explore the dataset
  2. Prepare the dataset for training
  3. Create a linear regression model
  4. Train the model to fit the data
  5. Make predictions using the trained model
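
To make the five steps concrete, here is a minimal sketch of the same workflow on synthetic data. It is not the assignment solution and does not touch the insurance dataset: the feature count, the "true" weights, the MSE loss, and the hyperparameters (learning rate, batch size, number of epochs) are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)  # make the sketch reproducible

# Steps 1-2: synthetic stand-in data (NOT the insurance dataset).
# 200 rows, 3 numeric features; the target is a noisy linear combination.
inputs = torch.randn(200, 3)
true_w = torch.tensor([[2.0], [-1.0], [0.5]])  # made-up weights
targets = inputs @ true_w + 3.0 + 0.1 * torch.randn(200, 1)

dataset = TensorDataset(inputs, targets)
train_ds, val_ds = random_split(dataset, [160, 40])
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=40)

# Step 3: a linear regression model is a single linear layer.
model = nn.Linear(3, 1)

def evaluate(model, loader):
    """Return the mean of the per-batch MSE losses over a data loader."""
    with torch.no_grad():
        losses = [F.mse_loss(model(xb), yb).item() for xb, yb in loader]
    return sum(losses) / len(losses)

loss_before = evaluate(model, val_loader)

# Step 4: fit the model with stochastic gradient descent.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(50):
    for xb, yb in train_loader:
        loss = F.mse_loss(model(xb), yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

loss_after = evaluate(model, val_loader)

# Step 5: predict for a single validation sample.
sample_x, sample_y = val_ds[0]
with torch.no_grad():
    prediction = model(sample_x.unsqueeze(0))  # add a batch dimension

print(f"validation loss before: {loss_before:.4f}, after: {loss_after:.4f}")
print(f"target: {sample_y.item():.3f}, prediction: {prediction.item():.3f}")
```

The validation loss should drop sharply after training, since the synthetic targets really are a linear function of the inputs plus a little noise. Real data is rarely this well behaved, which is why the assignment asks you to experiment with hyperparameters.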

This assignment builds upon the concepts from the first 2 lessons. It will help to review the Jupyter notebooks from those lessons.

As you go through this notebook, you will find a ??? in certain places. Your job is to replace the ??? with appropriate code or values, to ensure that the notebook runs properly end-to-end. In some cases, you'll be required to choose some hyperparameters (learning rate, batch size etc.). Try to experiment with the hyperparameters to get the lowest loss.

import torch
import jovian
import torchvision
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torch.utils.data import DataLoader, TensorDataset, random_split
project_name='Assignment 02-insurance-linear-regression' # will be used by jovian.commit

Step 1: Download and explore the data

Let us begin by downloading the data. We'll use the download_url function from torchvision (imported above from torchvision.datasets.utils) to get the data as a CSV (comma-separated values) file.

DATASET_URL = "https://hub.jovian.ml/wp-content/uploads/2020/05/insurance.csv"
DATA_FILENAME = "insurance.csv"
download_url(DATASET_URL, '.')