Learn practical skills, build real-world projects, and advance your career
%pylab inline

import jovian
# [notebook cell output] Populating the interactive namespace from numpy and matplotlib
import matplotlib.tri as tri
import scipy.stats as stats
import numpy as np

# Vertices of a unit-side equilateral triangle (the 2-simplex drawn in the plane).
_corners = np.array([
    [0, 0],
    [1, 0],
    [0.5, 0.75**0.5],
])
# Triangulation used for meshing: each vertex is nudged by 0.01 towards the
# interior. NOTE(review): presumably so refined mesh points never land exactly
# on the simplex boundary -- confirm.
_triangle = tri.Triangulation(
    _corners[:, 0] + [.01, -.01, 0],
    _corners[:, 1] + [0.01, 0.01, -0.01],
)
# _midpoints[k] is the midpoint of the edge opposite _corners[k].
_midpoints = [
    (_corners[(k + 1) % 3] + _corners[(k + 2) % 3]) / 2.0
    for k in range(3)
]

def xy2bc(xy, tol=0):
    '''Converts 2D Cartesian coordinates to barycentric coordinates.
    Arguments:
        `xy`: A length-2 sequence holding the x and y value. An (N, 2) array
              also works and yields an (N, 3) result.
        `tol`: Each coordinate is clipped to the range [tol, 1 - tol].
    Returns:
        The three barycentric weights, one per entry of `_corners`.
    '''
    weights = []
    for corner, midpoint in zip(_corners, _midpoints):
        # Project onto the median running from the opposite edge's midpoint
        # to this corner; 0.75 is that median's squared length.
        weights.append((xy - midpoint) @ (corner - midpoint) / 0.75)
    return np.clip(weights, tol, 1.0 - tol).T

def draw_pdf_contours(dist, linestyles, border=False, nlevels=5, subdiv=8, **kwargs):
    '''Draws pdf contours over an equilateral triangle (2-simplex).
    Arguments:
        `dist`: A distribution instance with a `pdf` method.
        `linestyles`: Line style(s) forwarded to `plt.tricontour`.
        `border` (bool): If True, the simplex border is drawn.
        `nlevels` (int): Number of contours to draw.
        `subdiv` (int): Number of recursive mesh subdivisions to create.
        kwargs: Keyword args passed on to `plt.tricontour`. `linewidths`
                defaults to 2 when not supplied.
    '''
    refiner = tri.UniformTriRefiner(_triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    # The pdf is evaluated one mesh vertex at a time.
    # NOTE(review): the original author flagged this line as the trouble
    # spot -- presumably `dist.pdf` rejects points on or outside the simplex
    # boundary; confirm against the distribution in use.
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]
    # Contour sets take `linewidths`; the former `width=2` was not a contour
    # property. setdefault lets callers still override it through kwargs.
    kwargs.setdefault('linewidths', 2)
    plt.tricontour(trimesh, pvals, nlevels, linestyles=linestyles, **kwargs)
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')
    if border:
        # No plt.hold() call: it was removed in Matplotlib 3.0 and successive
        # plot calls already draw onto the same axes.
        plt.triplot(_triangle, linewidth=1)

def plot_points(X, barycentric=True, border=True, ms=10, **kwargs):
    '''Plots a set of points in the simplex.
    Arguments:
        `X` (ndarray): One point per row: an N x 3 array (if in barycentric
                       coords) or an N x 2 array (if in Cartesian coords).
        `barycentric` (bool): Indicates if `X` is in barycentric coords.
        `border` (bool): Accepted for compatibility; the border drawing is
                         currently disabled, so this flag has no effect.
        `ms`: Marker size passed to `plt.plot`.
        kwargs: Keyword args passed on to `plt.plot`.
    '''
    # Barycentric rows become Cartesian by weighting the triangle corners.
    xy = X.dot(_corners) if barycentric is True else X
    xs, ys = xy[:, 0], xy[:, 1]
    plt.plot(xs, ys, 'k.', ms=ms, **kwargs)
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75**0.5)
    plt.axis('off')
    # Border drawing is switched off; it used to be:
    #     plt.triplot(_triangle, linewidth=2, color='k')

