Learn practical skills, build real-world projects, and advance your career
# uncomment the line below to install dependencies
# !pip install jovian netcdf4 xarray

import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import xarray as xr

from sklearn.metrics import mean_squared_error
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision.datasets.utils import download_and_extract_archive

# DOI link for the dataset; kept for reference and not used by the code below.
DATASET_SRC = 'https://doi.pangaea.de/10.1594/PANGAEA.909132'   # for reference
# Direct link to the zip archive that load_dataset() downloads and extracts.
DATASET_URL = 'https://store.pangaea.de/Publications/IizumiT_2019/gdhy_v1.2_v1.3_20190128.zip'

project_name = 'z2g-project'    # project name, used by jovian.commit()
# Directory the archive is extracted into and that load_dataset() reads from.
dest = 'yields'


def load_dataset(download_archive=True):
  """Load every per-year data file under ``dest`` into one DataFrame.

  Each sub-folder of ``dest`` is treated as one crop. Some crops have
  several seasons, each in its own folder; for simplicity every season is
  treated as a separate crop. Each file inside a crop folder holds the data
  for a single year.

  Args:
    download_archive: When true, download ``DATASET_URL`` into the current
      directory and extract it into ``dest`` before reading.

  Returns:
    A DataFrame holding the rows of every file with missing values dropped,
    the ``var`` column renamed to ``yield``, plus a ``crop`` column (the
    folder name) and a ``year`` column (int parsed from the file name).
    The per-file 0-based indexes are kept, so index labels repeat across
    files. An empty DataFrame is returned when no data files are found.

  Raises:
    ValueError: If a file name does not carry a 4-digit year right before
      its 4-character suffix.
  """
  if download_archive:
    download_and_extract_archive(DATASET_URL, '.', dest)
  data_dir = pathlib.Path(dest)

  # Collect the per-file frames and concatenate once at the end; calling
  # pd.concat inside the loop re-copies all accumulated rows on every file.
  frames = []

  # Sorted traversal keeps the row order independent of the filesystem.
  for crop_dir in sorted(data_dir.iterdir()):
    if not crop_dir.is_dir():
      # Stray files next to the crop folders are not crops; skip them.
      continue

    for data_file in sorted(crop_dir.iterdir()):
      # The context manager closes the file handle once the data has been
      # materialised into a pandas DataFrame.
      with xr.open_dataset(data_file) as dataset:
        df = (
            dataset.to_dataframe()
            .dropna()
            .reset_index()
            .rename(columns={'var': 'yield'})
        )

      # the 'crop' is the folder name
      df['crop'] = crop_dir.name

      # the file name ends in the year (4-digit) followed by a
      # 4-character suffix
      df['year'] = int(data_file.name[-8:][:4])

      frames.append(df)

  if not frames:
    # pd.concat raises on an empty list; keep returning an empty frame.
    return pd.DataFrame()
  return pd.concat(frames)
# Read the already-extracted archive; pass True instead to download it first.
data_frame = load_dataset(False)

# Sorted arrays of the distinct crop names and years present in the data.
crops, years = (
    np.sort(data_frame[column].unique()) for column in ('crop', 'year')
)

# Take a first look at the data frame: leading rows, summary statistics,
# and how the records are distributed over the years.
data_frame.head()
data_frame.describe()
sns.displot(data_frame['year'])
<seaborn.axisgrid.FacetGrid at 0x7f36700b0f98>
Notebook Image