Learn practical skills, build real-world projects, and advance your career
Updated 3 years ago
# Uncomment the line below to install dependencies
#!pip install jovian netcdf4 xarray
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import xarray as xr
from sklearn.metrics import mean_squared_error
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision.datasets.utils import download_and_extract_archive
# Landing page describing the dataset -- kept for reference only.
DATASET_SRC = 'https://doi.pangaea.de/10.1594/PANGAEA.909132'
# Direct link to the zip archive that load_dataset() downloads when asked to.
DATASET_URL = (
    'https://store.pangaea.de/Publications/IizumiT_2019/'
    'gdhy_v1.2_v1.3_20190128.zip'
)
# Project name, used by jovian.commit().
project_name = 'z2g-project'
# Directory the archive is extracted into and read back from.
dest = 'yields'
def load_dataset(download_archive=True, data_dir=None):
    """Load every crop/year file of the yield archive into one DataFrame.

    The extracted archive holds one folder per crop. Some crops have several
    seasons; for simplicity each season folder is treated as its own crop.
    Each folder holds one file per year, and the year is the last four
    characters of the file name before its extension.

    Args:
        download_archive: If true, download ``DATASET_URL`` into the current
            directory and extract it into the data directory before loading.
        data_dir: Directory containing the extracted crop folders. Defaults
            to the module-level ``dest``.

    Returns:
        A DataFrame containing the non-missing records of every file. The
        ``var`` column is renamed to ``yield``, and ``crop`` (folder name) and
        ``year`` columns are added. Rows are ordered by crop folder, then file
        name. An empty DataFrame is returned when no data files are found.
    """
    root = dest if data_dir is None else data_dir
    if download_archive:
        download_and_extract_archive(DATASET_URL, '.', root)

    frames = []
    # sorted(): iterdir() order is filesystem-dependent; sorting makes the
    # resulting row order reproducible.
    for crop_dir in sorted(pathlib.Path(root).iterdir()):
        # Stray files (a README, .DS_Store, ...) would crash iterdir() below.
        if not crop_dir.is_dir() or crop_dir.name.startswith('.'):
            continue
        for data_file in sorted(crop_dir.iterdir()):
            if not data_file.is_file() or data_file.name.startswith('.'):
                continue
            # Close the underlying file once its contents are in pandas;
            # otherwise every opened file handle stays alive.
            with xr.open_dataset(data_file) as ds:
                df = ds.to_dataframe()
            # Drop records with missing data and expose the coordinates as
            # regular columns.
            df = df.dropna().reset_index().rename(columns={'var': 'yield'})
            df['crop'] = crop_dir.name
            # The year is the last four characters of the name minus suffix.
            df['year'] = int(data_file.stem[-4:])
            frames.append(df)

    if not frames:
        return pd.DataFrame()
    # One concat at the end is linear; concatenating inside the loop re-copied
    # the whole accumulated frame for every file.
    return pd.concat(frames)
# Load from the already-extracted archive (no download).
data_frame = load_dataset(False)
crops = np.sort(data_frame['crop'].unique())
years = np.sort(data_frame['year'].unique())
# Group by the column name rather than a one-element list: with a list key,
# newer pandas yields 1-tuples such as ('maize',) when iterating the groups.
crop_groups = data_frame.groupby('crop')
# let's look a bit at the data frame
data_frame.head()
data_frame.describe()
# distplot is deprecated (it emits a FutureWarning); histplot with a KDE on a
# density scale draws the equivalent histogram of records per year.
sns.histplot(data_frame['year'], kde=True, stat='density')
/usr/local/lib/python3.6/dist-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
warnings.warn(msg, FutureWarning)
<matplotlib.axes._subplots.AxesSubplot at 0x7fe04e004e10>