# Cross-validation predictions, indexed as [true class, training image, :],
# where the last axis is the probability vector over the 200 classes.
# Alternative input (pretrained variant), originally taken from
# /mv_users/mgwillia/cvpr/hacks/birds/init_pretrained_preds_cross_val.npy:
#     PREDS = np.load('init_pretrained_preds_cross_val.npy')
#
# The original notes give the on-disk shape as 200 x 120 x 200 (TODO confirm);
# only the first 30 entries along the image axis are kept, so PREDS ends up
# 200 x 30 x 200.
#
# TODO: rows that are completely zero must be pulled out -- they mean a class
# has fewer than 30 training examples. Diagnostic for spotting them (run it on
# the unsliced array):
#     for classPred in range(200):
#         for imageNum in range(120):
#             row = PREDS[:, imageNum, classPred]
#             if np.sum(row) == 0:
#                 print(row)
PREDS = np.load('init_resnet_preds_cross_val.npy')[:, :30, :]
def plot_preds(all_preds, three_classes, class_to_plot, color):
    '''Plots one class's predictions, restricted to three classes, on the simplex.
    Arguments:
        `all_preds`: Predictions indexed as [true class][image, class].
        `three_classes`: The three class indices spanning the simplex.
        `class_to_plot`: The true class whose images are plotted; it should
                         be one of `three_classes`.
        `color`: Matplotlib colour used for both marker styles.
    Returns:
        `(pts, correct, wrong)`: the renormalised points (one row per image
        with non-zero mass on the three classes) and boolean masks marking
        the rows whose arg-max is / is not `class_to_plot`.
    '''
    # The selection, renormalisation and correct/wrong split are shared with
    # get_preds, so both functions always agree on the points they report.
    pts, correct, wrong = get_preds(all_preds, three_classes, class_to_plot, color)

    # Correct predictions are drawn as dots, incorrect ones as smaller x's.
    plot_points(pts[correct, :], color=color, marker='.')
    plot_points(pts[wrong, :], color=color, marker='x', ms=6)
    return pts, correct, wrong

def get_preds(all_preds, three_classes, class_to_plot, color):
    '''Selects and renormalises one class's predictions over three classes.
    Arguments:
        `all_preds`: Predictions indexed as [true class][image, class].
        `three_classes`: The three class indices to keep.
        `class_to_plot`: The true class whose images are selected; must be
                         one of `three_classes`.
        `color`: Unused; kept so the signature matches `plot_preds`.
    Returns:
        `(pts, correct, wrong)`: `pts` holds one row per image whose mass on
        the three classes is non-zero, rescaled so each row sums to one;
        `correct` / `wrong` are boolean masks over those rows telling whether
        the arg-max column is the one belonging to `class_to_plot`.
    Raises:
        ValueError: if `class_to_plot` is not in `three_classes`.
    '''
    class_ids = np.asarray(three_classes)
    matches = np.flatnonzero(class_ids == class_to_plot)
    if matches.size == 0:
        # Previously argmax over an all-False mask silently picked column 0,
        # labelling points against the wrong class.
        raise ValueError('class_to_plot {} is not one of three_classes {}'
                         .format(class_to_plot, three_classes))
    which = matches[0]

    # Keep only the three columns of interest for the requested class.
    samples = all_preds[class_to_plot][:, class_ids]
    row_sums = samples.sum(axis=1, keepdims=True)
    # All-zero rows carry no information and cannot be normalised.
    valid_rows = row_sums[:, 0] > 0
    # Out-of-place division also works for integer-typed predictions, which
    # an in-place `/=` would reject.
    pts = samples[valid_rows] / row_sums[valid_rows]

    correct = pts.argmax(axis=1) == which
    wrong = np.logical_not(correct)
    return pts, correct, wrong


# Scatter the renormalised predictions for three easily-confused classes.
# The class numbers quoted in the notes below are 1-based, while the indices
# passed to plot_preds are 0-based (number - 1).
# Disabled triplet: Crow = 29, Red-winged Blackbird = 10, Raven = 107
if False:
    plot_preds(PREDS, [28,9,106], 28, color='c')
    plot_preds(PREDS, [28,9,106], 9, color='r')
    plot_preds(PREDS, [28,9,106], 106, color='k')
