Learn practical skills, build real-world projects, and advance your career

Insurance cost prediction using linear regression

In this assignment we're going to use information like a person's age, sex, BMI, no. of children and smoking habit to predict the price of yearly medical bills. This kind of model is useful for insurance companies to determine the yearly insurance premium for a person. The dataset for this problem is taken from: https://www.kaggle.com/mirichoi0218/insurance

We will create a model with the following steps:

  1. Download and explore the dataset
  2. Prepare the dataset for training
  3. Create a linear regression model
  4. Train the model to fit the data
  5. Make predictions using the trained model

This assignment builds upon the concepts from the first 2 lectures. It will help to review these Jupyter notebooks:

As you go through this notebook, you will find a ??? in certain places. Your job is to replace the ??? with appropriate code or values, to ensure that the notebook runs properly end-to-end . In some cases, you'll be required to choose some hyperparameters (learning rate, batch size etc.). Try to experiment with the hypeparameters to get the lowest loss.

# Uncomment and run the commands below if imports fail
!conda install numpy pytorch torchvision cpuonly -c pytorch -y
!pip install matplotlib --upgrade --quiet
!pip install jovian --upgrade --quiet
!pip install pandas
!pip install seaborn
/bin/bash: conda: command not found Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (1.0.4) Requirement already satisfied: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from pandas) (1.18.5) Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas) (2.8.1) Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas) (2018.9) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.6.1->pandas) (1.12.0) Requirement already satisfied: seaborn in /usr/local/lib/python3.6/dist-packages (0.10.1) Requirement already satisfied: scipy>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from seaborn) (1.4.1) Requirement already satisfied: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from seaborn) (1.18.5) Requirement already satisfied: matplotlib>=2.1.2 in /usr/local/lib/python3.6/dist-packages (from seaborn) (3.2.2) Requirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from seaborn) (1.0.4) Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.1.2->seaborn) (2.8.1) Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.1.2->seaborn) (1.2.0) Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.1.2->seaborn) (2.4.7) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.1.2->seaborn) (0.10.0) Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->seaborn) (2018.9) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.1->matplotlib>=2.1.2->seaborn) (1.12.0)
import torch
import jovian
import torchvision
import torch.nn as nn
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torch.utils.data import DataLoader, TensorDataset, random_split
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead. import pandas.util.testing as tm
project_name='02-insurance-linear-regression' 

Step 1: Download and explore the data

Let us begin by downloading the data. We'll use the download_url function from PyTorch to get the data as a CSV (comma-separated values) file.