From 70202603c3d6c36a2adb6491bd0d5d6ad25d6def Mon Sep 17 00:00:00 2001 From: ftong Date: Wed, 4 Jun 2025 17:37:46 +0200 Subject: [PATCH] Update src/seismic_hazard_forecasting.py --- src/seismic_hazard_forecasting.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/seismic_hazard_forecasting.py b/src/seismic_hazard_forecasting.py index 0e889fd..91ea5cb 100644 --- a/src/seismic_hazard_forecasting.py +++ b/src/seismic_hazard_forecasting.py @@ -500,6 +500,12 @@ verbose: {verbose}") mode='reflect', anti_aliasing=False) iml_grid_hd[iml_grid_hd == 0.0] = np.nan # change zeroes back to nan + # trim edges so the grid is not so blocky + vmin_hd = min(x for x in iml_grid_hd.flatten() if not math.isnan(x)) + vmax_hd = max(x for x in iml_grid_hd.flatten() if not math.isnan(x)) + trim_thresh = vmin + iml_grid_hd[iml_grid_hd < trim_thresh] = np.nan + # generate image overlay north, south = lat.max(), lat.min() # Latitude range east, west = lon.max(), lon.min() # Longitude range @@ -508,8 +514,10 @@ verbose: {verbose}") map_center = [np.mean([north, south]), np.mean([east, west])] # Create an image from the grid + cmap_name = 'viridis' + cmap = plt.get_cmap(cmap_name) fig, ax = plt.subplots(figsize=(6, 6)) - ax.imshow(iml_grid_hd, origin='lower', cmap='viridis') + ax.imshow(iml_grid_hd, origin='lower', cmap=cmap, vmin=vmin, vmax=vmax) ax.axis('off') # Save the figure @@ -518,7 +526,6 @@ verbose: {verbose}") plt.close(fig) # Make the color bar - cmap_name = 'viridis' width = 50 height = 500 @@ -528,11 +535,11 @@ verbose: {verbose}") fig, ax = plt.subplots(figsize=((width + 40) / 100.0, (height + 20) / 100.0), dpi=100) # Increase fig size for labels - ax.imshow(gradient, aspect='auto', cmap=plt.get_cmap(cmap_name), - extent=[0, 1, vmin, vmax]) # Note: extent order is different for vertical + ax.imshow(gradient, aspect='auto', cmap=cmap.reversed(), + extent=[0, 1, vmin, vmax_hd]) # Note: extent order is different for vertical ax.set_xticks([]) # Remove x-ticks for vertical colorbar num_ticks = 11 # Show more ticks - tick_positions = np.linspace(vmin, vmax, num_ticks) + tick_positions = np.linspace(vmin, vmax_hd, num_ticks) ax.set_yticks(tick_positions) ax.set_yticklabels([f"{tick:.2f}" for tick in tick_positions]) # format tick labels ax.set_title(products[j], pad=15)