{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMSk8qTqDYqFnRJVZKlsue0",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap12/12_2_Multihead_Self_Attention.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# **Notebook 12.2: Multihead Self-Attention**\n",
"\n",
"This notebook builds a multihead self-attention mechanism as in figure 12.6.\n",
"\n",
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n",
"\n"
],
"metadata": {
"id": "t9vk9Elugvmi"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
],
"metadata": {
"id": "OLComQyvCIJ7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The multihead self-attention mechanism maps $N$ inputs $\\mathbf{x}_{n}\\in\\mathbb{R}^{D}$ to $N$ outputs $\\mathbf{x}'_{n}\\in \\mathbb{R}^{D}$. In the code below, the inputs are stored as the columns of a $D\\times N$ matrix $\\mathbf{X}$.\n",
"\n"
],
"metadata": {
"id": "9OJkkoNqCVK2"
}
},
{
"cell_type": "code",
"source": [
"# Set seed so we get the same random numbers\n",
"np.random.seed(3)\n",
"# Number of inputs\n",
"N = 6\n",
"# Number of dimensions of each input\n",
"D = 8\n",
"# Create a D x N matrix of random inputs (one input per column)\n",
"X = np.random.normal(size=(D,N))\n",
"# Print X\n",
"print(X)"
],
"metadata": {
"id": "oAygJwLiCSri"
},
"execution_count": null,
"outputs": []
},
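{
"cell_type": "markdown",
"source": [
"As a quick check on the layout, each column of $\\mathbf{X}$ holds one input $\\mathbf{x}_{n}\\in\\mathbb{R}^{D}$. The cell below is a small sketch added for illustration and is not part of the exercise. It prints the shape of $\\mathbf{X}$ and of the first input, which is column 0 because Python indexing starts at zero. The name `x_first` is hypothetical and used only here."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sanity check (added for illustration): X is D x N, one input per column\n",
"print(\"Shape of X (D, N):\", X.shape)\n",
"# Hypothetical name: the first input is column 0 of X\n",
"x_first = X[:, 0]\n",
"print(\"Shape of the first input:\", x_first.shape)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},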
{
"cell_type": "markdown",
"source": [
"We'll use two heads. For each head we need the weights and biases for the keys, queries, and values (equations 12.2 and 12.4). As in the figure, we'll make the queries, keys, and values of size $D/H$. We also need a matrix $\\boldsymbol\\Omega_{c}$ that combines the outputs of the heads."
],
"metadata": {
"id": "W2iHFbtKMaDp"
}
},
{
"cell_type": "code",
"source": [
"# Number of heads\n",
"H = 2\n",
"# Dimension of the queries, keys, and values for each head\n",
"H_D = int(D/H)\n",
"\n",
"# Set seed so we get the same random numbers\n",
"np.random.seed(0)\n",
"\n",
"# Choose random values for the parameters for the first head\n",
"omega_q1 = np.random.normal(size=(H_D,D))\n",
"omega_k1 = np.random.normal(size=(H_D,D))\n",
"omega_v1 = np.random.normal(size=(H_D,D))\n",
"beta_q1 = np.random.normal(size=(H_D,1))\n",
"beta_k1 = np.random.normal(size=(H_D,1))\n",
"beta_v1 = np.random.normal(size=(H_D,1))\n",
"\n",
"# Choose random values for the parameters for the second head\n",
"omega_q2 = np.random.normal(size=(H_D,D))\n",
"omega_k2 = np.random.normal(size=(H_D,D))\n",
"omega_v2 = np.random.normal(size=(H_D,D))\n",
"beta_q2 = np.random.normal(size=(H_D,1))\n",
"beta_k2 = np.random.normal(size=(H_D,1))\n",
"beta_v2 = np.random.normal(size=(H_D,1))\n",
"\n",
"# Choose random values for the matrix that combines the outputs of the heads\n",
"omega_c = np.random.normal(size=(D,D))"
],
"metadata": {
"id": "79TSK7oLMobe"
},
"execution_count": null,
"outputs": []
},
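{
"cell_type": "markdown",
"source": [
"Before you write the full mechanism, it may help to check the shapes of the per-head queries, keys, and values. The sketch below assumes the matrix form of the linear projections, for example $\\mathbf{Q}_{1} = \\boldsymbol\\beta_{q1}\\mathbf{1}^{T} + \\boldsymbol\\Omega_{q1}\\mathbf{X}$ for the first head's queries, with the keys and values computed in the same way. Each result should therefore be a $(D/H)\\times N$ matrix. The cell computes only the first head's projections, not the attention itself. The names `Q1`, `K1`, and `V1` are illustrative."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Shape check (added for illustration): projections for the first head only\n",
"# The biases have shape (H_D, 1), so NumPy broadcasting adds them to every column\n",
"Q1 = beta_q1 + omega_q1 @ X\n",
"K1 = beta_k1 + omega_k1 @ X\n",
"V1 = beta_v1 + omega_v1 @ X\n",
"print(\"Shapes of Q1, K1, V1:\", Q1.shape, K1.shape, V1.shape)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},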
{
"cell_type": "markdown",
"source": [
"Now let's compute the multihead self-attention."
],
"metadata": {
"id": "VxaKQtP3Ng6R"
}
},
{
"cell_type": "code",
"source": [
"# Define softmax operation that works independently on each column\n",
"def softmax_cols(data_in):\n",
"  # Exponentiate all of the values\n",
"  exp_values = np.exp(data_in)\n",
"  # Sum over the rows of each column\n",
"  denom = np.sum(exp_values, axis = 0)\n",
"  # Replicate the denominator to the same number of rows as the input\n",
"  denom = np.matmul(np.ones((data_in.shape[0],1)), denom[np.newaxis,:])\n",
"  # Compute softmax\n",
"  softmax = exp_values / denom\n",
"  # return the answer\n",
"  return softmax"
],
"metadata": {
"id": "obaQBdUAMXXv"
},
"execution_count": null,
"outputs": []
},
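{
"cell_type": "markdown",
"source": [
"One way to test `softmax_cols` is to check that every column of its output sums to one. The sketch below also defines `softmax_cols_stable`, a hypothetical alternative that is not part of the original notebook. It uses NumPy broadcasting instead of a matrix product, and it subtracts each column's maximum before exponentiating. This shift leaves the result mathematically unchanged but avoids overflow in `np.exp` for large inputs. The test data come from a separate generator, `np.random.default_rng`, so the seeded values used elsewhere in the notebook are unaffected."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical alternative (not part of the original notebook):\n",
"# subtracting each column's maximum avoids overflow in np.exp\n",
"def softmax_cols_stable(data_in):\n",
"  shifted = data_in - np.max(data_in, axis=0, keepdims=True)\n",
"  exp_values = np.exp(shifted)\n",
"  return exp_values / np.sum(exp_values, axis=0, keepdims=True)\n",
"\n",
"# Test data from a separate generator so the global seed is untouched\n",
"rng = np.random.default_rng(1)\n",
"test_in = rng.normal(size=(5, 3))\n",
"test_out = softmax_cols(test_in)\n",
"print(\"Column sums:\", np.sum(test_out, axis=0))\n",
"print(\"Agrees with stable version:\", np.allclose(test_out, softmax_cols_stable(test_in)))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},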
{
"cell_type": "code",
"source": [
"# Now let's compute multihead scaled self-attention in matrix form\n",
"def multihead_scaled_self_attention(X,omega_v1, omega_q1, omega_k1, beta_v1, beta_q1, beta_k1, omega_v2, omega_q2, omega_k2, beta_v2, beta_q2, beta_k2, omega_c):\n",
"\n",
"  # TODO Write the multihead scaled self-attention mechanism.\n",
"  # Replace this line\n",
"  X_prime = np.zeros_like(X)\n",
"\n",
"\n",
"  return X_prime"
],
"metadata": {
"id": "gb2WvQ3SiH8r"
},
"execution_count": null,
"outputs": []
},
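{
"cell_type": "markdown",
"source": [
"The function name refers to *scaled* self-attention, and a common motivation for the scaling can be checked numerically. This illustration assumes that the entries of a query and a key are independent with zero mean and unit variance. Under that assumption, their dot product over $d$ dimensions has variance $d$, so its typical size grows like $\\sqrt{d}$. Dividing by $\\sqrt{d}$ keeps the inputs to the softmax at roughly unit scale whatever the dimension. The experiment below is an added sketch, not part of the exercise, and the dimensions and sample count are arbitrary choices."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative experiment (added): spread of q.k for random vectors of length d\n",
"rng = np.random.default_rng(0)\n",
"for d in [4, 16, 64, 256]:\n",
"  q = rng.normal(size=(d, 10000))\n",
"  k = rng.normal(size=(d, 10000))\n",
"  # One dot product per column\n",
"  dots = np.sum(q * k, axis=0)\n",
"  print(f\"d={d:4d}  std of q.k: {np.std(dots):7.3f}  std of q.k/sqrt(d): {np.std(dots / np.sqrt(d)):.3f}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},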
{
"cell_type": "code",
"source": [
"# Run the self-attention mechanism\n",
"X_prime = multihead_scaled_self_attention(X,omega_v1, omega_q1, omega_k1, beta_v1, beta_q1, beta_k1, omega_v2, omega_q2, omega_k2, beta_v2, beta_q2, beta_k2, omega_c)\n",
"\n",
"# Print out the results\n",
"np.set_printoptions(precision=3)\n",
"print(\"Your answer:\")\n",
"print(X_prime)\n",
"\n",
"print(\"True values:\")\n",
"print(\"[[-21.207 -5.373 -20.933 -9.179 -11.319 -17.812]\")\n",
"print(\" [ -1.995 7.906 -10.516 3.452 9.863 -7.24 ]\")\n",
"print(\" [ 5.479 1.115 9.244 0.453 5.656 7.089]\")\n",
"print(\" [ -7.413 -7.416 0.363 -5.573 -6.736 -0.848]\")\n",
"print(\" [-11.261 -9.937 -4.848 -8.915 -13.378 -5.761]\")\n",
"print(\" [ 3.548 10.036 -2.244 1.604 12.113 -2.557]\")\n",
"print(\" [ 4.888 -5.814 2.407 3.228 -4.232 3.71 ]\")\n",
"print(\" [ 1.248 18.894 -6.409 3.224 19.717 -5.629]]\")\n",
"\n",
"# If your answers don't match, make sure that you are applying the scaling and that the scaling value is correct"
],
"metadata": {
"id": "MUOJbgJskUpl"
},
"execution_count": null,
"outputs": []
}
]
}