else:
    # Tern triplet. Two different numberings were noted:
    #   Arctic Tern = 141, Caspian Tern = 142, Common Tern = 143
    #   becca -- Arctic Tern = 141, Caspian Tern = 143, Common Tern = 144
    # The indices used here (140, 142, 143) match the second ("becca")
    # numbering. NOTE(review): confirm which numbering the data follows.
    plot_preds(PREDS, [140,142,143], 140, color='c')
    plot_preds(PREDS, [140,142,143], 142, color='r')
    plot_preds(PREDS, [140,142,143], 143, color='k')
# Handy check while exploring: PREDS.shape
# [notebook cell output] Notebook Image
#import sys
import scipy as sp
import scipy.stats as stats
from scipy.special import (psi, polygamma, gammaln)
from numpy import (array, asanyarray, ones, arange, log, diag, vstack, exp,
        asarray, ndarray, zeros, isscalar)
from numpy.linalg import norm
import numpy as np

# Euler-Mascheroni constant, obtained as the negated digamma function at 1.
GAMMA = -psi(1)

def loglikelihood(D, a):
    '''Compute log likelihood of Dirichlet distribution, i.e. log p(D|a).
    Parameters
    ----------
    D : 2D array
        ``N x K`` array of observations, where ``N`` is the number of
        observations and ``K`` is the number of parameters for the
        Dirichlet distribution.
    a : array
        Parameters for the Dirichlet distribution.
    Returns
    -------
    logl : float
        The log likelihood of the Dirichlet distribution'''
    n_obs, _ = D.shape
    # Mean of log(D) over the observations, one value per component.
    mean_log = log(D).mean(axis=0)
    # Log of the Dirichlet normalising constant for parameters ``a``.
    log_norm = gammaln(a.sum()) - gammaln(a).sum()
    return n_obs * (log_norm + ((a - 1) * mean_log).sum())


def mse(D, tol=1e-7, maxiter=None):
    '''Mean and precision alternating method for MLE of Dirichlet
    distribution.
    Parameters
    ----------
    D : 2D array
        ``N x K`` array of observations.
    tol : float
        Stop once the log likelihood changes by less than ``tol`` between
        consecutive iterations; also passed to the inner fits.
    maxiter : int or None
        Maximum number of alternating iterations (10000 when None).
    Returns
    -------
    a : array
        Fitted Dirichlet parameters.
    Raises
    ------
    Exception
        If the fit has not converged after ``maxiter`` iterations.'''
    logp = log(D).mean(axis=0)
    a0 = _init_a(D)
    s0 = a0.sum()
    if s0 < 0:
        a0 = a0/s0
    elif s0 == 0:
        # Degenerate initial guess: fall back to a uniform vector summing to
        # one. (This branch used to reference an undefined name ``a``.)
        a0 = ones(a0.shape) / len(a0)

    # Start updating
    if maxiter is None:
        maxiter = 10000
    a1 = a0  # keeps the failure message well-defined even when maxiter == 0
    logl0 = loglikelihood(D, a0)
    for i in range(maxiter):
        # Alternate: refit the precision with the mean fixed, then the mean
        # with the precision fixed.
        a1 = _fit_s(D, a0, logp, tol=tol)
        a1 = _fit_m(D, a1, logp, tol=tol)
        logl1 = loglikelihood(D, a1)
        # Comparing likelihoods converges much faster than norm(a1-a0) < tol.
        if abs(logl1 - logl0) < tol:
            return a1
        a0, logl0 = a1, logl1
    raise Exception('Failed to converge after {} iterations, values are {}.'
                    .format(maxiter, a1))

