{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Banded ridge regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial builds on the previous encoding model tutorial to map _multiple_ feature spaces onto brain activity during natural language comprehension. We will use two language models in one joint model: speech embeddings from Whisper's encoder and language embeddings from GPT2-XL. The [Himalaya](https://gallantlab.org/himalaya/index.html) package ([Dupré La Tour et al., 2022](https://doi.org/10.1016/j.neuroimage.2022.119728)) provides code to run and evaluate these joint encoding models (example [tutorial](https://gallantlab.org/voxelwise_tutorials/_auto_examples/shortclips/06_plot_banded_ridge_model.html)).\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hassonlab/podcast-ecog-tutorials/blob/main/notebooks/06-banded-ridge.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# only run this cell in colab\n", "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124\n", "!pip install mne mne_bids himalaya scikit-learn pandas matplotlib nilearn" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import mne\n", "import h5py\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "from nilearn.plotting import plot_markers\n", "from mne_bids import BIDSPath\n", "\n", "from himalaya.backend import set_backend, get_backend\n", "from himalaya.kernel_ridge import Kernelizer, ColumnKernelizer, MultipleKernelRidgeCV\n", "from himalaya.scoring import correlation_score_split\n", "\n", "from sklearn.model_selection import KFold\n", "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from sklearn import 
set_config\n", "set_config(display='diagram')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use a GPU for fitting an encoding model, if available." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cuda!\n" ] } ], "source": [ "if torch.cuda.is_available():\n", " set_backend(\"torch_cuda\")\n", " print(\"Using cuda!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load GPT-2 embeddings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar to previous tutorials, we will load the contextual word embeddings from GPT-2." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using embedding file path: ../../monkey/stimuli/gpt2-xl/features.hdf5\n" ] } ], "source": [ "bids_root = \"\" # if using a local dataset, set this variable accordingly\n", "\n", "# Download the transcript, if required\n", "embedding_path = f\"{bids_root}stimuli/gpt2-xl/features.hdf5\"\n", "if not len(bids_root):\n", " !wget -nc https://s3.amazonaws.com/openneuro.org/ds005574/$embedding_path\n", " embedding_path = \"features.hdf5\"\n", "\n", "print(f\"Using embedding file path: {embedding_path}\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LLM embedding matrix has shape: (5491, 1600)\n" ] } ], "source": [ "modelname, layer = 'gpt2-xl', 24\n", "with h5py.File(embedding_path, \"r\") as f:\n", " contextual_embeddings = f[f\"layer-{layer}\"][...]\n", "print(f\"LLM embedding matrix has shape: {contextual_embeddings.shape}\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model accuracy: 30.942%\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
word_idxwordstartendhftokentoken_idranktrue_probtop_predentropy
00Act3.7103.790ĠAct219116440.00001202.402717
11one,3.9904.190Ġone530920.0003423523.732053
21one,3.9904.190,1130.059520254.259335
32monkey4.6514.931Ġmonkey2165740220.00001837156.621269
43in4.9515.011Ġin287150.00423704.444838
\n", "
" ], "text/plain": [ " word_idx word start end hftoken token_id rank true_prob \\\n", "0 0 Act 3.710 3.790 ĠAct 2191 1644 0.000012 \n", "1 1 one, 3.990 4.190 Ġone 530 92 0.000342 \n", "2 1 one, 3.990 4.190 , 11 3 0.059520 \n", "3 2 monkey 4.651 4.931 Ġmonkey 21657 4022 0.000018 \n", "4 3 in 4.951 5.011 Ġin 287 15 0.004237 \n", "\n", " top_pred entropy \n", "0 0 2.402717 \n", "1 352 3.732053 \n", "2 25 4.259335 \n", "3 3715 6.621269 \n", "4 0 4.444838 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Download the transcript, if required\n", "transcript_path = f\"{bids_root}stimuli/gpt2-xl/transcript.tsv\"\n", "if not len(bids_root):\n", " !wget -nc https://s3.amazonaws.com/openneuro.org/ds005574/$transcript_path\n", " transcript_path = \"transcript.tsv\"\n", "\n", "# Load transcript\n", "df_contextual = pd.read_csv(transcript_path, sep=\"\\t\", index_col=0)\n", "if \"rank\" in df_contextual.columns:\n", " model_acc = (df_contextual[\"rank\"] == 0).mean()\n", " print(f\"Model accuracy: {model_acc*100:.3f}%\")\n", "\n", "df_contextual.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LLM embeddings matrix has shape: (5136, 1600)\n" ] } ], "source": [ "aligned_gpt_embeddings = []\n", "for _, group in df_contextual.groupby(\"word_idx\"): # group by word index\n", " indices = group.index.to_numpy()\n", " average_emb = contextual_embeddings[indices].mean(0) # average features\n", " aligned_gpt_embeddings.append(average_emb)\n", "aligned_gpt_embeddings = np.stack(aligned_gpt_embeddings)\n", "print(f\"LLM embeddings matrix has shape: {aligned_gpt_embeddings.shape}\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
wordstartend
word_idx
0Act3.7103.790
1one,3.9904.190
2monkey4.6514.931
3in4.9515.011
4the5.0515.111
\n", "
" ], "text/plain": [ " word start end\n", "word_idx \n", "0 Act 3.710 3.790\n", "1 one, 3.990 4.190\n", "2 monkey 4.651 4.931\n", "3 in 4.951 5.011\n", "4 the 5.051 5.111" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_word_gpt = df_contextual.groupby(\"word_idx\").agg(dict(word=\"first\", start=\"first\", end=\"last\"))\n", "df_word_gpt.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Whisper encoder embeddings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will also load speech embeddings from Whisper (medium size) from the dataset." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using embedding file path: ../../monkey/stimuli/whisper-medium/features.hdf5\n" ] } ], "source": [ "# Download the Whisper embeddings, if required\n", "embedding_path = f\"{bids_root}stimuli/whisper-medium/features.hdf5\"\n", "if not len(bids_root):\n", "    # Download into a model-specific folder: `wget -nc` never overwrites an existing file,\n", "    # so saving as ./features.hdf5 would be skipped once the GPT-2 features.hdf5 is there\n", "    # and the GPT-2 file would be opened below instead of the Whisper one.\n", "    !wget -nc -P whisper-medium https://s3.amazonaws.com/openneuro.org/ds005574/$embedding_path\n", "    embedding_path = \"whisper-medium/features.hdf5\"\n", "\n", "print(f\"Using embedding file path: {embedding_path}\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LLM embedding matrix has shape: (5134, 2048)\n" ] } ], "source": [ "with h5py.File(embedding_path, \"r\") as f:\n", "    whisper_embeddings = f[\"vectors\"][...]\n", "print(f\"LLM embedding matrix has shape: {whisper_embeddings.shape}\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
word_idxwordstartend
00Act3.7103.790
11one,3.9904.190
22monkey4.6514.931
33in4.9515.011
44the5.0515.111
\n", "
" ], "text/plain": [ " word_idx word start end\n", "0 0 Act 3.710 3.790\n", "1 1 one, 3.990 4.190\n", "2 2 monkey 4.651 4.931\n", "3 3 in 4.951 5.011\n", "4 4 the 5.051 5.111" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Download the Whisper transcript, if required\n", "transcript_path = f\"{bids_root}stimuli/whisper-medium/transcript.tsv\"\n", "if not len(bids_root):\n", "    # Download into a model-specific folder: `wget -nc` never overwrites an existing file,\n", "    # so saving as ./transcript.tsv would be skipped once the GPT-2 transcript.tsv is there\n", "    # and the GPT-2 transcript would be read below instead of the Whisper one.\n", "    !wget -nc -P whisper-medium https://s3.amazonaws.com/openneuro.org/ds005574/$transcript_path\n", "    transcript_path = \"whisper-medium/transcript.tsv\"\n", "\n", "# Load transcript\n", "df_whisper = pd.read_csv(transcript_path, sep=\"\\t\", index_col=0)\n", "df_whisper.head()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LLM embeddings matrix has shape: (5134, 2048)\n" ] } ], "source": [ "aligned_whisper_embeddings = []\n", "for _, group in df_whisper.groupby(\"word_idx\"): # group by word index\n", "    indices = group.index.to_numpy()\n", "    average_emb = whisper_embeddings[indices].mean(0) # average features\n", "    aligned_whisper_embeddings.append(average_emb)\n", "aligned_whisper_embeddings = np.stack(aligned_whisper_embeddings)\n", "print(f\"LLM embeddings matrix has shape: {aligned_whisper_embeddings.shape}\")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
wordstartend
word_idx
0Act3.7103.790
1one,3.9904.190
2monkey4.6514.931
3in4.9515.011
4the5.0515.111
\n", "
" ], "text/plain": [ " word start end\n", "word_idx \n", "0 Act 3.710 3.790\n", "1 one, 3.990 4.190\n", "2 monkey 4.651 4.931\n", "3 in 4.951 5.011\n", "4 the 5.051 5.111" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_word_whisper = df_whisper.groupby(\"word_idx\").agg(dict(word=\"first\", start=\"first\", end=\"last\"))\n", "df_word_whisper.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because we are using two different kinds of language models, they may have different tokenizers and thus their embeddings may not be aligned. We'll align their transcripts here by merging them on the word index, which keeps only the words that appear in both." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
word_xstart_xend_xword_ystart_yend_y
word_idx
0Act3.7103.790Act3.7103.790
1one,3.9904.190one,3.9904.190
2monkey4.6514.931monkey4.6514.931
3in4.9515.011in4.9515.011
4the5.0515.111the5.0515.111
.....................
5131go1798.5461798.646go1798.5461798.646
5132to1798.6661798.746to1798.6661798.746
5133court1798.7861799.006court1798.7861799.006
5134over1799.0461799.226over1799.0461799.226
5135it.1799.3271799.367it.1799.3271799.367
\n", "

5134 rows × 6 columns

\n", "
" ], "text/plain": [ " word_x start_x end_x word_y start_y end_y\n", "word_idx \n", "0 Act 3.710 3.790 Act 3.710 3.790\n", "1 one, 3.990 4.190 one, 3.990 4.190\n", "2 monkey 4.651 4.931 monkey 4.651 4.931\n", "3 in 4.951 5.011 in 4.951 5.011\n", "4 the 5.051 5.111 the 5.051 5.111\n", "... ... ... ... ... ... ...\n", "5131 go 1798.546 1798.646 go 1798.546 1798.646\n", "5132 to 1798.666 1798.746 to 1798.666 1798.746\n", "5133 court 1798.786 1799.006 court 1798.786 1799.006\n", "5134 over 1799.046 1799.226 over 1799.046 1799.226\n", "5135 it. 1799.327 1799.367 it. 1799.327 1799.367\n", "\n", "[5134 rows x 6 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_merged = pd.merge(df_word_gpt, df_word_whisper, left_index=True, right_index=True)\n", "df_merged" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading brain data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we will load the preprocessed high-gamma ECoG data using MNE. Here, we will demonstrate loading data from our third subject." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "File path within the dataset: ../../monkey/derivatives/ecogprep/sub-03/ieeg/sub-03_task-podcast_desc-highgamma_ieeg.fif\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "
\n", " \n", " \n", " General\n", "
Filename(s)\n", " \n", " sub-03_task-podcast_desc-highgamma_ieeg.fif\n", " \n", " \n", "
MNE object typeRaw
Measurement date2019-03-11 at 10:54:21 UTC
Participantsub-03
ExperimenterUnknown
\n", " \n", " \n", " Acquisition\n", "
Duration00:29:60 (HH:MM:SS)
Sampling frequency512.00 Hz
Time points921,600
\n", " \n", " \n", " Channels\n", "
ECoG\n", " \n", "\n", " \n", "
Head & sensor digitization238 points
\n", " \n", " \n", " Filters\n", "
Highpass70.00 Hz
Lowpass200.00 Hz
" ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "file_path = BIDSPath(root=f\"{bids_root}derivatives/ecogprep\",\n", " subject=\"03\", task=\"podcast\", datatype=\"ieeg\", description=\"highgamma\",\n", " suffix=\"ieeg\", extension=\".fif\")\n", "print(f\"File path within the dataset: {file_path}\")\n", "\n", "# You only need to run this if using Colab (i.e. if you did not set bids_root to a local directory)\n", "if not len(bids_root):\n", " !wget -nc https://s3.amazonaws.com/openneuro.org/ds005574/$file_path\n", " file_path = file_path.basename\n", "\n", "raw = mne.io.read_raw_fif(file_path, verbose=False)\n", "picks = mne.pick_channels_regexp(raw.ch_names, \"LG[AB]*\")\n", "raw = raw.pick(picks)\n", "raw" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will map the start information (in seconds) of each word in the dataframe onto the brain signal data by multiplying by the sampling rate. Here the first column of `events` mark the start of each word on the brain signal data." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(5134, 3)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "events = np.zeros((len(df_merged), 3), dtype=int)\n", "events[:, 0] = (df_merged.start_x * raw.info['sfreq']).astype(int)\n", "events.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we'll take advantage of MNE's tools for creating epochs around stimulus events, which here are the starts (onsets) of each word, to visualize brain signal that respond to word onsets. Here, we take a fixed-width window ranging from -2 seconds to +2 seconds relative to word onset. Since the sampling rate is 512 Hz (512 samples per second), we have 2049 lags total. The ECoG data is a numpy array with the shape of (number of words * number of ECoG electrodes * number of lags)." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Not setting metadata\n", "5134 matching events found\n", "No baseline correction applied\n", "Loading data for 5134 events and 2049 original time points ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_3099603/3482921024.py:1: RuntimeWarning: The events passed to the Epochs constructor are not chronologically ordered.\n", " epochs = mne.Epochs(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "6 bad epochs dropped\n", "Epochs object has a shape of: (5128, 127, 2049)\n" ] } ], "source": [ "epochs = mne.Epochs(\n", " raw,\n", " events,\n", " tmin=-2.0,\n", " tmax=2.0,\n", " baseline=None,\n", " proj=False,\n", " event_id=None,\n", " preload=True,\n", " event_repeated=\"merge\",\n", ")\n", "print(f\"Epochs object has a shape of: {epochs._data.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll downsample the temporal resolution to 32 Hz, which reduces the number of lags to 32 * 4 = 128.\n", "\n", "
\n", "\n", "**Note**\n", "\n", "This code block may take ~3 minutes to run.\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epochs object has a shape of: (5128, 127, 128)\n" ] } ], "source": [ "epochs = epochs.resample(sfreq=32, npad='auto', method='fft', window='hamming')\n", "print(f\"Epochs object has a shape of: {epochs._data.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting up feature and brain data\n", "\n", "Now we have both the features and the ECoG data ready. We plan to fit encoding models at each electrode and for each lag, so we'll reshape our target matrix `Y` to horizontally stack both electrodes and lags along the second dimension." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ECoG data matrix shape: (5128, 16256)\n" ] } ], "source": [ "epochs_data = epochs.get_data(copy=True)\n", "epochs_data = epochs_data.reshape(len(epochs), -1)\n", "print(f\"ECoG data matrix shape: {epochs_data.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will also align our features with the ECoG data. At the same time, we need to construct one design matrix as predictor variables in our encoding model. 
We do this by horizontally stacking the embeddings together, to get one wide model:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Combined model embeddings size: (5128, 3648)\n" ] } ], "source": [ "# `events` (and therefore `epochs.selection`) was built from `df_merged`, which keeps only\n", "# the words shared by both transcripts. Each word-level embedding matrix has one row per\n", "# word of its own transcript, so first map the shared words to row positions in each\n", "# matrix; indexing the longer GPT-2 matrix with `epochs.selection` directly would shift\n", "# its rows relative to the ECoG epochs and the Whisper rows.\n", "gpt_rows = df_word_gpt.index.get_indexer(df_merged.index)\n", "whisper_rows = df_word_whisper.index.get_indexer(df_merged.index)\n", "\n", "# Use new names so the raw `whisper_embeddings` array loaded above is not overwritten\n", "gpt2_embeddings = aligned_gpt_embeddings[gpt_rows][epochs.selection]\n", "whisper_epoch_embeddings = aligned_whisper_embeddings[whisper_rows][epochs.selection]\n", "input_embeddings = np.hstack((gpt2_embeddings, whisper_epoch_embeddings))\n", "\n", "print(f\"Combined model embeddings size: {input_embeddings.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will change the float precision to float32 for all data to take advantage of the GPU memory and computational speed." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((5128, 3648), (5128, 16256))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = input_embeddings\n", "Y = epochs_data\n", "\n", "if \"torch\" in get_backend().__name__:\n", "    X = X.astype(np.float32)\n", "    Y = Y.astype(np.float32)\n", "\n", "X.shape, Y.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Building the encoding model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This section closely follows the Himalaya [banded ridge tutorial](https://gallantlab.org/voxelwise_tutorials/_auto_examples/shortclips/06_plot_banded_ridge_model.html) to construct an encoding model pipeline prior to fitting." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('standardscaler', StandardScaler(with_std=False)),\n",
       "                ('kernelizer', Kernelizer())])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('standardscaler', StandardScaler(with_std=False)),\n", " ('kernelizer', Kernelizer())])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preprocess_pipeline = make_pipeline(\n", " StandardScaler(with_mean=True, with_std=False),\n", " Kernelizer(kernel=\"linear\"),\n", ")\n", "preprocess_pipeline" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
ColumnKernelizer(transformers=[('gpt2',\n",
       "                                Pipeline(steps=[('standardscaler',\n",
       "                                                 StandardScaler(with_std=False)),\n",
       "                                                ('kernelizer', Kernelizer())]),\n",
       "                                slice(0, 1600, None)),\n",
       "                               ('whisper',\n",
       "                                Pipeline(steps=[('standardscaler',\n",
       "                                                 StandardScaler(with_std=False)),\n",
       "                                                ('kernelizer', Kernelizer())]),\n",
       "                                slice(1600, None, None))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "ColumnKernelizer(transformers=[('gpt2',\n", " Pipeline(steps=[('standardscaler',\n", " StandardScaler(with_std=False)),\n", " ('kernelizer', Kernelizer())]),\n", " slice(0, 1600, None)),\n", " ('whisper',\n", " Pipeline(steps=[('standardscaler',\n", " StandardScaler(with_std=False)),\n", " ('kernelizer', Kernelizer())]),\n", " slice(1600, None, None))])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature_names = ['gpt2', 'whisper']\n", "slices = [slice(0, 1600), slice(1600, None)]\n", "\n", "kernelizers_tuples = [(name, preprocess_pipeline, slice_)\n", " for name, slice_ in zip(feature_names, slices)]\n", "column_kernelizer = ColumnKernelizer(kernelizers_tuples)\n", "column_kernelizer" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
MultipleKernelRidgeCV(cv=KFold(n_splits=5, random_state=None, shuffle=False),\n",
       "                      kernels='precomputed',\n",
       "                      solver_params={'alphas': array([1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05, 1.e+06, 1.e+07, 1.e+08,\n",
       "       1.e+09, 1.e+10]),\n",
       "                                     'n_iter': 20})
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "MultipleKernelRidgeCV(cv=KFold(n_splits=5, random_state=None, shuffle=False),\n", " kernels='precomputed',\n", " solver_params={'alphas': array([1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05, 1.e+06, 1.e+07, 1.e+08,\n", " 1.e+09, 1.e+10]),\n", " 'n_iter': 20})" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n_iter = 20\n", "alphas = np.logspace(1, 10, 10) # specify alpha values\n", "inner_cv = KFold(n_splits=5, shuffle=False) # inner 5-fold cross-validation setup\n", "\n", "solver_params = dict(n_iter=n_iter, alphas=alphas)\n", "mkr_model = MultipleKernelRidgeCV(kernels=\"precomputed\", solver='random_search',\n", " solver_params=solver_params, cv=inner_cv)\n", "\n", "mkr_model" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('columnkernelizer',\n",
       "                 ColumnKernelizer(transformers=[('gpt2',\n",
       "                                                 Pipeline(steps=[('standardscaler',\n",
       "                                                                  StandardScaler(with_std=False)),\n",
       "                                                                 ('kernelizer',\n",
       "                                                                  Kernelizer())]),\n",
       "                                                 slice(0, 1600, None)),\n",
       "                                                ('whisper',\n",
       "                                                 Pipeline(steps=[('standardscaler',\n",
       "                                                                  StandardScaler(with_std=False)),\n",
       "                                                                 ('kernelizer',\n",
       "                                                                  Kernelizer())]),\n",
       "                                                 slice(1600, None, None))])),\n",
       "                ('multiplekernelridgecv',\n",
       "                 MultipleKernelRidgeCV(cv=KFold(n_splits=5, random_state=None, shuffle=False),\n",
       "                                       kernels='precomputed',\n",
       "                                       solver_params={'alphas': array([1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05, 1.e+06, 1.e+07, 1.e+08,\n",
       "       1.e+09, 1.e+10]),\n",
       "                                                      'n_iter': 20}))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('columnkernelizer',\n", " ColumnKernelizer(transformers=[('gpt2',\n", " Pipeline(steps=[('standardscaler',\n", " StandardScaler(with_std=False)),\n", " ('kernelizer',\n", " Kernelizer())]),\n", " slice(0, 1600, None)),\n", " ('whisper',\n", " Pipeline(steps=[('standardscaler',\n", " StandardScaler(with_std=False)),\n", " ('kernelizer',\n", " Kernelizer())]),\n", " slice(1600, None, None))])),\n", " ('multiplekernelridgecv',\n", " MultipleKernelRidgeCV(cv=KFold(n_splits=5, random_state=None, shuffle=False),\n", " kernels='precomputed',\n", " solver_params={'alphas': array([1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05, 1.e+06, 1.e+07, 1.e+08,\n", " 1.e+09, 1.e+10]),\n", " 'n_iter': 20}))])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = make_pipeline(\n", " column_kernelizer,\n", " mkr_model,\n", ")\n", "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training encoding models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "This code block may take a while to run. Make sure you are using a GPU if you have one (verify by running `nvidia-smi`). You may also consider resampling the epochs even further to use fewer lags, and/or choose specific electrodes to run to use fewer electrodes.\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[........................................] 100% | 22.45 sec | 20 random sampling with cv | \n", "[........................................] 100% | 19.45 sec | 20 random sampling with cv | \n", "Encoding performance correlating matrix shape: (2, 2, 127, 128)\n" ] } ], "source": [ "epochs_shape = epochs.get_data(copy=False).shape[1:]  # (number of electrodes, number of lags)\n", "\n", "def train_encoding(X, Y, n_splits=2):\n", "    \"\"\"Fit and score the banded ridge pipeline with outer cross-validation.\n", "\n", "    Relies on the `model` pipeline built in the cells above. In every fold, `Y` is\n", "    standardized with statistics from the training split only, `model` is refit on\n", "    the training split, and its split (per feature space) predictions for the test\n", "    split are scored against the held-out `Y` with `correlation_score_split`.\n", "\n", "    Parameters\n", "    ----------\n", "    X : array\n", "        Design matrix; rows are samples (here, the stacked GPT-2 and Whisper embeddings).\n", "    Y : array\n", "        Targets with the same number of rows as `X`.\n", "    n_splits : int, default 2\n", "        Number of contiguous (unshuffled) outer folds.\n", "\n", "    Returns\n", "    -------\n", "    numpy.ndarray\n", "        The per-fold correlation scores stacked along a new first axis.\n", "    \"\"\"\n", "    backend = get_backend()\n", "    outer_cv = KFold(n_splits, shuffle=False)  # outer cross-validation setup\n", "    fold_scores = []\n", "    for train_index, test_index in outer_cv.split(X):\n", "\n", "        # Split train and test datasets\n", "        X_train, X_test = X[train_index], X[test_index]\n", "        Y_train, Y_test = Y[train_index], Y[test_index]\n", "\n", "        # Standardize Y using training statistics only, to avoid leaking test data\n", "        scaler = StandardScaler()\n", "        Y_train = scaler.fit_transform(Y_train)\n", "        Y_test = scaler.transform(Y_test)\n", "\n", "        model.fit(X_train, Y_train)  # Fit pipeline with transforms and ridge estimator\n", "        Y_preds = model.predict(X_test, split=True)  # One prediction per feature space\n", "        corr = correlation_score_split(Y_test, Y_preds)  # Compute correlation score\n", "\n", "        # Move the scores back to a numpy array whatever the active himalaya backend is\n", "        fold_scores.append(backend.to_numpy(corr))\n", "    return np.stack(fold_scores)\n", "\n", "# set_backend(\"torch\") # resort to torch or numpy if cuda out of memory\n", "corrs_embedding = train_encoding(X, Y)\n", "# Unpack the flattened target axis back into (electrodes, lags), keeping the leading axes\n", "corrs_embedding = corrs_embedding.reshape(*corrs_embedding.shape[:-1], *epochs_shape)\n", "print(f\"Encoding performance correlating matrix shape: {corrs_embedding.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting encoding performance" ] }, { "cell_type": "code", "execution_count": 29,
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Coordinate matrix shape: (127, 3)\n" ] } ], "source": [ "ch2loc = {ch['ch_name']: ch['loc'][:3] for ch in raw.info['chs']}\n", "coords = np.vstack([ch2loc[ch] for ch in raw.info['ch_names']])\n", "coords *= 1000 # convert meters to mm: nilearn plots coordinates in mm, not meters\n", "print(\"Coordinate matrix shape: \", coords.shape)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", "\n", "values = corrs_embedding[:, 0].mean(0).max(-1)\n", "order = values.argsort()\n", "plot_markers(values[order], coords[order],\n", " node_size=30, display_mode='l',\n", " figure=fig, axes=axes[0],\n", " node_vmin=0, node_vmax=0.2,\n", " title=\"GPT-2\",\n", " node_cmap='inferno_r', colorbar=True)\n", "\n", "\n", "values = corrs_embedding[:, 1].mean(0).max(-1)\n", "order = values.argsort()\n", "plot_markers(values[order], coords[order],\n", " node_size=30, display_mode='l',\n", " figure=fig, axes=axes[1],\n", " node_vmin=0, node_vmax=0.2,\n", " title=\"Whisper\",\n", " node_cmap='inferno_r', colorbar=True)\n", "\n", "fig.show()" ] } ], "metadata": { "kernelspec": { "display_name": "mne", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 2 }