[Notebook] Dask and Xarray on AWS-HPC Cluster: Distributed Processing of Earth Data
This notebook continues the previous post by showing the actual code for distributed data processing.
%matplotlib inline
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from dask.diagnostics import ProgressBar
from dask_jobqueue import SLURMCluster
from distributed import Client, progress
import dask
import distributed
dask.__version__, distributed.__version__
# Avoid HDF5 file-locking issues on the shared Lustre file system
%env HDF5_USE_FILE_LOCKING=FALSE
Data exploration¶
Data are organized by year/month:
ls /fsx
ls /fsx/2008/
ls /fsx/2008/01/data # one variable per file
# hourly data over a month
dr = xr.open_dataarray('/fsx/2008/01/data/sea_surface_temperature.nc')
dr
# Static plot of the first time slice
fig, ax = plt.subplots(1, 1, figsize=[12, 8], subplot_kw={'projection': ccrs.PlateCarree()})
dr[0].plot(ax=ax, transform=ccrs.PlateCarree(), cbar_kwargs={'shrink': 0.6})
ax.coastlines();
What happens to the values over land? That is easier to check with an interactive plot.
import geoviews as gv
import hvplot.xarray
fig_hv = dr[0].hvplot.quadmesh(
x='lon', y='lat', rasterize=True, cmap='viridis', geo=True,
crs=ccrs.PlateCarree(), projection=ccrs.PlateCarree(), project=True,
width=800, height=400,
) * gv.feature.coastline
# fig_hv
# This is just a hack to display the figure in a Nikola blog post.
# If you know an easier way, let me know.
import holoviews as hv
from bokeh.resources import CDN, INLINE
from bokeh.embed import file_html
from IPython.display import HTML
HTML(file_html(hv.render(fig_hv), CDN))
So it turns out that the "temperature" over land is set to 273.16 K (about 0 °C). A better approach would probably be to mask those values out.
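A minimal sketch of such a mask (assuming 273.16 K is used exactly as the land fill value observed above):

# Hypothetical masking step: set land points (fill value 273.16 K) to NaN
dr_ocean = dr[0].where(dr[0] != 273.16)
dr_ocean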
Serial read with master node¶
Let's see how slow it is to read one year of data with only the master node.
# Just querying the metadata will cause files to be pulled from S3 to FSx.
# This takes a while on the first execution, and is much faster the second time.
%time ds_1yr = xr.open_mfdataset('/fsx/2008/*/data/sea_surface_temperature.nc', chunks={'time0': 50})
dr_1yr = ds_1yr['sea_surface_temperature']
dr_1yr
The aggregated size is ~29 GB:
dr_1yr.nbytes / 1e9 # GB
with ProgressBar():
mean_1yr_ser = dr_1yr.mean().compute()
mean_1yr_ser
It takes ~2 minutes, so reading the full 10 years of data this way would take ~20 minutes. Such slowness motivates the use of a distributed cluster.
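The ~20 min figure is just a linear extrapolation of the serial throughput, using the numbers above:

# Back-of-the-envelope estimate behind the "~20 min" figure
serial_throughput = 29 / 120        # GB/s for the serial one-year read (~2 minutes)
10 * 29 / serial_throughput / 60    # estimated minutes to read ~10x the data serially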
Parallel read with dask cluster¶
Cluster initialization¶
!sinfo # 8 idle compute nodes spun up by AWS ParallelCluster
!mkdir -p ./dask_tempdir
# Reference: https://jobqueue.dask.org/en/latest/configuration.html
# - "cores" is the number of CPUs used per Slurm job.
#   Here it is fixed at 72, the number of vCPUs on a c5n.18xlarge node, so one Slurm job gets exactly one node.
# - "processes" specifies the number of dask workers in a single Slurm job.
# - "memory" specifies the memory requested in a single Slurm job.
cluster = SLURMCluster(cores=72, processes=36, memory='150GB',
local_directory='./dask_tempdir')
# 8 nodes * 36 workers/node = 288 workers
cluster.scale(8*36)
cluster
Visit http://localhost:8787 for the dashboard.
# Remember to also create a dask client to talk to the cluster!
client = Client(cluster) # automatically switches to distributed mode
client
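As an optional sanity check (using the standard Client.scheduler_info() call), one can count how many workers have actually connected:

# Optional: number of workers currently connected (should approach 8 * 36 = 288)
len(client.scheduler_info()['workers'])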
# now the default scheduler is dask.distributed
dask.config.get('scheduler')
!sinfo # nodes are now fully allocated
!squeue # all are dask worker jobs, one per compute node
Read 1-year data¶
# Strictly speaking there is no need to reopen the files; the previous dask graph could simply be reused on the cluster.
ds_1yr = xr.open_mfdataset('/fsx/2008/*/data/sea_surface_temperature.nc', chunks={'time0': 50})
dr_1yr = ds_1yr['sea_surface_temperature']
dr_1yr
%time mean_1yr_par = dr_1yr.mean().compute()
The throughput is roughly 29 GB / 5 s ≈ 6 GB/s, which seems to exceed our Lustre bandwidth of ~3 GB/s. That is likely because the NetCDF files are compressed on disk, while the 29 GB refers to the uncompressed in-memory arrays.
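One way to sanity-check that explanation (a sketch using only the standard library) is to compare the on-disk size of the files with the ~29 GB in-memory size:

import glob
import os

# Total on-disk size of the one-year files, for comparison with dr_1yr.nbytes
files_1yr = sorted(glob.glob('/fsx/2008/*/data/sea_surface_temperature.nc'))
sum(os.path.getsize(f) for f in files_1yr) / 1e9  # GB on disk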
mean_1yr_par.equals(mean_1yr_ser) # consistent with serial result
len(dr_1yr.chunks[0]) # number of dask chunks
There are actually not that many chunks to spread across the dask workers, even though I am using fairly small chunks. Let's try more files.
Read multi-year data¶
For this part you might get a "Too many open files" error. If so, run sudo sh -c "ulimit -n 65535 && exec su $LOGNAME" to raise the limit before starting Jupyter (ref: https://stackoverflow.com/a/17483998).
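To check the current limit from within the notebook (optional; standard library only):

import resource

# Current soft and hard limits on open file descriptors for this process
resource.getrlimit(resource.RLIMIT_NOFILE)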
file_list = [f'/fsx/{year}/{month:02d}/data/sea_surface_temperature.nc'
for year in range(2008, 2018) for month in range(1, 13)]
len(file_list) # number of files
file_list[0:3], file_list[-3:] # span over multiple years
# This will cause data to be pulled from S3 to FSx.
# It takes a long time on the first execution, and is much faster the second time.
%time ds_10yr = xr.open_mfdataset(file_list, chunks={'time0': 50})
dr_10yr = ds_10yr['sea_surface_temperature']
dr_10yr
Nearly 300 GB!
dr_10yr.nbytes / 1e9 # GB
%time mean_10yr = dr_10yr.mean().compute()
The throughput is roughly 287 GB / 15 s ≈ 19 GB/s?! Again, that is likely due to HDF5/NetCDF compression.
Finally, instead of computing just a scalar (which is boring), let's get a time series of the global-mean SST:
%time ts_10yr = dr_10yr.mean(dim=['lat', 'lon']).compute()
Despite this very crude approximation (no land mask and no weighting by grid-cell area, to keep the code simple), we can still see a clear increase in mean SST over these years (0.2 °C is a big deal for climate; more rigorous calculations suggest an even larger change).
ts_10yr.plot(size=6)
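For reference, a more careful calculation could mask out the 273.16 K land fill value and weight by grid-cell area; a minimal sketch, assuming a recent xarray with DataArray.weighted:

import numpy as np

# Hypothetical refinement: mask land points and weight by cos(latitude),
# which is proportional to grid-cell area on a regular lat-lon grid
weights = np.cos(np.deg2rad(dr_10yr['lat']))
ts_10yr_better = (dr_10yr.where(dr_10yr != 273.16)
                  .weighted(weights)
                  .mean(dim=['lat', 'lon'])
                  .compute())
ts_10yr_better.plot(size=6)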