In [1]:
``````%matplotlib inline
from fastai.basics import *
import jovian``````

Linear Regression

The goal here is to fit a line to a set of points.

In [2]:
``n = 100``
In [3]:
``````# Create a tensor of n rows and 2 columns
x = torch.ones(n, 2)``````
In [4]:
``````# Fill the first column, in place, with values drawn uniformly between -1 and 1
x[:, 0].uniform_(-1.,1);x[:3]``````
Out[4]:
``````tensor([[-0.0170,  1.0000],
[-0.2915,  1.0000],
[ 0.2883,  1.0000]])``````
In [5]:
``a = tensor(3.,2); a``
Out[5]:
``tensor([3., 2.])``
In [6]:
``y = x@a + torch.rand(n)``
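In `x`, the first column holds the inputs and the second column is all ones, so `x@a` computes `a[0]*x[:,0] + a[1]`: the column of ones turns the second weight into the intercept (bias). A minimal pure-Python sketch of the same data-generating process, without PyTorch, assuming the true weights `(3, 2)` and uniform noise in [0, 1) as above (`make_data` is a hypothetical helper name):

```python
import random

def make_data(n, a=(3.0, 2.0), seed=0):
    """Hypothetical helper mirroring the notebook: y = x@a + uniform noise."""
    rng = random.Random(seed)
    # Each row is [x0, 1]; the constant 1 multiplies a[1], so a[1] acts as the bias.
    X = [[rng.uniform(-1.0, 1.0), 1.0] for _ in range(n)]
    # Row-by-row dot product x@a, plus noise drawn from [0, 1) like torch.rand.
    y = [row[0] * a[0] + row[1] * a[1] + rng.random() for row in X]
    return X, y

X, y = make_data(100)
```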
In [7]:
``plt.scatter(x[:,0], y)``
Out[7]:
The goal is to find the weights `a` that minimize the error between the points and the line `x@a`. Here we treat `a` as unknown. For regression, a common choice of loss (error) function is the mean squared error.
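Written out, with $\hat{y} = Xa$ for the $n \times 2$ matrix $X$ (the `x` tensor, whose rows are $[x_{i,0}, 1]$) and weights $a = (a_0, a_1)$, the mean squared error is:

```latex
\mathrm{MSE}(a) = \frac{1}{n}\sum_{i=1}^{n}\left(\hat{y}_i - y_i\right)^2
               = \frac{1}{n}\sum_{i=1}^{n}\left(a_0 x_{i,0} + a_1 - y_i\right)^2
```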

In [8]:
``def mean_squared_error(y_hat, y): return ((y_hat - y)**2).mean()``
In [9]:
``a = tensor(-1.,1)``
In [10]:
``y_hat=x@a``
In [11]:
``mean_squared_error(y_hat, y)``
Out[11]:
``tensor(7.8217)``
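As a sanity check on what `mean_squared_error` computes, here is the same calculation in plain Python on a tiny, hand-checkable example (the list-based `mse` is a hypothetical stand-in for the tensor version):

```python
def mse(y_hat, y):
    # Average of the squared element-wise differences.
    return sum((p - t) ** 2 for p, t in zip(y_hat, y)) / len(y)

# The errors are 0, 0 and 2, so the loss is (0 + 0 + 4) / 3.
print(mse([1.0, 2.0, 5.0], [1.0, 2.0, 3.0]))  # → 1.3333333333333333
```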
In [12]:
``````plt.scatter(x[:,0], y)
plt.scatter(x[:,0], y_hat)``````
Out[12]:
We have now specified the model (linear regression) and the evaluation criterion (the loss function). The next question is how to find the values of `a` that give the best-fitting line.
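Gradient descent answers this by repeatedly stepping `a` against the gradient of the loss. Writing the loss as $L(a) = \frac{1}{n}\lVert Xa - y\rVert^2$, its gradient (what `loss.backward()` stores in `a.grad`) and the update with learning rate $\eta$ (`lr` in the code) are:

```latex
\nabla_a L(a) = \frac{2}{n} X^\top (Xa - y),
\qquad
a \leftarrow a - \eta \, \nabla_a L(a)
```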

In [13]:
``a = nn.Parameter(a);a``
Out[13]:
``````Parameter containing:
tensor([-1.,  1.], requires_grad=True)``````
In [14]:
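``````def update():
    y_hat = x@a
    loss = mean_squared_error(y_hat, y)
    if t % 10 == 0: print(loss)  # t is the loop counter from the training loop below
    loss.backward()              # store d(loss)/d(a) in a.grad
    with torch.no_grad():        # don't track the weight update in autograd
        a.sub_(lr * a.grad)      # gradient-descent step
        a.grad.zero_()           # reset, since PyTorch accumulates gradients
``````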
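After `loss.backward()` stores the gradient in `a.grad`, a gradient-descent step moves `a` by `-lr * a.grad`. The sketch below does the same thing with the gradient written out by hand, in plain Python and without autograd, assuming the two-column `[x, 1]` layout used above; `fit_gd` is a hypothetical name. It uses noise-free data so the optimum is exactly `(3, 2)`.

```python
import random

def fit_gd(X, y, lr=0.1, steps=100):
    """Plain gradient descent on MSE for y ≈ X @ a, using a hand-derived gradient."""
    n = len(y)
    a = [-1.0, 1.0]  # same starting guess as the notebook
    for _ in range(steps):
        # Residuals r_i = (X a)_i - y_i
        r = [row[0] * a[0] + row[1] * a[1] - yi for row, yi in zip(X, y)]
        # Gradient of MSE: (2/n) * X^T r
        grad = [2.0 / n * sum(row[j] * ri for row, ri in zip(X, r)) for j in range(2)]
        # Step against the gradient (the role of a.sub_(lr * a.grad))
        a = [a[j] - lr * grad[j] for j in range(2)]
    return a

rng = random.Random(0)
X = [[rng.uniform(-1.0, 1.0), 1.0] for _ in range(100)]
y = [3.0 * row[0] + 2.0 for row in X]  # noise-free targets
print(fit_gd(X, y, steps=1000))  # approximately [3.0, 2.0]
```

Autograd saves us from deriving `grad` by hand; for this model the two give the same gradient.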
In [40]:
``````lr = 1e-1
for t in range(100): update()``````
```tensor(7.3373, grad_fn=<MeanBackward1>)
tensor(1.3988, grad_fn=<MeanBackward1>)
tensor(0.3868, grad_fn=<MeanBackward1>)
tensor(0.1484, grad_fn=<MeanBackward1>)
tensor(0.0912, grad_fn=<MeanBackward1>)
tensor(0.0775, grad_fn=<MeanBackward1>)
tensor(0.0742, grad_fn=<MeanBackward1>)
tensor(0.0734, grad_fn=<MeanBackward1>)
tensor(0.0732, grad_fn=<MeanBackward1>)
tensor(0.0731, grad_fn=<MeanBackward1>)```
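The loss levels off around 0.073 rather than reaching 0. A plausible explanation: `torch.rand` adds noise that is uniform on [0, 1). Its mean (0.5) can be absorbed by the intercept, but its variance, 1/12 ≈ 0.083, cannot be removed by any line, so the best achievable MSE is roughly that size (the exact value varies with the sample). The sketch below, assuming the same data-generating process with a larger sample, computes the closed-form least-squares line, which gradient descent approaches, and its MSE. `fit_ols` is a hypothetical helper name.

```python
import random

def fit_ols(xs, ys):
    """Closed-form least squares for y ≈ slope * x + intercept."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var = sum((x - mx) ** 2 for x in xs)
    slope = cov / var
    return slope, my - slope * mx

rng = random.Random(42)
n = 10_000  # larger than the notebook's 100, to reduce sampling noise
xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
ys = [3.0 * x + 2.0 + rng.random() for x in xs]  # uniform [0, 1) noise, like torch.rand

slope, intercept = fit_ols(xs, ys)
mse = sum((slope * x + intercept - y) ** 2 for x, y in zip(xs, ys)) / n
print(slope, intercept, mse)  # slope near 3, intercept near 2.5, mse near 1/12
```

Note that the fitted intercept lands near 2.5 rather than 2, because it absorbs the mean of the noise.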
In [41]:
``````plt.scatter(x[:,0],y)
plt.scatter(x[:,0],x@a);``````

Animate it

In [42]:
``````from matplotlib import animation, rc
rc('animation', html='jshtml')``````
