initial commit
This commit is contained in:
commit
128d657e76
10
.gitignore
vendored
Normal file
10
.gitignore
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
__pycache__/
|
||||||
|
.idea/
|
||||||
|
*/.ipynb_checkpoints/
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
.env
|
||||||
|
models/
|
||||||
|
data/
|
||||||
|
wip
|
||||||
|
artifacts/
|
||||||
|
wandb/
|
18
README.md
Normal file
18
README.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Demo notebooks and scripts for EPOS AI Platform
|
||||||
|
|
||||||
|
|
||||||
|
This repo contains notebooks and scripts demonstrating how to:
|
||||||
|
- Prepare IGF data for training model detecting P phase (i.e. transform mseeds into [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)), check the [notebook](utils/Transforming%20mseeds%20to%20SeisBench%20dataset.ipynb).
|
||||||
|
The original data can be downloaded from the [drive](https://drive.google.com/drive/folders/1InVI9DLaD7gdzraM2jMzeIrtiBSu-UIK?usp=drive_link)
|
||||||
|
|
||||||
|
- Explore available data, check the [notebook](notebooks/Explore%20igf%20data.ipynb)
|
||||||
|
- Train cnn model (Seisbench PhaseNet)to detect P phase, check the [script](scripts/train.py)
|
||||||
|
- Search for the best training hyperparams, check the [script](scripts/training_wandb_sweep.py)
|
||||||
|
- Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb)
|
||||||
|
- Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb)
|
||||||
|
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
To install all dependencies run:
|
||||||
|
`poetry install`
|
11
experiments/config.json
Normal file
11
experiments/config.json
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"epochs": 10,
|
||||||
|
"batch_size": 256,
|
||||||
|
"dataset": "igf_1",
|
||||||
|
"sampling_rate": 100,
|
||||||
|
"model_names": "EQTransformer,BasicPhaseAE,GPD",
|
||||||
|
"model_name": "PhaseNet",
|
||||||
|
"learning_rate": 0.01,
|
||||||
|
"pretrained": null,
|
||||||
|
"sampling_rate": 100
|
||||||
|
}
|
38
experiments/sweep.yaml
Normal file
38
experiments/sweep.yaml
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
method: bayes
|
||||||
|
metric:
|
||||||
|
goal: minimize
|
||||||
|
name: test_loss
|
||||||
|
parameters:
|
||||||
|
batch_size:
|
||||||
|
distribution: int_uniform
|
||||||
|
max: 512
|
||||||
|
min: 128
|
||||||
|
dataset:
|
||||||
|
distribution: categorical
|
||||||
|
values:
|
||||||
|
- igf_1
|
||||||
|
epochs:
|
||||||
|
distribution: int_uniform
|
||||||
|
max: 60
|
||||||
|
min: 15
|
||||||
|
learning_rate:
|
||||||
|
distribution: uniform
|
||||||
|
max: 0.02
|
||||||
|
min: 0.005
|
||||||
|
model_name:
|
||||||
|
distribution: categorical
|
||||||
|
values:
|
||||||
|
- PhaseNet
|
||||||
|
pretrained:
|
||||||
|
distribution: categorical
|
||||||
|
values:
|
||||||
|
- diting
|
||||||
|
- ethz
|
||||||
|
- geofon
|
||||||
|
- instance
|
||||||
|
- iquique
|
||||||
|
- lendb
|
||||||
|
- neic
|
||||||
|
- original
|
||||||
|
- scedc
|
||||||
|
program: training_wandb_sweep.py
|
32
experiments/sweep4.yaml
Normal file
32
experiments/sweep4.yaml
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
method: bayes
|
||||||
|
metric:
|
||||||
|
goal: minimize
|
||||||
|
name: test_loss
|
||||||
|
parameters:
|
||||||
|
batch_size:
|
||||||
|
distribution: int_uniform
|
||||||
|
max: 512
|
||||||
|
min: 256
|
||||||
|
dataset:
|
||||||
|
distribution: categorical
|
||||||
|
values:
|
||||||
|
- igf_1
|
||||||
|
epochs:
|
||||||
|
distribution: categorical
|
||||||
|
values:
|
||||||
|
- 10
|
||||||
|
learning_rate:
|
||||||
|
distribution: uniform
|
||||||
|
max: 0.02
|
||||||
|
min: 0.01
|
||||||
|
model_name:
|
||||||
|
distribution: categorical
|
||||||
|
values:
|
||||||
|
# - EQTransformer
|
||||||
|
- PhaseNet
|
||||||
|
pretrained:
|
||||||
|
distribution: categorical
|
||||||
|
values:
|
||||||
|
- instance
|
||||||
|
- iquique
|
||||||
|
- false
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1422
notebooks/Explore igf data.ipynb
Normal file
1422
notebooks/Explore igf data.ipynb
Normal file
File diff suppressed because one or more lines are too long
389
notebooks/Present model predictions.ipynb
Normal file
389
notebooks/Present model predictions.ipynb
Normal file
File diff suppressed because one or more lines are too long
1745
poetry.lock
generated
Normal file
1745
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
22
pyproject.toml
Normal file
22
pyproject.toml
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "ai_platform_demo_scripts"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = ""
|
||||||
|
authors = ["Krystyna Milian <krystyna.milian@gmail.com>"]
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "^3.10"
|
||||||
|
seisbench = "^0.4.1"
|
||||||
|
torch = "^2.0.1"
|
||||||
|
PyYAML = "^6.0"
|
||||||
|
python-dotenv = "^1.0.0"
|
||||||
|
pandas = "^2.0.3"
|
||||||
|
obspy = "^1.4.0"
|
||||||
|
wandb = "^0.15.4"
|
||||||
|
torchmetrics = "^0.11.4"
|
||||||
|
|
||||||
|
[tool.poetry.dev-dependencies]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
339
scripts/train.py
Normal file
339
scripts/train.py
Normal file
@ -0,0 +1,339 @@
|
|||||||
|
import os.path
|
||||||
|
import wandb
|
||||||
|
import seisbench.data as sbd
|
||||||
|
import seisbench.generate as sbg
|
||||||
|
import seisbench.models as sbm
|
||||||
|
from seisbench.util import worker_seeding
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import torch.nn.functional as f
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchmetrics import Metric
|
||||||
|
from torch import Tensor, tensor
|
||||||
|
import json
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
wandb_api_key = os.environ.get('WANDB_API_KEY')
|
||||||
|
if wandb_api_key is None:
|
||||||
|
raise ValueError("WANDB_API_KEY environment variable is not set.")
|
||||||
|
|
||||||
|
wandb.login(key=wandb_api_key)
|
||||||
|
|
||||||
|
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
class PickMAE(Metric):
|
||||||
|
higher_is_better: bool = False
|
||||||
|
mae_error: Tensor
|
||||||
|
|
||||||
|
def __init__(self, sampling_rate):
|
||||||
|
super().__init__()
|
||||||
|
self.add_state("mae_error", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
|
||||||
|
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||||
|
|
||||||
|
assert preds.shape == target.shape
|
||||||
|
|
||||||
|
pred_pick_idx = torch.argmax(preds[:, 0, :], dim=1).type(torch.FloatTensor)
|
||||||
|
true_pick_idx = torch.argmax(target[:, 0, :], dim=-1).type(torch.FloatTensor)
|
||||||
|
|
||||||
|
mae = nn.L1Loss()
|
||||||
|
self.mae_error = mae(pred_pick_idx, true_pick_idx) / self.sampling_rate #mae in seconds
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
return self.mae_error.float()
|
||||||
|
|
||||||
|
|
||||||
|
class EarlyStopper:
|
||||||
|
def __init__(self, patience=1, min_delta=0):
|
||||||
|
self.patience = patience
|
||||||
|
self.min_delta = min_delta
|
||||||
|
self.counter = 0
|
||||||
|
self.min_validation_loss = np.inf
|
||||||
|
|
||||||
|
def early_stop(self, validation_loss):
|
||||||
|
if validation_loss < self.min_validation_loss:
|
||||||
|
self.min_validation_loss = validation_loss
|
||||||
|
self.counter = 0
|
||||||
|
elif validation_loss > (self.min_validation_loss + self.min_delta):
|
||||||
|
self.counter += 1
|
||||||
|
if self.counter >= self.patience:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_generator(split, sampling_rate, path, sb_dataset="ethz", station=None, window='random'):
|
||||||
|
|
||||||
|
if path is not None:
|
||||||
|
data = sbd.WaveformDataset(path, sampling_rate=sampling_rate)
|
||||||
|
phase_dict = {
|
||||||
|
"trace_Pg_arrival_sample": "P"
|
||||||
|
}
|
||||||
|
elif sb_dataset == "ethz":
|
||||||
|
data = sbd.ETHZ(sampling_rate=sampling_rate, force=True)
|
||||||
|
|
||||||
|
phase_dict = {
|
||||||
|
"trace_p_arrival_sample": "P",
|
||||||
|
"trace_pP_arrival_sample": "P",
|
||||||
|
"trace_P_arrival_sample": "P",
|
||||||
|
"trace_P1_arrival_sample": "P",
|
||||||
|
"trace_Pg_arrival_sample": "P",
|
||||||
|
"trace_Pn_arrival_sample": "P",
|
||||||
|
"trace_PmP_arrival_sample": "P",
|
||||||
|
"trace_pwP_arrival_sample": "P",
|
||||||
|
"trace_pwPm_arrival_sample": "P",
|
||||||
|
# "trace_s_arrival_sample": "S",
|
||||||
|
# "trace_S_arrival_sample": "S",
|
||||||
|
# "trace_S1_arrival_sample": "S",
|
||||||
|
# "trace_Sg_arrival_sample": "S",
|
||||||
|
# "trace_SmS_arrival_sample": "S",
|
||||||
|
# "trace_Sn_arrival_sample": "S",
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset = data.get_split(split)
|
||||||
|
dataset.filter(dataset.metadata.trace_Pg_arrival_sample.notna())
|
||||||
|
|
||||||
|
print(split, dataset.metadata.shape, sampling_rate)
|
||||||
|
|
||||||
|
if station is not None:
|
||||||
|
dataset.filter(dataset.metadata.station_code==station)
|
||||||
|
|
||||||
|
data_generator = sbg.GenericGenerator(dataset)
|
||||||
|
if window == 'random':
|
||||||
|
print("using random window")
|
||||||
|
window_selector = sbg.RandomWindow(windowlen=3001, strategy="pad")
|
||||||
|
else:
|
||||||
|
window_selector = sbg.FixedWindow(windowlen=3001, p0=0, strategy="pad")
|
||||||
|
|
||||||
|
augmentations = [
|
||||||
|
sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000, selection="random",
|
||||||
|
strategy="variable"),
|
||||||
|
window_selector,
|
||||||
|
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
|
||||||
|
sbg.ChangeDtype(np.float32),
|
||||||
|
sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=30, dim=0)
|
||||||
|
]
|
||||||
|
|
||||||
|
data_generator.add_augmentations(augmentations)
|
||||||
|
|
||||||
|
return data_generator
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_generators(sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz", station=None,
|
||||||
|
window='random'):
|
||||||
|
|
||||||
|
train_generator = get_data_generator("train", sampling_rate, path, sb_dataset, station, window)
|
||||||
|
dev_generator = get_data_generator("dev", sampling_rate, path, sb_dataset, station, window)
|
||||||
|
test_generator = get_data_generator("test", sampling_rate, path, sb_dataset, station, window)
|
||||||
|
|
||||||
|
return train_generator, dev_generator, test_generator
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_loaders(batch_size=256, sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz",
|
||||||
|
window='random'):
|
||||||
|
|
||||||
|
train_generator, dev_generator, test_generator = get_data_generators(sampling_rate, path, sb_dataset, window=window)
|
||||||
|
num_workers = 0 # The number of threads used for loading data
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_generator, batch_size=batch_size, shuffle=True, num_workers=num_workers,
|
||||||
|
worker_init_fn=worker_seeding)
|
||||||
|
dev_loader = DataLoader(dev_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
|
||||||
|
worker_init_fn=worker_seeding)
|
||||||
|
|
||||||
|
test_loader = DataLoader(test_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
|
||||||
|
worker_init_fn=worker_seeding)
|
||||||
|
|
||||||
|
return train_loader, dev_loader, test_loader
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(name="PhaseNet", pretrained=None, classes=2, modify_output=True):
|
||||||
|
|
||||||
|
if name == "PhaseNet":
|
||||||
|
|
||||||
|
if pretrained is not None and pretrained:
|
||||||
|
model = sbm.PhaseNet(phases="PN", norm="peak").from_pretrained(pretrained)
|
||||||
|
else:
|
||||||
|
model = sbm.PhaseNet(phases="PN", norm="peak")
|
||||||
|
|
||||||
|
if modify_output:
|
||||||
|
model.out = nn.Conv1d(model.filters_root, classes, 1, padding="same")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def train_one_epoch(model, dataloader, optimizer, pick_mae):
|
||||||
|
size = len(dataloader.dataset)
|
||||||
|
for batch_id, batch in enumerate(dataloader):
|
||||||
|
|
||||||
|
# Compute prediction and loss
|
||||||
|
|
||||||
|
pred = model(batch["X"].to(model.device))
|
||||||
|
|
||||||
|
loss = loss_fn(pred, batch["y"].to(model.device))
|
||||||
|
|
||||||
|
# Compute cross entropy loss
|
||||||
|
cross_entropy_loss = f.cross_entropy(pred, batch["y"])
|
||||||
|
|
||||||
|
# Compute mae
|
||||||
|
mae = pick_mae(pred, batch['y'])
|
||||||
|
|
||||||
|
wandb.log({"loss": loss})
|
||||||
|
wandb.log({"batch cross entropy loss": cross_entropy_loss})
|
||||||
|
wandb.log({"p_mae": mae})
|
||||||
|
|
||||||
|
|
||||||
|
# Backpropagation
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
if batch_id % 5 == 0:
|
||||||
|
loss, current = loss.item(), batch_id * batch["X"].shape[0]
|
||||||
|
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
|
||||||
|
print(f"mae: {mae:>7f}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_one_epoch(model, dataloader, pick_mae, wandb_log=True):
|
||||||
|
|
||||||
|
num_batches = len(dataloader)
|
||||||
|
test_loss = 0
|
||||||
|
test_mae = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in dataloader:
|
||||||
|
pred = model(batch["X"].to(model.device))
|
||||||
|
|
||||||
|
test_loss += loss_fn(pred, batch["y"].to(model.device)).item()
|
||||||
|
test_mae += pick_mae(pred, batch['y'])
|
||||||
|
test_cross_entropy_loss = f.cross_entropy(pred, batch["y"])
|
||||||
|
if wandb_log:
|
||||||
|
wandb.log({"batch cross entropy test loss": test_cross_entropy_loss})
|
||||||
|
|
||||||
|
test_loss /= num_batches
|
||||||
|
test_mae /= num_batches
|
||||||
|
|
||||||
|
wandb.log({"test_p_mae": test_mae, "test_loss": test_loss})
|
||||||
|
|
||||||
|
print(f"Test avg loss: {test_loss:>8f}")
|
||||||
|
print(f"Test avg mae: {test_mae:>7f}\n")
|
||||||
|
|
||||||
|
return test_loss, test_mae
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(model, path_to_trained_model, train_loader, dev_loader):
|
||||||
|
|
||||||
|
wandb.watch(model, log_freq=10)
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
|
||||||
|
early_stopper = EarlyStopper(patience=3, min_delta=10)
|
||||||
|
pick_mae = PickMAE(wandb.config.sampling_rate)
|
||||||
|
|
||||||
|
best_loss = np.inf
|
||||||
|
best_metrics = {}
|
||||||
|
|
||||||
|
for t in range(wandb.config.epochs):
|
||||||
|
print(f"Epoch {t + 1}\n-------------------------------")
|
||||||
|
train_one_epoch(model, train_loader, optimizer, pick_mae)
|
||||||
|
test_loss, test_mae = test_one_epoch(model, dev_loader, pick_mae)
|
||||||
|
|
||||||
|
if test_loss < best_loss:
|
||||||
|
best_loss = test_loss
|
||||||
|
best_metrics = {"test_p_mae": test_mae, "test_loss": test_loss}
|
||||||
|
torch.save(model.state_dict(), path_to_trained_model)
|
||||||
|
|
||||||
|
if early_stopper.early_stop(test_loss):
|
||||||
|
break
|
||||||
|
|
||||||
|
print("Best model: ", str(best_metrics))
|
||||||
|
|
||||||
|
|
||||||
|
def loss_fn(y_pred, y_true, eps=1e-5):
|
||||||
|
# vector cross entropy loss
|
||||||
|
h = y_true * torch.log(y_pred + eps)
|
||||||
|
h = h.mean(-1).sum(-1) # Mean along sample dimension and sum along pick dimension
|
||||||
|
h = h.mean() # Mean over batch axis
|
||||||
|
return -h
|
||||||
|
|
||||||
|
|
||||||
|
def train_phasenet_on_sb_data():
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"epochs": 3,
|
||||||
|
"batch_size": 256,
|
||||||
|
"dataset": "ethz",
|
||||||
|
"sampling_rate": 100,
|
||||||
|
"model_name": "PhaseNet"
|
||||||
|
}
|
||||||
|
|
||||||
|
run = wandb.init(
|
||||||
|
# set the wandb project where this run will be logged
|
||||||
|
project="training_seisbench_models_on_igf_data",
|
||||||
|
# track hyperparameters and run metadata
|
||||||
|
config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
|
||||||
|
|
||||||
|
train_loader, dev_loader, test = get_data_loaders(batch_size=wandb.config.batch_size,
|
||||||
|
sampling_rate=wandb.config.sampling_rate,
|
||||||
|
path=None,
|
||||||
|
sb_dataset=wandb.config.dataset)
|
||||||
|
|
||||||
|
model = load_model(name=wandb.config.model_name, pretrained=None, modify_output=True)
|
||||||
|
path_to_trained_model = f"{project_path}/models/{wandb.config.model_name}_trained_on_{wandb.config.data_set}.pt"
|
||||||
|
train_model(model, path_to_trained_model,
|
||||||
|
train_loader, dev_loader)
|
||||||
|
|
||||||
|
artifact = wandb.Artifact('model', type='model')
|
||||||
|
artifact.add_file(path_to_trained_model)
|
||||||
|
run.log_artifact(artifact)
|
||||||
|
|
||||||
|
run.finish()
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config_path):
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def train_sbmodel_on_igf_data():
|
||||||
|
|
||||||
|
config_path = project_path + "/experiments/config.json"
|
||||||
|
config = load_config(config_path)
|
||||||
|
|
||||||
|
run = wandb.init(
|
||||||
|
# set the wandb project where this run will be logged
|
||||||
|
project="training_seisbench_models_on_igf_data",
|
||||||
|
# track hyperparameters and run metadata
|
||||||
|
config=config
|
||||||
|
)
|
||||||
|
wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
|
||||||
|
|
||||||
|
print(wandb.config.batch_size, wandb.config.sampling_rate)
|
||||||
|
train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
|
||||||
|
sampling_rate=wandb.config.sampling_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
model_name = wandb.config.model_name
|
||||||
|
pretrained = wandb.config.pretrained
|
||||||
|
|
||||||
|
print(model_name, pretrained)
|
||||||
|
model = load_model(name=model_name, pretrained=pretrained)
|
||||||
|
path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
|
||||||
|
train_model(model, path_to_trained_model, train_loader, dev_loader)
|
||||||
|
|
||||||
|
artifact = wandb.Artifact('model', type='model')
|
||||||
|
artifact.add_file(path_to_trained_model)
|
||||||
|
run.log_artifact(artifact)
|
||||||
|
|
||||||
|
run.finish()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# train_phasenet_on_sb_data()
|
||||||
|
train_sbmodel_on_igf_data()
|
||||||
|
|
62
scripts/training_wandb_sweep.py
Normal file
62
scripts/training_wandb_sweep.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import os.path
|
||||||
|
import wandb
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from train import get_data_loaders, load_model, train_model
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
wandb_api_key = os.environ.get('WANDB_API_KEY')
|
||||||
|
if wandb_api_key is None:
|
||||||
|
raise ValueError("WANDB_API_KEY environment variable is not set.")
|
||||||
|
|
||||||
|
wandb.login(key=wandb_api_key)
|
||||||
|
|
||||||
|
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
sweep_config_path = project_path + "/experiments/sweep4.yaml"
|
||||||
|
|
||||||
|
with open(sweep_config_path) as file:
|
||||||
|
sweep_configuration = yaml.load(file, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
sweep_id = wandb.sweep(
|
||||||
|
sweep=sweep_configuration,
|
||||||
|
project='training_seisbench_models_on_igf_data'
|
||||||
|
)
|
||||||
|
sampling_rate = 100
|
||||||
|
|
||||||
|
def tune_training_hyperparams():
|
||||||
|
|
||||||
|
run = wandb.init(
|
||||||
|
# set the wandb project where this run will be logged
|
||||||
|
project="training_seisbench_models_on_igf_data",
|
||||||
|
# track hyperparameters and run metadata
|
||||||
|
config={"sampling_rate":sampling_rate}
|
||||||
|
)
|
||||||
|
|
||||||
|
wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
|
||||||
|
|
||||||
|
train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
|
||||||
|
sampling_rate=wandb.config.sampling_rate,
|
||||||
|
sb_dataset=wandb.config.dataset)
|
||||||
|
|
||||||
|
model_name = wandb.config.model_name
|
||||||
|
pretrained = wandb.config.pretrained
|
||||||
|
print(wandb.config)
|
||||||
|
print(model_name, pretrained, type(pretrained), wandb.config.sampling_rate)
|
||||||
|
if not pretrained:
|
||||||
|
pretrained
|
||||||
|
model = load_model(name=model_name, pretrained=pretrained)
|
||||||
|
path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
|
||||||
|
train_model(model, path_to_trained_model, train_loader, dev_loader)
|
||||||
|
|
||||||
|
artifact = wandb.Artifact('model', type='model')
|
||||||
|
artifact.add_file(path_to_trained_model)
|
||||||
|
run.log_artifact(artifact)
|
||||||
|
|
||||||
|
run.finish()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
wandb.agent(sweep_id, function=tune_training_hyperparams, count=10)
|
708
utils/Transforming mseeds to SeisBench dataset.ipynb
Normal file
708
utils/Transforming mseeds to SeisBench dataset.ipynb
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user