Compare commits
10 Commits
switching_
...
master
Author | SHA1 | Date | |
---|---|---|---|
281c73764d | |||
e86f131cc0 | |||
|
5c3ce04868 | ||
94054bc391 | |||
318a344c15 | |||
bb2e136d42 | |||
87de2e7a6c | |||
503bec883e | |||
|
4658b8d866 | ||
|
4c2679a005 |
27
README.md
27
README.md
@ -22,7 +22,7 @@ Please download and install [Mambaforge](https://github.com/conda-forge/miniforg
|
||||
After successful installation and within the Mambaforge environment please clone this repository:
|
||||
|
||||
```
|
||||
git clone ssh://git@git.plgrid.pl:7999/eai/platform-demo-scripts.git
|
||||
git clone https://epos-apps.grid.cyfronet.pl/epos-ai/platform-demo-scripts.git
|
||||
```
|
||||
and please run for Linux or Windows platforms:
|
||||
|
||||
@ -111,7 +111,7 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
|
||||
|
||||
The script performs the following steps:
|
||||
1. Generates evaluation targets in `datasets/<dataset_name>/targets` directory.
|
||||
1. Trains multiple versions of GPD, PhaseNet and ... models to find the best hyperparameters, producing the lowest validation loss.
|
||||
1. Trains multiple versions of GPD, PhaseNet, BasicPhaseAE, and EQTransformer models to find the best hyperparameters, producing the lowest validation loss.
|
||||
|
||||
This step utilizes the Weights & Biases platform to perform the hyperparameters search (called sweeping) and track the training process and store the results.
|
||||
The results are available at
|
||||
@ -126,14 +126,31 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
|
||||
The results are saved in the `scripts/pred/results.csv` file. They are additionally logged in Weights & Biases platform as summary metrics of corresponding runs.
|
||||
|
||||
<br/>
|
||||
The default settings are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for GPD model, run:
|
||||
The default settings for max number of experiments and paths are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for the GPD model, run:
|
||||
|
||||
```python pipeline.py --gpd_config <new config file>```
|
||||
|
||||
The new config file should be placed in the `experiments` folder or as specified in the `configs_path` parameter in the config.json file.
|
||||
|
||||
|
||||
Sweep configs are used to define the max number of epochs to run and the hyperparameters search space for the following parameters:
|
||||
* `batch_size`
|
||||
* `learning_rate`
|
||||
|
||||
Phasenet model has additional available parameters:
|
||||
* `norm` - normalization method, options ('peak', 'std')
|
||||
* `pretrained` - pretrained seisbench models used for transfer learning
|
||||
* `finetuning` - the type of layers to finetune first, options ('all', 'top', 'encoder', 'decoder')
|
||||
* `lr_reduce_factor` - factor to reduce learning rate after unfreezing layers
|
||||
|
||||
GPD model has additional parameters for filtering:
|
||||
* `highpass` - highpass filter frequency
|
||||
* `lowpass` - lowpass filter frequency
|
||||
|
||||
The sweep configs are saved in the `experiments` folder.
|
||||
|
||||
|
||||
If you have multiple datasets, you can run the pipeline for each dataset separately by specifying the dataset name as an argument:
|
||||
|
||||
|
||||
```python pipeline.py --dataset <dataset_name>```
|
||||
|
||||
### Troubleshooting
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"dataset_name": "lumineos",
|
||||
"data_path": "datasets/lumineos/seisbench_format/",
|
||||
"dataset_name": "bogdanka_2018_2022",
|
||||
"data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
|
||||
"targets_path": "datasets/targets",
|
||||
"models_path": "weights",
|
||||
"configs_path": "experiments",
|
||||
@ -13,5 +13,5 @@
|
||||
"BasicPhaseAE": "sweep_basicphase_ae.yaml",
|
||||
"EQTransformer": "sweep_eqtransformer.yaml"
|
||||
},
|
||||
"experiment_count": 20
|
||||
"experiment_count": 15
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
name: BasicPhaseAE
|
||||
method: bayes
|
||||
metric:
|
||||
goal: minimize
|
||||
@ -7,13 +8,9 @@ parameters:
|
||||
value:
|
||||
- BasicPhaseAE
|
||||
batch_size:
|
||||
distribution: int_uniform
|
||||
max: 1024
|
||||
min: 256
|
||||
values: [64, 128, 256]
|
||||
max_epochs:
|
||||
value:
|
||||
- 20
|
||||
- 30
|
||||
learning_rate:
|
||||
distribution: uniform
|
||||
max: 0.02
|
||||
min: 0.001
|
||||
values: [0.01, 0.005, 0.001]
|
||||
|
@ -8,13 +8,9 @@ parameters:
|
||||
value:
|
||||
- EQTransformer
|
||||
batch_size:
|
||||
distribution: int_uniform
|
||||
max: 1024
|
||||
min: 256
|
||||
values: [64, 128, 256]
|
||||
max_epochs:
|
||||
value:
|
||||
- 30
|
||||
learning_rate:
|
||||
distribution: uniform
|
||||
max: 0.02
|
||||
min: 0.005
|
||||
values: [0.01, 0.005, 0.001]
|
@ -1,4 +1,4 @@
|
||||
name: GPD_fixed_highpass:2-10
|
||||
name: GPD
|
||||
method: bayes
|
||||
metric:
|
||||
goal: minimize
|
||||
@ -8,16 +8,12 @@ parameters:
|
||||
value:
|
||||
- GPD
|
||||
batch_size:
|
||||
distribution: int_uniform
|
||||
max: 1024
|
||||
min: 256
|
||||
values: [64, 128, 256]
|
||||
max_epochs:
|
||||
value:
|
||||
- 30
|
||||
learning_rate:
|
||||
distribution: uniform
|
||||
max: 0.02
|
||||
min: 0.005
|
||||
values: [0.01, 0.005, 0.001]
|
||||
highpass:
|
||||
value:
|
||||
- 1
|
||||
|
@ -1,8 +1,13 @@
|
||||
"""
|
||||
This file contains functionality related to data.
|
||||
"""
|
||||
import os.path
|
||||
|
||||
import seisbench.data as sbd
|
||||
import logging
|
||||
|
||||
logging.root.setLevel(logging.INFO)
|
||||
logger = logging.getLogger('data')
|
||||
|
||||
|
||||
def get_dataset_by_name(name):
|
||||
@ -30,3 +35,26 @@ def get_custom_dataset(path):
|
||||
except AttributeError:
|
||||
raise ValueError(f"Unknown dataset '{path}'.")
|
||||
|
||||
|
||||
def validate_custom_dataset(data_path):
|
||||
"""
|
||||
Validate the dataset
|
||||
:param data_path: path to the dataset
|
||||
:return:
|
||||
"""
|
||||
# check if path exists
|
||||
if not os.path.isdir((data_path)):
|
||||
raise ValueError(f"Data path {data_path} does not exist.")
|
||||
|
||||
dataset = sbd.WaveformDataset(data_path)
|
||||
# check if the dataset is split into train, dev and test
|
||||
if len(dataset.train()) == 0:
|
||||
raise ValueError(f"Training set is empty.")
|
||||
if len(dataset.dev()) == 0:
|
||||
raise ValueError(f"Dev set is empty.")
|
||||
if len(dataset.test()) == 0:
|
||||
raise ValueError(f"Test set is empty.")
|
||||
|
||||
logger.info("Custom dataset validated successfully.")
|
||||
|
||||
|
||||
|
@ -7,8 +7,9 @@ import os
|
||||
import os.path
|
||||
import argparse
|
||||
from pytorch_lightning.loggers import WandbLogger, CSVLogger
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import wandb
|
||||
import torch
|
||||
@ -20,6 +21,8 @@ import train
|
||||
import util
|
||||
import config_loader
|
||||
|
||||
|
||||
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
os.system("ulimit -n unlimited")
|
||||
|
||||
@ -31,14 +34,13 @@ host = os.environ.get("WANDB_HOST")
|
||||
if host is None:
|
||||
raise ValueError("WANDB_HOST environment variable is not set.")
|
||||
|
||||
|
||||
wandb.login(key=wandb_api_key, host=host)
|
||||
wandb_project_name = os.environ.get("WANDB_PROJECT")
|
||||
wandb_user_name = os.environ.get("WANDB_USER")
|
||||
|
||||
script_name = os.path.splitext(os.path.basename(__file__))[0]
|
||||
logger = logging.getLogger(script_name)
|
||||
logger.setLevel(logging.WARNING)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def set_random_seed(seed=3):
|
||||
@ -54,6 +56,11 @@ def get_trainer_args(config):
|
||||
return trainer_args
|
||||
|
||||
|
||||
def get_arg(arg):
|
||||
if type(arg) == list:
|
||||
return arg[0]
|
||||
return arg
|
||||
|
||||
class HyperparameterSweep:
|
||||
def __init__(self, project_name, sweep_config):
|
||||
self.project_name = project_name
|
||||
@ -84,25 +91,40 @@ class HyperparameterSweep:
|
||||
return all_not_running
|
||||
|
||||
def run_experiment(self):
|
||||
|
||||
try:
|
||||
|
||||
logger.debug("Starting a new run...")
|
||||
logger.info("Starting a new run...")
|
||||
|
||||
run = wandb.init(
|
||||
project=self.project_name,
|
||||
config=config_loader.config,
|
||||
save_code=True
|
||||
save_code=True,
|
||||
entity=wandb_user_name
|
||||
)
|
||||
run.log_code(
|
||||
root=".",
|
||||
include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
|
||||
exclude_fn=lambda path: path.endswith("template.sh")
|
||||
) # not working as expected
|
||||
)
|
||||
|
||||
model_name = wandb.config.model_name[0]
|
||||
model_name = get_arg(wandb.config.model_name)
|
||||
model_args = models.get_model_specific_args(wandb.config)
|
||||
logger.debug(f"Initializing {model_name}")
|
||||
|
||||
if "pretrained" in wandb.config:
|
||||
weights = get_arg(wandb.config.pretrained)
|
||||
if weights != "false":
|
||||
model_args["pretrained"] = weights
|
||||
|
||||
if "norm" in wandb.config:
|
||||
model_args["norm"] = get_arg(wandb.config.norm)
|
||||
|
||||
if "finetuning" in wandb.config:
|
||||
model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
|
||||
|
||||
if "lr_reduce_factor" in wandb.config:
|
||||
model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
|
||||
|
||||
logger.debug(f"Initializing {model_name} with args: {model_args}")
|
||||
model = models.__getattribute__(model_name + "Lit")(**model_args)
|
||||
|
||||
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
|
||||
@ -132,10 +154,13 @@ class HyperparameterSweep:
|
||||
|
||||
early_stopping_callback = EarlyStopping(
|
||||
monitor="val_loss",
|
||||
patience=3,
|
||||
patience=5,
|
||||
verbose=True,
|
||||
mode="min")
|
||||
callbacks = [checkpoint_callback, early_stopping_callback]
|
||||
|
||||
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||
|
||||
callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]
|
||||
|
||||
trainer = pl.Trainer(
|
||||
default_root_dir=config_loader.models_path,
|
||||
@ -143,7 +168,6 @@ class HyperparameterSweep:
|
||||
callbacks=callbacks,
|
||||
**get_trainer_args(wandb.config)
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, dev_loader)
|
||||
|
||||
except Exception as e:
|
||||
@ -155,7 +179,6 @@ class HyperparameterSweep:
|
||||
|
||||
|
||||
def start_sweep(sweep_config):
|
||||
|
||||
logger.info("Starting sweep with config: " + str(sweep_config))
|
||||
set_random_seed(config_loader.seed)
|
||||
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
|
||||
@ -165,7 +188,6 @@ def start_sweep(sweep_config):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--sweep_config", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
146
scripts/input_validate.py
Normal file
146
scripts/input_validate.py
Normal file
@ -0,0 +1,146 @@
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from typing_extensions import Literal
|
||||
from typing import Union, List, Optional
|
||||
import yaml
|
||||
import logging
|
||||
|
||||
logging.root.setLevel(logging.INFO)
|
||||
logger = logging.getLogger('input_validator')
|
||||
|
||||
#todo
|
||||
# 1. check if a single value is allowed in a sweep
|
||||
# 2. merge input params
|
||||
# 3. change names of the classes
|
||||
# 4. add constraints for PhaseNet, GPD
|
||||
|
||||
|
||||
model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
|
||||
norm_values = Literal["peak", "std"]
|
||||
finetuning_values = Literal["all", "top", "decoder", "encoder"]
|
||||
pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic',
|
||||
'original', 'scedc', False]
|
||||
|
||||
|
||||
class Metric(BaseModel):
|
||||
goal: str
|
||||
name: str
|
||||
|
||||
|
||||
class NumericValue(BaseModel):
|
||||
value: Union[int, float, List[Union[int, float]]]
|
||||
|
||||
|
||||
class NumericValues(BaseModel):
|
||||
values: List[Union[int, float]]
|
||||
|
||||
|
||||
class IntDistribution(BaseModel):
|
||||
distribution: str = "int_uniform"
|
||||
min: int
|
||||
max: int
|
||||
|
||||
|
||||
class FloatDistribution(BaseModel):
|
||||
distribution: str = "uniform"
|
||||
min: float
|
||||
max: float
|
||||
|
||||
|
||||
class Pretrained(BaseModel):
|
||||
distribution: Optional[str] = "categorical"
|
||||
values: List[pretrained_values] = None
|
||||
value: Union[pretrained_values, List[pretrained_values]] = None
|
||||
|
||||
|
||||
class Finetuning(BaseModel):
|
||||
distribution: Optional[str] = "categorical"
|
||||
values: List[finetuning_values] = None
|
||||
value: Union[finetuning_values, List[finetuning_values]] = None
|
||||
|
||||
|
||||
class Norm(BaseModel):
|
||||
distribution: Optional[str] = "categorical"
|
||||
values: List[norm_values] = None
|
||||
value: Union[norm_values, List[norm_values]] = None
|
||||
|
||||
|
||||
class ModelType(BaseModel):
|
||||
distribution: Optional[str] = "categorical"
|
||||
value: Union[model_names, List[model_names]] = None
|
||||
values: List[model_names] = None
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
model_config = ConfigDict(extra='forbid', protected_namespaces=())
|
||||
model_name: ModelType
|
||||
batch_size: Union[IntDistribution, NumericValue, NumericValues]
|
||||
learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
|
||||
max_epochs: Union[IntDistribution, NumericValue, NumericValues]
|
||||
|
||||
|
||||
class PhaseNetParameters(Parameters):
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
norm: Norm = None
|
||||
pretrained: Pretrained = None
|
||||
finetuning: Finetuning = None
|
||||
lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
|
||||
|
||||
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||
|
||||
@field_validator("model_name")
|
||||
def validate_model(cls, v):
|
||||
if "PhaseNet" not in v.value:
|
||||
raise ValueError("Additional parameters implemented for PhaseNet only")
|
||||
return v
|
||||
|
||||
|
||||
class FilteringParameters(Parameters):
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
|
||||
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||
|
||||
@field_validator("model_name")
|
||||
def validate_model(cls, v):
|
||||
print(v.value)
|
||||
if v.value[0] not in ["GPD", "PhaseNet"]:
|
||||
raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
|
||||
|
||||
|
||||
class InputParams(BaseModel):
|
||||
name: str
|
||||
method: str
|
||||
metric: Metric
|
||||
parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
|
||||
|
||||
|
||||
def validate_sweep_yaml(yaml_filename, model_name=None):
|
||||
# Load YAML configuration
|
||||
with open(yaml_filename, 'r') as f:
|
||||
sweep_config = yaml.safe_load(f)
|
||||
|
||||
validate_sweep_config(sweep_config, model_name)
|
||||
|
||||
|
||||
def validate_sweep_config(sweep_config, model_name=None):
|
||||
|
||||
# Validate sweep config
|
||||
|
||||
input_params = InputParams(**sweep_config)
|
||||
|
||||
# Check consistency of input parameters and sweep configuration
|
||||
sweep_model_name = input_params.parameters.model_name.value
|
||||
if model_name is not None and model_name not in sweep_model_name:
|
||||
info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
|
||||
logger.info(info)
|
||||
raise ValueError(info)
|
||||
logger.info("Input validation successful.")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
|
||||
validate_sweep_yaml(yaml_filename, None)
|
@ -7,16 +7,26 @@ import seisbench.generate as sbg
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.optim import lr_scheduler
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from abc import abstractmethod, ABC
|
||||
|
||||
# import lightning as L
|
||||
|
||||
|
||||
# Allows to import this file in both jupyter notebook and code
|
||||
try:
|
||||
from .augmentations import DuplicateEvent
|
||||
except ImportError:
|
||||
from augmentations import DuplicateEvent
|
||||
|
||||
import os
|
||||
import logging
|
||||
|
||||
script_name = os.path.splitext(os.path.basename(__file__))[0]
|
||||
logger = logging.getLogger(script_name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Phase dict for labelling. We only study P and S phases without differentiating between them.
|
||||
phase_dict = {
|
||||
@ -131,7 +141,31 @@ class PhaseNetLit(SeisBenchModuleLit):
|
||||
self.sigma = sigma
|
||||
self.sample_boundaries = sample_boundaries
|
||||
self.loss = vector_cross_entropy
|
||||
self.model = sbm.PhaseNet(phases="PN", **kwargs)
|
||||
self.pretrained = kwargs.pop("pretrained", None)
|
||||
self.norm = kwargs.pop("norm", "peak")
|
||||
self.highpass = kwargs.pop("highpass", None)
|
||||
self.lowpass = kwargs.pop("lowpass", None)
|
||||
|
||||
|
||||
if self.pretrained is not None:
|
||||
self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
|
||||
# self.norm = self.model.norm
|
||||
else:
|
||||
self.model = sbm.PhaseNet(**kwargs)
|
||||
|
||||
self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
|
||||
self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
|
||||
self.reduce_lr_on_plateau = False
|
||||
|
||||
self.initial_epochs = 0
|
||||
|
||||
if self.finetuning_strategy is not None:
|
||||
if self.finetuning_strategy == "top":
|
||||
self.initial_epochs = 3
|
||||
elif self.finetuning_strategy in ["decoder", "encoder"]:
|
||||
self.initial_epochs = 6
|
||||
|
||||
self.freeze()
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
@ -154,9 +188,49 @@ class PhaseNetLit(SeisBenchModuleLit):
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
return optimizer
|
||||
if self.finetuning_strategy is not None:
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
|
||||
self.reduce_lr_on_plateau = False
|
||||
else:
|
||||
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
|
||||
self.reduce_lr_on_plateau = True
|
||||
#
|
||||
return {
|
||||
'optimizer': optimizer,
|
||||
'lr_scheduler': {
|
||||
'scheduler': scheduler,
|
||||
'monitor': 'val_loss',
|
||||
'interval': 'epoch',
|
||||
'reduce_on_plateau': self.reduce_lr_on_plateau,
|
||||
},
|
||||
}
|
||||
|
||||
def lr_lambda(self, epoch):
|
||||
# reduce lr after x initial epochs
|
||||
if epoch == self.initial_epochs:
|
||||
self.lr *= self.steplr_gamma
|
||||
|
||||
return self.lr
|
||||
|
||||
def lr_scheduler_step(self, scheduler, metric):
|
||||
if self.reduce_lr_on_plateau:
|
||||
scheduler.step(metric, epoch=self.current_epoch)
|
||||
else:
|
||||
scheduler.step(epoch=self.current_epoch)
|
||||
|
||||
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
|
||||
# scheduler.step(epoch=self.current_epoch)
|
||||
|
||||
def get_augmentations(self):
|
||||
filter = []
|
||||
if self.highpass is not None:
|
||||
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
|
||||
logger.info(f"Using highpass filer {self.highpass}")
|
||||
if self.lowpass is not None:
|
||||
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
|
||||
logger.info(f"Using lowpass filer {self.lowpass}")
|
||||
logger.info(filter)
|
||||
|
||||
return [
|
||||
# In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
|
||||
# Uses strategy variable, as padding will be handled by the random window.
|
||||
@ -180,18 +254,26 @@ class PhaseNetLit(SeisBenchModuleLit):
|
||||
windowlen=3001,
|
||||
strategy="pad",
|
||||
),
|
||||
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
|
||||
*filter,
|
||||
sbg.ChangeDtype(np.float32),
|
||||
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
|
||||
sbg.ProbabilisticLabeller(
|
||||
label_columns=phase_dict, sigma=self.sigma, dim=0
|
||||
),
|
||||
]
|
||||
|
||||
def get_eval_augmentations(self):
|
||||
filter = []
|
||||
if self.highpass is not None:
|
||||
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
|
||||
if self.lowpass is not None:
|
||||
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
|
||||
|
||||
return [
|
||||
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
|
||||
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
|
||||
*filter,
|
||||
sbg.ChangeDtype(np.float32),
|
||||
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
|
||||
]
|
||||
|
||||
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
|
||||
@ -219,6 +301,27 @@ class PhaseNetLit(SeisBenchModuleLit):
|
||||
|
||||
return score_detection, score_p_or_s, p_sample, s_sample
|
||||
|
||||
def freeze(self):
|
||||
if self.finetuning_strategy == "decoder": # finetune decoder branch and freeze encoder branch
|
||||
for p in self.model.down_branch.parameters():
|
||||
p.requires_grad = False
|
||||
elif self.finetuning_strategy == "encoder": # finetune encoder branch and freeze decoder branch
|
||||
for p in self.model.up_branch.parameters():
|
||||
p.requires_grad = False
|
||||
elif self.finetuning_strategy == "top":
|
||||
for p in self.model.out.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
def unfreeze(self):
|
||||
logger.info("Unfreezing layers")
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
# Unfreeze some layers after x initial epochs
|
||||
if self.current_epoch == self.initial_epochs:
|
||||
self.unfreeze()
|
||||
|
||||
|
||||
class GPDLit(SeisBenchModuleLit):
|
||||
"""
|
||||
@ -846,7 +949,7 @@ class BasicPhaseAELit(SeisBenchModuleLit):
|
||||
# Create overlapping windows
|
||||
re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
|
||||
for i, start in enumerate(range(0, 2401, 400)):
|
||||
re[:, :, i] = x[:, :, start : start + 600]
|
||||
re[:, :, i] = x[:, :, start: start + 600]
|
||||
x = re
|
||||
|
||||
x = x.permute(0, 2, 1, 3) # --> (batch, windows, channels, samples)
|
||||
@ -862,9 +965,9 @@ class BasicPhaseAELit(SeisBenchModuleLit):
|
||||
for i, start in enumerate(range(0, 2401, 400)):
|
||||
if start == 0:
|
||||
# Use full window (for start==0, the end will be overwritten)
|
||||
pred[:, :, start : start + 600] = window_pred[:, i]
|
||||
pred[:, :, start: start + 600] = window_pred[:, i]
|
||||
else:
|
||||
pred[:, :, start + 100 : start + 600] = window_pred[:, i, :, 100:]
|
||||
pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:]
|
||||
|
||||
score_detection = torch.zeros(pred.shape[0])
|
||||
score_p_or_s = torch.zeros(pred.shape[0])
|
||||
@ -1112,7 +1215,6 @@ class DPPPickerLit(SeisBenchModuleLit):
|
||||
|
||||
|
||||
def get_model_specific_args(config):
|
||||
|
||||
model = config.model_name[0]
|
||||
lr = config.learning_rate
|
||||
if type(lr) == list:
|
||||
@ -1125,8 +1227,12 @@ def get_model_specific_args(config):
|
||||
if 'highpass' in config:
|
||||
args['highpass'] = config.highpass
|
||||
if 'lowpass' in config:
|
||||
args['lowpass'] = config.lowpass[0]
|
||||
args['lowpass'] = config.lowpass
|
||||
case 'PhaseNet':
|
||||
if 'highpass' in config:
|
||||
args['highpass'] = config.highpass
|
||||
if 'lowpass' in config:
|
||||
args['lowpass'] = config.lowpass
|
||||
if 'sample_boundaries' in config:
|
||||
args['sample_boundaries'] = config.sample_boundaries
|
||||
case 'DPPPicker':
|
||||
|
@ -1,66 +1,44 @@
|
||||
"""
|
||||
-----------------
|
||||
Copyright © 2023 ACK Cyfronet AGH, Poland.
|
||||
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
|
||||
-----------------
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import glob
|
||||
from pathlib import Path
|
||||
|
||||
import obspy
|
||||
from obspy.core.event import read_events
|
||||
|
||||
import seisbench
|
||||
import seisbench.data as sbd
|
||||
import seisbench.util as sbu
|
||||
import sys
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
|
||||
|
||||
logging.basicConfig(filename="output.out",
|
||||
filemode='a',
|
||||
format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
|
||||
datefmt='%H:%M:%S',
|
||||
level=logging.DEBUG)
|
||||
|
||||
|
||||
|
||||
logging.root.setLevel(logging.INFO)
|
||||
logger = logging.getLogger('converter')
|
||||
|
||||
def create_traces_catalog(directory, years):
|
||||
for year in years:
|
||||
directory = f"{directory}/{year}"
|
||||
files = glob.glob(directory)
|
||||
traces = []
|
||||
for i, f in enumerate(files):
|
||||
st = obspy.read(f)
|
||||
|
||||
for tr in st.traces:
|
||||
# trace_id = tr.id
|
||||
# start = tr.meta.starttime
|
||||
# end = tr.meta.endtime
|
||||
|
||||
trs = pd.Series({
|
||||
'trace_id': tr.id,
|
||||
'trace_st': tr.meta.starttime,
|
||||
'trace_end': tr.meta.endtime,
|
||||
'stream_fname': f
|
||||
})
|
||||
traces.append(trs)
|
||||
|
||||
traces_catalog = pd.DataFrame(pd.concat(traces)).transpose()
|
||||
traces_catalog.to_csv("data/bogdanka/traces_catalog.csv", append=True, index=False)
|
||||
|
||||
|
||||
def split_events(events, input_path):
|
||||
|
||||
logger.info("Splitting available events into train, dev and test sets ...")
|
||||
events_stats = pd.DataFrame()
|
||||
events_stats.index.name = "event"
|
||||
|
||||
for i, event in enumerate(events):
|
||||
#check if mseed exists
|
||||
# check if mseed exists
|
||||
actual_picks = 0
|
||||
for pick in event.picks:
|
||||
trace_params = get_trace_params(pick)
|
||||
trace_path = get_trace_path(input_path, trace_params)
|
||||
|
||||
if os.path.isfile(trace_path):
|
||||
actual_picks += 1
|
||||
|
||||
@ -79,6 +57,10 @@ def split_events(events, input_path):
|
||||
events_stats.loc[i, 'split'] = 'dev'
|
||||
else:
|
||||
break
|
||||
|
||||
logger.info(f"Split: {events_stats['split'].value_counts()}")
|
||||
|
||||
logger.info(f"Split: {events_stats['split'].value_counts()}")
|
||||
|
||||
return events_stats
|
||||
|
||||
@ -87,7 +69,6 @@ def get_event_params(event):
|
||||
origin = event.preferred_origin()
|
||||
if origin is None:
|
||||
return {}
|
||||
# print(origin)
|
||||
|
||||
mag = event.preferred_magnitude()
|
||||
|
||||
@ -115,12 +96,9 @@ def get_event_params(event):
|
||||
|
||||
|
||||
def get_trace_params(pick):
|
||||
net = pick.waveform_id.network_code
|
||||
sta = pick.waveform_id.station_code
|
||||
|
||||
trace_params = {
|
||||
"station_network_code": net,
|
||||
"station_code": sta,
|
||||
"station_network_code": pick.waveform_id.network_code,
|
||||
"station_code": pick.waveform_id.station_code,
|
||||
"trace_channel": pick.waveform_id.channel_code,
|
||||
"station_location_code": pick.waveform_id.location_code,
|
||||
"time": pick.time
|
||||
@ -134,7 +112,6 @@ def find_trace(pick_time, traces):
|
||||
if pick_time > tr.stats.endtime:
|
||||
continue
|
||||
if pick_time >= tr.stats.starttime:
|
||||
# print(pick_time, " - selected trace: ", tr)
|
||||
return tr
|
||||
|
||||
logger.warning(f"no matching trace for peak: {pick_time}")
|
||||
@ -152,12 +129,26 @@ def get_trace_path(input_path, trace_params):
|
||||
return path
|
||||
|
||||
|
||||
def get_three_channels_trace_paths(input_path, trace_params):
|
||||
year = trace_params["time"].year
|
||||
day_of_year = pd.Timestamp(str(trace_params["time"])).day_of_year
|
||||
net = trace_params["station_network_code"]
|
||||
station = trace_params["station_code"]
|
||||
channel_base = trace_params["trace_channel"]
|
||||
paths = []
|
||||
for ch in ["E", "N", "Z"]:
|
||||
channel = channel_base[:-1] + ch
|
||||
paths.append(
|
||||
f"{input_path}/{year}/{net}/{station}/{channel}.D/{net}.{station}..{channel}.D.{year}.{day_of_year:03}")
|
||||
return paths
|
||||
|
||||
|
||||
def load_trace(input_path, trace_params):
|
||||
trace_path = get_trace_path(input_path, trace_params)
|
||||
trace = None
|
||||
|
||||
if not os.path.isfile(trace_path):
|
||||
logger.w(trace_path + " not found")
|
||||
logger.warning(trace_path + " not found")
|
||||
else:
|
||||
stream = obspy.read(trace_path)
|
||||
if len(stream.traces) > 1:
|
||||
@ -171,19 +162,26 @@ def load_trace(input_path, trace_params):
|
||||
|
||||
|
||||
def load_stream(input_path, trace_params, time_before=60, time_after=60):
|
||||
trace_path = get_trace_path(input_path, trace_params)
|
||||
sampling_rate, stream = None, None
|
||||
pick_time = trace_params["time"]
|
||||
|
||||
if not os.path.isfile(trace_path):
|
||||
print(trace_path + " not found")
|
||||
else:
|
||||
stream = obspy.read(trace_path)
|
||||
trace_paths = get_three_channels_trace_paths(input_path, trace_params)
|
||||
for trace_path in trace_paths:
|
||||
if not os.path.isfile(trace_path):
|
||||
logger.warning(trace_path + " not found")
|
||||
else:
|
||||
if stream is None:
|
||||
stream = obspy.read(trace_path)
|
||||
else:
|
||||
stream += obspy.read(trace_path)
|
||||
|
||||
if stream is not None:
|
||||
stream = stream.slice(pick_time - time_before, pick_time + time_after)
|
||||
if len(stream.traces) == 0:
|
||||
print(f"no data in: {trace_path}")
|
||||
else:
|
||||
sampling_rate = stream.traces[0].stats.sampling_rate
|
||||
stream.merge()
|
||||
|
||||
return sampling_rate, stream
|
||||
|
||||
@ -202,23 +200,36 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
|
||||
|
||||
metadata_path = output_path + "/metadata.csv"
|
||||
waveforms_path = output_path + "/waveforms.hdf5"
|
||||
|
||||
events_to_convert = events_stats[events_stats['pick_count'] > 0]
|
||||
|
||||
logger.debug("Catalog loaded, starting conversion ...")
|
||||
|
||||
logger.debug("Catalog loaded, starting converting {events_to_convert} events ...")
|
||||
|
||||
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:
|
||||
writer.data_format = {
|
||||
"dimension_order": "CW",
|
||||
"component_order": "ZNE",
|
||||
}
|
||||
|
||||
for i, event in enumerate(events):
|
||||
logger.debug(f"Converting {i} event")
|
||||
event_params = get_event_params(event)
|
||||
event_params["split"] = events_stats.loc[i, "split"]
|
||||
|
||||
picks_per_station = {}
|
||||
for pick in event.picks:
|
||||
trace_params = get_trace_params(pick)
|
||||
station = pick.waveform_id.station_code
|
||||
if station in picks_per_station:
|
||||
picks_per_station[station].append(pick)
|
||||
else:
|
||||
picks_per_station[station] = [pick]
|
||||
|
||||
for picks in picks_per_station.values():
|
||||
|
||||
trace_params = get_trace_params(picks[0])
|
||||
sampling_rate, stream = load_stream(input_path, trace_params)
|
||||
if stream is None:
|
||||
if stream is None or len(stream.traces) == 0:
|
||||
continue
|
||||
|
||||
actual_t_start, data, _ = sbu.stream_to_array(
|
||||
@ -229,22 +240,19 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
|
||||
trace_params["trace_sampling_rate_hz"] = sampling_rate
|
||||
trace_params["trace_start_time"] = str(actual_t_start)
|
||||
|
||||
pick_time = obspy.core.utcdatetime.UTCDateTime(trace_params["time"])
|
||||
pick_idx = (pick_time - actual_t_start) * sampling_rate
|
||||
|
||||
trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(pick_idx)
|
||||
for pick in picks:
|
||||
pick_time = obspy.core.utcdatetime.UTCDateTime(pick.time)
|
||||
pick_idx = (pick_time - actual_t_start) * sampling_rate
|
||||
trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(pick_idx)
|
||||
|
||||
writer.add_trace({**event_params, **trace_params}, data)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description='Convert mseed files to seisbench format')
|
||||
parser.add_argument('--input_path', type=str, help='Path to mseed files')
|
||||
parser.add_argument('--catalog_path', type=str, help='Path to events catalog in quakeml format')
|
||||
parser.add_argument('--output_path', type=str, help='Path to output files')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
convert_mseed_to_seisbench_format(args.input_path, args.catalog_path, args.output_path)
|
||||
|
149
scripts/perf_analysis.py
Normal file
149
scripts/perf_analysis.py
Normal file
@ -0,0 +1,149 @@
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
import numpy as np
|
||||
import obspy
|
||||
import pandas as pd
|
||||
import seisbench.data as sbd
|
||||
import seisbench.models as sbm
|
||||
from seisbench.models.team import itertools
|
||||
from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve
|
||||
|
||||
datasets = [
|
||||
# path to datasets in seisbench format
|
||||
]
|
||||
|
||||
models = [
|
||||
# model names
|
||||
]
|
||||
|
||||
|
||||
def find_keys_phase(meta, phase):
|
||||
phases = []
|
||||
for k in meta.keys():
|
||||
if k.startswith("trace_" + phase) and k.endswith("_arrival_sample"):
|
||||
phases.append(k)
|
||||
|
||||
return phases
|
||||
|
||||
|
||||
def create_stream(meta, raw, start, length=30):
|
||||
|
||||
st = obspy.Stream()
|
||||
|
||||
for i in range(3):
|
||||
tr = obspy.Trace(raw[i, :])
|
||||
tr.stats.starttime = meta["trace_start_time"]
|
||||
tr.stats.sampling_rate = meta["trace_sampling_rate_hz"]
|
||||
tr.stats.network = meta["station_network_code"]
|
||||
tr.stats.station = meta["station_code"]
|
||||
tr.stats.channel = meta["trace_channel"][:2] + meta["trace_component_order"][i]
|
||||
|
||||
stop = start + length
|
||||
tr = tr.slice(start, stop)
|
||||
|
||||
st.append(tr)
|
||||
|
||||
return st
|
||||
|
||||
|
||||
def get_pred(model, stream):
|
||||
ann = model.annotate(stream)
|
||||
noise = ann.select(channel="PhaseNet_N")[0]
|
||||
pred = max(1 - noise.data)
|
||||
return pred
|
||||
|
||||
|
||||
def to_short(stream):
|
||||
short = [tr for tr in stream if tr.data.shape[0] < 3001]
|
||||
return any(short)
|
||||
|
||||
|
||||
for ds, model_name in itertools.product(datasets, models):
|
||||
|
||||
data = sbd.WaveformDataset(ds, sampling_rate=100).test()
|
||||
data_name = pathlib.Path(ds).stem
|
||||
fname = f"roc___{model_name}___{data_name}.csv"
|
||||
|
||||
print(f"{fname:.<50s}.... ", flush=True, end="")
|
||||
|
||||
if pathlib.Path(fname).is_file():
|
||||
print(" ready, skipping", flush=True)
|
||||
continue
|
||||
|
||||
p_labels = find_keys_phase(data.metadata, "P")
|
||||
s_labels = find_keys_phase(data.metadata, "S")
|
||||
|
||||
model = sbm.PhaseNet().from_pretrained(model_name)
|
||||
|
||||
label_true = []
|
||||
label_pred = []
|
||||
|
||||
for i in range(len(data)):
|
||||
|
||||
waveform, metadata = data.get_sample(i)
|
||||
m = pd.Series(metadata)
|
||||
|
||||
has_p_label = m[p_labels].notna()
|
||||
has_s_label = m[s_labels].notna()
|
||||
|
||||
if any(has_p_label):
|
||||
|
||||
trace_start_time = obspy.UTCDateTime(m["trace_start_time"])
|
||||
pick_sample = m[p_labels][has_p_label][0]
|
||||
|
||||
start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15
|
||||
|
||||
try:
|
||||
st_p = create_stream(m, waveform, start)
|
||||
if not (to_short(st_p)):
|
||||
pred_p = get_pred(model, st_p)
|
||||
label_true.append(1)
|
||||
label_pred.append(pred_p)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
try:
|
||||
st_n = create_stream(m, waveform, trace_start_time + 1)
|
||||
if not (to_short(st_n)):
|
||||
pred_n = get_pred(model, st_n)
|
||||
label_true.append(0)
|
||||
label_pred.append(pred_n)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
if any(has_s_label):
|
||||
trace_start_time = obspy.UTCDateTime(m["trace_start_time"])
|
||||
pick_sample = m[s_labels][has_s_label][0]
|
||||
start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15
|
||||
|
||||
try:
|
||||
st_s = create_stream(m, waveform, start)
|
||||
if not (to_short(st_s)):
|
||||
pred_s = get_pred(model, st_s)
|
||||
label_true.append(1)
|
||||
label_pred.append(pred_s)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
fpr, tpr, roc_thresholds = roc_curve(label_true, label_pred)
|
||||
df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": roc_thresholds})
|
||||
df.to_csv(fname)
|
||||
|
||||
precision, recall, prc_thresholds = precision_recall_curve(label_true, label_pred)
|
||||
prc_thresholds_extra = np.append(prc_thresholds, -999)
|
||||
df = pd.DataFrame(
|
||||
{"pre": precision, "rec": recall, "thresholds": prc_thresholds_extra}
|
||||
)
|
||||
df.to_csv(fname.replace("roc", "pr"))
|
||||
|
||||
stats = {
|
||||
"model": str(model_name),
|
||||
"data": str(data_name),
|
||||
"auc": float(roc_auc_score(label_true, label_pred)),
|
||||
}
|
||||
|
||||
with open(f"stats___{model_name}___{data_name}.json", "w") as fp:
|
||||
json.dump(stats, fp)
|
||||
|
||||
print(" finished", flush=True)
|
@ -17,6 +17,8 @@ import eval
|
||||
import collect_results
|
||||
import importlib
|
||||
import config_loader
|
||||
import input_validate
|
||||
import data
|
||||
|
||||
logging.root.setLevel(logging.INFO)
|
||||
logger = logging.getLogger('pipeline')
|
||||
@ -40,11 +42,22 @@ def load_sweep_config(model_name, args):
|
||||
return util.load_sweep_config(sweep_fname)
|
||||
|
||||
|
||||
def find_the_best_params(model_name, args):
|
||||
def validate_pipeline_input(args):
|
||||
|
||||
# validate input parameters
|
||||
for model_name in args.models:
|
||||
sweep_config = load_sweep_config(model_name, args)
|
||||
input_validate.validate_sweep_config(sweep_config, model_name)
|
||||
|
||||
# validate dataset
|
||||
data.validate_custom_dataset(config_loader.data_path)
|
||||
|
||||
|
||||
def find_the_best_params(sweep_config):
|
||||
# find the best hyperparams for the model_name
|
||||
model_name = sweep_config['parameters']['model_name']
|
||||
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
|
||||
|
||||
sweep_config = load_sweep_config(model_name, args)
|
||||
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
|
||||
|
||||
# wait for all runs to finish
|
||||
@ -79,26 +92,33 @@ def main():
|
||||
parser.add_argument("--basic_phase_ae_config", type=str, required=False)
|
||||
parser.add_argument("--eqtransformer_config", type=str, required=False)
|
||||
parser.add_argument("--dataset", type=str, required=False)
|
||||
available_models = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]
|
||||
parser.add_argument("--models", nargs='*', required=False, choices=available_models, default=available_models,
|
||||
help="Models to train and evaluate (default: all)")
|
||||
parser.add_argument("--collect_results", action="store_true", help="Collect and log results without training")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dataset is not None:
|
||||
util.set_dataset(args.dataset)
|
||||
importlib.reload(config_loader)
|
||||
if not args.collect_results:
|
||||
|
||||
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
|
||||
if args.dataset is not None:
|
||||
util.set_dataset(args.dataset)
|
||||
importlib.reload(config_loader)
|
||||
|
||||
# generate labels
|
||||
logger.info("Started generating labels for the dataset.")
|
||||
generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3", config_loader.sampling_rate,
|
||||
None)
|
||||
validate_pipeline_input(args)
|
||||
|
||||
# find the best hyperparams for the models
|
||||
logger.info("Started training the models.")
|
||||
for model_name in ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]:
|
||||
if config_loader.dataset_name == "lumineos" and model_name == "EQTransformer":
|
||||
break
|
||||
sweep_id = find_the_best_params(model_name, args)
|
||||
generate_predictions(sweep_id, model_name)
|
||||
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
|
||||
|
||||
# generate labels
|
||||
logger.info("Started generating labels for the dataset.")
|
||||
generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3", config_loader.sampling_rate,
|
||||
None)
|
||||
|
||||
# find the best hyperparams for the models
|
||||
logger.info("Started training the models.")
|
||||
for model_name in args.models:
|
||||
sweep_config = load_sweep_config(model_name, args)
|
||||
sweep_id = find_the_best_params(sweep_config)
|
||||
generate_predictions(sweep_id, model_name)
|
||||
|
||||
# collect results
|
||||
logger.info("Collecting results.")
|
||||
|
@ -16,5 +16,4 @@ python -c "import torch; print('Number of CUDA devices:', torch.cuda.device_coun
|
||||
python -c "import torch; print('Name of GPU:', torch.cuda.get_device_name(torch.cuda.current_device()))"
|
||||
|
||||
|
||||
python pipeline.py --dataset "lumineos"
|
||||
python pipeline.py --dataset "bogdanka"
|
||||
|
Loading…
Reference in New Issue
Block a user