def _fit_s(D, a0, logp, tol=1e-7, maxiter=1000):
    '''Assuming a fixed mean for Dirichlet distribution, maximize likelihood
    for precision a.k.a. s.
    Parameters
    ----------
    D : 2D array
        ``N x K`` array of observations (only its shape is checked here).
    a0 : array
        Current Dirichlet parameters; the mean ``a0 / a0.sum()`` is held fixed.
    logp : array
        Mean of ``log(D)`` over the observations.
    tol : float
        Stop once the precision changes by less than ``tol``.
    maxiter : int
        Maximum number of update steps.
    Returns
    -------
    a : array
        Parameters with the same mean as ``a0`` and the updated precision.
    Raises
    ------
    Exception
        If no positive update for s can be found, or on non-convergence.'''
    N, K = D.shape
    s1 = a0.sum()
    m = a0 / s1
    mlogp = (m*logp).sum()
    for i in range(maxiter):
        s0 = s1
        # g and h: first and second derivative terms of the log likelihood
        # with respect to s, built from the digamma and trigamma functions.
        g = psi(s1) - (m*psi(s1*m)).sum() + mlogp
        h = _trigamma(s1) - ((m**2)*_trigamma(s1*m)).sum()

        # NOTE(review): when g + s1*h >= 0 no update is applied, s1 stays
        # equal to s0 and the convergence test below returns immediately.
        if g + s1 * h < 0:
            s1 = 1/(1/s0 + g/h/(s0**2))
        # Fallback updates, tried in turn while s1 is not positive.
        if s1 <= 0:
            s1 = s0 * exp(-g/(s0*h + g)) # Newton on log s
        if s1 <= 0:
            s1 = 1/(1/s0 + g/((s0**2)*h + 2*s0*g)) # Newton on 1/s
        if s1 <= 0:
            s1 = s0 - g/h # Newton
        if s1 <= 0:
            raise Exception('Unable to update s from {}'.format(s0))

        if abs(s1 - s0) < tol:
            return s1 * m

    raise Exception('Failed to converge after {} iterations, s is {}'
            .format(maxiter, s1))

def _fit_m(D, a0, logp, tol=1e-7, maxiter=1000):
    '''With fixed precision s, maximize mean m.
    Repeats a fixed-point update on the parameters, rescaling after every
    step so that their sum stays equal to ``a0.sum()``, until two consecutive
    parameter vectors differ by less than ``tol`` in norm. Raises Exception
    when ``maxiter`` steps are not enough.'''
    n_obs, n_dim = D.shape  # values unused; D must still unpack as 2-D
    precision = a0.sum()
    current = a0

    for _ in range(maxiter):
        mean = current / precision
        offset = (mean*(psi(current) - logp)).sum()
        proposal = _ipsi(logp + offset)
        # Restore the fixed precision.
        proposal = proposal/proposal.sum() * precision

        if norm(proposal - current) < tol:
            return proposal
        current = proposal

    raise Exception('Failed to converge after {} iterations, s is {}'
            .format(maxiter, precision))

def _piecewise(x, condlist, funclist, *args, **kw):
    '''Fixed version of numpy.piecewise for 0-d arrays.
    Arguments:
        `x`: Scalar or array the pieces are evaluated on.
        `condlist`: Boolean condition (or list of conditions) selecting the
                    elements of `x` each piece applies to.
        `funclist`: One entry per condition, optionally followed by one extra
                    "otherwise" entry. Each entry is either a callable
                    (called as `f(values, *args, **kw)`) or a constant.
    Returns:
        An array shaped like `x` (0-d when `x` is 0-d) with `x`'s dtype.
    Raises:
        ValueError: if the number of functions matches neither the number of
                    conditions nor that number plus one.
    '''
    x = asanyarray(x)
    n2 = len(funclist)
    if isscalar(condlist) or \
            (isinstance(condlist, np.ndarray) and condlist.ndim == 0) or \
            (x.ndim > 0 and condlist[0].ndim == 0):
        condlist = [condlist]
    condlist = [asarray(c, dtype=bool) for c in condlist]
    n = len(condlist)

    zerod = False
    # This is a hack to work around problems with NumPy's
    #  handling of 0-d arrays and boolean indexing with
    #  numpy.bool_ scalars
    if x.ndim == 0:
        x = x[None]
        zerod = True
        condlist = [c[None] if c.ndim == 0 else c for c in condlist]

    if n == n2-1:  # compute the "otherwise" condition.
        totlist = condlist[0]
        for k in range(1, n):
            # Out-of-place OR: an in-place `|=` would overwrite the first
            # condition (and the caller's array when it is already boolean).
            totlist = totlist | condlist[k]
        condlist.append(~totlist)
        n += 1
    if n != n2:
        raise ValueError(
                "function list and condition list must be the same")

    y = zeros(x.shape, x.dtype)
    for cond, item in zip(condlist, funclist):
        if not callable(item):
            y[cond] = item
        else:
            vals = x[cond]
            # Skip the call when the condition selects nothing.
            if vals.size > 0:
                y[cond] = item(vals, *args, **kw)
    if zerod:
        y = y.squeeze()
    return y

