platform-demo-scripts/notebooks/Present model predictions.ipynb

390 lines
130 KiB
Plaintext
Raw Normal View History

2023-07-05 09:58:06 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "55ef77b9-9320-44db-862b-088d7af03112",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mkmilian\u001b[0m (\u001b[33mepos\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /Users/krystynamilian/.netrc\n"
]
}
],
"source": [
"import pandas as pd\n",
"from obspy.core.event import read_events\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import seisbench.models as sbm\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"import seisbench.data as sbd\n",
"import seisbench.generate as sbg\n",
"import seisbench.models as sbm\n",
"from seisbench.util import worker_seeding\n",
"import numpy as np\n",
"from torch.utils.data import DataLoader\n",
"\n",
"import wandb\n",
"import os\n",
"import sys\n",
"\n",
"from pathlib import Path\n",
"project_path = str(Path.cwd().parent)\n",
"sys.path.append(project_path)\n",
"from scripts import train"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "694ceb35-d2f3-4654-a6af-d84e494a4660",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"wandb version 0.15.4 is available! To upgrade, please run:\n",
" $ pip install wandb --upgrade"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.15.3"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in <code>/Users/krystynamilian/Documents/praca/Cyfronet/epos/ai/repo/demo_scripts/notebooks/wandb/run-20230704_110602-ufltoqra</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/epos/training_seisbench_models_on_igf_data/runs/ufltoqra' target=\"_blank\">polished-totem-255</a></strong> to <a href='https://wandb.ai/epos/training_seisbench_models_on_igf_data' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at <a href='https://wandb.ai/epos/training_seisbench_models_on_igf_data' target=\"_blank\">https://wandb.ai/epos/training_seisbench_models_on_igf_data</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/epos/training_seisbench_models_on_igf_data/runs/ufltoqra' target=\"_blank\">https://wandb.ai/epos/training_seisbench_models_on_igf_data/runs/ufltoqra</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n"
]
}
],
"source": [
"# Start an online W&B run and download the registered PhaseNet model artifact.\n",
"wandb_project = \"training_seisbench_models_on_igf_data\"\n",
"model_artifact_ref = \"epos/model-registry/phasenet_p:v0\"\n",
"\n",
"run = wandb.init(\n",
"    project=wandb_project,\n",
"    entity=\"epos\",\n",
"    mode=\"online\",\n",
")\n",
"artifact = run.use_artifact(model_artifact_ref, type=\"model\")\n",
"# Local directory holding the artifact's files; the weights file is picked from it later.\n",
"artifact_dir = artifact.download()"
]
},
{
"cell_type": "markdown",
"id": "0589f4aa-e5a4-485e-9213-a4696c38a60c",
"metadata": {},
"source": [
"# Load the model and the data generators"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9935a11d-e8b5-4019-aafa-f031f1024a71",
"metadata": {},
"outputs": [],
"source": [
"# Build the model via the project's `scripts.train` helper; trained weights are loaded from the artifact further down.\n",
"model = train.load_model()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "922e8062-e503-4958-b192-b274453d64c3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'./artifacts/model:v4'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show where the model artifact was downloaded.\n",
"# NOTE(review): the saved output ('model:v4') does not match the 'phasenet_p:v0' artifact requested\n",
"# earlier; it is probably stale - re-run the notebook top to bottom to refresh it.\n",
"artifact_dir"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7e100ac5-af2b-4d20-95f3-4f0a6a4fc0ad",
"metadata": {},
"outputs": [],
"source": [
"# Pick the weights file out of the downloaded artifact directory.\n",
"# os.listdir() order is arbitrary, so sort the entries and prefer PyTorch checkpoint\n",
"# files (.pt/.pth); this keeps the choice deterministic if the artifact holds several files.\n",
"artifact_files = sorted(p for p in Path(artifact_dir).iterdir() if p.is_file())\n",
"weight_files = [p for p in artifact_files if p.suffix in {\".pt\", \".pth\"}] or artifact_files\n",
"if not weight_files:\n",
"    raise FileNotFoundError(f\"No files found in artifact directory {artifact_dir!r}\")\n",
"fname = str(weight_files[0])\n",
"\n",
"# To use a local checkpoint instead of the W&B artifact, uncomment:\n",
"# fname = \"../models/PhaseNet_pretrained_on_iquique_finetuned_on_igf_1.pt\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ffdf1280-2b25-49a3-9e6b-38a85a9c5501",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PhaseNet(\n",
" (inc): Conv1d(3, 8, kernel_size=(7,), stride=(1,), padding=same)\n",
" (in_bn): BatchNorm1d(8, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (down_branch): ModuleList(\n",
" (0): ModuleList(\n",
" (0): Conv1d(8, 8, kernel_size=(7,), stride=(1,), padding=same, bias=False)\n",
" (1): BatchNorm1d(8, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): Conv1d(8, 8, kernel_size=(7,), stride=(4,), padding=(3,), bias=False)\n",
" (3): BatchNorm1d(8, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): ModuleList(\n",
" (0): Conv1d(8, 16, kernel_size=(7,), stride=(1,), padding=same, bias=False)\n",
" (1): BatchNorm1d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): Conv1d(16, 16, kernel_size=(7,), stride=(4,), bias=False)\n",
" (3): BatchNorm1d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (2): ModuleList(\n",
" (0): Conv1d(16, 32, kernel_size=(7,), stride=(1,), padding=same, bias=False)\n",
" (1): BatchNorm1d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): Conv1d(32, 32, kernel_size=(7,), stride=(4,), bias=False)\n",
" (3): BatchNorm1d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (3): ModuleList(\n",
" (0): Conv1d(32, 64, kernel_size=(7,), stride=(1,), padding=same, bias=False)\n",
" (1): BatchNorm1d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): Conv1d(64, 64, kernel_size=(7,), stride=(4,), bias=False)\n",
" (3): BatchNorm1d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (4): ModuleList(\n",
" (0): Conv1d(64, 128, kernel_size=(7,), stride=(1,), padding=same, bias=False)\n",
" (1): BatchNorm1d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2-3): 2 x None\n",
" )\n",
" )\n",
" (up_branch): ModuleList(\n",
" (0): ModuleList(\n",
" (0): ConvTranspose1d(128, 64, kernel_size=(7,), stride=(4,), bias=False)\n",
" (1): BatchNorm1d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): Conv1d(128, 64, kernel_size=(7,), stride=(1,), padding=same, bias=False)\n",
" (3): BatchNorm1d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (1): ModuleList(\n",
" (0): ConvTranspose1d(64, 32, kernel_size=(7,), stride=(4,), bias=False)\n",
" (1): BatchNorm1d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): Conv1d(64, 32, kernel_size=(7,), stride=(1,), padding=same, bias=False)\n",
" (3): BatchNorm1d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (2): ModuleList(\n",
" (0): ConvTranspose1d(32, 16, kernel_size=(7,), stride=(4,), bias=False)\n",
" (1): BatchNorm1d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): Conv1d(32, 16, kernel_size=(7,), stride=(1,), padding=same, bias=False)\n",
" (3): BatchNorm1d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (3): ModuleList(\n",
" (0): ConvTranspose1d(16, 8, kernel_size=(7,), stride=(4,), bias=False)\n",
" (1): BatchNorm1d(8, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): Conv1d(16, 8, kernel_size=(7,), stride=(1,), padding=same, bias=False)\n",
" (3): BatchNorm1d(8, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (out): Conv1d(8, 2, kernel_size=(1,), stride=(1,), padding=same)\n",
" (softmax): Softmax(dim=1)\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# map_location=\"cpu\" lets a checkpoint saved on a GPU be read on a CPU-only machine;\n",
"# load_state_dict then copies the tensors into the model's existing parameters.\n",
"# NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources\n",
"# (consider weights_only=True if the installed torch version supports it).\n",
"model.load_state_dict(torch.load(fname, map_location=\"cpu\"))\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4656e5fe-bb7b-4564-923a-8af385eb312d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train (12444, 17) 100\n",
"using random window\n",
"dev (2773, 17) 100\n",
"using random window\n",
"test (2785, 17) 100\n",
"using random window\n"
]
}
],
"source": [
"# Build train/dev/test generators via the project's train module; only test_gen is used below.\n",
"# NOTE(review): the output reports 'using random window', so repeated access to the same\n",
"# index may return a different window - confirm in scripts.train.\n",
"train_gen, dev_gen, test_gen = train.get_data_generators()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dfdaa857-9f2b-419f-8f28-9ce8f4b11a0f",
"metadata": {},
"outputs": [],
"source": [
"def plot_sample(sample, model, i):\n",
"    \"\"\"Plot one sample's input waveform, its label and the model prediction.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    sample : Mapping\n",
"        Generator output (e.g. ``test_gen[idx]``) with keys \"X\" (model input,\n",
"        channels first) and \"y\" (target labels).\n",
"    model : torch.nn.Module\n",
"        Model exposing a ``device`` attribute; it is switched to eval mode.\n",
"    i : int\n",
"        Sample index, shown in the figure title.\n",
"    \"\"\"\n",
"    fig = plt.figure(figsize=(15, 10))\n",
"    fig.suptitle(f\"Predictions for test sample: {i}\")\n",
"    axs = fig.subplots(2, 1, sharex=True, gridspec_kw={\"hspace\": 0, \"height_ratios\": [3, 2]})\n",
"\n",
"    # Top panel: first channel of the input waveform.\n",
"    axs[0].plot(sample[\"X\"][0], label=\"x\")\n",
"    # Attach legends to explicit axes: a bare plt.legend() here would target the current\n",
"    # axes (the bottom panel, created last), which has no labelled lines yet.\n",
"    axs[0].legend()\n",
"\n",
"    # Bottom panel: first label channel vs. first output channel of the model.\n",
"    axs[1].plot(sample[\"y\"][0], label=\"y\")\n",
"\n",
"    model.eval()  # evaluation mode, e.g. BatchNorm uses its running statistics\n",
"    with torch.no_grad():\n",
"        # Add a fake batch dimension: (C, N) -> (1, C, N).\n",
"        pred = model(torch.tensor(sample[\"X\"], device=model.device).unsqueeze(0))\n",
"        pred = pred[0].cpu().numpy()\n",
"\n",
"    axs[1].plot(pred[0], label=\"pred\", color=\"orange\")\n",
"    axs[1].legend()\n",
"\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "26ee3888-8138-4d02-a5f8-a899b8ca6436",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 3001)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABNIAAAORCAYAAAA3ZI+fAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd5hU1fkH8O/MbAOWJh1BQFCwYsSI2LAQsWsSo2LyUzH2kERJopIitoglKnbUKJZYEGtioYggovQiRXpvu7C07bszc8/vj9l755xzz525OzuzC8v38zw+7M7euW1m0fn6vu8JCCEEiIiIiIiIiIiIKKFgQ58AERERERERERHRgYBBGhERERERERERkQ8M0oiIiIiIiIiIiHxgkEZEREREREREROQDgzQiIiIiIiIiIiIfGKQRERERERERERH5wCCNiIiIiIiIiIjIBwZpREREREREREREPjBIIyIiIiIiIiIi8oFBGhERERl1794d119/vfP9tGnTEAgEMG3atLQdIxAI4L777kvb/tLt8ccfx+GHH45QKIQTTjihoU+HfDrrrLNw1llnNfRpEBERUSPEII2IiGg/9PrrryMQCDj/5OXl4cgjj8SwYcNQWFjY0KdXK1988cV+HZZ5mTRpEu666y6cdtppGDt2LB5++OGMHu+dd97B6NGjM3qMbdu24b777sOiRYsyepyDxaRJk/Db3/4Wxx57LEKhELp37+7reW+//TYCgQDy8/NdP7v++uuV3337nz59+ijbbdu2Db/5zW/Qu3dvNG/eHK1atcLJJ5+MN954A0KIdFweERERGWQ19AkQERGRtwceeAA9evRAZWUlZsyYgRdffBFffPEFli5diqZNm9bruZx55pmoqKhATk5OrZ73xRdf4PnnnzeGaRUVFcjK2j//c+Trr79GMBjEq6++WutrTsU777yDpUuX4o477sjYMbZt24b7778f3bt3Z4VdGrzzzjsYN24cTjzxRHTu3NnXc0pLS3HXXXehWbNmntvk5ubi3//+t/JYy5Ytle+LioqwZcsWXHHFFTjssMMQDocxefJkXH/99Vi5cmXGg18iIqKD1f75X65EREQEALjgggtw0kknAQBuvPFGtGnTBk8++SQ+/fRTDBkyxPicsrKyhB/SUxUMBpGXl5fWfaZ7f+m0Y8cONGnSJG0hmhAClZWVaNKkSVr2Rw3v4YcfxiuvvILs7GxcfPHFWLp0adLnPPTQQ2jevDnOPvtsfPLJJ8ZtsrKy8Jvf/Cbhfo4//nhXm/WwYcNwySWX4JlnnsGDDz6IUCjk91KIiIjIJ7Z2EhERHUDOOeccAMD69esBxNrA8vPzsXbtWlx44YVo3rw5fv3rXwMALMvC6NGjccwxxyAvLw8dOnTALbfcgj179ij7FELgoYceQpcuXdC0aVOcffbZWLZsmevYXjPSZs+ejQsvvBCtW7dGs2bNcPzxx+Ppp592zu/5558HAKVNzWaakbZw4UJccMEFaNGiBfLz83Huuedi1qxZyjZ26+t3332H4cOHo127dmjWrBl+/vOfY+fOncq28+bNw+DBg9G2bVs0adIEPXr0wA033JDwPgcCAYwdOxZlZWXOOb/++usAgEgkggcffBA9e/ZEbm4uunfvjr/+9a+oqqpS9tG9e3dcfPHFmDhxIk466SQ0adIEL730kvF4Z511Fj7//HNs3LjROZ7cJlhVVYWRI0eiV69eyM3NRdeuXXHXXXe5jjl58mScfvrpaNWqFfLz89G7d2/89a9/BRB7/X76058CAIYOHeq6LpOSkhLccccd6N69O3Jzc9G+fXv87Gc/w4IFC5xtvv32W/zqV7/CYYcd5pzbnXfeiYqKCmVf9nt106ZNuPjii5Gfn49DDz3UeX8sWbIE55xzDpo1a4Zu3brhnXfeUZ5vv+bTp0/HLbfcgjZt2qBFixa49tprXe9pE7/3sKioCCtWrEB5eXnSfXbu3BnZ2dlJt7
OtXr0aTz31FJ588smklZjRaBTFxcW+923r3r07ysvLUV1dXevnEhERUXKsSCMiIjqArF27FgDQpk0b57FIJILBgwfj9NNPx7/+9S+n5fOWW27B66+/jqFDh+IPf/gD1q9fj+eeew4LFy7Ed9995wQA9957Lx566CFceOGFuPDCC7FgwQKcd955vj6IT548GRdffDE6deqEP/7xj+jYsSOWL1+Ozz77DH/84x9xyy23YNu2bZg8eTLeeuutpPtbtmwZzjjjDLRo0QJ33XUXsrOz8dJLL+Gss87CN998g/79+yvb//73v0fr1q0xcuRIbNiwAaNHj8awYcMwbtw4ALGqsvPOOw/t2rXDPffcg1atWmHDhg346KOPEp7HW2+9hZdffhlz5sxxWuxOPfVUALHKwDfeeANXXHEF/vSnP2H27NkYNWoUli9fjo8//ljZz8qVKzFkyBDccsstuOmmm9C7d2/j8f72t79h37592LJlC5566ikAcOZnWZaFSy+9FDNmzMDNN9+Mo446CkuWLMFTTz2FVatWOVVNy5Ytw8UXX4zjjz8eDzzwAHJzc7FmzRp89913AICjjjoKDzzwAO69917cfPPNOOOMM5TrMrn11lvxwQcfYNiwYTj66KOxa9cuzJgxA8uXL8eJJ54IABg/fjzKy8tx2223oU2bNpgzZw6effZZbNmyBePHj1f2F41GccEFF+DMM8/EY489hrfffhvDhg1Ds2bN8Le//Q2//vWv8Ytf/AJjxozBtddeiwEDBqBHjx7KPoYNG4ZWrVrhvvvuw8qVK/Hiiy9i48aNTtBr4vceAsBzzz2H+++/H1OnTk37ggV33HEHzj77bFx44YV4//33PbcrLy9HixYtUF5ejtatW2PIkCF49NFHjTPVKioqUFZWhtLSUnzzzTcYO3YsBgwYwMpHIiKiTBFERES03xk7dqwAIL766iuxc+dOsXnzZvHee++JNm3aiCZNmogtW7YIIYS47rrrBABxzz33KM//9ttvBQDx9ttvK49PmDBBeXzHjh0iJydHXHTRRcKyLGe7v/71rwKAuO6665zHpk6dKgCIqVOnCiGEiEQiokePHqJbt25iz549ynHkff3ud78TXv/JAUCMHDnS+f7yyy8XOTk5Yu3atc5j27ZtE82bNxdnnnmm6/4MGjRIOdadd94pQqGQ2Lt3rxBCiI8//lgAEHPnzjUeP5HrrrtONGvWTHls0aJFAoC48cYblcf//Oc/CwDi66+/dh7r1q2bACAmTJjg63gXXXSR6Natm+vxt956SwSDQfHtt98qj48ZM0YAEN99950QQoinnnpKABA7d+70PMbcuXMFADF27Fhf59SyZUvxu9/9LuE25eXlrsdGjRolAoGA2Lhxo/OY/V59+OGHncf27NkjmjRpIgKBgHjvvfecx1esWOF6b9iveb9+/UR1dbXz+GOPPSYAiE8//dR5bODAgWLgwIHO937voRBCjBw5Unmf++X1+tk+++wzkZWVJZYtWyaEML+/hBDinnvuEXfffbcYN26cePfdd537dtppp4lwOOzaftSoUQKA88+5554rNm3aVKtzJyIiIv/Y2klERLQfGzRoENq1a4euXbvi6quvRn5+Pj7++GMceuihyna33Xab8v348ePRsmVL/OxnP0NRUZHzT79+/ZCfn4+pU6cCAL766itUV1fj97//vVLN42fg/cKFC7F+/XrccccdaNWqlfIzr8qgRKLRKCZNmoTLL78chx9+uPN4p06dcM0112DGjBmuVrebb75ZOdYZZ5yBaDSKjRs3AoBzXp999hnC4XCtz0n3xRdfAACGDx+uPP6nP/0JAPD5558rj/fo0QODBw+u0zHHjx+Po446Cn369FFeS7vN134t7Wv99NNPYVlWnY5pa9WqFWbPno1t27Z5biNXPpWVlaGoqAinnnoqhBBYuHCha/sbb7xR2X/v3r3RrFkzXHnllc7jvXv3RqtWrbBu3TrX82+++Walnf
K2225DVlaW89qY+L2HAHDfffdBCJHWarTq6mrceeeduPXWW3H00Ucn3HbUqFF45JFHcOWVV+Lqq6/G66+/jn/+85/
"text/plain": [
"<Figure size 1500x1000 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Seeded RNG so the same test sample is drawn on every run; change SEED to inspect others.\n",
"# NOTE(review): the generator may still crop a random window on access (the data-generator\n",
"# cell reports 'using random window') - seed that too if exact reproducibility matters.\n",
"SEED = 42\n",
"rng = np.random.default_rng(SEED)\n",
"idx = int(rng.integers(len(test_gen)))\n",
"sample = test_gen[idx]\n",
"plot_sample(sample, model, idx)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}