Compare commits

10 Commits: switching_...master

| Author | SHA1 | Date |
|---|---|---|
| | 281c73764d | |
| | e86f131cc0 | |
| | 5c3ce04868 | |
| | 94054bc391 | |
| | 318a344c15 | |
| | bb2e136d42 | |
| | 87de2e7a6c | |
| | 503bec883e | |
| | 4658b8d866 | |
| | 4c2679a005 | |
							
								
								
									
.gitignore (3 changed lines, vendored)

| @@ -3,6 +3,7 @@ __pycache__/ | ||||
| */.ipynb_checkpoints/ | ||||
| .ipynb_checkpoints/ | ||||
| .env | ||||
| *.out | ||||
| weights/ | ||||
| datasets/ | ||||
| wip | ||||
| @@ -10,4 +11,4 @@ artifacts/ | ||||
| wandb/ | ||||
| scripts/pred/ | ||||
| scripts/pred_resampled/ | ||||
| scripts/lightning_logs/ | ||||
| scripts/lightning_logs/ | ||||
|   | ||||
							
								
								
									
README.md (99 changed lines)

| @@ -2,12 +2,13 @@ | ||||
|  | ||||
|  | ||||
| This repo contains notebooks and scripts demonstrating how to: | ||||
| - Prepare data for training a seisbench model detecting P and S waves (i.e. transform mseeds into [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)), check the [notebook](utils/Transforming%20mseeds%20from%20Bogdanka%20to%20Seisbench%20format.ipynb) and the [script](utils/mseeds_to_seisbench.py) | ||||
| - [to update] Explore available data, check the [notebook](notebooks/Explore%20igf%20data.ipynb) | ||||
| - Prepare data for training a seisbench model detecting P and S waves (i.e. transform mseeds into [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)), check the [notebook](notebooks/Transforming%20mseeds%20from%20Bogdanka%20to%20Seisbench%20format.ipynb) and the [script](scripts/mseeds_to_seisbench.py) | ||||
|  | ||||
| [//]: # (- [to update] Explore available data, check the [notebook](notebooks/Explore%20igf%20data.ipynb)) | ||||
| - Train various CNN models available in the SeisBench library and compare their performance in detecting P and S waves, check the [script](scripts/pipeline.py)  | ||||
|    | ||||
| - [to update] Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb) | ||||
| - [to update] Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb) | ||||
| [//]: # (- [to update] Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb)) | ||||
| [//]: # (- [to update] Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb)) | ||||
|  | ||||
|  | ||||
| ### Acknowledgments | ||||
| @@ -21,7 +22,7 @@ Please download and install [Mambaforge](https://github.com/conda-forge/miniforg | ||||
| After successful installation and within the Mambaforge environment please clone this repository:  | ||||
|  | ||||
| ``` | ||||
| git clone ssh://git@git.plgrid.pl:7999/eai/platform-demo-scripts.git | ||||
| git clone https://epos-apps.grid.cyfronet.pl/epos-ai/platform-demo-scripts.git | ||||
| ``` | ||||
| and please run for Linux or Windows platforms: | ||||
|  | ||||
| @@ -69,10 +70,13 @@ poetry shell | ||||
|    WANDB_USER="your user" | ||||
|    WANDB_PROJECT="training_seisbench_models" | ||||
|    BENCHMARK_DEFAULT_WORKER=2 | ||||
|    ``` | ||||
|  | ||||
| 2. Transform data into seisbench format.  | ||||
|      | ||||
|     To utilize functionality of Seisbench library, data need to be transformed to [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)). If your data is in the MSEED format, you can use the prepared script `mseeds_to_seisbench.py` to perform the transformation. Please make sure that your data has the same structure as the data used in this project. | ||||
|     To use the SeisBench library, the data need to be transformed to the [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html).  | ||||
|  | ||||
|     If your data is stored in the MSEED format and the catalog in the QuakeML format, you can use the prepared script `mseeds_to_seisbench.py` to perform the transformation. Please make sure that your data has the same structure as the data used in this project. | ||||
|     The script assumes that: | ||||
|    *  the data is stored in the following directory structure: | ||||
|    `input_path/year/station_network_code/station_code/trace_channel.D` e.g. | ||||
| @@ -80,24 +84,20 @@ poetry shell | ||||
|     * the file names follow the pattern:   | ||||
|     `station_network_code.station_code..trace_channel.D.year.day_of_year` | ||||
|    e.g. `PL.ALBE..EHE.D.2018.282` | ||||
|     * events catalog is stored in quakeML format | ||||
|     | ||||
|     Run the script `mseeds_to_seisbench` located in the `utils` directory | ||||
|  | ||||
|     Run the `mseeds_to_seisbench.py` script with the following arguments: | ||||
|     ``` | ||||
|     cd utils | ||||
|     python mseeds_to_seisbench.py --input_path $input_path --catalog_path $catalog_path --output_path $output_path | ||||
|     ``` | ||||
|     If you want to run the script on a cluster, you can use the script `convert_data.sh` as a template (adjust the grant name, computing name and paths) and send the job to queue using sbatch command on login node of e.g. Ares:  | ||||
|     | ||||
|     ``` | ||||
|     cd utils | ||||
|     sbatch convert_data.sh | ||||
|     If you want to run the script on a cluster, you can use the template script `convert_data_template.sh`.  | ||||
| After adjusting the grant name, the path to the conda environment, and the paths to the data, send the job to the queue using the `sbatch` command on a login node of e.g. Ares:  | ||||
|    ``` | ||||
|     sbatch convert_data_template.sh | ||||
|    ``` | ||||
|     | ||||
|     If your data has a different structure or format, use the notebooks to gain an understanding of the Seisbench format and what needs to be done to transform your data: | ||||
|     If your data has a different structure or format, check the notebooks to gain an understanding of the Seisbench format and what needs to be done to transform your data: | ||||
|    * [Seisbench example](https://colab.research.google.com/github/seisbench/seisbench/blob/main/examples/01a_dataset_basics.ipynb) or | ||||
|    * [Transforming mseeds from Bogdanka to Seisbench format](utils/Transforming mseeds from Bogdanka to Seisbench format.ipynb) notebook  | ||||
|    * [Transforming mseeds from Bogdanka to Seisbench format](notebooks/Transforming%20mseeds%20from%20Bogdanka%20to%20Seisbench%20format.ipynb) notebook  | ||||
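
The file-name convention listed in step 2 can be checked before starting a conversion. The following is a minimal sketch that assumes nothing beyond the pattern `station_network_code.station_code..trace_channel.D.year.day_of_year` quoted in the README; `parse_mseed_name` and `MseedName` are hypothetical helpers for illustration and are not part of `mseeds_to_seisbench.py`.

```python
import re
from typing import NamedTuple, Optional

# Pattern from the README: network.station..channel.D.year.day_of_year
MSEED_NAME = re.compile(
    r"^(?P<network>[^.]+)\.(?P<station>[^.]+)\.\.(?P<channel>[^.]+)\.D\."
    r"(?P<year>\d{4})\.(?P<day>\d{1,3})$"
)


class MseedName(NamedTuple):
    network: str
    station: str
    channel: str
    year: int
    day_of_year: int


def parse_mseed_name(name: str) -> Optional[MseedName]:
    """Return the parsed components, or None if the name does not match."""
    m = MSEED_NAME.match(name)
    if m is None:
        return None
    return MseedName(
        m["network"], m["station"], m["channel"], int(m["year"]), int(m["day"])
    )


# The README's own example file name.
print(parse_mseed_name("PL.ALBE..EHE.D.2018.282"))
```

Names that return `None` would be the ones to inspect by hand before running the script.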
|      | ||||
|  | ||||
| 3. Adjust the `config.json` and specify:  | ||||
| @@ -110,34 +110,65 @@ poetry shell | ||||
| `python pipeline.py` | ||||
|  | ||||
|    The script performs the following steps: | ||||
|    * Generates evaluation targets in `datasets/<dataset_name>/targets` directory.  | ||||
|      * Trains multiple versions of GPD, PhaseNet and ... models to find the best hyperparameters, producing the lowest validation loss. | ||||
|    1. Generates evaluation targets in `datasets/<dataset_name>/targets` directory.  | ||||
|    1. Trains multiple versions of GPD, PhaseNet, BasicPhaseAE, and EQTransformer models to find the best hyperparameters, producing the lowest validation loss. | ||||
|       | ||||
|      This step utilizes the Weights & Biases platform to perform the hyperparameters search (called sweeping) and track the training process and store the results. | ||||
|           The results are available at    | ||||
|           `https://epos-ai.grid.cyfronet.pl/<WANDB_USER>/<WANDB_PROJECT>` | ||||
|      Weights and training logs can be downloaded from the platform.  | ||||
|     Additionally, the most important data are saved locally in `weights/<dataset_name>_<model_name>/ ` directory: | ||||
|      * Weights of the best checkpoint of each model are saved as  `<dataset_name>_<model_name>_sweep=<sweep_id>-run=<run_id>-epoch=<epoch_number>-val_loss=<val_loss>.ckpt` | ||||
|      * Metrics and hyperparams are saved  in <run_id> folders | ||||
|         This step utilizes the Weights & Biases platform to perform the hyperparameters search (called sweeping) and track the training process and store the results. | ||||
|              The results are available at    | ||||
|              `https://epos-ai.grid.cyfronet.pl/<WANDB_USER>/<WANDB_PROJECT>` | ||||
|         Weights and training logs can be downloaded from the platform.  | ||||
|        Additionally, the most important data are saved locally in `weights/<dataset_name>_<model_name>/ ` directory: | ||||
|         * Weights of the best checkpoint of each model are saved as  `<dataset_name>_<model_name>_sweep=<sweep_id>-run=<run_id>-epoch=<epoch_number>-val_loss=<val_loss>.ckpt` | ||||
|         * Metrics and hyperparams are saved  in <run_id> folders | ||||
|         | ||||
|    * Uses the best performing model of each type to generate predictions. The predictions are saved in the `scripts/pred/<dataset_name>_<model_name>/<run_id>` directory. | ||||
|    * Evaluates the performance of each model by comparing the predictions with the evaluation targets.  | ||||
|    The results are saved in the `scripts/pred/results.csv` file. | ||||
|    1. Uses the best performing model of each type to generate predictions. The predictions are saved in the `scripts/pred/<dataset_name>_<model_name>/<run_id>` directory. | ||||
|    1. Evaluates the performance of each model by comparing the predictions with the evaluation targets and calculating MAE metrics. | ||||
|    The results are saved in the `scripts/pred/results.csv` file. They are also logged on the Weights & Biases platform as summary metrics of the corresponding runs.  | ||||
|      | ||||
|    <br/> | ||||
|     The default settings for max number of experiments and paths are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for the GPD model, run: | ||||
|  | ||||
|   The default settings are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script.  | ||||
|   For example, to change the sweep configuration file for GPD model, run: | ||||
|   `python pipeline.py --gpd_config <new config file>` | ||||
|   The new config file should be placed in the `experiments` folder or as specified in the `configs_path` parameter in the config.json file. | ||||
|    ```python pipeline.py --gpd_config <new config file>``` | ||||
|        | ||||
|    The new config file should be placed in the `experiments` folder or as specified in the `configs_path` parameter in the config.json file. | ||||
|     | ||||
|    Sweep configs are used to define the max number of epochs to run and the hyperparameters search space for the following parameters:  | ||||
|     * `batch_size` | ||||
|     * `learning_rate` | ||||
|     | ||||
|    The PhaseNet model has additional parameters: | ||||
|      * `norm` - normalization method, options ('peak', 'std') | ||||
|      * `pretrained` - pretrained seisbench models used for transfer learning  | ||||
|      * `finetuning` - the type of layers to finetune first, options ('all', 'top', 'encoder', 'decoder') | ||||
|      * `lr_reduce_factor` - factor to reduce learning rate after unfreezing layers | ||||
|  | ||||
|    The GPD model has additional parameters for filtering:  | ||||
|      * `highpass` - highpass filter frequency | ||||
|      * `lowpass` - lowpass filter frequency | ||||
|  | ||||
|    The sweep configs are saved in the `experiments` folder.   | ||||
|  | ||||
|  | ||||
|    If you have multiple datasets, you can run the pipeline for each dataset separately by specifying the dataset name as an argument: | ||||
|         | ||||
|    ```python pipeline.py --dataset <dataset_name>``` | ||||
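
The checkpoint naming scheme described in the pipeline steps (`..._sweep=<sweep_id>-run=<run_id>-epoch=<epoch_number>-val_loss=<val_loss>.ckpt`) makes it possible to pick the best local checkpoint from file names alone. Below is a minimal sketch of that idea. It assumes that sweep and run ids contain no hyphens, and `best_checkpoint` is a hypothetical helper; how `pipeline.py` actually selects the best model is not shown in this diff.

```python
import re

# Follows the README's naming scheme; the ids are assumed to be hyphen-free.
CKPT = re.compile(
    r"sweep=(?P<sweep_id>[^-]+)-run=(?P<run_id>[^-]+)"
    r"-epoch=(?P<epoch>\d+)-val_loss=(?P<val_loss>\d+(?:\.\d+)?)\.ckpt$"
)


def best_checkpoint(names):
    """Return the file name with the lowest val_loss, or None if none match."""
    best_name, best_loss = None, float("inf")
    for name in names:
        m = CKPT.search(name)
        if m and float(m["val_loss"]) < best_loss:
            best_name, best_loss = name, float(m["val_loss"])
    return best_name


# Made-up file names, for illustration only.
names = [
    "bogdanka_GPD_sweep=abc123-run=r1-epoch=4-val_loss=0.210.ckpt",
    "bogdanka_GPD_sweep=abc123-run=r2-epoch=7-val_loss=0.185.ckpt",
    "notes.txt",
]
print(best_checkpoint(names))
```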
|  | ||||
| ### Troubleshooting | ||||
|  | ||||
| * Problem with reading the catalog file: please make sure that your quakeML xml file has the following opening and closing tags: | ||||
| ``` | ||||
| <?xml version="1.0"?> | ||||
| <q:quakeml xmlns="http://quakeml.org/xmlns/bed/1.2" xmlns:q="http://quakeml.org/xmlns/quakeml/1.2"> | ||||
|   .... | ||||
| </q:quakeml> | ||||
| ``` | ||||
|  | ||||
| * `wandb: ERROR Run .. errored: OSError(24, 'Too many open files')` | ||||
| -> https://github.com/wandb/wandb/issues/2825 | ||||
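
The catalog check from the first troubleshooting item can be automated. The sketch below uses only the standard library and tests whether the root element is `quakeml` in the `http://quakeml.org/xmlns/quakeml/1.2` namespace, as in the tags quoted above. `has_quakeml_root` is a hypothetical helper, and it does not validate anything below the root element.

```python
import xml.etree.ElementTree as ET

# ElementTree reports tags as "{namespace}localname".
QUAKEML_ROOT = "{http://quakeml.org/xmlns/quakeml/1.2}quakeml"


def has_quakeml_root(xml_text: str) -> bool:
    """True if the text parses as XML and its root is <q:quakeml>."""
    try:
        root = ET.fromstring(xml_text)
    except ET.ParseError:
        return False
    return root.tag == QUAKEML_ROOT


sample = """<?xml version="1.0"?>
<q:quakeml xmlns="http://quakeml.org/xmlns/bed/1.2" xmlns:q="http://quakeml.org/xmlns/quakeml/1.2">
</q:quakeml>"""

print(has_quakeml_root(sample))
print(has_quakeml_root("<quakeml></quakeml>"))
```

For a catalog on disk, `ET.parse(path).getroot().tag` can be compared against the same constant.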
|  | ||||
| ### Licence | ||||
|  | ||||
| TODO | ||||
| The code is licensed under the GNU General Public License v3.0. See the [LICENSE](LICENSE.txt) file for details. | ||||
|  | ||||
| ### Copyright | ||||
|  | ||||
|   | ||||
							
								
								
									
config.json (28 changed lines)

| @@ -1,15 +1,17 @@ | ||||
| { | ||||
|   "dataset_name": "bogdanka", | ||||
|   "data_path": "datasets/bogdanka/seisbench_format/", | ||||
|   "targets_path": "datasets/targets", | ||||
|   "models_path": "weights", | ||||
|   "configs_path": "experiments", | ||||
|   "sampling_rate": 100, | ||||
|   "num_workers": 1, | ||||
|   "seed": 10, | ||||
|   "sweep_files": { | ||||
|     "GPD": "sweep_gpd.yaml", | ||||
|     "PhaseNet": "sweep_phasenet.yaml" | ||||
|   }, | ||||
|   "experiment_count": 20 | ||||
|     "dataset_name": "bogdanka_2018_2022", | ||||
|     "data_path": "datasets/bogdanka_2018_2022/seisbench_format/", | ||||
|     "targets_path": "datasets/targets", | ||||
|     "models_path": "weights", | ||||
|     "configs_path": "experiments", | ||||
|     "sampling_rate": 100, | ||||
|     "num_workers": 1, | ||||
|     "seed": 10, | ||||
|     "sweep_files": { | ||||
|         "GPD": "sweep_gpd.yaml", | ||||
|         "PhaseNet": "sweep_phasenet.yaml", | ||||
|         "BasicPhaseAE": "sweep_basicphase_ae.yaml", | ||||
|         "EQTransformer": "sweep_eqtransformer.yaml" | ||||
|     }, | ||||
|     "experiment_count": 15 | ||||
| } | ||||
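
The README states that sweep config files live in the `experiments` folder or wherever `configs_path` points. A minimal sketch of how the `sweep_files` entries could be resolved against `configs_path` follows. The actual logic in `config_loader` is not part of this diff, so `sweep_config_paths` is a hypothetical helper, and the JSON here is a trimmed copy of the new `config.json`.

```python
import json
from pathlib import Path

# Trimmed copy of the new config.json, keeping only the keys used here.
CONFIG_TEXT = """
{
    "configs_path": "experiments",
    "sweep_files": {
        "GPD": "sweep_gpd.yaml",
        "PhaseNet": "sweep_phasenet.yaml",
        "BasicPhaseAE": "sweep_basicphase_ae.yaml",
        "EQTransformer": "sweep_eqtransformer.yaml"
    }
}
"""


def sweep_config_paths(config: dict) -> dict:
    """Map each model name to the path of its sweep file under configs_path."""
    base = Path(config["configs_path"])
    return {model: base / name for model, name in config["sweep_files"].items()}


paths = sweep_config_paths(json.loads(CONFIG_TEXT))
for model, path in paths.items():
    print(f"{model}: {path.as_posix()}")
```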
							
								
								
									
experiments/sweep_basicphase_ae.yaml (16 changed lines, Normal file)

| @@ -0,0 +1,16 @@ | ||||
| name: BasicPhaseAE | ||||
| method: bayes | ||||
| metric: | ||||
|   goal: minimize | ||||
|   name: val_loss | ||||
| parameters: | ||||
|   model_name: | ||||
|     value: | ||||
|       - BasicPhaseAE | ||||
|   batch_size: | ||||
|     values: [64, 128, 256] | ||||
|   max_epochs: | ||||
|     value: | ||||
|       - 30 | ||||
|   learning_rate: | ||||
|     values: [0.01, 0.005, 0.001] | ||||
							
								
								
									
experiments/sweep_eqtransformer.yaml (16 changed lines, Normal file)

| @@ -0,0 +1,16 @@ | ||||
| name: EQTransformer | ||||
| method: bayes | ||||
| metric: | ||||
|   goal: minimize | ||||
|   name: val_loss | ||||
| parameters: | ||||
|   model_name: | ||||
|     value: | ||||
|       - EQTransformer | ||||
|   batch_size: | ||||
|     values: [64, 128, 256] | ||||
|   max_epochs: | ||||
|     value: | ||||
|       - 30 | ||||
|   learning_rate: | ||||
|     values: [0.01, 0.005, 0.001] | ||||
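
The two new sweep files define the same discrete search space: three batch sizes and three learning rates, with `max_epochs` fixed at 30. The short sketch below simply enumerates that space to show its size. It does not reproduce how the `bayes` method chooses among the values, which is decided by Weights & Biases.

```python
from itertools import product

# Values copied from sweep_basicphase_ae.yaml and sweep_eqtransformer.yaml.
batch_sizes = [64, 128, 256]
learning_rates = [0.01, 0.005, 0.001]

combinations = list(product(batch_sizes, learning_rates))
print(len(combinations))
for batch_size, learning_rate in combinations:
    print(batch_size, learning_rate)
```

The space holds 9 distinct combinations, while `experiment_count` in the new `config.json` is 15.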
| @@ -1,4 +1,4 @@ | ||||
| name: GPD_fixed_highpass:2-10 | ||||
| name: GPD | ||||
| method: bayes | ||||
| metric: | ||||
|   goal: minimize | ||||
| @@ -8,19 +8,15 @@ parameters: | ||||
|     value: | ||||
|       - GPD | ||||
|   batch_size: | ||||
|     distribution: int_uniform | ||||
|     max: 1024 | ||||
|     min: 256 | ||||
|     values: [64, 128, 256] | ||||
|   max_epochs: | ||||
|     value: | ||||
|       - 3 | ||||
|       - 30 | ||||
|   learning_rate: | ||||
|     distribution: uniform | ||||
|     max: 0.02 | ||||
|     min: 0.005 | ||||
|     values: [0.01, 0.005, 0.001] | ||||
|   highpass: | ||||
|     value: | ||||
|       - 2 | ||||
|       - 1 | ||||
|   lowpass: | ||||
|     value: | ||||
|       - 10 | ||||
|       - 10 | ||||
|   | ||||
| @@ -13,7 +13,7 @@ parameters: | ||||
|     min: 256 | ||||
|   max_epochs: | ||||
|     value: | ||||
|       - 15 | ||||
|       - 30 | ||||
|   learning_rate: | ||||
|     distribution: uniform | ||||
|     max: 0.02 | ||||
|   | ||||
| @@ -18,9 +18,7 @@ | ||||
|     "import seisbench.data as sbd\n", | ||||
|     "import seisbench.util as sbu\n", | ||||
|     "import numpy as np\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "import utils\n" | ||||
|     "\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
| @@ -1126,7 +1124,7 @@ | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.11.5" | ||||
|    "version": "3.10.6" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
| @@ -36,6 +36,16 @@ | ||||
|     "\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "70c64dc6-e4dd-4c01-939d-a28914866f5d", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "##### The catalog has a custom format with the following properties: \n", | ||||
|     "###### 'Datetime', 'X', 'Y', 'Depth', 'Mw', 'Phases', 'mseed_name'\n", | ||||
|     "###### Phases is a string with detected phases separated by commas: <Phase> <Station> <Datetime> e.g. \"Pg BRDW 2020-01-01 10:09:44.400, Sg BRDW 2020-01-01 10:09:45.696\"" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
| @@ -106,6 +116,27 @@ | ||||
|     "catalog.head(1)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "id": "03257d45-299d-4ed1-bc64-03303d2a9873", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "'Pg BRDW 2020-01-01 10:09:44.400, Sg BRDW 2020-01-01 10:09:45.696, Pg GROD 2020-01-01 10:09:45.206, Sg GROD 2020-01-01 10:09:46.655, Pg GUZI 2020-01-01 10:09:45.116, Sg GUZI 2020-01-01 10:09:46.561, Pg JEDR 2020-01-01 10:09:44.920, Sg JEDR 2020-01-01 10:09:46.285, Pg MOSK2 2020-01-01 10:09:45.417, Sg MOSK2 2020-01-01 10:09:46.921, Pg NWLU 2020-01-01 10:09:45.686, Sg NWLU 2020-01-01 10:09:47.175, Pg PCHB 2020-01-01 10:09:45.213, Sg PCHB 2020-01-01 10:09:46.565, Pg PPOL 2020-01-01 10:09:44.755, Sg PPOL 2020-01-01 10:09:46.069, Pg RUDN 2020-01-01 10:09:44.502, Sg RUDN 2020-01-01 10:09:45.756, Pg RYNR 2020-01-01 10:09:43.442, Sg RYNR 2020-01-01 10:09:44.394, Pg RZEC 2020-01-01 10:09:46.075, Sg RZEC 2020-01-01 10:09:47.587, Pg SGOR 2020-01-01 10:09:45.817, Sg SGOR 2020-01-01 10:09:47.284, Pg TRBC2 2020-01-01 10:09:44.833, Sg TRBC2 2020-01-01 10:09:46.095, Pg TRN2 2020-01-01 10:09:44.488, Sg TRN2 2020-01-01 10:09:45.698, Pg TRZS 2020-01-01 10:09:46.232, Sg TRZS 2020-01-01 10:09:47.727, Pg ZMST 2020-01-01 10:09:43.592, Sg ZMST 2020-01-01 10:09:44.553, Pg LUBW 2020-01-01 10:09:43.119, Sg LUBW 2020-01-01 10:09:43.929'" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 4, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "catalog.Phases[0]" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "id": "fe0627b1-6fa0-4b5a-8a60-d80626b5c9be", | ||||
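
The `Phases` format described in the new markdown cell (`<Phase> <Station> <Datetime>`, comma separated) can be split into structured picks with the standard library alone. This is a minimal sketch based only on the example string shown in the notebook output. `Pick` and `parse_phases` are hypothetical names, and the notebook itself may parse the column differently.

```python
from datetime import datetime
from typing import List, NamedTuple


class Pick(NamedTuple):
    phase: str
    station: str
    time: datetime


def parse_phases(phases: str) -> List[Pick]:
    """Split 'Pg BRDW 2020-01-01 10:09:44.400, Sg BRDW ...' into picks."""
    picks = []
    for entry in phases.split(","):
        if not entry.strip():
            continue  # tolerate empty strings and trailing commas
        phase, station, date, clock = entry.split()
        time = datetime.strptime(f"{date} {clock}", "%Y-%m-%d %H:%M:%S.%f")
        picks.append(Pick(phase, station, time))
    return picks


picks = parse_phases(
    "Pg BRDW 2020-01-01 10:09:44.400, Sg BRDW 2020-01-01 10:09:45.696"
)
for pick in picks:
    print(pick.phase, pick.station, pick.time.isoformat())
```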
							
								
								
									
scripts/convert_data_template.sh (19 changed lines, Normal file)

| @@ -0,0 +1,19 @@ | ||||
| #!/bin/bash | ||||
| #SBATCH --job-name=mseeds_to_seisbench | ||||
| #SBATCH --time=1:00:00 | ||||
| #SBATCH --account=  									### to fill | ||||
| #SBATCH --partition plgrid | ||||
| #SBATCH --cpus-per-task=1 | ||||
| #SBATCH --ntasks-per-node=1 | ||||
| #SBATCH --mem=24gb | ||||
|  | ||||
|  | ||||
| ## activate conda environment | ||||
| source /path/to/mambaforge/bin/activate					### to  adjust | ||||
| conda activate epos-ai-train | ||||
|  | ||||
| input_path="/path/to/folder/with/mseed/files" | ||||
| catalog_path="/path/to/catalog.xml" | ||||
| output_path="/path/to/output/in/seisbench_format" | ||||
|  | ||||
| python mseeds_to_seisbench.py --input_path "$input_path" --catalog_path "$catalog_path" --output_path "$output_path" | ||||
| @@ -1,8 +1,13 @@ | ||||
| """ | ||||
| This file contains functionality related to data. | ||||
| """ | ||||
| import os.path | ||||
|  | ||||
| import seisbench.data as sbd | ||||
| import logging | ||||
|  | ||||
| logging.root.setLevel(logging.INFO) | ||||
| logger = logging.getLogger('data') | ||||
|  | ||||
|  | ||||
| def get_dataset_by_name(name): | ||||
| @@ -30,3 +35,26 @@ def get_custom_dataset(path): | ||||
|     except AttributeError: | ||||
|         raise ValueError(f"Unknown dataset '{path}'.") | ||||
|  | ||||
|  | ||||
| def validate_custom_dataset(data_path): | ||||
|     """ | ||||
|     Validate the dataset | ||||
|     :param data_path: path to the dataset | ||||
|     :return: | ||||
|     """ | ||||
|     # check that the dataset directory exists | ||||
|     if not os.path.isdir(data_path): | ||||
|         raise ValueError(f"Data path {data_path} does not exist.") | ||||
|  | ||||
|     dataset = sbd.WaveformDataset(data_path) | ||||
|     # check that the train, dev and test splits are all non-empty | ||||
|     if len(dataset.train()) == 0: | ||||
|         raise ValueError("Training set is empty.") | ||||
|     if len(dataset.dev()) == 0: | ||||
|         raise ValueError("Dev set is empty.") | ||||
|     if len(dataset.test()) == 0: | ||||
|         raise ValueError("Test set is empty.") | ||||
|  | ||||
|     logger.info("Custom dataset validated successfully.") | ||||
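
`validate_custom_dataset` relies on SeisBench to load the dataset and then checks that the train, dev and test splits are non-empty. The same check can be illustrated without SeisBench on a metadata table that has a `split` column, which is the column name used by the targets script in this comparison. The sketch below makes that assumption; `missing_splits` and the sample rows are hypothetical.

```python
import csv
import io
from collections import Counter

REQUIRED_SPLITS = ("train", "dev", "test")


def missing_splits(metadata_csv: str) -> list:
    """Return the required splits that have no rows in the metadata."""
    rows = csv.DictReader(io.StringIO(metadata_csv))
    counts = Counter(row["split"] for row in rows)
    return [split for split in REQUIRED_SPLITS if counts[split] == 0]


# Made-up metadata: two training traces, one dev trace, no test traces.
sample = "trace_name,split\na,train\nb,train\nc,dev\n"
print(missing_splits(sample))
```

An empty result corresponds to the case in which `validate_custom_dataset` logs success.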
|  | ||||
|  | ||||
|   | ||||
| @@ -39,10 +39,15 @@ from pathlib import Path | ||||
| import pandas as pd | ||||
| import numpy as np | ||||
| from tqdm import tqdm | ||||
|  | ||||
| import logging | ||||
| from models import phase_dict | ||||
|  | ||||
|  | ||||
|  | ||||
| logging.root.setLevel(logging.INFO) | ||||
| logger = logging.getLogger('targets generator') | ||||
|  | ||||
|  | ||||
| def main(dataset_name, output, tasks, sampling_rate, noise_before_events): | ||||
|     np.random.seed(42) | ||||
|     tasks = [str(i) in tasks.split(",") for i in range(1, 4)] | ||||
| @@ -64,17 +69,24 @@ def main(dataset_name, output, tasks, sampling_rate, noise_before_events): | ||||
|         dataset = sbd.WaveformDataset(dataset_name, **dataset_args) | ||||
|  | ||||
|     output = Path(output) | ||||
|     output.mkdir(parents=True, exist_ok=False) | ||||
|     output.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     if "split" in dataset.metadata.columns: | ||||
|         dataset.filter(dataset["split"].isin(["dev", "test"]), inplace=True) | ||||
|  | ||||
|     dataset.preload_waveforms(pbar=True) | ||||
|  | ||||
|      | ||||
|     if tasks[0]: | ||||
|         generate_task1(dataset, output, sampling_rate, noise_before_events) | ||||
|         if not (output / "task1.csv").exists(): | ||||
|             generate_task1(dataset, output, sampling_rate, noise_before_events) | ||||
|         else: | ||||
|             logger.info(f"{output}/task1.csv already exists. Skipping generation of targets.") | ||||
|     if tasks[1] or tasks[2]: | ||||
|         generate_task23(dataset, output, sampling_rate) | ||||
|         if not (output / "task23.csv").exists(): | ||||
|             generate_task23(dataset, output, sampling_rate) | ||||
|         else: | ||||
|             logger.info(f"{output}/task23.csv already exists. Skipping generation of targets.") | ||||
|  | ||||
|  | ||||
|  | ||||
| def generate_task1(dataset, output, sampling_rate, noise_before_events): | ||||
|   | ||||
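
The `tasks` argument of `main` is turned into three boolean flags by the list comprehension at the top of the function. The sketch below isolates that expression to show how a comma-separated string maps to the flags that gate `generate_task1` and `generate_task23`. The wrapper name `parse_tasks` is hypothetical.

```python
def parse_tasks(tasks: str) -> list:
    """Same expression as in main(): one flag per task number 1..3."""
    return [str(i) in tasks.split(",") for i in range(1, 4)]


print(parse_tasks("1,2,3"))  # [True, True, True]
print(parse_tasks("1"))      # [True, False, False]
print(parse_tasks("2,3"))    # [False, True, True]
print(parse_tasks("1, 2"))   # [True, False, False]
```

The last call shows that the comparison is exact: `" 2"` with a leading space does not match `"2"`, so the list should be passed without spaces.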
| @@ -7,8 +7,9 @@ import os | ||||
| import os.path | ||||
| import argparse | ||||
| from pytorch_lightning.loggers import WandbLogger, CSVLogger | ||||
| from pytorch_lightning.callbacks import ModelCheckpoint | ||||
| from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor | ||||
| from pytorch_lightning.callbacks.early_stopping import EarlyStopping | ||||
|  | ||||
| import pytorch_lightning as pl | ||||
| import wandb | ||||
| import torch | ||||
| @@ -18,8 +19,8 @@ from dotenv import load_dotenv | ||||
| import models | ||||
| import train | ||||
| import util | ||||
| from config_loader import config as common_config | ||||
| from config_loader import models_path, dataset_name, seed, experiment_count | ||||
| import config_loader | ||||
|  | ||||
|  | ||||
|  | ||||
| torch.multiprocessing.set_sharing_strategy('file_system') | ||||
| @@ -33,16 +34,13 @@ host = os.environ.get("WANDB_HOST") | ||||
| if host is None: | ||||
|     raise ValueError("WANDB_HOST environment variable is not set.") | ||||
|  | ||||
|  | ||||
| wandb.login(key=wandb_api_key, host=host) | ||||
| # wandb.login(key=wandb_api_key) | ||||
|  | ||||
| wandb_project_name = os.environ.get("WANDB_PROJECT") | ||||
| wandb_user_name = os.environ.get("WANDB_USER") | ||||
|  | ||||
| script_name = os.path.splitext(os.path.basename(__file__))[0] | ||||
| logger = logging.getLogger(script_name) | ||||
| logger.setLevel(logging.WARNING) | ||||
| logger.setLevel(logging.INFO) | ||||
|  | ||||
|  | ||||
| def set_random_seed(seed=3): | ||||
| @@ -58,6 +56,11 @@ def get_trainer_args(config): | ||||
|     return trainer_args | ||||
|  | ||||
|  | ||||
| def get_arg(arg): | ||||
|     # Sweep entries declared as a one-element list arrive as lists; unwrap them. | ||||
|     if isinstance(arg, list): | ||||
|         return arg[0] | ||||
|     return arg | ||||
|  | ||||
| class HyperparameterSweep: | ||||
|     def __init__(self, project_name, sweep_config): | ||||
|         self.project_name = project_name | ||||
| @@ -68,11 +71,9 @@ class HyperparameterSweep: | ||||
|  | ||||
|         # Create the sweep | ||||
|         self.sweep_id = wandb.sweep(self.sweep_config, project=self.project_name) | ||||
|  | ||||
|         logger.info("Created sweep with ID: " + self.sweep_id) | ||||
|  | ||||
|         # Run the sweep | ||||
|         wandb.agent(self.sweep_id, function=self.run_experiment, count=experiment_count) | ||||
|         wandb.agent(self.sweep_id, function=self.run_experiment, count=config_loader.experiment_count) | ||||
|  | ||||
|     def all_runs_finished(self): | ||||
|  | ||||
| @@ -90,24 +91,40 @@ class HyperparameterSweep: | ||||
|         return all_not_running | ||||
|  | ||||
|     def run_experiment(self): | ||||
|  | ||||
|         try: | ||||
|  | ||||
|             logger.debug("Starting a new run...") | ||||
|             logger.info("Starting a new run...") | ||||
|  | ||||
|             run = wandb.init( | ||||
|                 project=self.project_name, | ||||
|                 config=common_config, | ||||
|                 config=config_loader.config, | ||||
|                 save_code=True, | ||||
|                 entity=wandb_user_name | ||||
|             ) | ||||
|             run.log_code( | ||||
|                 root=".", | ||||
|                 include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"), | ||||
|                 exclude_fn=lambda path: path.endswith("template.sh") | ||||
|             ) | ||||
|  | ||||
|             wandb.run.log_code( | ||||
|                 ".", | ||||
|                 include_fn=lambda path: path.endswith(os.path.basename(__file__)) | ||||
|             ) | ||||
|  | ||||
|             model_name = wandb.config.model_name[0] | ||||
|             model_name = get_arg(wandb.config.model_name) | ||||
|             model_args = models.get_model_specific_args(wandb.config) | ||||
|             logger.debug(f"Initializing {model_name}") | ||||
|  | ||||
|             if "pretrained" in wandb.config: | ||||
|                 weights = get_arg(wandb.config.pretrained) | ||||
|                 if weights != "false": | ||||
|                     model_args["pretrained"] = weights | ||||
|  | ||||
|             if "norm" in wandb.config: | ||||
|                 model_args["norm"] = get_arg(wandb.config.norm) | ||||
|  | ||||
|             if "finetuning" in wandb.config: | ||||
|                 model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning) | ||||
|  | ||||
|             if "lr_reduce_factor" in wandb.config: | ||||
|                 model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor) | ||||
|  | ||||
|             logger.debug(f"Initializing {model_name} with args: {model_args}") | ||||
|             model = models.__getattribute__(model_name + "Lit")(**model_args) | ||||
|  | ||||
|             train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False) | ||||
| @@ -116,8 +133,8 @@ class HyperparameterSweep: | ||||
|             wandb_logger.watch(model) | ||||
|  | ||||
|             # CSV logger - also used for saving configuration as yaml | ||||
|             experiment_name = f"{dataset_name}_{model_name}" | ||||
|             csv_logger = CSVLogger(models_path, experiment_name, version=run.id) | ||||
|             experiment_name = f"{config_loader.dataset_name}_{model_name}" | ||||
|             csv_logger = CSVLogger(config_loader.models_path, experiment_name, version=run.id) | ||||
|             csv_logger.log_hyperparams(wandb.config) | ||||
|  | ||||
|             loggers = [wandb_logger, csv_logger] | ||||
| @@ -131,24 +148,26 @@ class HyperparameterSweep: | ||||
|                 filename=experiment_signature + "-{epoch}-{val_loss:.3f}", | ||||
|                 monitor="val_loss", | ||||
|                 mode="min", | ||||
|                 dirpath=f"{models_path}/{experiment_name}/", | ||||
|                 dirpath=f"{config_loader.models_path}/{experiment_name}/", | ||||
|             )  # save_top_k=1, monitor="val_loss", mode="min": save the best model in terms of validation loss | ||||
|             checkpoint_callback.STARTING_VERSION = 1 | ||||
|  | ||||
|             early_stopping_callback = EarlyStopping( | ||||
|                 monitor="val_loss", | ||||
|                 patience=3, | ||||
|                 patience=5, | ||||
|                 verbose=True, | ||||
|                 mode="min") | ||||
|             callbacks = [checkpoint_callback, early_stopping_callback] | ||||
|  | ||||
|             lr_monitor = LearningRateMonitor(logging_interval='epoch') | ||||
|  | ||||
|             callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor] | ||||
|  | ||||
|             trainer = pl.Trainer( | ||||
|                 default_root_dir=models_path, | ||||
|                 default_root_dir=config_loader.models_path, | ||||
|                 logger=loggers, | ||||
|                 callbacks=callbacks, | ||||
|                 **get_trainer_args(wandb.config) | ||||
|             ) | ||||
|  | ||||
|             trainer.fit(model, train_loader, dev_loader) | ||||
|  | ||||
|         except Exception as e: | ||||
| @@ -160,9 +179,8 @@ class HyperparameterSweep: | ||||
|  | ||||
|  | ||||
| def start_sweep(sweep_config): | ||||
|  | ||||
|     logger.info("Starting sweep with config: " + str(sweep_config)) | ||||
|     set_random_seed(seed) | ||||
|     set_random_seed(config_loader.seed) | ||||
|     sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config) | ||||
|     sweep_runner.run_sweep() | ||||
|  | ||||
| @@ -170,7 +188,6 @@ def start_sweep(sweep_config): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument("--sweep_config", type=str, required=True) | ||||
|     args = parser.parse_args() | ||||
|   | ||||

scripts/input_validate.py (146, Normal file)

							| @@ -0,0 +1,146 @@ | ||||
| from pydantic import BaseModel, ConfigDict, field_validator | ||||
| from typing_extensions import Literal | ||||
| from typing import Union, List, Optional | ||||
| import yaml | ||||
| import logging | ||||
|  | ||||
| logging.root.setLevel(logging.INFO) | ||||
| logger = logging.getLogger('input_validator') | ||||
|  | ||||
| #todo | ||||
| # 1. check if a single value is allowed in a sweep | ||||
| # 2. merge input params | ||||
| # 3. change names of the classes | ||||
| # 4. add constraints for PhaseNet, GPD | ||||
|  | ||||
|  | ||||
| model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"] | ||||
| norm_values = Literal["peak", "std"] | ||||
| finetuning_values = Literal["all", "top", "decoder", "encoder"] | ||||
| pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic', | ||||
|                                                     'original', 'scedc', False] | ||||
|  | ||||
|  | ||||
| class Metric(BaseModel): | ||||
|     goal: str | ||||
|     name: str | ||||
|  | ||||
|  | ||||
| class NumericValue(BaseModel): | ||||
|     value: Union[int, float, List[Union[int, float]]] | ||||
|  | ||||
|  | ||||
| class NumericValues(BaseModel): | ||||
|     values: List[Union[int, float]] | ||||
|  | ||||
|  | ||||
| class IntDistribution(BaseModel): | ||||
|     distribution: str = "int_uniform" | ||||
|     min: int | ||||
|     max: int | ||||
|  | ||||
|  | ||||
| class FloatDistribution(BaseModel): | ||||
|     distribution: str = "uniform" | ||||
|     min: float | ||||
|     max: float | ||||
|  | ||||
|  | ||||
| class Pretrained(BaseModel): | ||||
|     distribution: Optional[str] = "categorical" | ||||
|     values: List[pretrained_values] = None | ||||
|     value: Union[pretrained_values, List[pretrained_values]] = None | ||||
|  | ||||
|  | ||||
| class Finetuning(BaseModel): | ||||
|     distribution: Optional[str] = "categorical" | ||||
|     values: List[finetuning_values] = None | ||||
|     value: Union[finetuning_values, List[finetuning_values]] = None | ||||
|  | ||||
|  | ||||
| class Norm(BaseModel): | ||||
|     distribution: Optional[str] = "categorical" | ||||
|     values: List[norm_values] = None | ||||
|     value: Union[norm_values, List[norm_values]] = None | ||||
|  | ||||
|  | ||||
| class ModelType(BaseModel): | ||||
|     distribution: Optional[str] = "categorical" | ||||
|     value: Union[model_names, List[model_names]] = None | ||||
|     values: List[model_names] = None | ||||
|  | ||||
|  | ||||
| class Parameters(BaseModel): | ||||
|     model_config = ConfigDict(extra='forbid', protected_namespaces=()) | ||||
|     model_name: ModelType | ||||
|     batch_size: Union[IntDistribution, NumericValue, NumericValues] | ||||
|     learning_rate: Union[FloatDistribution, NumericValue, NumericValues] | ||||
|     max_epochs: Union[IntDistribution, NumericValue, NumericValues] | ||||
|  | ||||
|  | ||||
| class PhaseNetParameters(Parameters): | ||||
|     model_config = ConfigDict(extra='forbid') | ||||
|     norm: Norm = None | ||||
|     pretrained: Pretrained = None | ||||
|     finetuning: Finetuning = None | ||||
|     lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None | ||||
|  | ||||
|     highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None | ||||
|     lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None | ||||
|  | ||||
|     @field_validator("model_name") | ||||
|     def validate_model(cls, v): | ||||
|         if "PhaseNet" not in v.value: | ||||
|             raise ValueError("Additional parameters implemented for PhaseNet only") | ||||
|         return v | ||||
|  | ||||
|  | ||||
| class FilteringParameters(Parameters): | ||||
|     model_config = ConfigDict(extra='forbid') | ||||
|  | ||||
|     highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None | ||||
|     lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None | ||||
|  | ||||
|     @field_validator("model_name") | ||||
|     def validate_model(cls, v): | ||||
|         if v.value[0] not in ["GPD", "PhaseNet"]: | ||||
|             raise ValueError("Filtering parameters implemented for GPD and PhaseNet only") | ||||
|         return v  # a field validator must return the (validated) value | ||||
|  | ||||
|  | ||||
| class InputParams(BaseModel): | ||||
|     name: str | ||||
|     method: str | ||||
|     metric: Metric | ||||
|     parameters: Union[Parameters, PhaseNetParameters, FilteringParameters] | ||||
|  | ||||
|  | ||||
| def validate_sweep_yaml(yaml_filename, model_name=None): | ||||
|     # Load YAML configuration | ||||
|     with open(yaml_filename, 'r') as f: | ||||
|         sweep_config = yaml.safe_load(f) | ||||
|  | ||||
|     validate_sweep_config(sweep_config, model_name) | ||||
|  | ||||
|  | ||||
| def validate_sweep_config(sweep_config, model_name=None): | ||||
|  | ||||
|     # Validate sweep config | ||||
|  | ||||
|     input_params = InputParams(**sweep_config) | ||||
|  | ||||
|     # Check consistency of input parameters and sweep configuration | ||||
|     sweep_model_name = input_params.parameters.model_name.value | ||||
|     if model_name is not None and model_name not in sweep_model_name: | ||||
|         info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}." | ||||
|         logger.info(info) | ||||
|         raise ValueError(info) | ||||
|     logger.info("Input validation successful.") | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml" | ||||
|     validate_sweep_yaml(yaml_filename, None) | ||||
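
For orientation, a sweep configuration of the shape `InputParams` appears to accept could look like the dictionary below, which has the same structure `yaml.safe_load` would return from the YAML file. The concrete values (sweep name, metric, ranges) are illustrative assumptions and are not taken from the repository. The `missing_keys` helper is hypothetical: it only checks for required keys with the standard library and does not replace the pydantic validation.

```python
# Illustrative sweep configuration; every value here is an assumption.
sweep_config = {
    "name": "example_sweep",
    "method": "bayes",
    "metric": {"goal": "minimize", "name": "val_loss"},
    "parameters": {
        # Given as a one-element list because the validators index into `value`.
        "model_name": {"value": ["PhaseNet"]},
        "batch_size": {"values": [64, 128, 256]},
        "learning_rate": {"distribution": "uniform", "min": 0.0001, "max": 0.01},
        "max_epochs": {"value": [20]},
    },
}

REQUIRED_TOP_LEVEL = ("name", "method", "metric", "parameters")
REQUIRED_PARAMETERS = ("model_name", "batch_size", "learning_rate", "max_epochs")


def missing_keys(config):
    """Return the required keys absent from the config (hypothetical helper)."""
    missing = [key for key in REQUIRED_TOP_LEVEL if key not in config]
    params = config.get("parameters", {})
    missing += [f"parameters.{key}" for key in REQUIRED_PARAMETERS if key not in params]
    return missing


print(missing_keys(sweep_config))
```

Because `Parameters` sets `extra='forbid'`, any key outside the declared fields would presumably be rejected by the real validator, so a structural pre-check like this one is only a convenience.
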
| @@ -7,16 +7,26 @@ import seisbench.generate as sbg | ||||
|  | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| from torch.optim import lr_scheduler | ||||
| import torch.nn.functional as F | ||||
| import numpy as np | ||||
| from abc import abstractmethod, ABC | ||||
|  | ||||
| # import lightning as L | ||||
|  | ||||
|  | ||||
| # Allows to import this file in both jupyter notebook and code | ||||
| try: | ||||
|     from .augmentations import DuplicateEvent | ||||
| except ImportError: | ||||
|     from augmentations import DuplicateEvent | ||||
|  | ||||
| import os | ||||
| import logging | ||||
|  | ||||
| script_name = os.path.splitext(os.path.basename(__file__))[0] | ||||
| logger = logging.getLogger(script_name) | ||||
| logger.setLevel(logging.DEBUG) | ||||
|  | ||||
| # Phase dict for labelling. We only study P and S phases without differentiating between them. | ||||
| phase_dict = { | ||||
| @@ -131,7 +141,31 @@ class PhaseNetLit(SeisBenchModuleLit): | ||||
|         self.sigma = sigma | ||||
|         self.sample_boundaries = sample_boundaries | ||||
|         self.loss = vector_cross_entropy | ||||
|         self.model = sbm.PhaseNet(phases="PN", **kwargs) | ||||
|         self.pretrained = kwargs.pop("pretrained", None) | ||||
|         self.norm = kwargs.pop("norm", "peak") | ||||
|         self.highpass = kwargs.pop("highpass", None) | ||||
|         self.lowpass = kwargs.pop("lowpass", None) | ||||
|  | ||||
|  | ||||
|         if self.pretrained is not None: | ||||
|             self.model = sbm.PhaseNet.from_pretrained(self.pretrained) | ||||
|             # self.norm = self.model.norm | ||||
|         else: | ||||
|             self.model = sbm.PhaseNet(**kwargs) | ||||
|  | ||||
|         self.finetuning_strategy = kwargs.pop("finetuning_strategy", None) | ||||
|         self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1) | ||||
|         self.reduce_lr_on_plateau = False | ||||
|  | ||||
|         self.initial_epochs = 0 | ||||
|  | ||||
|         if self.finetuning_strategy is not None: | ||||
|             if self.finetuning_strategy == "top": | ||||
|                 self.initial_epochs = 3 | ||||
|             elif self.finetuning_strategy in ["decoder", "encoder"]: | ||||
|                 self.initial_epochs = 6 | ||||
|  | ||||
|             self.freeze() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.model(x) | ||||
| @@ -154,9 +188,49 @@ class PhaseNetLit(SeisBenchModuleLit): | ||||
|  | ||||
|     def configure_optimizers(self): | ||||
|         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) | ||||
|         return optimizer | ||||
|         if self.finetuning_strategy is not None: | ||||
|             scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda) | ||||
|             self.reduce_lr_on_plateau = False | ||||
|         else: | ||||
|             scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3) | ||||
|             self.reduce_lr_on_plateau = True | ||||
|         # | ||||
|         return { | ||||
|             'optimizer': optimizer, | ||||
|             'lr_scheduler': { | ||||
|                 'scheduler': scheduler, | ||||
|                 'monitor': 'val_loss', | ||||
|                 'interval': 'epoch', | ||||
|                 'reduce_on_plateau': self.reduce_lr_on_plateau, | ||||
|             }, | ||||
|         } | ||||
|  | ||||
|     def lr_lambda(self, epoch): | ||||
|         # reduce lr after x initial epochs | ||||
|         if epoch == self.initial_epochs: | ||||
|             self.lr *= self.steplr_gamma | ||||
|  | ||||
|         return self.lr | ||||
|  | ||||
|     def lr_scheduler_step(self, scheduler, metric): | ||||
|         if self.reduce_lr_on_plateau: | ||||
|             scheduler.step(metric, epoch=self.current_epoch) | ||||
|         else: | ||||
|             scheduler.step(epoch=self.current_epoch) | ||||
|  | ||||
|     # def lr_scheduler_step(self, scheduler, optimizer_idx, metric): | ||||
|     #         scheduler.step(epoch=self.current_epoch) | ||||
|  | ||||
|     def get_augmentations(self): | ||||
|         filter = [] | ||||
|         if self.highpass is not None: | ||||
|             filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)] | ||||
|             logger.info(f"Using highpass filter {self.highpass}") | ||||
|         if self.lowpass is not None: | ||||
|             filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)] | ||||
|             logger.info(f"Using lowpass filter {self.lowpass}") | ||||
|             logger.info(filter) | ||||
|  | ||||
|         return [ | ||||
|             # In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training. | ||||
|             # Uses strategy variable, as padding will be handled by the random window. | ||||
| @@ -180,18 +254,26 @@ class PhaseNetLit(SeisBenchModuleLit): | ||||
|                 windowlen=3001, | ||||
|                 strategy="pad", | ||||
|             ), | ||||
|             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm), | ||||
|             *filter, | ||||
|             sbg.ChangeDtype(np.float32), | ||||
|             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"), | ||||
|             sbg.ProbabilisticLabeller( | ||||
|                 label_columns=phase_dict, sigma=self.sigma, dim=0 | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
|     def get_eval_augmentations(self): | ||||
|         filter = [] | ||||
|         if self.highpass is not None: | ||||
|             filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)] | ||||
|         if self.lowpass is not None: | ||||
|             filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)] | ||||
|  | ||||
|         return [ | ||||
|             sbg.SteeredWindow(windowlen=3001, strategy="pad"), | ||||
|             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm), | ||||
|             *filter, | ||||
|             sbg.ChangeDtype(np.float32), | ||||
|             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"), | ||||
|         ] | ||||
|  | ||||
|     def predict_step(self, batch, batch_idx=None, dataloader_idx=None): | ||||
| @@ -219,6 +301,27 @@ class PhaseNetLit(SeisBenchModuleLit): | ||||
|  | ||||
|         return score_detection, score_p_or_s, p_sample, s_sample | ||||
|  | ||||
|     def freeze(self): | ||||
|         if self.finetuning_strategy == "decoder":  # finetune decoder branch and freeze encoder branch | ||||
|             for p in self.model.down_branch.parameters(): | ||||
|                 p.requires_grad = False | ||||
|         elif self.finetuning_strategy == "encoder":  # finetune encoder branch and freeze decoder branch | ||||
|             for p in self.model.up_branch.parameters(): | ||||
|                 p.requires_grad = False | ||||
|         elif self.finetuning_strategy == "top": | ||||
|             for p in self.model.out.parameters(): | ||||
|                 p.requires_grad = False | ||||
|  | ||||
|     def unfreeze(self): | ||||
|         logger.info("Unfreezing layers") | ||||
|         for p in self.model.parameters(): | ||||
|             p.requires_grad = True | ||||
|  | ||||
|     def on_train_epoch_start(self): | ||||
|         # Unfreeze some layers after x initial epochs | ||||
|         if self.current_epoch == self.initial_epochs: | ||||
|             self.unfreeze() | ||||
|  | ||||
|  | ||||
| class GPDLit(SeisBenchModuleLit): | ||||
|     """ | ||||
| @@ -846,7 +949,7 @@ class BasicPhaseAELit(SeisBenchModuleLit): | ||||
|         # Create overlapping windows | ||||
|         re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device) | ||||
|         for i, start in enumerate(range(0, 2401, 400)): | ||||
|             re[:, :, i] = x[:, :, start : start + 600] | ||||
|             re[:, :, i] = x[:, :, start: start + 600] | ||||
|         x = re | ||||
|  | ||||
|         x = x.permute(0, 2, 1, 3)  # --> (batch, windows, channels, samples) | ||||
| @@ -862,9 +965,9 @@ class BasicPhaseAELit(SeisBenchModuleLit): | ||||
|         for i, start in enumerate(range(0, 2401, 400)): | ||||
|             if start == 0: | ||||
|                 # Use full window (for start==0, the end will be overwritten) | ||||
|                 pred[:, :, start : start + 600] = window_pred[:, i] | ||||
|                 pred[:, :, start: start + 600] = window_pred[:, i] | ||||
|             else: | ||||
|                 pred[:, :, start + 100 : start + 600] = window_pred[:, i, :, 100:] | ||||
|                 pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:] | ||||
|  | ||||
|         score_detection = torch.zeros(pred.shape[0]) | ||||
|         score_p_or_s = torch.zeros(pred.shape[0]) | ||||
| @@ -1112,7 +1215,6 @@ class DPPPickerLit(SeisBenchModuleLit): | ||||
|  | ||||
|  | ||||
| def get_model_specific_args(config): | ||||
|  | ||||
|     model = config.model_name[0] | ||||
|     lr = config.learning_rate | ||||
|     if type(lr) == list: | ||||
| @@ -1125,8 +1227,12 @@ def get_model_specific_args(config): | ||||
|             if 'highpass' in config: | ||||
|                 args['highpass'] = config.highpass | ||||
|             if 'lowpass' in config: | ||||
|                 args['lowpass'] = config.lowpass[0] | ||||
|                 args['lowpass'] = config.lowpass | ||||
|         case 'PhaseNet': | ||||
|             if 'highpass' in config: | ||||
|                 args['highpass'] = config.highpass | ||||
|             if 'lowpass' in config: | ||||
|                 args['lowpass'] = config.lowpass | ||||
|             if 'sample_boundaries' in config: | ||||
|                 args['sample_boundaries'] = config.sample_boundaries | ||||
|         case 'DPPPicker': | ||||
|   | ||||
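
`torch.optim.lr_scheduler.LambdaLR` multiplies the optimizer's initial learning rate by whatever the lambda returns, so a step decay such as the one in `PhaseNetLit.lr_lambda` is usually written as a function returning a factor rather than an absolute rate. The sketch below is plain Python (no PyTorch needed) and only illustrates that arithmetic. `step_factor`, `effective_lr`, and the numbers are hypothetical; `initial_epochs` and `gamma` mirror the `initial_epochs` and `steplr_gamma` attributes in the diff.

```python
def step_factor(epoch, initial_epochs, gamma):
    """Multiplicative factor: 1.0 during the initial epochs, gamma afterwards."""
    return 1.0 if epoch < initial_epochs else gamma


def effective_lr(base_lr, epoch, initial_epochs=3, gamma=0.1):
    """Learning rate a LambdaLR-style scheduler would apply at this epoch."""
    return base_lr * step_factor(epoch, initial_epochs, gamma)


# Show the schedule for the first few epochs (illustrative base rate).
for epoch in range(5):
    print(epoch, effective_lr(1e-3, epoch))
```

With PyTorch, such a factor function would be passed as the `lr_lambda` argument, for example `LambdaLR(optimizer, lambda e: step_factor(e, 3, 0.1))`. A lambda that returns the learning rate itself would have that value applied as a multiplier, which may give a much smaller rate than intended.
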
| @@ -1,66 +1,44 @@ | ||||
| """ | ||||
| ----------------- | ||||
| Copyright © 2023 ACK Cyfronet AGH, Poland. | ||||
| This work was partially funded by EPOS Project funded in frame of PL-POIR4.2 | ||||
| ----------------- | ||||
| """ | ||||
| 
 | ||||
| import os | ||||
| import pandas as pd | ||||
| import glob | ||||
| from pathlib import Path | ||||
| 
 | ||||
| import obspy | ||||
| from obspy.core.event import read_events | ||||
| 
 | ||||
| import seisbench | ||||
| import seisbench.data as sbd | ||||
| import seisbench.util as sbu | ||||
| import sys | ||||
| 
 | ||||
| import logging | ||||
| import argparse | ||||
| 
 | ||||
| 
 | ||||
| logging.basicConfig(filename="output.out", | ||||
|                     filemode='a', | ||||
|                     format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s', | ||||
|                     datefmt='%H:%M:%S', | ||||
|                     level=logging.DEBUG) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| logging.root.setLevel(logging.INFO) | ||||
| logger = logging.getLogger('converter') | ||||
| 
 | ||||
| def create_traces_catalog(directory, years): | ||||
|     for year in years: | ||||
|         directory = f"{directory}/{year}" | ||||
|         files = glob.glob(directory) | ||||
|         traces = [] | ||||
|         for i, f in enumerate(files): | ||||
|             st = obspy.read(f) | ||||
| 
 | ||||
|             for tr in st.traces: | ||||
|                 # trace_id = tr.id | ||||
|                 # start = tr.meta.starttime | ||||
|                 # end = tr.meta.endtime | ||||
| 
 | ||||
|                 trs = pd.Series({ | ||||
|                     'trace_id': tr.id, | ||||
|                     'trace_st': tr.meta.starttime, | ||||
|                     'trace_end': tr.meta.endtime, | ||||
|                     'stream_fname': f | ||||
|                 }) | ||||
|                 traces.append(trs) | ||||
| 
 | ||||
|         traces_catalog = pd.DataFrame(pd.concat(traces)).transpose() | ||||
|         traces_catalog.to_csv("data/bogdanka/traces_catalog.csv", append=True, index=False) | ||||
| 
 | ||||
| 
 | ||||
| def split_events(events, input_path): | ||||
| 
 | ||||
|     logger.info("Splitting available events into train, dev and test sets ...") | ||||
|     events_stats = pd.DataFrame() | ||||
|     events_stats.index.name = "event" | ||||
| 
 | ||||
|     for i, event in enumerate(events): | ||||
|         #check if mseed exists | ||||
|         # check if mseed exists | ||||
|         actual_picks = 0 | ||||
|         for pick in event.picks: | ||||
|             trace_params = get_trace_params(pick) | ||||
|             trace_path = get_trace_path(input_path, trace_params) | ||||
| 
 | ||||
|             if os.path.isfile(trace_path): | ||||
|                 actual_picks += 1 | ||||
| 
 | ||||
| @@ -79,6 +57,10 @@ def split_events(events, input_path): | ||||
|             events_stats.loc[i, 'split'] = 'dev' | ||||
|         else: | ||||
|             break | ||||
|  | ||||
|     logger.info(f"Split: {events_stats['split'].value_counts()}") | ||||
| 
 | ||||
|     return events_stats | ||||
| 
 | ||||
| @@ -87,7 +69,6 @@ def get_event_params(event): | ||||
|     origin = event.preferred_origin() | ||||
|     if origin is None: | ||||
|         return {} | ||||
|     # print(origin) | ||||
| 
 | ||||
|     mag = event.preferred_magnitude() | ||||
| 
 | ||||
| @@ -115,12 +96,9 @@ def get_event_params(event): | ||||
| 
 | ||||
| 
 | ||||
| def get_trace_params(pick): | ||||
|     net = pick.waveform_id.network_code | ||||
|     sta = pick.waveform_id.station_code | ||||
| 
 | ||||
|     trace_params = { | ||||
|         "station_network_code": net, | ||||
|         "station_code": sta, | ||||
|         "station_network_code": pick.waveform_id.network_code, | ||||
|         "station_code": pick.waveform_id.station_code, | ||||
|         "trace_channel": pick.waveform_id.channel_code, | ||||
|         "station_location_code": pick.waveform_id.location_code, | ||||
|         "time": pick.time | ||||
| @@ -134,7 +112,6 @@ def find_trace(pick_time, traces): | ||||
|         if pick_time > tr.stats.endtime: | ||||
|             continue | ||||
|         if pick_time >= tr.stats.starttime: | ||||
|             # print(pick_time, " - selected trace: ", tr) | ||||
|             return tr | ||||
| 
 | ||||
|     logger.warning(f"no matching trace for pick: {pick_time}") | ||||
| @@ -152,12 +129,26 @@ def get_trace_path(input_path, trace_params): | ||||
|     return path | ||||
| 
 | ||||
| 
 | ||||
| def get_three_channels_trace_paths(input_path, trace_params): | ||||
|     year = trace_params["time"].year | ||||
|     day_of_year = pd.Timestamp(str(trace_params["time"])).day_of_year | ||||
|     net = trace_params["station_network_code"] | ||||
|     station = trace_params["station_code"] | ||||
|     channel_base = trace_params["trace_channel"] | ||||
|     paths = [] | ||||
|     for ch in ["E", "N", "Z"]: | ||||
|         channel = channel_base[:-1] + ch | ||||
|         paths.append( | ||||
|             f"{input_path}/{year}/{net}/{station}/{channel}.D/{net}.{station}..{channel}.D.{year}.{day_of_year:03}") | ||||
|     return paths | ||||
| 
 | ||||
| 
 | ||||
| def load_trace(input_path, trace_params): | ||||
|     trace_path = get_trace_path(input_path, trace_params) | ||||
|     trace = None | ||||
| 
 | ||||
|     if not os.path.isfile(trace_path): | ||||
|         logger.w(trace_path + " not found") | ||||
|         logger.warning(trace_path + " not found") | ||||
|     else: | ||||
|         stream = obspy.read(trace_path) | ||||
|         if len(stream.traces) > 1: | ||||
| @@ -171,19 +162,26 @@ def load_trace(input_path, trace_params): | ||||
| 
 | ||||
| 
 | ||||
| def load_stream(input_path, trace_params, time_before=60, time_after=60): | ||||
|     trace_path = get_trace_path(input_path, trace_params) | ||||
|     sampling_rate, stream = None, None | ||||
|     pick_time = trace_params["time"] | ||||
| 
 | ||||
|     if not os.path.isfile(trace_path): | ||||
|         print(trace_path + " not found") | ||||
|     else: | ||||
|         stream = obspy.read(trace_path) | ||||
|     trace_paths = get_three_channels_trace_paths(input_path, trace_params) | ||||
|     for trace_path in trace_paths: | ||||
|         if not os.path.isfile(trace_path): | ||||
|             logger.warning(trace_path + " not found") | ||||
|         else: | ||||
|             if stream is None: | ||||
|                 stream = obspy.read(trace_path) | ||||
|             else: | ||||
|                 stream += obspy.read(trace_path) | ||||
| 
 | ||||
|     if stream is not None: | ||||
|         stream = stream.slice(pick_time - time_before, pick_time + time_after) | ||||
|         if len(stream.traces) == 0: | ||||
|             print(f"no data in: {trace_path}") | ||||
|         else: | ||||
|             sampling_rate = stream.traces[0].stats.sampling_rate | ||||
|             stream.merge() | ||||
| 
 | ||||
|     return sampling_rate, stream | ||||
| 
 | ||||
| @@ -202,23 +200,36 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path): | ||||
| 
 | ||||
|     metadata_path = output_path + "/metadata.csv" | ||||
|     waveforms_path = output_path + "/waveforms.hdf5" | ||||
|      | ||||
|     events_to_convert = events_stats[events_stats['pick_count'] > 0] | ||||
| 
 | ||||
|     logger.debug("Catalog loaded, starting conversion ...") | ||||
| 
 | ||||
|     logger.debug(f"Catalog loaded, starting conversion of {len(events_to_convert)} events ...") | ||||
| 
 | ||||
|     with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer: | ||||
|         writer.data_format = { | ||||
|             "dimension_order": "CW", | ||||
|             "component_order": "ZNE", | ||||
|         } | ||||
| 
 | ||||
|         for i, event in enumerate(events): | ||||
|             logger.debug(f"Converting {i} event") | ||||
|             event_params = get_event_params(event) | ||||
|             event_params["split"] = events_stats.loc[i, "split"] | ||||
| 
 | ||||
|             picks_per_station = {} | ||||
|             for pick in event.picks: | ||||
|                 trace_params = get_trace_params(pick) | ||||
|                 station = pick.waveform_id.station_code | ||||
|                 if station in picks_per_station: | ||||
|                     picks_per_station[station].append(pick) | ||||
|                 else: | ||||
|                     picks_per_station[station] = [pick] | ||||
| 
 | ||||
|             for picks in picks_per_station.values(): | ||||
| 
 | ||||
|                 trace_params = get_trace_params(picks[0]) | ||||
|                 sampling_rate, stream = load_stream(input_path, trace_params) | ||||
|                 if stream is None: | ||||
|                 if stream is None or len(stream.traces) == 0: | ||||
|                     continue | ||||
| 
 | ||||
|                 actual_t_start, data, _ = sbu.stream_to_array( | ||||
| @@ -229,22 +240,19 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path): | ||||
|                 trace_params["trace_sampling_rate_hz"] = sampling_rate | ||||
|                 trace_params["trace_start_time"] = str(actual_t_start) | ||||
| 
 | ||||
|                 pick_time = obspy.core.utcdatetime.UTCDateTime(trace_params["time"]) | ||||
|                 pick_idx = (pick_time - actual_t_start) * sampling_rate | ||||
| 
 | ||||
|                 trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(pick_idx) | ||||
|                 for pick in picks: | ||||
|                     pick_time = obspy.core.utcdatetime.UTCDateTime(pick.time) | ||||
|                     pick_idx = (pick_time - actual_t_start) * sampling_rate | ||||
|                     trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(pick_idx) | ||||
| 
 | ||||
|                 writer.add_trace({**event_params, **trace_params}, data) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
| 
 | ||||
|     parser = argparse.ArgumentParser(description='Convert mseed files to seisbench format') | ||||
|     parser.add_argument('--input_path', type=str, help='Path to mseed files') | ||||
|     parser.add_argument('--catalog_path', type=str, help='Path to events catalog in quakeml format') | ||||
|     parser.add_argument('--output_path', type=str, help='Path to output files') | ||||
|     args = parser.parse_args() | ||||
| 
 | ||||
| 
 | ||||
|     convert_mseed_to_seisbench_format(args.input_path, args.catalog_path, args.output_path) | ||||
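
The conversion loop groups picks by station and stores each pick as a sample offset from the start of the extracted waveform. A minimal standard-library sketch of those two steps follows. The tuples stand in for ObsPy `Pick` objects, and the station names, times, and the 100 Hz sampling rate are made-up assumptions; in the script the trace start comes from `sbu.stream_to_array` rather than being computed by hand.

```python
from collections import defaultdict
from datetime import datetime, timedelta

# (station, phase_hint, pick_time): stand-ins for ObsPy picks (hypothetical data)
picks = [
    ("STA1", "P", datetime(2023, 1, 1, 12, 0, 10)),
    ("STA1", "S", datetime(2023, 1, 1, 12, 0, 14)),
    ("STA2", "P", datetime(2023, 1, 1, 12, 0, 11)),
]


def group_by_station(picks):
    """Collect all picks of one event per station, so P and S share a trace."""
    grouped = defaultdict(list)
    for station, phase, time in picks:
        grouped[station].append((phase, time))
    return dict(grouped)


def arrival_samples(station_picks, trace_start, sampling_rate):
    """Convert pick times into sample indices relative to the trace start."""
    return {
        f"trace_{phase}_arrival_sample": int(
            (time - trace_start).total_seconds() * sampling_rate
        )
        for phase, time in station_picks
    }


grouped = group_by_station(picks)
# Assume the waveform starts 60 s before the first pick of the station.
trace_start = datetime(2023, 1, 1, 12, 0, 10) - timedelta(seconds=60)
print(arrival_samples(grouped["STA1"], trace_start, sampling_rate=100.0))
```

Grouping first means one waveform row per station can carry both the P and the S arrival columns, instead of one row per pick.
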

scripts/perf_analysis.py (149, Normal file)

							| @@ -0,0 +1,149 @@ | ||||
| import json | ||||
| import pathlib | ||||
|  | ||||
| import numpy as np | ||||
| import obspy | ||||
| import pandas as pd | ||||
| import seisbench.data as sbd | ||||
| import seisbench.models as sbm | ||||
| import itertools | ||||
| from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve | ||||
|  | ||||
| datasets = [ | ||||
|     # path to datasets in seisbench format | ||||
| ] | ||||
|  | ||||
| models = [ | ||||
|     # model names | ||||
| ] | ||||
|  | ||||
|  | ||||
| def find_keys_phase(meta, phase): | ||||
|     phases = [] | ||||
|     for k in meta.keys(): | ||||
|         if k.startswith("trace_" + phase) and k.endswith("_arrival_sample"): | ||||
|             phases.append(k) | ||||
|  | ||||
|     return phases | ||||
|  | ||||
|  | ||||
| def create_stream(meta, raw, start, length=30): | ||||
|  | ||||
|     st = obspy.Stream() | ||||
|  | ||||
|     for i in range(3): | ||||
|         tr = obspy.Trace(raw[i, :]) | ||||
|         tr.stats.starttime = meta["trace_start_time"] | ||||
|         tr.stats.sampling_rate = meta["trace_sampling_rate_hz"] | ||||
|         tr.stats.network = meta["station_network_code"] | ||||
|         tr.stats.station = meta["station_code"] | ||||
|         tr.stats.channel = meta["trace_channel"][:2] + meta["trace_component_order"][i] | ||||
|  | ||||
|         stop = start + length | ||||
|         tr = tr.slice(start, stop) | ||||
|  | ||||
|         st.append(tr) | ||||
|  | ||||
|     return st | ||||
|  | ||||
|  | ||||
| def get_pred(model, stream): | ||||
|     ann = model.annotate(stream) | ||||
|     noise = ann.select(channel="PhaseNet_N")[0] | ||||
|     pred = max(1 - noise.data) | ||||
|     return pred | ||||
|  | ||||
|  | ||||
| def to_short(stream): | ||||
|     # True if any trace holds fewer than 3001 samples | ||||
|     return any(tr.data.shape[0] < 3001 for tr in stream) | ||||
|  | ||||
|  | ||||
| for ds, model_name in itertools.product(datasets, models): | ||||
|  | ||||
|     data = sbd.WaveformDataset(ds, sampling_rate=100).test() | ||||
|     data_name = pathlib.Path(ds).stem | ||||
|     fname = f"roc___{model_name}___{data_name}.csv" | ||||
|  | ||||
|     print(f"{fname:.<50s}.... ", flush=True, end="") | ||||
|  | ||||
|     if pathlib.Path(fname).is_file(): | ||||
|         print(" ready, skipping", flush=True) | ||||
|         continue | ||||
|  | ||||
|     p_labels = find_keys_phase(data.metadata, "P") | ||||
|     s_labels = find_keys_phase(data.metadata, "S") | ||||
|  | ||||
|     model = sbm.PhaseNet.from_pretrained(model_name) | ||||
|  | ||||
|     label_true = [] | ||||
|     label_pred = [] | ||||
|  | ||||
|     for i in range(len(data)): | ||||
|  | ||||
|         waveform, metadata = data.get_sample(i) | ||||
|         m = pd.Series(metadata) | ||||
|  | ||||
|         has_p_label = m[p_labels].notna() | ||||
|         has_s_label = m[s_labels].notna() | ||||
|  | ||||
|         if any(has_p_label): | ||||
|  | ||||
|             trace_start_time = obspy.UTCDateTime(m["trace_start_time"]) | ||||
|             pick_sample = m[p_labels][has_p_label][0] | ||||
|  | ||||
|             start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15 | ||||
|  | ||||
|             try: | ||||
|                 st_p = create_stream(m, waveform, start) | ||||
|                 if not (to_short(st_p)): | ||||
|                     pred_p = get_pred(model, st_p) | ||||
|                     label_true.append(1) | ||||
|                     label_pred.append(pred_p) | ||||
|             except IndexError: | ||||
|                 pass | ||||
|  | ||||
|             try: | ||||
|                 st_n = create_stream(m, waveform, trace_start_time + 1) | ||||
|                 if not (to_short(st_n)): | ||||
|                     pred_n = get_pred(model, st_n) | ||||
|                     label_true.append(0) | ||||
|                     label_pred.append(pred_n) | ||||
|             except IndexError: | ||||
|                 pass | ||||
|  | ||||
|         if any(has_s_label): | ||||
|             trace_start_time = obspy.UTCDateTime(m["trace_start_time"]) | ||||
|             pick_sample = m[s_labels][has_s_label].iloc[0] | ||||
|             start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15 | ||||
|  | ||||
|             try: | ||||
|                 st_s = create_stream(m, waveform, start) | ||||
|                 if not (to_short(st_s)): | ||||
|                     pred_s = get_pred(model, st_s) | ||||
|                     label_true.append(1) | ||||
|                     label_pred.append(pred_s) | ||||
|             except IndexError: | ||||
|                 pass | ||||
|  | ||||
|     fpr, tpr, roc_thresholds = roc_curve(label_true, label_pred) | ||||
|     df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": roc_thresholds}) | ||||
|     df.to_csv(fname) | ||||
|  | ||||
|     precision, recall, prc_thresholds = precision_recall_curve(label_true, label_pred) | ||||
|     prc_thresholds_extra = np.append(prc_thresholds, -999) | ||||
|     df = pd.DataFrame( | ||||
|         {"pre": precision, "rec": recall, "thresholds": prc_thresholds_extra} | ||||
|     ) | ||||
|     df.to_csv(fname.replace("roc", "pr")) | ||||
|  | ||||
|     stats = { | ||||
|         "model": str(model_name), | ||||
|         "data": str(data_name), | ||||
|         "auc": float(roc_auc_score(label_true, label_pred)), | ||||
|     } | ||||
|  | ||||
|     with open(f"stats___{model_name}___{data_name}.json", "w") as fp: | ||||
|         json.dump(stats, fp) | ||||
|  | ||||
|     print(" finished", flush=True) | ||||
| @@ -15,33 +15,49 @@ import generate_eval_targets | ||||
| import hyperparameter_sweep | ||||
| import eval | ||||
| import collect_results | ||||
| from config_loader import data_path, targets_path, sampling_rate, dataset_name, sweep_files | ||||
| import importlib | ||||
| import config_loader | ||||
| import input_validate | ||||
| import data | ||||
|  | ||||
| logging.root.setLevel(logging.INFO) | ||||
| logger = logging.getLogger('pipeline') | ||||
|  | ||||
|  | ||||
| def load_sweep_config(model_name, args): | ||||
|  | ||||
|     if model_name == "PhaseNet" and args.phasenet_config is not None: | ||||
|         sweep_fname = args.phasenet_config | ||||
|     elif model_name == "GPD" and args.gpd_config is not None: | ||||
|         sweep_fname = args.gpd_config | ||||
|     elif model_name == "BasicPhaseAE" and args.basic_phase_ae_config is not None: | ||||
|         sweep_fname = args.basic_phase_ae_config | ||||
|     elif model_name == "EQTransformer" and args.eqtransformer_config is not None: | ||||
|         sweep_fname = args.eqtransformer_config | ||||
|     else: | ||||
|         # use the default sweep config for the model | ||||
|         sweep_fname = sweep_files[model_name] | ||||
|         sweep_fname = config_loader.sweep_files[model_name] | ||||
|  | ||||
|     logger.info(f"Loading sweep config: {sweep_fname}") | ||||
|  | ||||
|     return util.load_sweep_config(sweep_fname) | ||||
|  | ||||
|  | ||||
| def find_the_best_params(model_name, args): | ||||
| def validate_pipeline_input(args): | ||||
|  | ||||
|     # validate input parameters | ||||
|     for model_name in args.models: | ||||
|         sweep_config = load_sweep_config(model_name, args) | ||||
|         input_validate.validate_sweep_config(sweep_config, model_name) | ||||
|  | ||||
|     # validate dataset | ||||
|     data.validate_custom_dataset(config_loader.data_path) | ||||
|  | ||||
|  | ||||
| def find_the_best_params(sweep_config): | ||||
|     # find the best hyperparams for the model_name | ||||
|     model_name = sweep_config['parameters']['model_name'] | ||||
|     logger.info(f"Started searching for the best hyperparams for the model: {model_name}") | ||||
|  | ||||
|     sweep_config = load_sweep_config(model_name, args) | ||||
|     sweep_runner = hyperparameter_sweep.start_sweep(sweep_config) | ||||
|  | ||||
|     # wait for all runs to finish | ||||
| @@ -58,9 +74,9 @@ def find_the_best_params(model_name, args): | ||||
|  | ||||
|  | ||||
| def generate_predictions(sweep_id, model_name): | ||||
|     experiment_name = f"{dataset_name}_{model_name}" | ||||
|     experiment_name = f"{config_loader.dataset_name}_{model_name}" | ||||
|     eval.main(weights=experiment_name, | ||||
|               targets=targets_path, | ||||
|               targets=config_loader.targets_path, | ||||
|               sets='dev,test', | ||||
|               batchsize=128, | ||||
|               num_workers=4, | ||||
| @@ -73,22 +89,49 @@ def main(): | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument("--phasenet_config", type=str, required=False) | ||||
|     parser.add_argument("--gpd_config", type=str, required=False) | ||||
|     parser.add_argument("--basic_phase_ae_config", type=str, required=False) | ||||
|     parser.add_argument("--eqtransformer_config", type=str, required=False) | ||||
|     parser.add_argument("--dataset", type=str, required=False) | ||||
|     available_models = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"] | ||||
|     parser.add_argument("--models", nargs='*', required=False, choices=available_models, default=available_models, | ||||
|                         help="Models to train and evaluate (default: all)") | ||||
|     parser.add_argument("--collect_results", action="store_true", help="Collect and log results without training") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     # generate labels | ||||
|     logger.info("Started generating labels for the dataset.") | ||||
|     generate_eval_targets.main(data_path, targets_path, "2,3", sampling_rate, None) | ||||
|     if not args.collect_results: | ||||
|  | ||||
|     # find the best hyperparams for the models | ||||
|     logger.info("Started training the models.") | ||||
|     for model_name in ["GPD", "PhaseNet"]: | ||||
|         sweep_id = find_the_best_params(model_name, args) | ||||
|         generate_predictions(sweep_id, model_name) | ||||
|         if args.dataset is not None: | ||||
|             util.set_dataset(args.dataset) | ||||
|             importlib.reload(config_loader) | ||||
|  | ||||
|         validate_pipeline_input(args) | ||||
|  | ||||
|         logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.") | ||||
|  | ||||
|         # generate labels | ||||
|         logger.info("Started generating labels for the dataset.") | ||||
|         generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3", config_loader.sampling_rate, | ||||
|                                    None) | ||||
|  | ||||
|         # find the best hyperparams for the models | ||||
|         logger.info("Started training the models.") | ||||
|         for model_name in args.models: | ||||
|             sweep_config = load_sweep_config(model_name, args) | ||||
|             sweep_id = find_the_best_params(sweep_config) | ||||
|             generate_predictions(sweep_id, model_name) | ||||
|  | ||||
|     # collect results | ||||
|     logger.info("Collecting results.") | ||||
|     collect_results.traverse_path("pred", "pred/results.csv") | ||||
|     logger.info("Results saved in pred/results.csv") | ||||
|     results_path = "pred/results.csv" | ||||
|     collect_results.traverse_path("pred", results_path) | ||||
|     logger.info(f"Results saved in {results_path}") | ||||
|  | ||||
|     # log calculated metrics (MAE) on w&b | ||||
|     logger.info("Logging MAE metrics on w&b.") | ||||
|     util.log_metrics(results_path) | ||||
|  | ||||
|     logger.info("Pipeline finished") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
|   | ||||
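The if/elif chain in `load_sweep_config` above maps each model name to its optional command-line override and falls back to the default sweep file otherwise. A minimal table-driven sketch of the same lookup follows. The names `resolve_sweep_fname`, `OVERRIDE_ATTRS` and `DEFAULT_SWEEP_FILES`, as well as the file names in the defaults, are hypothetical and chosen only for illustration; the real defaults live in `config_loader.sweep_files`.

```python
import argparse

# Hypothetical default sweep files (stand-in for config_loader.sweep_files).
DEFAULT_SWEEP_FILES = {
    "GPD": "sweep_gpd.yaml",
    "PhaseNet": "sweep_phasenet.yaml",
    "BasicPhaseAE": "sweep_basicphase_ae.yaml",
    "EQTransformer": "sweep_eqtransformer.yaml",
}

# Maps each model name to the argparse attribute that may hold an override.
OVERRIDE_ATTRS = {
    "PhaseNet": "phasenet_config",
    "GPD": "gpd_config",
    "BasicPhaseAE": "basic_phase_ae_config",
    "EQTransformer": "eqtransformer_config",
}


def resolve_sweep_fname(model_name, args, defaults=DEFAULT_SWEEP_FILES):
    """Return the override given on the command line, else the default file."""
    attr = OVERRIDE_ATTRS.get(model_name)
    override = getattr(args, attr, None) if attr else None
    return override if override is not None else defaults[model_name]


# Example: only PhaseNet gets a custom sweep config.
example_args = argparse.Namespace(
    phasenet_config="custom_phasenet.yaml",
    gpd_config=None,
    basic_phase_ae_config=None,
    eqtransformer_config=None,
)
```

With this layout, supporting another model would mean adding one entry to each dictionary rather than another `elif` branch.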
							
								
								
									
										19
									
								
								scripts/run_pipeline_template.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
							| @@ -0,0 +1,19 @@ | ||||
| #!/bin/bash | ||||
| #SBATCH --job-name=job_name | ||||
| #SBATCH --time=10:00:00 | ||||
| #SBATCH --account=						### to fill | ||||
| #SBATCH --partition=plgrid-gpu-v100 | ||||
| #SBATCH --cpus-per-task=1 | ||||
| #SBATCH --ntasks-per-node=1 | ||||
| #SBATCH --gres=gpu:1 | ||||
|  | ||||
| source path/to/mambaforge/bin/activate   ### to change | ||||
| conda activate epos-ai-train | ||||
|  | ||||
|  | ||||
| python -c "import torch; print('CUDA available:', torch.cuda.is_available())" | ||||
| python -c "import torch; print('Number of CUDA devices:', torch.cuda.device_count())" | ||||
| python -c "import torch; print('Name of GPU:', torch.cuda.get_device_name(torch.cuda.current_device()))" | ||||
|  | ||||
|  | ||||
| python pipeline.py --dataset "bogdanka" | ||||
| @@ -1,5 +1,10 @@ | ||||
| """ | ||||
| This script offers general functionality required in multiple places. | ||||
| ----------------- | ||||
| Copyright © 2023 ACK Cyfronet AGH, Poland. | ||||
| This work was partially funded by EPOS Project funded in frame of PL-POIR4.2 | ||||
| ----------------- | ||||
|  | ||||
| This script runs the pipeline for the training and evaluation of the models. | ||||
| """ | ||||
|  | ||||
| import numpy as np | ||||
| @@ -7,13 +12,15 @@ import pandas as pd | ||||
| import os | ||||
| import logging | ||||
| import glob | ||||
| import json | ||||
| import wandb | ||||
|  | ||||
| from dotenv import load_dotenv | ||||
| import sys | ||||
| from config_loader import models_path, configs_path | ||||
| from config_loader import models_path, configs_path, config_path | ||||
| import yaml | ||||
| load_dotenv() | ||||
|  | ||||
| load_dotenv() | ||||
|  | ||||
| logging.basicConfig() | ||||
| logging.getLogger().setLevel(logging.INFO) | ||||
| @@ -38,8 +45,16 @@ def load_best_model_data(sweep_id, weights): | ||||
|         # Get best run parameters | ||||
|         best_run = sweep.best_run() | ||||
|         run_id = best_run.id | ||||
|         matching_models = glob.glob(f"{models_path}/{weights}/*run={run_id}*ckpt") | ||||
|         if len(matching_models)!=1: | ||||
|  | ||||
|         run = api.run(f"{wandb_user}/{wandb_project_name}/runs/{run_id}") | ||||
|         dataset = run.config["dataset_name"] | ||||
|         model = run.config["model_name"][0] | ||||
|         experiment = f"{dataset}_{model}" | ||||
|  | ||||
|         checkpoints_path = f"{models_path}/{experiment}/*run={run_id}*ckpt" | ||||
|         logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}") | ||||
|         matching_models = glob.glob(checkpoints_path) | ||||
|         if len(matching_models) != 1: | ||||
|             raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id) | ||||
|         best_checkpoint_path = matching_models[0] | ||||
|  | ||||
| @@ -62,31 +77,6 @@ def load_best_model_data(sweep_id, weights): | ||||
|     return best_checkpoint_path, run_id | ||||
|  | ||||
|  | ||||
| def load_best_model(model_cls, weights, version): | ||||
|     """ | ||||
|     Determines the model with lowest validation loss from the csv logs and loads it | ||||
|  | ||||
|     :param model_cls: Class of the lightning module to load | ||||
|     :param weights: Path to weights as in cmd arguments | ||||
|     :param version: String of version file | ||||
|     :return: Instance of lightning module that was loaded from the best checkpoint | ||||
|     """ | ||||
|     metrics = pd.read_csv(weights / version / "metrics.csv") | ||||
|  | ||||
|     idx = np.nanargmin(metrics["val_loss"]) | ||||
|     min_row = metrics.iloc[idx] | ||||
|  | ||||
|     #  For default checkpoint filename, see https://github.com/Lightning-AI/lightning/pull/11805 | ||||
|     #  and https://github.com/Lightning-AI/lightning/issues/16636. | ||||
|     #  For example, 'epoch=0-step=1.ckpt' means the 1st step has finished, but the 1st epoch hasn't | ||||
|     checkpoint = f"epoch={min_row['epoch']:.0f}-step={min_row['step']+1:.0f}.ckpt" | ||||
|  | ||||
|     # For default save path of checkpoints, see https://github.com/Lightning-AI/lightning/pull/12372 | ||||
|     checkpoint_path = weights / version / "checkpoints" / checkpoint | ||||
|  | ||||
|     return model_cls.load_from_checkpoint(checkpoint_path) | ||||
|  | ||||
|  | ||||
| default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", None) | ||||
| if default_workers is None: | ||||
|     logging.warning( | ||||
| @@ -117,3 +107,51 @@ def load_sweep_config(sweep_fname): | ||||
|         sys.exit(1) | ||||
|  | ||||
|     return sweep_config | ||||
|  | ||||
|  | ||||
| def log_metrics(results_file): | ||||
|     """ | ||||
|     Logs the MAE metrics from the results file to the summary of each matching W&B run. | ||||
|     :param results_file: csv file with calculated metrics | ||||
|     :return: None | ||||
|     """ | ||||
|  | ||||
|     api = wandb.Api() | ||||
|     wandb_project_name = os.environ.get("WANDB_PROJECT") | ||||
|     wandb_user = os.environ.get("WANDB_USER") | ||||
|  | ||||
|     results = pd.read_csv(results_file) | ||||
|     for run_id in results["version"].unique(): | ||||
|         try: | ||||
|             run = api.run(f"{wandb_user}/{wandb_project_name}/{run_id}") | ||||
|             metrics_to_log = {} | ||||
|             run_results = results[results["version"] == run_id] | ||||
|             for col in run_results.columns: | ||||
|                 if 'mae' in col: | ||||
|                     metrics_to_log[col] = run_results[col].values[0] | ||||
|                     run.summary[col] = run_results[col].values[0] | ||||
|  | ||||
|             run.summary.update() | ||||
|             logging.info(f"Logged metrics for run: {run_id}, {metrics_to_log}") | ||||
|              | ||||
|         except Exception as e: | ||||
|             logging.error(f"Failed to log metrics for run {run_id}: {e}, {type(e).__name__}, {e.args}") | ||||
|  | ||||
|  | ||||
| def set_dataset(dataset_name): | ||||
|     """ | ||||
|     Sets the dataset name and the matching data path in the config file | ||||
|     :param dataset_name: name of the dataset, also used to build the data path | ||||
|     :return: None | ||||
|     """ | ||||
|  | ||||
|     with open(config_path, "r+") as f: | ||||
|         config = json.load(f) | ||||
|         config["dataset_name"] = dataset_name | ||||
|         config["data_path"] = f"datasets/{dataset_name}/seisbench_format/" | ||||
|  | ||||
|         f.seek(0)  # rewind | ||||
|         json.dump(config, f, indent=4) | ||||
|         f.truncate() | ||||
|  | ||||
|  | ||||
|   | ||||
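`set_dataset` above rewrites the config file in place: it opens the file with `r+`, loads the JSON, rewinds, dumps the updated content and truncates. The sketch below reproduces that pattern on a temporary file to show why the `truncate()` call matters when the new JSON is shorter than the old one. The explicit `config_path` parameter, the `demo` helper and the `sampling_rate` key in the sample config are assumptions made for illustration only.

```python
import json
import os
import tempfile


def set_dataset(config_path, dataset_name):
    """Update dataset_name and data_path in a JSON config file, in place."""
    with open(config_path, "r+") as f:
        config = json.load(f)
        config["dataset_name"] = dataset_name
        config["data_path"] = f"datasets/{dataset_name}/seisbench_format/"

        f.seek(0)  # rewind to overwrite from the start of the file
        json.dump(config, f, indent=4)
        f.truncate()  # drop leftover bytes if the new content is shorter


def demo():
    """Write a sample config, switch the dataset, and return the reloaded config."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "config.json")
        initial = {
            "dataset_name": "a_much_longer_placeholder_dataset_name",
            "data_path": "datasets/a_much_longer_placeholder_dataset_name/seisbench_format/",
            "sampling_rate": 100,
        }
        with open(path, "w") as f:
            json.dump(initial, f, indent=4)

        set_dataset(path, "bogdanka")

        with open(path) as f:
            return json.load(f)
```

Without `truncate()`, the tail of the longer original file would remain after the new JSON and the file would no longer parse.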
| @@ -1,19 +0,0 @@ | ||||
| #!/bin/bash | ||||
| #SBATCH --job-name=mseeds_to_seisbench | ||||
| #SBATCH --time=1:00:00 | ||||
| #SBATCH --account=plgeposai22gpu-gpu | ||||
| #SBATCH --partition plgrid | ||||
| #SBATCH --cpus-per-task=1 | ||||
| #SBATCH --ntasks-per-node=1 | ||||
| #SBATCH --mem=24gb | ||||
|  | ||||
|  | ||||
| ## activate conda environment | ||||
| source /net/pr2/projects/plgrid/plggeposai/kmilian/mambaforge/bin/activate | ||||
| conda activate epos-ai-train | ||||
|  | ||||
| input_path="/net/pr2/projects/plgrid/plggeposai/datasets/bogdanka" | ||||
| catalog_path="/net/pr2/projects/plgrid/plggeposai/datasets/bogdanka/BOIS_all.xml" | ||||
| output_path="/net/pr2/projects/plgrid/plggeposai/kmilian/platform-demo-scripts/datasets/bogdanka/seisbench_format" | ||||
|  | ||||
| python mseeds_to_seisbench.py --input_path $input_path --catalog_path $catalog_path --output_path $output_path | ||||
							
								
								
									
										230
									
								
								utils/utils.py
									
									
									
									
									
								
							
							
						
						
									
							| @@ -1,230 +0,0 @@ | ||||
| import os | ||||
| import pandas as pd | ||||
| import glob | ||||
| from pathlib import Path | ||||
|  | ||||
| import obspy | ||||
| from obspy.core.event import read_events | ||||
|  | ||||
| import seisbench.data as sbd | ||||
| import seisbench.util as sbu | ||||
| import sys | ||||
| import logging | ||||
|  | ||||
| logging.basicConfig(filename="output.out", | ||||
|                     filemode='a', | ||||
|                     format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s', | ||||
|                     datefmt='%H:%M:%S', | ||||
|                     level=logging.DEBUG) | ||||
|  | ||||
| logger = logging.getLogger('converter') | ||||
|  | ||||
| def create_traces_catalog(directory, years): | ||||
|     for year in years: | ||||
|         year_dir = f"{directory}/{year}"  # keep the base directory unchanged between iterations | ||||
|         files = glob.glob(year_dir) | ||||
|         traces = [] | ||||
|         for i, f in enumerate(files): | ||||
|             st = obspy.read(f) | ||||
|  | ||||
|             for tr in st.traces: | ||||
|                 # trace_id = tr.id | ||||
|                 # start = tr.meta.starttime | ||||
|                 # end = tr.meta.endtime | ||||
|  | ||||
|                 trs = pd.Series({ | ||||
|                     'trace_id': tr.id, | ||||
|                     'trace_st': tr.meta.starttime, | ||||
|                     'trace_end': tr.meta.endtime, | ||||
|                     'stream_fname': f | ||||
|                 }) | ||||
|                 traces.append(trs) | ||||
|  | ||||
|         traces_catalog = pd.DataFrame(traces)  # one row per trace | ||||
|         traces_catalog.to_csv("data/bogdanka/traces_catalog.csv", mode="a", index=False) | ||||
|  | ||||
|  | ||||
| def split_events(events, input_path): | ||||
|  | ||||
|     logger.info("Splitting available events into train, dev and test sets ...") | ||||
|     events_stats = pd.DataFrame() | ||||
|     events_stats.index.name = "event" | ||||
|  | ||||
|     for i, event in enumerate(events): | ||||
|         #check if mseed exists | ||||
|         actual_picks = 0 | ||||
|         for pick in event.picks: | ||||
|             trace_params = get_trace_params(pick) | ||||
|             trace_path = get_trace_path(input_path, trace_params) | ||||
|             if os.path.isfile(trace_path): | ||||
|                 actual_picks += 1 | ||||
|  | ||||
|         events_stats.loc[i, "pick_count"] = actual_picks | ||||
|  | ||||
|     events_stats['pick_count_cumsum'] = events_stats.pick_count.cumsum() | ||||
|  | ||||
|     train_th = 0.7 * events_stats.pick_count_cumsum.values[-1] | ||||
|     dev_th = 0.85 * events_stats.pick_count_cumsum.values[-1] | ||||
|  | ||||
|     events_stats['split'] = 'test' | ||||
|     for i, event in events_stats.iterrows(): | ||||
|         if event['pick_count_cumsum'] < train_th: | ||||
|             events_stats.loc[i, 'split'] = 'train' | ||||
|         elif event['pick_count_cumsum'] < dev_th: | ||||
|             events_stats.loc[i, 'split'] = 'dev' | ||||
|         else: | ||||
|             break | ||||
|  | ||||
|     return events_stats | ||||
|  | ||||
|  | ||||
| def get_event_params(event): | ||||
|     origin = event.preferred_origin() | ||||
|     if origin is None: | ||||
|         return {} | ||||
|     # print(origin) | ||||
|  | ||||
|     mag = event.preferred_magnitude() | ||||
|  | ||||
|     source_id = str(event.resource_id) | ||||
|  | ||||
|     event_params = { | ||||
|         "source_id": source_id, | ||||
|         "source_origin_uncertainty_sec": origin.time_errors["uncertainty"], | ||||
|         "source_latitude_deg": origin.latitude, | ||||
|         "source_latitude_uncertainty_km": origin.latitude_errors["uncertainty"], | ||||
|         "source_longitude_deg": origin.longitude, | ||||
|         "source_longitude_uncertainty_km": origin.longitude_errors["uncertainty"], | ||||
|         "source_depth_km": origin.depth / 1e3, | ||||
|         "source_depth_uncertainty_km": origin.depth_errors["uncertainty"] / 1e3 if origin.depth_errors[ | ||||
|                                                                                        "uncertainty"] is not None else None, | ||||
|     } | ||||
|  | ||||
|     if mag is not None: | ||||
|         event_params["source_magnitude"] = mag.mag | ||||
|         event_params["source_magnitude_uncertainty"] = mag.mag_errors["uncertainty"] | ||||
|         event_params["source_magnitude_type"] = mag.magnitude_type | ||||
|         event_params["source_magnitude_author"] = mag.creation_info.agency_id if mag.creation_info is not None else None | ||||
|  | ||||
|     return event_params | ||||
|  | ||||
|  | ||||
| def get_trace_params(pick): | ||||
|     net = pick.waveform_id.network_code | ||||
|     sta = pick.waveform_id.station_code | ||||
|  | ||||
|     trace_params = { | ||||
|         "station_network_code": net, | ||||
|         "station_code": sta, | ||||
|         "trace_channel": pick.waveform_id.channel_code, | ||||
|         "station_location_code": pick.waveform_id.location_code, | ||||
|         "time": pick.time | ||||
|     } | ||||
|  | ||||
|     return trace_params | ||||
|  | ||||
|  | ||||
| def find_trace(pick_time, traces): | ||||
|     for tr in traces: | ||||
|         if pick_time > tr.stats.endtime: | ||||
|             continue | ||||
|         if pick_time >= tr.stats.starttime: | ||||
|             # print(pick_time, " - selected trace: ", tr) | ||||
|             return tr | ||||
|  | ||||
|     logger.warning(f"no matching trace for pick: {pick_time}") | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def get_trace_path(input_path, trace_params): | ||||
|     year = trace_params["time"].year | ||||
|     day_of_year = pd.Timestamp(str(trace_params["time"])).day_of_year | ||||
|     net = trace_params["station_network_code"] | ||||
|     station = trace_params["station_code"] | ||||
|     tr_channel = trace_params["trace_channel"] | ||||
|  | ||||
|     path = f"{input_path}/{year}/{net}/{station}/{tr_channel}.D/{net}.{station}..{tr_channel}.D.{year}.{day_of_year}" | ||||
|     return path | ||||
|  | ||||
|  | ||||
| def load_trace(input_path, trace_params): | ||||
|     trace_path = get_trace_path(input_path, trace_params) | ||||
|     trace = None | ||||
|  | ||||
|     if not os.path.isfile(trace_path): | ||||
|         logger.warning(trace_path + " not found") | ||||
|     else: | ||||
|         stream = obspy.read(trace_path) | ||||
|         if len(stream.traces) > 1: | ||||
|             trace = find_trace(trace_params["time"], stream.traces) | ||||
|         elif len(stream.traces) == 0: | ||||
|             logger.warning(f"no data in: {trace_path}") | ||||
|         else: | ||||
|             trace = stream.traces[0] | ||||
|  | ||||
|     return trace | ||||
|  | ||||
|  | ||||
| def load_stream(input_path, trace_params, time_before=60, time_after=60): | ||||
|     trace_path = get_trace_path(input_path, trace_params) | ||||
|     sampling_rate, stream = None, None | ||||
|     pick_time = trace_params["time"] | ||||
|  | ||||
|     if not os.path.isfile(trace_path): | ||||
|         print(trace_path + " not found") | ||||
|     else: | ||||
|         stream = obspy.read(trace_path) | ||||
|         stream = stream.slice(pick_time - time_before, pick_time + time_after) | ||||
|         if len(stream.traces) == 0: | ||||
|             print(f"no data in: {trace_path}") | ||||
|         else: | ||||
|             sampling_rate = stream.traces[0].stats.sampling_rate | ||||
|  | ||||
|     return sampling_rate, stream | ||||
|  | ||||
|  | ||||
| def convert_mseed_to_seisbench_format(): | ||||
|     input_path = "/net/pr2/projects/plgrid/plggeposai" | ||||
|     logger.info("Loading events catalog ...") | ||||
|     events = read_events(input_path + "/BOIS_all.xml") | ||||
|     events_stats = split_events(events, input_path) | ||||
|     output_path = input_path + "/seisbench_format" | ||||
|     metadata_path = output_path + "/metadata.csv" | ||||
|     waveforms_path = output_path + "/waveforms.hdf5" | ||||
|  | ||||
|     with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer: | ||||
|         writer.data_format = { | ||||
|             "dimension_order": "CW", | ||||
|             "component_order": "ZNE", | ||||
|         } | ||||
|         for i, event in enumerate(events): | ||||
|             logger.debug(f"Converting {i} event") | ||||
|             event_params = get_event_params(event) | ||||
|             event_params["split"] = events_stats.loc[i, "split"] | ||||
|             #             b = False | ||||
|  | ||||
|             for pick in event.picks: | ||||
|                 trace_params = get_trace_params(pick) | ||||
|                 sampling_rate, stream = load_stream(input_path, trace_params) | ||||
|                 if stream is None: | ||||
|                     continue | ||||
|  | ||||
|                 actual_t_start, data, _ = sbu.stream_to_array( | ||||
|                     stream, | ||||
|                     component_order=writer.data_format["component_order"], | ||||
|                 ) | ||||
|  | ||||
|                 trace_params["trace_sampling_rate_hz"] = sampling_rate | ||||
|                 trace_params["trace_start_time"] = str(actual_t_start) | ||||
|  | ||||
|                 pick_time = obspy.core.utcdatetime.UTCDateTime(trace_params["time"]) | ||||
|                 pick_idx = (pick_time - actual_t_start) * sampling_rate | ||||
|  | ||||
|                 trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(pick_idx) | ||||
|  | ||||
|                 writer.add_trace({**event_params, **trace_params}, data) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     convert_mseed_to_seisbench_format() | ||||
|     # create_traces_catalog("/net/pr2/projects/plgrid/plggeposai/", ["2018", "2019"]) | ||||
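The removed `split_events` helper assigned events to train, dev and test sets by comparing the cumulative pick count against 70% and 85% of the total number of picks. A minimal pure-Python sketch of that thresholding is given below, assuming the per-event pick counts are already known. `split_by_pick_count` and its parameters are hypothetical names, not functions from the repository.

```python
from itertools import accumulate


def split_by_pick_count(pick_counts, train_frac=0.7, dev_frac=0.85):
    """Label each event 'train', 'dev' or 'test' from cumulative pick counts."""
    cumsum = list(accumulate(pick_counts))
    total = cumsum[-1] if cumsum else 0
    train_th = train_frac * total
    dev_th = dev_frac * total

    splits = []
    for running_total in cumsum:
        if running_total < train_th:
            splits.append("train")
        elif running_total < dev_th:
            splits.append("dev")
        else:
            splits.append("test")
    return splits


# Six events with 20 picks in total: thresholds are 14 and 17 picks.
example_splits = split_by_pick_count([5, 1, 4, 2, 3, 5])
```

Because the split follows catalog order, events are not shuffled; whether that is desirable depends on how the catalog is sorted.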