{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Getting word embeddings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial introduces how to extract features, or word embeddings, from our stimulus transcript. Features are numeric vectors that capture the meaning of the words in our transcript. Here, we will present two types of features: interpretable syntactic features and high-dimensional word embeddings from a language model.\n", "\n", "Acknowledgments: This tutorial draws heavily on the [encling tutorial](https://github.com/snastase/encling-tutorial/blob/main/encling_tutorial.ipynb) by Samuel A. Nastase.\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hassonlab/podcast-ecog-tutorials/blob/main/notebooks/03-features.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Only run this cell in Colab\n", "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124\n", "!pip install accelerate transformers spacy scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we'll import some general-purpose Python packages." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Extracting syntactic features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One type of linguistic feature consists of the explicit grammatical categories we are familiar with and have names for. These include parts of speech (e.g., noun, verb) and syntactic dependencies (e.g., root, subject, object). We will use the [spaCy](https://github.com/explosion/spaCy) library (Honnibal et al., 2020)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import spacy\n", "from sklearn.preprocessing import LabelBinarizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we need to load the transcript as a `pandas` dataframe. It contains columns of words and their start and end timestamps." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bids_root = \"\"  # if using a local dataset, set this variable accordingly\n", "\n", "# Download the transcript, if required\n", "transcript_path = f\"{bids_root}stimuli/podcast_transcript.csv\"\n", "if not len(bids_root):\n", "    !wget -nc https://s3.amazonaws.com/openneuro.org/ds005574/$transcript_path\n", "    transcript_path = \"podcast_transcript.csv\"\n", "\n", "df = pd.read_csv(transcript_path)\n", "df.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "spaCy requires us to download and load a model that enables its features. First, we will download the [en-core-web-sm](https://spacy.io/models/en#en_core_web_sm) model, which is trained on English text and includes components for part-of-speech tagging and dependency parsing."
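, "\n", "\n", "As a quick preview of what these components produce, here is a minimal sketch (it assumes the model has already been downloaded; the example sentence is arbitrary):\n", "\n", "```python\n", "import spacy\n", "\n", "nlp = spacy.load(\"en_core_web_sm\")\n", "for token in nlp(\"The monkey sat in the middle.\"):\n", "    # Each token carries a part-of-speech tag and a dependency label\n", "    print(token.text, token.tag_, token.dep_)\n", "```"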
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!python -m spacy download en_core_web_sm" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "modelname = \"en_core_web_sm\"\n", "nlp = spacy.load(modelname)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Language processing pipelines typically use a `tokenizer` to standardize the (sub-)word units (called tokens) they operate on. Some words and punctuation will get separated into multiple tokens. For example, the word \"there's\" will be tokenized into \"there\" and \"'s\". Thus, the first step is to transform our transcript words into tokens that spaCy can work with.\n", "\n", "To keep track of our words and their indices, we first create a `word_idx` column. We then tokenize the words using the tokenizer. Then, we will [explode](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.explode.html) the dataframe so that each row of the dataframe is a token (and not a word). Note that we will add a whitespace to the end of each word before tokenization so we can track the boundary of each word. Compare the dataframes from before and from below." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": {
"text/plain": [ " word_idx word start end word_with_ws hftoken\n", "0 0 Act 3.710 3.790 Act Act\n", "1 1 one, 3.990 4.190 one, one\n", "2 1 one, 3.990 4.190 one, ,\n", "3 2 monkey 4.651 4.931 monkey monkey\n", "4 3 in 4.951 5.011 in in\n", "5 4 the 5.051 5.111 the the\n", "6 5 middle. 5.151 5.391 middle. middle\n", "7 5 middle. 5.151 5.391 middle. .\n", "8 6 So 6.592 6.732 So So\n", "9 7 there's 6.752 6.912 there's there" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.insert(0, \"word_idx\", df.index.values)\n", "df[\"word_with_ws\"] = df.word.astype(str) + \" \"\n", "df[\"hftoken\"] = df.word_with_ws.apply(nlp.tokenizer)\n", "df = df.explode(\"hftoken\", ignore_index=True)\n", "df.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will create a [doc](https://spacy.io/api/doc) object (essentially a list of token objects) from our tokenized text:\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "words = [token.text for token in df.hftoken.tolist()]\n", "spaces = [token.whitespace_ == \" \" for token in df.hftoken.tolist()]\n", "doc = spacy.tokens.Doc(nlp.vocab, words=words, spaces=spaces)\n", "doc = nlp(doc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will loop through the doc and get the features for each token. The [features](https://spacy.io/usage/linguistic-features#pos-tagging) include `text`, `tag` (detailed part-of-speech tag), `dep` (syntactic dependency, i.e., the relation between tokens), and `is_stop` (whether the token is part of a [stop list](https://en.wikipedia.org/wiki/Stop_word)). We will organize the features into a second dataframe and add those columns back to `df`. We will drop the two columns we don't need anymore, and then save `df` for future encoding." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": {
"text/plain": [ " word_idx word start end token pos dep stop\n", "0 0 Act 3.710 3.790 Act NNP ROOT False\n", "1 1 one, 3.990 4.190 one CD nummod True\n", "2 1 one, 3.990 4.190 , , punct False\n", "3 2 monkey 4.651 4.931 monkey NN appos False\n", "4 3 in 4.951 5.011 in IN prep True\n", "5 4 the 5.051 5.111 the DT det True\n", "6 5 middle. 5.151 5.391 middle NN pobj False\n", "7 5 middle. 5.151 5.391 . . punct False\n", "8 6 So 6.592 6.732 So RB advmod True\n", "9 7 there's 6.752 6.912 there EX expl True" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "features = []\n", "for token in doc:\n", "    features.append([token.text, token.tag_, token.dep_, token.is_stop])\n", "\n", "df2 = pd.DataFrame(\n", "    features, columns=[\"token\", \"pos\", \"dep\", \"stop\"], index=df.index\n", ")\n", "df = pd.concat([df, df2], axis=1)\n", "df.drop([\"hftoken\", \"word_with_ws\"], axis=1, inplace=True)\n", "df.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the features we extracted are all categorical, we need to turn them into numerical vectors. We will use [LabelBinarizer](https://scikit-learn.org/dev/modules/generated/sklearn.preprocessing.LabelBinarizer.html) from `sklearn`, which fits to all the possible category labels for a feature and then transforms our labels into one-hot binary vectors. There are 50 possible labels for `tag` and 45 possible labels for `dep`, so those two features will be turned into 50-dimensional and 45-dimensional vectors, respectively. Our `is_stop` feature is binary, so it will just be one-dimensional. We concatenate all three to form an overall 96-dimensional syntactic feature vector and save it for future encoding." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Embeddings have a shape of: (5305, 96)\n" ] } ], "source": [ "taggerEncoder = LabelBinarizer().fit(nlp.get_pipe(\"tagger\").labels)\n", "dependencyEncoder = LabelBinarizer().fit(nlp.get_pipe(\"parser\").labels)\n", "\n", "a = taggerEncoder.transform(df.pos.tolist())\n", "b = dependencyEncoder.transform(df.dep.tolist())\n", "c = LabelBinarizer().fit_transform(df.stop.tolist())\n", "embeddings = np.hstack((a, b, c))\n", "print(f\"Embeddings have a shape of: {embeddings.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Extracting GPT-2 features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will extract contextual word embeddings from an autoregressive (or \"causal\") large language model (LLM) called GPT-2 ([Radford et al., 2019](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)). GPT-2 relies on the Transformer architecture to sculpt the embedding of a given word based on the preceding context. The model is composed of a repeated circuit motif—called the \"attention head\"—by which the model can \"attend\" to previous words in the context window when determining the meaning of the current word. This GPT-2 implementation is composed of 12 layers, each of which contains 12 attention heads that influence the embedding as it proceeds to the subsequent layer. The embeddings at each layer of the model comprise 768 features, and the context window includes the preceding 1024 tokens. Note that certain words will be broken up into multiple tokens; we'll need to use GPT-2's \"tokenizer\" to convert words into the appropriate tokens.\n", "\n", "
GPT-2 has been (pre)trained on large corpora of text according to a simple self-supervised objective function: predict the next word based on the prior context.\n", "\n", "We will be using the [HuggingFace](https://huggingface.co) [transformers](https://huggingface.co/docs/transformers/index) library for working with these models. If you want to learn more about LLMs and GPT-2, here are some great blog posts explaining the [transformer](https://jalammar.github.io/illustrated-transformer/) and [GPT-2](https://jalammar.github.io/illustrated-gpt2/) architectures. The HuggingFace website also has many useful resources.\n", "\n", "
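As a quick illustration of how the GPT-2 tokenizer breaks words into sub-word tokens, here is a minimal sketch (assuming only that `transformers` is installed):\n", "\n", "```python\n", "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n", "# One word can become several tokens, e.g., \" there's\" -> ['Ġthere', \"'s\"]\n", "print(tokenizer.tokenize(\" there's\"))\n", "```\n", "\n", "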
\n", "\n", "**Note**\n", "\n", "Using large language models, even small ones, requires a lot of compute resources. If you're using Colab, go to `Edit` → `Notebook Settings` and select a GPU. Restart the runtime and try running again. Afterwards, you can run `!nvidia-smi` in a new cell to verify that a GPU is available." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from accelerate import Accelerator, find_executable_batch_size\n", "from transformers import AutoModelForCausalLM, AutoTokenizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's reload the stimulus transcript." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": {
"text/plain": [ " word start end\n", "0 Act 3.710 3.790\n", "1 one, 3.990 4.190\n", "2 monkey 4.651 4.931\n", "3 in 4.951 5.011\n", "4 the 5.051 5.111\n", "5 middle. 5.151 5.391\n", "6 So 6.592 6.732\n", "7 there's 6.752 6.912\n", "8 some 6.892 7.052\n", "9 places 7.072 7.342" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(transcript_path)\n", "df.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will define some of the general arguments, including the model name as it appears on HuggingFace, the context length (i.e., how many tokens we input into the model), and the compute device. We can set the device to `cuda` to utilize a GPU if one is available." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "modelname = \"gpt2\"\n", "context_len = 32\n", "device = torch.device(\"cpu\")\n", "if torch.cuda.is_available():\n", "    device = torch.device(\"cuda\", 0)\n", "    print(\"Using cuda!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will now load the GPT-2 tokenizer to convert words into a list of tokens. Then, we will [explode](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.explode.html) the dataframe so that each row of the dataframe is a token. We will convert tokens to `token_ids` (integer IDs corresponding to words in the GPT-2 vocabulary, which contains approximately 50,000 tokens) to use as input into GPT-2." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": {
"text/plain": [ " word_idx word start end hftoken token_id\n", "0 0 Act 3.710 3.790 ĠAct 2191\n", "1 1 one, 3.990 4.190 Ġone 530\n", "2 1 one, 3.990 4.190 , 11\n", "3 2 monkey 4.651 4.931 Ġmonkey 21657\n", "4 3 in 4.951 5.011 Ġin 287\n", "5 4 the 5.051 5.111 Ġthe 262\n", "6 5 middle. 5.151 5.391 Ġmiddle 3504\n", "7 5 middle. 5.151 5.391 . 13\n", "8 6 So 6.592 6.732 ĠSo 1406\n", "9 7 there's 6.752 6.912 Ġthere 612" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(modelname)\n", "\n", "df.insert(0, \"word_idx\", df.index.values)\n", "df[\"hftoken\"] = df.word.apply(lambda x: tokenizer.tokenize(\" \" + x))\n", "\n", "df = df.explode(\"hftoken\", ignore_index=True)\n", "df[\"token_id\"] = df.hftoken.apply(tokenizer.convert_tokens_to_ids)\n", "\n", "df.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we will download and load the pretrained GPT-2 model. You can inspect its configuration in `model.config` for more detailed information (e.g., number of layers, max context length)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading model...\n", "Model : gpt2\n", "Layers: 12\n", "EmbDim: 768\n", "Config: GPT2Config {\n", " \"_name_or_path\": \"gpt2\",\n", " \"activation_function\": \"gelu_new\",\n", " \"architectures\": [\n", " \"GPT2LMHeadModel\"\n", " ],\n", " \"attn_pdrop\": 0.1,\n", " \"bos_token_id\": 50256,\n", " \"embd_pdrop\": 0.1,\n", " \"eos_token_id\": 50256,\n", " \"initializer_range\": 0.02,\n", " \"layer_norm_epsilon\": 1e-05,\n", " \"model_type\": \"gpt2\",\n", " \"n_ctx\": 1024,\n", " \"n_embd\": 768,\n", " \"n_head\": 12,\n", " \"n_inner\": null,\n", " \"n_layer\": 12,\n", " \"n_positions\": 1024,\n", " \"reorder_and_upcast_attn\": false,\n", " \"resid_pdrop\": 0.1,\n", " \"scale_attn_by_inverse_layer_idx\": false,\n", " \"scale_attn_weights\": true,\n", " \"summary_activation\": null,\n", " \"summary_first_dropout\": 0.1,\n", " \"summary_proj_to_labels\": true,\n", " \"summary_type\": \"cls_index\",\n", " \"summary_use_proj\": true,\n", " \"task_specific_params\": {\n", " \"text-generation\": {\n", " \"do_sample\": true,\n", " \"max_length\": 50\n", " }\n", " },\n", " \"transformers_version\": \"4.45.2\",\n", " \"use_cache\": true,\n", " \"vocab_size\": 50257\n", "}\n", "\n" ] } ], "source": [ "print(\"Loading model...\")\n", "model = AutoModelForCausalLM.from_pretrained(modelname)\n", "\n", "print(\n", "    f\"Model : {modelname}\"\n", "    f\"\\nLayers: {model.config.num_hidden_layers}\"\n", "    f\"\\nEmbDim: {model.config.hidden_size}\"\n", "    f\"\\nConfig: {model.config}\"\n", ")\n", "model = model.eval()\n", "model = model.to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since our transcript contains more tokens than the context window (32), we will reformat all the `token_ids` into `data`, a torch tensor with a shape of (number of tokens x 33). To extract features for a token from GPT-2 using a context length of 32, we need to input 33 tokens: the token itself and the 32 preceding tokens. Note that for the first 32 tokens in the transcript, we pad the input to length 33 using the `pad_token_id` (or 0 if the tokenizer does not define one)."
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data has a shape of: torch.Size([5491, 33])\n" ] } ], "source": [ "token_ids = df.token_id.tolist()\n", "fill_value = 0\n", "if tokenizer.pad_token_id is not None:\n", "    fill_value = tokenizer.pad_token_id\n", "\n", "data = torch.full((len(token_ids), context_len + 1), fill_value, dtype=torch.long)\n", "# Row i holds tokens [i - context_len, ..., i], left-padded with fill_value\n", "for i in range(len(token_ids)):\n", "    example_tokens = token_ids[max(0, i - context_len) : i + 1]\n", "    data[i, -len(example_tokens) :] = torch.tensor(example_tokens)\n", "\n", "print(f\"Data has a shape of: {data.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use [Accelerator](https://github.com/huggingface/accelerate) to make extracting features more efficient. It includes a [find_executable_batch_size](https://huggingface.co/docs/accelerate/v0.11.0/en/memory) decorator, which finds a workable batch size by halving the batch size after each failed (e.g., out-of-memory) run of the decorated function (in this case, our `inference_loop` function).\n", "\n", "Inside the `inference_loop` function, we will use a PyTorch `DataLoader` to supply token IDs to the model in batches and extract the features. In addition to the embeddings, we'll also extract several other features of potential interest from the model. As GPT-2 proceeds through the text, it generates a probability distribution (the `logits` extracted below) across all words in the vocabulary with the goal of correctly predicting the next word. We can use this probability distribution to derive other features of the model's internal computations. We'll extract the following features from GPT-2:\n", "\n", "* **embeddings**: the 768-dimensional contextual embedding capturing the meaning of the current word\n", "* **top_guesses**: the highest-probability word GPT-2 predicts for the current word\n", "* **ranks**: the rank of the correct word given the probabilities across the vocabulary\n", "* **true_probs**: the probability GPT-2 assigned to the correct current word\n", "* **entropies**: how uncertain GPT-2 was about the current word\n", "  * low entropy indicates that the probability distribution was \"focused\" on certain words\n", "  * high entropy indicates the probability distribution was more uniform/dispersed across words" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": []
, "source": [ "accelerator = Accelerator()\n", "\n", "@find_executable_batch_size(starting_batch_size=32)\n", "def inference_loop(batch_size=32):\n", "    accelerator.free_memory()  # Free all lingering references\n", "\n", "    data_dl = torch.utils.data.DataLoader(\n", "        data, batch_size=batch_size, shuffle=False\n", "    )\n", "\n", "    top_guesses = []\n", "    ranks = []\n", "    true_probs = []\n", "    entropies = []\n", "    embeddings = []\n", "\n", "    with torch.no_grad():\n", "        for batch in data_dl:\n", "            # Get output from model\n", "            output = model(batch.to(device), output_hidden_states=True)\n", "            logits = output.logits\n", "            states = output.hidden_states\n", "\n", "            # Logits at position -2 are the model's predictions for the final token\n", "            true_ids = batch[:, -1]\n", "            brange = list(range(len(true_ids)))\n", "            logits_order = logits[:, -2, :].argsort(descending=True)\n", "            batch_top_guesses = logits_order[:, 0]\n", "            batch_ranks = torch.eq(logits_order, true_ids.reshape(-1, 1).to(device)).nonzero()[:, 1]\n", "            batch_probs = torch.softmax(logits[:, -2, :], dim=-1)\n", "            batch_true_probs = batch_probs[brange, true_ids]\n", "            batch_entropy = torch.distributions.Categorical(probs=batch_probs).entropy()\n", "            # Hidden state of the final token at each layer\n", "            batch_embeddings = [state[:, -1, :].numpy(force=True) for state in states]\n", "\n", "            top_guesses.append(batch_top_guesses.numpy(force=True))\n", "            ranks.append(batch_ranks.numpy(force=True))\n", "            true_probs.append(batch_true_probs.numpy(force=True))\n", "            entropies.append(batch_entropy.numpy(force=True))\n", "            embeddings.append(batch_embeddings)\n", "\n", "    return top_guesses, ranks, true_probs, entropies, embeddings\n", "\n", "top_guesses, ranks, true_probs, entropies, embeddings = inference_loop()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will add the additional information from GPT-2 as columns to `df`." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": {
"text/plain": [ " word_idx word start end hftoken token_id rank true_prob \\\n", "0 0 Act 3.710 3.790 ĠAct 2191 3185 1.000139e-08 \n", "1 1 one, 3.990 4.190 Ġone 530 46 2.847577e-03 \n", "2 1 one, 3.990 4.190 , 11 2 8.006448e-02 \n", "3 2 monkey 4.651 4.931 Ġmonkey 21657 6978 6.075863e-06 \n", "4 3 in 4.951 5.011 Ġin 287 24 1.004823e-03 \n", "5 4 the 5.051 5.111 Ġthe 262 0 3.898537e-01 \n", "6 5 middle. 5.151 5.391 Ġmiddle 3504 2 4.331103e-02 \n", "7 5 middle. 5.151 5.391 . 13 3 4.237065e-02 \n", "8 6 So 6.592 6.732 ĠSo 1406 116 1.016026e-03 \n", "9 7 there's 6.752 6.912 Ġthere 612 16 8.699116e-03 \n", "\n", " top_pred entropy \n", "0 0 0.092728 \n", "1 352 5.294118 \n", "2 0 4.976894 \n", "3 734 5.869678 \n", "4 0 2.478687 \n", "5 262 4.340655 \n", "6 5228 5.842120 \n", "7 286 2.115351 \n", "8 2191 5.861630 \n", "9 11 5.249004 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[\"rank\"] = np.concatenate(ranks)\n", "df[\"true_prob\"] = np.concatenate(true_probs)\n", "df[\"top_pred\"] = np.concatenate(top_guesses)\n", "df[\"entropy\"] = np.concatenate(entropies)\n", "\n", "df.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's confirm the size and number of embeddings we got. Note that there are 13 layers (instead of the expected 12) because the initial embeddings from before the first layer of the network are also included. Note that the list of embeddings is organized by batch, which will require flattening to match the number of tokens." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "There are 13 layers of embeddings\n", "Each word embedding is 768 dimensions long\n" ] } ], "source": [ "print(f\"There are {len(embeddings[0])} layers of embeddings\")\n", "print(f\"Each word embedding is {embeddings[0][0].shape[1]} dimensions long\")" ] } ], "metadata": { "kernelspec": { "display_name": "mne", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 2 }