def _init_a(D):
    '''Initial guess for Dirichlet alpha parameters given data D.
    The column means of D are scaled by a common factor computed from the
    first and second moments of the first column.'''
    mean = D.mean(axis=0)
    second_moment = (D**2).mean(axis=0)
    scale = (mean[0] - second_moment[0])/(second_moment[0]-mean[0]**2)
    return scale * mean

def _ipsi(y, tol=1.48e-9, maxiter=10):
    '''Inverse of psi (digamma) using Newton's method. For the purposes
    of Dirichlet MLE, since the parameters a[i] must always
    satisfy a > 0, we define ipsi :: R -> (0,inf).
    Raises Exception when two consecutive iterates are still at least
    ``tol`` apart (in norm) after ``maxiter`` Newton steps.'''
    y = asanyarray(y, dtype='float')

    # Starting point: one closed-form guess on each side of y = -2.22.
    def _upper_guess(v):
        return exp(v) + 0.5

    def _lower_guess(v):
        return -1/(v+GAMMA)

    estimate = _piecewise(y, [y >= -2.22, y < -2.22],
            [_upper_guess, _lower_guess])
    for _ in range(maxiter):
        # Newton step on psi(x) - y; the derivative of psi is the trigamma.
        refined = estimate - (psi(estimate) - y)/_trigamma(estimate)
        if norm(refined - estimate) < tol:
            return refined
        estimate = refined
    raise Exception(
        'Unable to converge in {} iterations, value is {}'.format(maxiter, refined))

def _trigamma(x):
    '''Trigamma function: the first-order polygamma function evaluated at x.'''
    order = 1
    return polygamma(order, x)
# [notebook cell output] File "<ipython-input-138-64df0d5d4dca>", line 71:
#     h = _trigamma(map(lambda)s1) - ((m**2)*_trigamma(s1*m)).sum()
# SyntaxError: invalid syntax

def plot_correct(correct_params, correct_pts, color):
    '''Plots correctly classified points with their Dirichlet contours.
    Arguments:
        `correct_params`: Dirichlet alpha parameters (e.g. obtained by
                          maximum likelihood estimation on `correct_pts`).
        `correct_pts`: The points to draw, as accepted by `plot_points`.
        `color`: Colour used for both the dots and the solid contour lines.
    '''
    # Points first, as dots ...
    plot_points(correct_pts, color=color, marker='.')
    # ... then the solid contours of the Dirichlet pdf.
    fitted = stats.dirichlet(correct_params)
    draw_pdf_contours(fitted, 'solid', colors=color)
    # A marker at the mean (data.mean(0) @ _corners) used to be drawn here;
    # it is disabled.

def plot_wrong(wrong_params, wrong_pts, color):
    '''Plots misclassified points with their Dirichlet contours.
    Arguments:
        `wrong_params`: Dirichlet alpha parameters for the contours.
        `wrong_pts`: The points to draw, as accepted by `plot_points`.
        `color`: Colour used for both the x markers and the dashed contours.
    '''
    # Dashed contours of the Dirichlet pdf first ...
    fitted = stats.dirichlet(wrong_params)
    draw_pdf_contours(fitted, 'dashed', colors=color)
    # ... then the points themselves, as x's.
    plot_points(wrong_pts, color=color, marker='x')
    # A marker at the mean (data.mean(0) @ _corners) used to be drawn here;
    # it is disabled.