Learn practical skills, build real-world projects, and advance your career
# uncomment the line below to install dependencies
#!pip install jovian netcdf4 xarray

import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import xarray as xr

from sklearn.metrics import mean_squared_error
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision.datasets.utils import download_and_extract_archive

# Dataset landing page (DOI) on PANGAEA -- kept for reference only.
DATASET_SRC = "https://doi.pangaea.de/10.1594/PANGAEA.909132"
# Direct link to the zipped archive that load_dataset() downloads.
DATASET_URL = (
    "https://store.pangaea.de/Publications/IizumiT_2019/"
    "gdhy_v1.2_v1.3_20190128.zip"
)

project_name = "z2g-project"  # project name, used by jovian.commit()
dest = "yields"  # folder the archive is extracted into and later read from


def _load_yield_file(data_file, crop):
    """Read one yearly data file into a flat DataFrame tagged with crop and year.

    Rows containing any missing value are dropped, the index levels produced by
    ``to_dataframe`` become ordinary columns, and the ``var`` column is renamed
    to ``yield``.
    """
    # Open the file as a context manager so its handle is released once the
    # values have been copied into the pandas DataFrame.
    with xr.open_dataset(data_file) as dataset:
        df = (
            dataset.to_dataframe()
            .dropna()
            .reset_index()
            .rename(columns={'var': 'yield'})
        )
    df['crop'] = crop
    # The file stem ends in the 4-digit year. Slicing the stem (instead of
    # counting back a fixed number of characters from the end of the full
    # name) does not depend on the length of the file extension.
    df['year'] = int(data_file.stem[-4:])
    return df


def load_dataset(download_archive=True):
    """Load every crop/year file of the yield archive into one DataFrame.

    Each sub-folder of ``dest`` is treated as one crop. Some crops have
    multiple seasons; for simplicity each season folder is treated as its own
    crop. Every file inside a crop folder holds the data for one year.

    Args:
        download_archive: If True, first download ``DATASET_URL`` and extract
            it into ``dest``; otherwise reuse the already-extracted files.

    Returns:
        A DataFrame with the columns produced by xarray's ``to_dataframe``
        (with ``var`` renamed to ``yield``) plus ``crop`` (the folder name)
        and ``year`` (parsed from the file name). Rows with missing values are
        dropped. An empty DataFrame is returned if no data files are found.
    """
    if download_archive:
        download_and_extract_archive(DATASET_URL, '.', dest)
    data_dir = pathlib.Path(dest)

    frames = []
    # Sorted iteration makes the row order deterministic; iterdir() order
    # depends on the filesystem.
    for crop_dir in sorted(data_dir.iterdir()):
        if not crop_dir.is_dir():
            continue  # skip stray files sitting next to the crop folders
        for data_file in sorted(crop_dir.iterdir()):
            frames.append(_load_yield_file(data_file, crop_dir.name))

    if not frames:
        return pd.DataFrame()
    # Concatenate once: calling pd.concat inside the loop re-copies the
    # accumulated frame on every iteration (quadratic in the number of files).
    return pd.concat(frames)
# Load the already-extracted archive (pass download_archive=True on first run).
data_frame = load_dataset(download_archive=False)
crops = np.sort(data_frame['crop'].unique())
years = np.sort(data_frame['year'].unique())

# Group by the column name itself rather than a one-element list, so group
# keys are plain crop names instead of 1-tuples on pandas 2.x.
crop_groups = data_frame.groupby('crop')

# let's look a bit at the data frame
# (in a notebook, only the last bare expression of a cell is displayed)
data_frame.head()
data_frame.describe()

# `distplot` is deprecated (it emits a FutureWarning); `histplot` is the
# axes-level replacement. `discrete=True` gives one bin per integer year.
sns.histplot(data_frame['year'], discrete=True)
/usr/local/lib/python3.6/dist-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning)
<matplotlib.axes._subplots.AxesSubplot at 0x7fe04e004e10>
Notebook Image