{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "99e430c8", "metadata": {}, "outputs": [], "source": [ "!conda install --yes ipympl pytorch graphviz\n", "!pip install torchviz" ] }, { "cell_type": "code", "execution_count": null, "id": "b2c6b958", "metadata": {}, "outputs": [], "source": [ "import math\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from torch.autograd import Variable\n", "import tqdm\n", "\n", "import IPython\n", "from ipywidgets import interactive, widgets, Layout\n", "from IPython.display import display, HTML" ] }, { "cell_type": "code", "execution_count": null, "id": "34d01662", "metadata": {}, "outputs": [], "source": [ "# other possible options: inline, notebook, qt, web\n", "%matplotlib widget" ] }, { "cell_type": "code", "execution_count": null, "id": "61cf3329", "metadata": {}, "outputs": [], "source": [ "# Constants\n", "cap_value = 1e-6 # Farads\n", "R_init = 500 # Ohms\n", "cutoff_mag = 1. / math.sqrt(2)\n", "cutoff_dB = 20 * math.log10(cutoff_mag)\n", "dataset_size = 1000\n", "max_training_steps = 100000" ] }, { "cell_type": "markdown", "id": "6a5aebcb", "metadata": {}, "source": [ "## (a) Designing a Low Pass Filter by Matching Transfer Functions" ] }, { "cell_type": "code", "execution_count": null, "id": "103bf234", "metadata": {}, "outputs": [], "source": [ "# Transfer function: evaluates magnitude of given frequencies for a resistor value in the low pass circuit\n", "def evaluate_lp_circuit(freqs, R_low):\n", " return 1. / torch.sqrt(1 + (R_low * cap_value * freqs) ** 2)" ] }, { "cell_type": "code", "execution_count": null, "id": "b636c220", "metadata": {}, "outputs": [], "source": [ "# Plot transfer function for a given low pass circuit\n", "fig = plt.figure(figsize=(9, 4))\n", "ws = 2 * math.pi * 10 ** torch.linspace(0, 6, 1000)\n", "mags = 20 * torch.log10(evaluate_lp_circuit(ws, R_init))\n", "R_low_des = 1 / (2 * math.pi * 800 * cap_value)\n", "mags_des = 20 * torch.log10(evaluate_lp_circuit(ws, R_low_des))\n", "tf, = plt.semilogx(ws / (2 * math.pi), mags, linewidth=3)\n", "tf_des, = plt.semilogx(ws / (2 * math.pi), mags_des, linestyle=\"--\", linewidth=3)\n", "plt.xlim([1, 1e6])\n", "plt.ylim([-60, 1])\n", "plt.title(\"Low Pass Transfer Functions\")\n", "plt.xlabel(\"Frequency (Hz)\")\n", "plt.ylabel(\"dB\")\n", "plt.grid(which=\"both\")\n", "leg = plt.legend([\"Predicted Transfer Function\", \"Desired Transfer Function\"])\n", "plt.tight_layout()\n", "\n", "# Main update function for interactive plot\n", "def update_tfs(R=R_init):\n", " mags = 20 * torch.log10(evaluate_lp_circuit(ws, R))\n", " tf.set_data(ws / (2 * math.pi), mags)\n", " fig.canvas.draw_idle()\n", " \n", "# Include sliders for relevant quantities\n", "ip = interactive(update_tfs, \n", " R=widgets.IntSlider(value=R_init, min=1, max=1000, step=1, description=\"R\", layout=Layout(width='100%')))\n", "ip" ] }, { "cell_type": "markdown", "id": "c1b99831", "metadata": {}, "source": [ "## (b) Designing a Low pass Filter from Binary Data" ] }, { "cell_type": "code", "execution_count": null, "id": "c1ec5419", "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Plot transfer function for a given low pass circuit\n", "fig = plt.figure(figsize=(9, 5))\n", "ws = 2 * math.pi * 10 ** torch.linspace(0, 6, 1000)\n", "mags = 20 * torch.log10(evaluate_lp_circuit(ws, R_init))\n", "cutoff = ws[np.argmax(mags < cutoff_dB)]\n", "tf, = plt.semilogx(ws / (2 * math.pi), mags, linewidth=3)\n", "cut = plt.axvline(cutoff / (2 * math.pi), c=\"red\", linestyle=\"--\", linewidth=3)\n", "plt.xlim([1, 1e6])\n", "plt.ylim([-60, 1])\n", "plt.title(\"Low Pass Transfer Function\")\n", "plt.xlabel(\"Frequency (Hz)\")\n", "plt.ylabel(\"dB\")\n", "plt.grid(which=\"both\")\n", "leg = plt.legend([\"Transfer Function\", f\"Cutoff Frequency ({1 / (2 * math.pi * R_init * cap_value):.0f} Hz)\"])\n", "\n", "# Plot table of LED on/off values (predicted and desired)\n", "ws_test = 2 * math.pi * np.linspace(300, 1500, num=7)\n", "table_txt = np.zeros((3, len(ws_test) + 1), dtype=\"U15\")\n", "table_txt[0, :] = [\"Frequency\"] + [f\"{w / (2 * math.pi):.0f} Hz\" for w in ws_test]\n", "table_txt[1:, 0] = [\"Predicted\", \"Desired\"]\n", "table_colors = np.zeros_like(table_txt, dtype=(np.int32, (3,)))\n", "table_colors[-1, 1:4] = (1, 0, 0)\n", "table_colors[1, 1] = (1, 0, 0)\n", "table_colors[:, :1] = (1, 1, 1)\n", "table_colors[:1, :] = (1, 1, 1)\n", "tab = plt.table(table_txt, table_colors, bbox=[0.0, -0.5, 1.0, 0.25], cellLoc=\"center\")\n", "plt.tight_layout()\n", "\n", "# Main update function for interactive plot\n", "def update_lights(R=R_init):\n", " mags = 20 * torch.log10(evaluate_lp_circuit(ws, R))\n", " cutoff = ws[np.argmax(mags < cutoff_dB)]\n", " tf.set_data(ws / (2 * math.pi), mags)\n", " cut.set_xdata(cutoff / (2 * math.pi))\n", " for i, w in enumerate(ws_test):\n", " if w < cutoff:\n", " tab[(1, i+1)].set_facecolor((1, 0, 0))\n", " else:\n", " tab[(1, i+1)].set_facecolor((0, 0, 0))\n", " leg.get_texts()[1].set_text(f\"Cutoff Frequency ({1 / (2 * math.pi * R * cap_value):.0f} Hz)\")\n", " fig.canvas.draw_idle()\n", " \n", "# Include sliders for relevant quantities\n", "ip = interactive(update_lights, \n", " R=widgets.IntSlider(value=R_init, min=1, max=1000, step=1, description=\"R\", layout=Layout(width='100%')))\n", "ip" ] }, { "cell_type": "markdown", "id": "943a723c", "metadata": {}, "source": [ "## (c) Learning a Low Pass Filter from Desired Transfer Function Samples" ] }, { "cell_type": "code", "execution_count": null, "id": "eb078f6c", "metadata": {}, "outputs": [], "source": [ "# PyTorch model of the low pass circuit (for training)\n", "class LowPassCircuit(nn.Module):\n", " def __init__(self, R=None):\n", " super().__init__()\n", " self.R = nn.Parameter(torch.tensor(R, dtype=float) if R is not None else torch.rand(1) * 1000)\n", " \n", " # Note: the forward function is called automatically when the __call__ function of this object is called\n", " def forward(self, freqs):\n", " return evaluate_lp_circuit(freqs, self.R)\n", " \n", "# Generate training data in a uniform log scale of frequences, then evaluate using the true transfer function\n", "def generate_lp_training_data(n):\n", " rand_ws = 2 * math.pi * torch.pow(10, torch.rand(n) * 6)\n", " labels = evaluate_lp_circuit(rand_ws, R_low_des)\n", " return rand_ws, labels\n", "\n", "# Train a given low pass filter\n", "def train_lp_circuit_tf(circuit, loss_fn, dataset_size, max_training_steps, lr):\n", " \n", " R_values = [float(circuit.R.data)]\n", " grad_values = [np.nan]\n", " train_data = generate_lp_training_data(dataset_size)\n", " print(f\"Initial Resistor Value: R = {float(circuit.R.data):.0f}\")\n", " iter_bar = tqdm.trange(max_training_steps, desc=\"Training Iter\")\n", " for i in iter_bar:\n", " pred = circuit(train_data[0])\n", " loss = loss_fn(pred, train_data[1]).mean()\n", " grad = torch.autograd.grad(loss, circuit.R)\n", " with torch.no_grad():\n", " circuit.R -= lr * grad[0]\n", "\n", " R_values.append(float(circuit.R.data))\n", " grad_values.append(float(grad[0].data))\n", " iter_bar.set_postfix_str(f\"Loss: {float(loss.data):.3f}, R={float(circuit.R.data):.0f}\")\n", " if loss.data < 1e-6 or abs(grad[0].data) < 1e-6:\n", " break\n", "\n", " print(f\"Final Resistor Value: R = {float(circuit.R.data):.0f}\") \n", " return train_data, R_values, grad_values" ] }, { "cell_type": "code", "execution_count": null, "id": "4502a565", "metadata": {}, "outputs": [], "source": [ "# Create a circuit, use mean squared error loss w/ learning rate of 200\n", "circuit = LowPassCircuit(1000)\n", "loss_fn = lambda x, y: (x - y) ** 2\n", "lr = 200\n", "train_data_low_tf, R_values_low_tf, grad_values_low_tf = train_lp_circuit_tf(circuit, loss_fn, dataset_size, max_training_steps, lr)" ] }, { "cell_type": "code", "execution_count": null, "id": "79fe9b39", "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Plot transfer function over training\n", "fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))\n", "ws = 2 * math.pi * 10 ** torch.linspace(0, 6, 1000)\n", "subsample = int(dataset_size / 100)\n", "ax1.scatter(train_data_low_tf[0][::subsample] / (2 * math.pi), 20 * torch.log10(train_data_low_tf[1][::subsample]), c=\"k\", marker=\"x\")\n", "learned_tf, = ax1.semilogx(ws / (2 * math.pi), 20 * torch.log10(evaluate_lp_circuit(ws, R_values_low_tf[0])), linewidth=3)\n", "ax1.set_xlim([1, 1e6])\n", "ax1.set_title(\"Transfer Function\")\n", "ax1.set_xlabel(\"Frequency (Hz)\")\n", "ax1.set_ylabel(\"dB\")\n", "ax1.legend([\"Learned Transfer Function\", \"True Transfer Function Samples\"])\n", "\n", "# Show loss surface over training\n", "eval_pts = torch.arange(10, 1001, 1)\n", "eval_vals = evaluate_lp_circuit(train_data_low_tf[0][:, None], eval_pts[None, :])\n", "loss_surface_mse = loss_fn(eval_vals, train_data_low_tf[1][:, None].expand(eval_vals.shape))\n", "ax2.plot(eval_pts, loss_surface_mse.sum(0), linewidth=3)\n", "cur_loss, = ax2.plot(R_values_low_tf[0], loss_surface_mse[:, int(R_values_low_tf[0] - 10)].sum(0), marker=\"o\")\n", "cur_loss_label = ax2.annotate(f\"R = {R_values_low_tf[0]:.0f}\", (0, 0), xytext=(0.82, 0.9), textcoords='axes fraction')\n", "ax2.set_title(\"Loss Surface\")\n", "ax2.set_xlim([0, 1000])\n", "ax2.set_xlabel(\"$R \\; (\\Omega)$\")\n", "ax2.set_ylabel(\"Loss\")\n", "\n", "# Show loss contributions of each data point\n", "cur_circuit = LowPassCircuit(R_values_low_tf[0])\n", "data_losses = loss_fn(cur_circuit(train_data_low_tf[0][::subsample]), (train_data_low_tf[1][::subsample]).float())\n", "data_grads = torch.zeros(len(data_losses))\n", "for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.autograd.grad(dl, cur_circuit.R, retain_graph=True)[0]\n", "data_grads_scat = ax3.scatter(train_data_low_tf[0][::subsample] / (2 * math.pi), data_grads, marker=\"x\", c=\"k\")\n", "ax3.set_xscale(\"log\")\n", "ax3.set_ylabel(\"Derivative\")\n", "ax3.set_xlim([1, 1e6])\n", "ax3.set_ylim([-1e-4, 1e-3])\n", "ax3.set_xlabel(\"Frequency (Hz)\")\n", "ax3.set_title(\"Derivative by Training Datapoint\")\n", "\n", "# Show total gradient at each training iteration\n", "ax4.plot(np.arange(len(grad_values_low_tf)), grad_values_low_tf, linewidth=3)\n", "cur_iter, = ax4.plot(0, grad_values_low_tf[0], marker=\"o\")\n", "cur_grad_label = ax4.annotate(f\"Grad = {grad_values_low_tf[0]:.2e}\", (0, 0), xytext=(0.65, 0.9), textcoords='axes fraction')\n", "ax4.set_xlabel(\"Training Iteration\")\n", "ax4.set_ylabel(\"Gradient\")\n", "ax4.set_title(\"Gradients\")\n", "ax4.set_xlim([-1, len(grad_values_low_tf)])\n", "\n", "plt.tight_layout()\n", "\n", "# Main update function for interactive plots\n", "def update_iter_tf(t=0):\n", " learned_tf.set_data(ws / (2 * math.pi), 20 * torch.log10(evaluate_lp_circuit(ws, R_values_low_tf[t])))\n", " cur_loss.set_data(R_values_low_tf[t], loss_surface_mse[:, int(R_values_low_tf[t] - 10)].sum(0))\n", " cur_loss_label.set_text(f\"R = {R_values_low_tf[t]:.0f}\")\n", " cur_iter.set_data(t, grad_values_low_tf[t])\n", " cur_grad_label.set_text(f\"Grad = {grad_values_low_tf[t]:.2e}\")\n", " cur_circuit = LowPassCircuit(R_values_low_tf[t])\n", " data_losses = loss_fn(cur_circuit(train_data_low_tf[0][::subsample]), (train_data_low_tf[1][::subsample]).float())\n", " data_grads = torch.zeros(len(data_losses))\n", " for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.autograd.grad(dl, cur_circuit.R, retain_graph=True)[0]\n", " data_grads_scat.set_offsets(torch.stack((train_data_low_tf[0][::subsample] / (2 * math.pi), data_grads)).T)\n", " fig.canvas.draw_idle()\n", " \n", "# Include sliders for relevant quantities\n", "ip = interactive(update_iter_tf, \n", " t=widgets.IntSlider(value=0, min=0, max=len(R_values_low_tf) - 1, step=1, description=\"Training Iteration\", style={'description_width': 'initial'}, layout=Layout(width='100%')))\n", "ip" ] }, { "cell_type": "markdown", "id": "38bd4f56", "metadata": {}, "source": [ "## (d) Learning a Low Pass Filter from Binary Data with Mean Squared Error Loss" ] }, { "cell_type": "code", "execution_count": null, "id": "cd02eb2e", "metadata": {}, "outputs": [], "source": [ "# Train a given low pass filter from binary data\n", "def train_lp_circuit_binary(circuit, loss_fn, dataset_size, max_training_steps, lr):\n", " \n", " R_values = [float(circuit.R.data)]\n", " grad_values = [np.nan]\n", " train_data = generate_lp_training_data(dataset_size)\n", " print(f\"Initial Resistor Value: R = {float(circuit.R.data):.0f}\")\n", " iter_bar = tqdm.trange(max_training_steps, desc=\"Training Iter\")\n", " for i in iter_bar:\n", " pred = circuit(train_data[0])\n", " ### YOUR CODE HERE\n", " loss = loss_fn(?, ?).mean()\n", " ### END YOUR CODE\n", " grad = torch.autograd.grad(loss, circuit.R)\n", " with torch.no_grad():\n", " circuit.R -= lr * grad[0]\n", "\n", " R_values.append(float(circuit.R.data))\n", " grad_values.append(float(grad[0].data))\n", " iter_bar.set_postfix_str(f\"Loss: {float(loss.data):.3f}, R={float(circuit.R.data):.0f}\")\n", " if loss.data < 1e-6 or abs(grad[0].data) < 1e-6:\n", " break\n", "\n", " print(f\"Final Resistor Value: R = {float(circuit.R.data):.0f}\") \n", " return train_data, R_values, grad_values" ] }, { "cell_type": "code", "execution_count": null, "id": "93cf1d40", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Create a circuit, use MSE loss with learning rate of 200\n", "circuit = LowPassCircuit(500)\n", "loss_fn = lambda x, y: (x - y) ** 2\n", "lr = 200\n", "train_data_low_bin, R_values_low_bin, grad_values_low_bin = train_lp_circuit_binary(circuit, loss_fn, dataset_size, max_training_steps, lr)" ] }, { "cell_type": "code", "execution_count": null, "id": "0f86febe", "metadata": {}, "outputs": [], "source": [ "# Plot transfer function over training\n", "fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))\n", "ws = 2 * math.pi * 10 ** torch.linspace(0, 6, 1000)\n", "subsample = int(dataset_size / 100)\n", "train_data_mask = train_data_low_bin[1][::subsample] > cutoff_mag\n", "ax1.scatter(train_data_low_bin[0][::subsample][train_data_mask] / (2 * math.pi), np.ones(train_data_mask.sum()), c=\"r\", marker=\"x\")\n", "ax1.scatter(train_data_low_bin[0][::subsample][~train_data_mask] / (2 * math.pi), np.zeros((~train_data_mask).sum()), c=\"k\", marker=\"x\")\n", "mags = evaluate_lp_circuit(ws, R_values_low_bin[0])\n", "learned_tf, = ax1.semilogx(ws / (2 * math.pi), mags, linewidth=3)\n", "cutoff = ws[np.argmax(mags < cutoff_mag)]\n", "cut = ax1.axvline(cutoff / (2 * math.pi), c=\"red\", linestyle=\"--\", linewidth=3)\n", "ax1.set_xlim([1, 1e6])\n", "ax1.set_title(\"Transfer Function\")\n", "ax1.set_xlabel(\"Frequency (Hz)\")\n", "ax1.set_ylabel(\"Magnitude\")\n", "ax1.legend([\"Learned TF\", \"Learned $f_c$\", \"TF + Samples\", \"TF - Samples\"])\n", "\n", "# Show loss surface over training\n", "eval_pts = torch.arange(10, 1001, 1)\n", "eval_vals = evaluate_lp_circuit(train_data_low_bin[0][:, None], eval_pts[None, :])\n", "loss_surface_mse = loss_fn(eval_vals, (train_data_low_bin[1][:, None].expand(eval_vals.shape) > cutoff_mag).float())\n", "ax2.plot(eval_pts, loss_surface_mse.sum(0), linewidth=3)\n", "cur_loss, = ax2.plot(R_values_low_bin[0], loss_surface_mse[:, int(R_values_low_bin[0] - 10)].sum(0), marker=\"o\")\n", "cur_loss_label = ax2.annotate(f\"R = {R_values_low_bin[0]:.0f}\", (0, 0), xytext=(0.82, 0.9), textcoords='axes fraction')\n", "ax2.set_title(\"Loss Surface\")\n", "ax2.set_xlim([0, 1000])\n", "ax2.set_xlabel(\"$R \\; (\\Omega)$\")\n", "ax2.set_ylabel(\"Loss\")\n", "\n", "# Show loss contributions of each data point\n", "cur_circuit = LowPassCircuit(R_values_low_bin[0])\n", "data_losses = loss_fn(cur_circuit(train_data_low_bin[0][::subsample]), (train_data_low_bin[1][::subsample] > cutoff_mag).float())\n", "data_grads = torch.zeros(len(data_losses))\n", "for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.autograd.grad(dl, cur_circuit.R, retain_graph=True)[0]\n", "data_grads_scat = ax3.scatter(train_data_low_bin[0][::subsample] / (2 * math.pi), data_grads, marker=\"x\", c=\"k\")\n", "ax3.set_xscale(\"log\")\n", "ax3.set_ylabel(\"Derivative\")\n", "ax3.set_xlim([1, 1e6])\n", "ax3.set_ylim([-1.5e-3, 1.5e-3])\n", "ax3.set_xlabel(\"Frequency (Hz)\")\n", "ax3.set_title(\"Derivative by Training Datapoint\")\n", "\n", "# Show gradient at each training iteration\n", "ax4.plot(np.arange(len(grad_values_low_bin)), grad_values_low_bin, linewidth=3)\n", "cur_iter, = ax4.plot(0, grad_values_low_bin[0], marker=\"o\")\n", "cur_grad_label = ax4.annotate(f\"Grad = {grad_values_low_bin[0]:.2e}\", (0, 0), xytext=(0.65, 0.9), textcoords='axes fraction')\n", "ax4.set_xlabel(\"Training Iteration\")\n", "ax4.set_ylabel(\"Gradient\")\n", "ax4.set_title(\"Gradients\")\n", "ax4.set_xlim([-1, len(grad_values_low_bin)])\n", "\n", "plt.tight_layout()\n", "\n", "# Main update function for interactive plots\n", "def update_iter_low_bin(t=0):\n", " mags = evaluate_lp_circuit(ws, R_values_low_bin[t])\n", " learned_tf.set_data(ws / (2 * math.pi), mags)\n", " cutoff = ws[np.argmax(mags < cutoff_mag)]\n", " cut.set_xdata(cutoff / (2 * math.pi))\n", " cur_loss.set_data(R_values_low_bin[t], loss_surface_mse[:, int(R_values_low_bin[t] - 10)].sum(0))\n", " cur_loss_label.set_text(f\"R = {R_values_low_bin[t]:.0f}\")\n", " cur_iter.set_data(t, grad_values_low_bin[t])\n", " cur_grad_label.set_text(f\"Grad = {grad_values_low_bin[t]:.2e}\")\n", " cur_circuit = LowPassCircuit(R_values_low_bin[t])\n", " data_losses = loss_fn(cur_circuit(train_data_low_bin[0][::subsample]), (train_data_low_bin[1][::subsample] > cutoff_mag).float())\n", " data_grads = torch.zeros(len(data_losses))\n", " for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.autograd.grad(dl, cur_circuit.R, retain_graph=True)[0]\n", " data_grads_scat.set_offsets(torch.stack((train_data_low_bin[0][::subsample] / (2 * math.pi), data_grads)).T)\n", " fig.canvas.draw_idle()\n", " \n", "# Include sliders for relevant quantities\n", "ip = interactive(update_iter_low_bin, \n", " t=widgets.IntSlider(value=0, min=0, max=len(R_values_low_bin) - 1, step=1, description=\"Training Iteration\", style={'description_width': 'initial'}, layout=Layout(width='100%')))\n", "ip" ] }, { "cell_type": "markdown", "id": "320acfe9", "metadata": {}, "source": [ "## (e) Learning a Low Pass Filter from Binary Data with a Different Loss" ] }, { "cell_type": "code", "execution_count": null, "id": "c6b1d442", "metadata": {}, "outputs": [], "source": [ "circuit = LowPassCircuit(500)\n", "### YOUR CODE HERE\n", "loss_fn = lambda x, y:\n", "### END YOUR CODE\n", "train_data_low_bin, R_values_low_bin, grad_values_low_bin = train_lp_circuit_binary(circuit, loss_fn, dataset_size, max_training_steps, lr)" ] }, { "cell_type": "code", "execution_count": null, "id": "a43257ab", "metadata": {}, "outputs": [], "source": [ "# Plot transfer function over training\n", "fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))\n", "ws = 2 * math.pi * 10 ** torch.linspace(0, 6, 1000)\n", "subsample = int(dataset_size / 100)\n", "train_data_mask = train_data_low_bin[1][::subsample] > cutoff_mag\n", "ax1.scatter(train_data_low_bin[0][::subsample][train_data_mask] / (2 * math.pi), np.ones(train_data_mask.sum()), c=\"r\", marker=\"x\")\n", "ax1.scatter(train_data_low_bin[0][::subsample][~train_data_mask] / (2 * math.pi), np.zeros((~train_data_mask).sum()), c=\"k\", marker=\"x\")\n", "mags = evaluate_lp_circuit(ws, R_values_low_bin[0])\n", "learned_tf, = ax1.semilogx(ws / (2 * math.pi), mags, linewidth=3)\n", "cutoff = ws[np.argmax(mags < cutoff_mag)]\n", "cut = ax1.axvline(cutoff / (2 * math.pi), c=\"red\", linestyle=\"--\", linewidth=3)\n", "ax1.set_xlim([1, 1e6])\n", "ax1.set_title(\"Transfer Function\")\n", "ax1.set_xlabel(\"Frequency (Hz)\")\n", "ax1.set_ylabel(\"Magnitude\")\n", "ax1.legend([\"Learned TF\", \"Learned $f_c$\", \"TF + Samples\", \"TF - Samples\"])\n", "\n", "# Show loss surface over training\n", "eval_pts = torch.arange(10, 1001, 1)\n", "eval_vals = evaluate_lp_circuit(train_data_low_bin[0][:, None], eval_pts[None, :])\n", "loss_surface_mse = loss_fn(eval_vals, (train_data_low_bin[1][:, None].expand(eval_vals.shape) > cutoff_mag).float())\n", "ax2.plot(eval_pts, loss_surface_mse.sum(0), linewidth=3)\n", "cur_loss, = ax2.plot(R_values_low_bin[0], loss_surface_mse[:, int(R_values_low_bin[0] - 10)].sum(0), marker=\"o\")\n", "cur_loss_label = ax2.annotate(f\"R = {R_values_low_bin[0]:.0f}\", (0, 0), xytext=(0.82, 0.9), textcoords='axes fraction')\n", "ax2.set_title(\"Loss Surface\")\n", "ax2.set_xlim([0, 1000])\n", "ax2.set_xlabel(\"$R \\; (\\Omega)$\")\n", "ax2.set_ylabel(\"Loss\")\n", "\n", "# Show loss contributions of each data point\n", "cur_circuit = LowPassCircuit(R_values_low_bin[0])\n", "data_losses = loss_fn(cur_circuit(train_data_low_bin[0][::subsample]), (train_data_low_bin[1][::subsample] > cutoff_mag).float())\n", "data_grads = torch.zeros(len(data_losses))\n", "for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.autograd.grad(dl, cur_circuit.R, retain_graph=True)[0]\n", "data_grads_scat = ax3.scatter(train_data_low_bin[0][::subsample] / (2 * math.pi), data_grads, marker=\"x\", c=\"k\")\n", "ax3.set_xscale(\"log\")\n", "ax3.set_ylabel(\"Derivative\")\n", "ax3.set_xlim([1, 1e6])\n", "ax3.set_ylim([-1.5e-3, 1.5e-3])\n", "ax3.set_xlabel(\"Frequency (Hz)\")\n", "ax3.set_title(\"Derivative by Training Datapoint\")\n", "\n", "# Show gradient at each training iteration\n", "ax4.plot(np.arange(len(grad_values_low_bin)), grad_values_low_bin, linewidth=3)\n", "cur_iter, = ax4.plot(0, grad_values_low_bin[0], marker=\"o\")\n", "cur_grad_label = ax4.annotate(f\"Grad = {grad_values_low_bin[0]:.2e}\", (0, 0), xytext=(0.65, 0.9), textcoords='axes fraction')\n", "ax4.set_xlabel(\"Training Iteration\")\n", "ax4.set_ylabel(\"Gradient\")\n", "ax4.set_title(\"Gradients\")\n", "ax4.set_xlim([-1, len(grad_values_low_bin)])\n", "\n", "plt.tight_layout()\n", "\n", "# Main update function for interactive plots\n", "def update_iter_low_bin(t=0):\n", " mags = evaluate_lp_circuit(ws, R_values_low_bin[t])\n", " learned_tf.set_data(ws / (2 * math.pi), mags)\n", " cutoff = ws[np.argmax(mags < cutoff_mag)]\n", " cut.set_xdata(cutoff / (2 * math.pi))\n", " cur_loss.set_data(R_values_low_bin[t], loss_surface_mse[:, int(R_values_low_bin[t] - 10)].sum(0))\n", " cur_loss_label.set_text(f\"R = {R_values_low_bin[t]:.0f}\")\n", " cur_iter.set_data(t, grad_values_low_bin[t])\n", " cur_grad_label.set_text(f\"Grad = {grad_values_low_bin[t]:.2e}\")\n", " cur_circuit = LowPassCircuit(R_values_low_bin[t])\n", " data_losses = loss_fn(cur_circuit(train_data_low_bin[0][::subsample]), (train_data_low_bin[1][::subsample] > cutoff_mag).float())\n", " data_grads = torch.zeros(len(data_losses))\n", " for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.autograd.grad(dl, cur_circuit.R, retain_graph=True)[0]\n", " data_grads_scat.set_offsets(torch.stack((train_data_low_bin[0][::subsample] / (2 * math.pi), data_grads)).T)\n", " fig.canvas.draw_idle()\n", " \n", "# Include sliders for relevant quantities\n", "ip = interactive(update_iter_low_bin, \n", " t=widgets.IntSlider(value=0, min=0, max=len(R_values_low_bin) - 1, step=1, description=\"Training Iteration\", style={'description_width': 'initial'}, layout=Layout(width='100%')))\n", "ip" ] }, { "cell_type": "markdown", "id": "6284868f", "metadata": {}, "source": [ "## (f) Learning a High Pass Filter from Binary Data" ] }, { "cell_type": "code", "execution_count": null, "id": "f66e7a16", "metadata": {}, "outputs": [], "source": [ "# Transfer function: evaluates magnitude of given frequencies for a resistor value in the high pass circuit\n", "def evaluate_hp_circuit(freqs, R_high):\n", " ### YOUR CODE HERE\n", " return \n", " ### END YOUR CODE\n", "\n", "# PyTorch model of the high pass circuit (for training)\n", "class HighPassCircuit(nn.Module):\n", " def __init__(self, R=None):\n", " super().__init__()\n", " self.R = nn.Parameter(torch.tensor(R, dtype=float) if R is not None else torch.rand(1) * 1000)\n", " \n", " def forward(self, freqs):\n", " return evaluate_hp_circuit(freqs, self.R)\n", " \n", "# Generate training data in a uniform log scale of frequences, then evaluate using the true transfer function\n", "R_high_des = 1 / (2 * math.pi * 5000 * cap_value)\n", "def generate_hp_training_data(n):\n", " rand_ws = 2 * math.pi * torch.pow(10, torch.rand(n) * 6)\n", " labels = evaluate_hp_circuit(rand_ws, R_high_des)\n", " return rand_ws, labels\n", "\n", "# Train a given low pass filter from binary data\n", "def train_hp_circuit_binary(circuit, loss_fn, dataset_size, max_training_steps, lr):\n", " \n", " R_values = [float(circuit.R.data)]\n", " grad_values = [np.nan]\n", " train_data = generate_hp_training_data(dataset_size)\n", " print(f\"Initial Resistor Value: R = {float(circuit.R.data):.0f}\")\n", " iter_bar = tqdm.trange(max_training_steps, desc=\"Training Iter\")\n", " for i in iter_bar:\n", " pred = circuit(train_data[0])\n", " loss = loss_fn(pred, (train_data[1] > cutoff_mag).float()).mean()\n", " ### YOUR CODE HERE\n", " grad = torch.autograd.grad(?, ?)\n", " ### END YOUR CODE\n", " with torch.no_grad():\n", " ### YOUR CODE HERE\n", " circuit.R -=\n", " ### END YOUR CODE\n", " \n", " R_values.append(float(circuit.R.data))\n", " grad_values.append(float(grad[0].data))\n", " iter_bar.set_postfix_str(f\"Loss: {float(loss.data):.3f}, R={float(circuit.R.data):.0f}\")\n", " if loss.data < 1e-6 or abs(grad[0].data) < 1e-6:\n", " break\n", "\n", " print(f\"Final Resistor Value: R = {float(circuit.R.data):.0f}\") \n", " return train_data, R_values, grad_values" ] }, { "cell_type": "code", "execution_count": null, "id": "83d01c6a", "metadata": {}, "outputs": [], "source": [ "# Create a circuit, use loss_fn with learning rate of 1000\n", "circuit = HighPassCircuit(500)\n", "### YOUR CODE HERE\n", "loss_fn = lambda x, y: \n", "### END YOUR CODE\n", "lr = 1000\n", "train_data_high_bin, R_values_high_bin, grad_values_high_bin = train_hp_circuit_binary(circuit, loss_fn, dataset_size, max_training_steps, lr)" ] }, { "cell_type": "code", "execution_count": null, "id": "209aee6b", "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Plot transfer function over training\n", "fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))\n", "ws = 2 * math.pi * 10 ** torch.linspace(0, 6, 1000)\n", "subsample = int(dataset_size / 100)\n", "train_data_mask = train_data_high_bin[1][::subsample] > cutoff_mag\n", "ax1.scatter(train_data_high_bin[0][::subsample][train_data_mask] / (2 * math.pi), np.ones(train_data_mask.sum()), c=\"r\", marker=\"x\")\n", "ax1.scatter(train_data_high_bin[0][::subsample][~train_data_mask] / (2 * math.pi), np.zeros((~train_data_mask).sum()), c=\"k\", marker=\"x\")\n", "mags = evaluate_hp_circuit(ws, R_values_high_bin[0])\n", "learned_tf, = ax1.semilogx(ws / (2 * math.pi), mags, linewidth=3)\n", "cutoff = ws[np.argmax(mags > cutoff_mag)]\n", "cut = ax1.axvline(cutoff / (2 * math.pi), c=\"red\", linestyle=\"--\", linewidth=3)\n", "ax1.set_xlim([1, 1e6])\n", "ax1.set_title(\"Transfer Function\")\n", "ax1.set_xlabel(\"Frequency (Hz)\")\n", "ax1.set_ylabel(\"Magnitude\")\n", "ax1.legend([\"Learned TF\", \"Learned $f_c$\", \"TF + Samples\", \"TF - Samples\"])\n", "\n", "# Show loss surface over training\n", "eval_pts = torch.arange(10, 1001, 1)\n", "eval_vals = evaluate_hp_circuit(train_data_high_bin[0][:, None], eval_pts[None, :])\n", "loss_surface_mse = loss_fn(eval_vals, (train_data_high_bin[1][:, None].expand(eval_vals.shape) > cutoff_mag).float())\n", "ax2.plot(eval_pts, loss_surface_mse.sum(0), linewidth=3)\n", "cur_loss, = ax2.plot(R_values_high_bin[0], loss_surface_mse[:, int(R_values_high_bin[0] - 10)].sum(0), marker=\"o\")\n", "cur_loss_label = ax2.annotate(f\"R = {R_values_high_bin[0]:.0f}\", (0, 0), xytext=(0.82, 0.9), textcoords='axes fraction')\n", "ax2.set_title(\"Loss Surface\")\n", "ax2.set_xlim([0, 1000])\n", "ax2.set_xlabel(\"$R \\; (\\Omega)$\")\n", "ax2.set_ylabel(\"Loss\")\n", "\n", "# Show loss contributions of each data point\n", "cur_circuit = HighPassCircuit(R_values_high_bin[0])\n", "data_losses = loss_fn(cur_circuit(train_data_high_bin[0][::subsample]), (train_data_high_bin[1][::subsample] > cutoff_mag).float())\n", "data_grads = torch.zeros(len(data_losses))\n", "for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.autograd.grad(dl, cur_circuit.R, retain_graph=True)[0]\n", "data_grads_scat = ax3.scatter(train_data_high_bin[0][::subsample] / (2 * math.pi), data_grads, marker=\"x\", c=\"k\")\n", "ax3.set_xscale(\"log\")\n", "ax3.set_ylabel(\"Derivative\")\n", "ax3.set_xlim([1, 1e6])\n", "ax3.set_ylim([-3e-3, 3e-3])\n", "ax3.set_xlabel(\"Frequency (Hz)\")\n", "ax3.set_title(\"Derivative by Training Datapoint\")\n", "\n", "# Show gradient at each training iteration\n", "ax4.plot(np.arange(len(grad_values_high_bin)), grad_values_high_bin, linewidth=3)\n", "cur_iter, = ax4.plot(0, grad_values_high_bin[0], marker=\"o\")\n", "cur_grad_label = ax4.annotate(f\"Grad = {grad_values_high_bin[0]:.2e}\", (0, 0), xytext=(0.65, 0.9), textcoords='axes fraction')\n", "ax4.set_xlabel(\"Training Iteration\")\n", "ax4.set_ylabel(\"Gradient\")\n", "ax4.set_title(\"Gradients\")\n", "ax4.set_xlim([-1, len(grad_values_high_bin)])\n", "\n", "plt.tight_layout()\n", "\n", "# Main update function for interactive plots\n", "def update_iter_high_bin(t=0):\n", " mags = evaluate_hp_circuit(ws, R_values_high_bin[t])\n", " learned_tf.set_data(ws / (2 * math.pi), mags)\n", " cutoff = ws[np.argmax(mags > cutoff_mag)]\n", " cut.set_xdata(cutoff / (2 * math.pi))\n", " cur_loss.set_data(R_values_high_bin[t], loss_surface_mse[:, int(R_values_high_bin[t] - 10)].sum(0))\n", " cur_loss_label.set_text(f\"R = {R_values_high_bin[t]:.0f}\")\n", " cur_iter.set_data(t, grad_values_high_bin[t])\n", " cur_grad_label.set_text(f\"Grad = {grad_values_high_bin[t]:.2e}\")\n", " cur_circuit = HighPassCircuit(R_values_high_bin[t])\n", " data_losses = loss_fn(cur_circuit(train_data_high_bin[0][::subsample]), (train_data_high_bin[1][::subsample] > cutoff_mag).float())\n", " data_grads = torch.zeros(len(data_losses))\n", " for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.autograd.grad(dl, cur_circuit.R, retain_graph=True)[0]\n", " data_grads_scat.set_offsets(torch.stack((train_data_high_bin[0][::subsample] / (2 * math.pi), data_grads)).T)\n", " fig.canvas.draw_idle()\n", " \n", "# Include sliders for relevant quantities\n", "ip = interactive(update_iter_high_bin, \n", " t=widgets.IntSlider(value=0, min=0, max=len(R_values_high_bin) - 1, step=1, description=\"Training Iteration\", style={'description_width': 'initial'}, layout=Layout(width='100%')))\n", "ip" ] }, { "cell_type": "markdown", "id": "95e0606c", "metadata": {}, "source": [ "## (g) Learning a Band Pass Filter from Binary Data" ] }, { "cell_type": "code", "execution_count": null, "id": "17a763da", "metadata": {}, "outputs": [], "source": [ "# Transfer function: evaluates magnitude of given frequencies for resistor values in the band pass circuit\n", "def evaluate_bp_circuit(freqs, R_low, R_high):\n", " ### YOUR CODE HERE\n", " return \n", " ### END YOUR CODE\n", "\n", "# PyTorch model of the band pass circuit (for training)\n", "class BandPassCircuit(nn.Module):\n", " def __init__(self, R_low=None, R_high=None):\n", " super().__init__()\n", " self.R_low = nn.Parameter(torch.tensor(R_low, dtype=float) if R_low is not None else torch.rand(1) * 1000)\n", " self.R_high = nn.Parameter(torch.tensor(R_high, dtype=float) if R_high is not None else torch.rand(1) * 1000)\n", " \n", " def forward(self, freqs):\n", " return evaluate_bp_circuit(freqs, self.R_low, self.R_high)\n", " \n", "# Generate training data in a uniform log scale of frequences, then evaluate using true transfer function\n", "R_low_des = 1 / (2 * math.pi * 4000 * cap_value)\n", "R_high_des = 1 / (2 * math.pi * 1000 * cap_value)\n", "def generate_bp_training_data(n):\n", " rand_ws = 2 * math.pi * torch.pow(10, torch.rand(n) * 6)\n", " labels = evaluate_bp_circuit(rand_ws, R_low_des, R_high_des)\n", " return rand_ws, labels\n", "\n", "# Train a given low pass filter from binary data\n", "def train_bp_circuit_binary(circuit, loss_fn, dataset_size, max_training_steps, lr):\n", " \n", " R_values = [[float(circuit.R_low.data), float(circuit.R_high.data)]]\n", " grad_values = [[np.nan, np.nan]]\n", " train_data = generate_bp_training_data(dataset_size)\n", " print(f\"Initial Resistor Values: R_low = {float(circuit.R_low.data):.0f}, R_high = {float(circuit.R_high.data):.0f}\")\n", " iter_bar = tqdm.trange(max_training_steps, desc=\"Training Iter\")\n", " for i in iter_bar:\n", " pred = circuit(train_data[0])\n", " loss = loss_fn(pred, (train_data[1] > cutoff_mag).float()).mean()\n", " ### YOUR CODE HERE\n", " grad = torch.autograd.grad(?, ?)\n", " ### END YOUR CODE\n", " with torch.no_grad():\n", " ### YOUR CODE HERE\n", " circuit.R_low -= \n", " circuit.R_high -= \n", " ### END YOUR CODE\n", " \n", " R_values.append([float(circuit.R_low.data), float(circuit.R_high.data)])\n", " grad_values.append([float(grad[0].data), float(grad[1].data)])\n", " iter_bar.set_postfix_str(f\"Loss: {float(loss.data):.3f}, R_low={float(circuit.R_low.data):.0f}, R_high={float(circuit.R_high.data):.0f}\")\n", " if loss.data < 1e-6 or (abs(grad[0].data) < 1e-6 and abs(grad[1].data) < 1e-6):\n", " break\n", "\n", " print(f\"Final Resistor Values: R_low = {float(circuit.R_low.data):.0f}, R_high = {float(circuit.R_high.data):.0f}\")\n", " return train_data, R_values, grad_values" ] }, { "cell_type": "code", "execution_count": null, "id": "07eea001", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Create a circuit, use loss_fn with learning rate of 1000\n", "circuit = BandPassCircuit(500, 500)\n", "lr = 1000\n", "train_data_band_bin, R_values_band_bin, grad_values_band_bin = train_bp_circuit_binary(circuit, loss_fn, dataset_size, max_training_steps, lr)" ] }, { "cell_type": "code", "execution_count": null, "id": "eee25792", "metadata": {}, "outputs": [], "source": [ "# Plot transfer function over training\n", "fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))\n", "ws = 2 * math.pi * 10 ** torch.linspace(0, 6, 1000)\n", "subsample = int(dataset_size / 100)\n", "train_data_mask = train_data_band_bin[1][::subsample] > cutoff_mag\n", "ax1.scatter(train_data_band_bin[0][::subsample][train_data_mask] / (2 * math.pi), np.ones(train_data_mask.sum()), c=\"r\", marker=\"x\")\n", "ax1.scatter(train_data_band_bin[0][::subsample][~train_data_mask] / (2 * math.pi), np.zeros((~train_data_mask).sum()), c=\"k\", marker=\"x\")\n", "learned_tf, = ax1.semilogx(ws / (2 * math.pi), evaluate_bp_circuit(ws, *R_values_band_bin[0]), linewidth=3)\n", "ax1.set_xlim([1, 1e6])\n", "ax1.set_title(\"Transfer Function\")\n", "ax1.set_xlabel(\"Frequency (Hz)\")\n", "ax1.set_ylabel(\"Magnitude\")\n", "ax1.legend([\"Learned TF\", \"TF + Samples\", \"TF - Samples\"])\n", "\n", "# Show loss surfaces for BCE and MSE Loss\n", "eval_pts = torch.stack(torch.meshgrid((torch.arange(0, 1000, 10), torch.arange(0, 1000, 10)), indexing=\"ij\"))\n", "eval_vals = evaluate_bp_circuit(train_data_band_bin[0][:, None, None], eval_pts[0][None, ...], eval_pts[1][None, ...])\n", "loss_surface = loss_fn(eval_vals, (train_data_band_bin[1][..., None, None].expand(eval_vals.shape) > cutoff_mag).float())\n", "loss_surf = ax2.imshow(torch.log(loss_surface.mean(0)).T, cmap=plt.cm.jet, extent=(0, 1000, 0, 1000), aspect=\"auto\", origin=\"lower\")\n", "cur_loss, = ax2.plot(*R_values_band_bin[0], marker=\"o\")\n", "cur_loss_label = ax2.annotate(f\"R_low = {R_values_band_bin[0][0]:.0f}\\nR_high = {R_values_band_bin[0][1]:.0f}\", (0, 0), xytext=(0.6, 0.85), textcoords='axes fraction')\n", "ax2.set_title(\"Loss Surface\")\n", "ax2.set_xlabel(\"$R_\\mathrm{low} \\; (\\Omega)$\")\n", "ax2.set_ylabel(\"$R_\\mathrm{high} \\; (\\Omega)$\")\n", "fig.colorbar(loss_surf, ax=ax2, label=\"log(loss)\")\n", "\n", "# Show loss contributions of each data point\n", "cur_circuit = BandPassCircuit(*R_values_band_bin[0])\n", "data_losses = loss_fn(cur_circuit(train_data_band_bin[0][::subsample]), (train_data_band_bin[1][::subsample] > cutoff_mag).float())\n", "data_grads = torch.zeros((len(data_losses), 2))\n", "for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.tensor(torch.autograd.grad(dl, (cur_circuit.R_low, cur_circuit.R_high), retain_graph=True))\n", "data_grads_scat1 = ax3.scatter(train_data_band_bin[0][::subsample] / (2 * math.pi), data_grads[:, 0], marker=\"x\")\n", "data_grads_scat2 = ax3.scatter(train_data_band_bin[0][::subsample] / (2 * math.pi), data_grads[:, 1], marker=\"x\")\n", "ax3.set_xscale(\"log\")\n", "ax3.set_ylabel(\"Derivative\")\n", "ax3.set_xlim([1, 1e6])\n", "ax3.set_ylim([-2e-3, 2e-3])\n", "ax3.set_xlabel(\"Frequency (Hz)\")\n", "ax3.set_title(\"Derivative by Training Datapoint\")\n", "ax3.legend([\"$R_\\mathrm{low}$ Derivatives\", \"$R_\\mathrm{high}$ Derivatives\"])\n", "\n", "# Show gradient at each training iteration\n", "ax4.plot(np.arange(len(grad_values_band_bin)), grad_values_band_bin, linewidth=3)\n", "cur_grad0, = ax4.plot(0, grad_values_band_bin[0][0], marker=\"o\")\n", "cur_grad1, = ax4.plot(0, grad_values_band_bin[0][1], marker=\"o\")\n", "ax4.set_xlabel(\"Training Iteration\")\n", "ax4.set_ylabel(\"Gradient\")\n", "ax4.set_title(\"Gradients\")\n", "ax4.set_xlim([-1, len(grad_values_band_bin)])\n", "ax4.legend([\"$R_\\mathrm{low}$ Grad\", \"$R_\\mathrm{high}$ Grad\"])\n", "\n", "plt.tight_layout()\n", "\n", "# Main update function for interactive plots\n", "def update_iter_band_bin(t=0):\n", " mags = evaluate_bp_circuit(ws, *R_values_band_bin[t])\n", " learned_tf.set_data(ws / (2 * math.pi), mags)\n", " cur_loss.set_data(*R_values_band_bin[t])\n", " cur_loss_label.set_text(f\"R_low = {R_values_band_bin[t][0]:.0f}\\nR_high = {R_values_band_bin[t][1]:.0f}\")\n", " cur_grad0.set_data(t, grad_values_band_bin[t][0])\n", " cur_grad1.set_data(t, grad_values_band_bin[t][1])\n", " cur_circuit = BandPassCircuit(*R_values_band_bin[t])\n", " data_losses = loss_fn(cur_circuit(train_data_band_bin[0][::subsample]), (train_data_band_bin[1][::subsample] > cutoff_mag).float())\n", " data_grads = torch.zeros((len(data_losses), 2))\n", " for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.tensor(torch.autograd.grad(dl, (cur_circuit.R_low, cur_circuit.R_high), retain_graph=True))\n", " data_grads_scat1.set_offsets(torch.stack((train_data_band_bin[0][::subsample] / (2 * math.pi), data_grads[:, 0])).T)\n", " data_grads_scat2.set_offsets(torch.stack((train_data_band_bin[0][::subsample] / (2 * math.pi), data_grads[:, 1])).T)\n", " fig.canvas.draw_idle()\n", " \n", "# Include sliders for relevant quantities\n", "ip = interactive(update_iter_band_bin, \n", " t=widgets.IntSlider(value=0, min=0, max=len(R_values_band_bin) - 1, step=1, description=\"Training Iteration\", style={'description_width': 'initial'}, layout=Layout(width='100%')))\n", "ip" ] }, { "cell_type": "markdown", "id": "18acd86a", "metadata": {}, "source": [ "## (h) Learning a Band Pass Filter Bode Plot from Transfer Function Samples" ] }, { "cell_type": "code", "execution_count": null, "id": "0e7a1ef2", "metadata": {}, "outputs": [], "source": [ "def evaluate_bp_bode(freqs, low_cutoff, high_cutoff):\n", " return -20 * nn.ReLU()(torch.log10(freqs / low_cutoff)) + -20 * nn.ReLU()(torch.log10(high_cutoff / freqs))\n", "\n", "# PyTorch model of the band pass bode plot\n", "class BandPassBodePlot(nn.Module):\n", " def __init__(self, low_cutoff=None, high_cutoff=None):\n", " super().__init__()\n", " self.low_cutoff = nn.Parameter(torch.rand(1) * 5000 if low_cutoff is None else torch.tensor(float(low_cutoff)))\n", " self.high_cutoff = nn.Parameter(torch.rand(1) * 5000 if high_cutoff is None else torch.tensor(float(high_cutoff)))\n", " \n", " def forward(self, freqs):\n", " return evaluate_bp_bode(freqs, self.low_cutoff, self.high_cutoff)\n", "\n", "# Train a given band pass bode plot\n", "def train_bp_bode(bode, loss_fn, dataset_size, max_training_steps, lr):\n", " \n", " cutoff_values = [[float(bode.low_cutoff.data), float(bode.high_cutoff.data)]]\n", " grad_values = [[np.nan, np.nan]]\n", " train_data = generate_bp_training_data(dataset_size)\n", " print(f\"Initial Cutoff Values: f_c,l = {float(bode.low_cutoff.data / (2 * math.pi)):.0f} Hz, f_c,h = {float(bode.high_cutoff.data / (2 * math.pi)):.0f} Hz\")\n", " iter_bar = tqdm.trange(max_training_steps, desc=\"Training Iter\")\n", " for i in iter_bar:\n", " \n", " pred = bode(train_data[0])\n", " loss = loss_fn(pred, 20 * torch.log10(train_data[1])).mean()\n", " grad = torch.autograd.grad(loss, (bode.low_cutoff, bode.high_cutoff))\n", " with torch.no_grad():\n", " bode.low_cutoff -= lr * grad[0]\n", " bode.high_cutoff -= lr * grad[1]\n", " \n", " cutoff_values.append([float(bode.low_cutoff.data), float(bode.high_cutoff.data)])\n", " grad_values.append([float(grad[0].data), float(grad[1].data)])\n", " iter_bar.set_postfix_str(f\"Loss: {float(loss.data):.3f}, f_c,l = {float(bode.low_cutoff.data / (2 * math.pi)):.0f} Hz, f_c,h = {float(bode.high_cutoff.data / (2 * math.pi)):.0f} Hz\")\n", " if loss.data < 1e-6 or (abs(grad[0].data) < 1e-6 and abs(grad[1].data) < 1e-6):\n", " break\n", "\n", " print(f\"Final Cutoff Values: f_c,l = {float(bode.low_cutoff.data / (2 * math.pi)):.0f} Hz, f_c,h = {float(bode.high_cutoff.data / (2 * math.pi)):.0f} Hz\")\n", " return train_data, cutoff_values, grad_values" ] }, { "cell_type": "code", "execution_count": null, "id": "e74b7515", "metadata": { "scrolled": true }, "outputs": [], "source": [ "bode = BandPassBodePlot()\n", "loss_fn = lambda x, y: (x - y) ** 2 # MSE loss\n", "lr = 1000\n", "train_data_band_bode, cutoffs_band_bode, grad_values_band_bode = train_bp_bode(bode, loss_fn, dataset_size, max_training_steps, lr)" ] }, { "cell_type": "code", "execution_count": null, "id": "c4352b63", "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Plot transfer function over training\n", "fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))\n", "ws = 2 * math.pi * 10 ** torch.linspace(0, 6, 1000)\n", "subsample = int(dataset_size / 100)\n", "train_data_mask = train_data_band_bode[1][::subsample] > cutoff_mag\n", "ax1.scatter(train_data_band_bode[0][::subsample]/ (2 * math.pi), 20 * torch.log10(train_data_band_bode[1][::subsample]), c=\"k\", marker=\"x\")\n", "learned_tf, = ax1.semilogx(ws / (2 * math.pi), evaluate_bp_bode(ws, *cutoffs_band_bode[0]), linewidth=3)\n", "ax1.set_xlim([1, 1e6])\n", "ax1.set_title(\"Transfer Function\")\n", "ax1.set_xlabel(\"Frequency (Hz)\")\n", "ax1.set_ylabel(\"dB\")\n", "ax1.legend([\"Learned Bode Plot\", \"TF Samples\"])\n", "\n", "# Show loss surfaces for BCE and MSE Loss\n", "eval_pts = torch.stack(torch.meshgrid((torch.arange(1, 5001, 50), torch.arange(1, 5001, 50)), indexing=\"ij\"))\n", "eval_vals = evaluate_bp_bode(train_data_band_bode[0][:, None, None], 2 * math.pi * eval_pts[0][None, ...], 2 * math.pi * eval_pts[1][None, ...])\n", "loss_surface = loss_fn(eval_vals, 20 * torch.log10(train_data_band_bode[1])[..., None, None].expand(eval_vals.shape))\n", "loss_surf = ax2.imshow(torch.log(loss_surface.mean(0)).T, cmap=plt.cm.jet, extent=(1, 5000, 1, 5000), aspect=\"auto\", origin=\"lower\")\n", "cur_loss, = ax2.plot(cutoffs_band_bode[0][0] / (2 * math.pi), cutoffs_band_bode[0][1] / (2 * math.pi), marker=\"o\")\n", "cur_loss_label = ax2.annotate(f\"$f_{{c,l}}$ = {cutoffs_band_bode[0][0]:.0f}\\n$f_{{c,h}}$ = {cutoffs_band_bode[0][1]:.0f}\", (0, 0), xytext=(0.7, 0.82), textcoords='axes fraction')\n", "ax2.set_title(\"Loss Surface\")\n", "ax2.set_xlabel(\"$f_\\mathrm{c,low} \\; (Hz)$\")\n", "ax2.set_ylabel(\"$f_\\mathrm{c,high} \\; (Hz)$\")\n", "fig.colorbar(loss_surf, ax=ax2, label=\"log(loss)\")\n", "\n", "# Show loss contributions of each data point\n", "cur_bode = BandPassBodePlot(*cutoffs_band_bode[0])\n", "data_losses = loss_fn(cur_bode(train_data_band_bode[0][::subsample]), 20 * torch.log10(train_data_band_bode[1][::subsample]))\n", "data_grads = torch.zeros((len(data_losses), 2))\n", "for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.tensor(torch.autograd.grad(dl, (cur_bode.low_cutoff, cur_bode.high_cutoff), retain_graph=True))\n", "data_grads_scat1 = ax3.scatter(train_data_band_bode[0][::subsample] / (2 * math.pi), data_grads[:, 0], marker=\"x\")\n", "data_grads_scat2 = ax3.scatter(train_data_band_bode[0][::subsample] / (2 * math.pi), data_grads[:, 1], marker=\"x\")\n", "ax3.set_xscale(\"log\")\n", "ax3.set_ylabel(\"Derivative\")\n", "ax3.set_xlim([1, 1e6])\n", "ax3.set_ylim([-5e-3, 5e-3])\n", "ax3.set_xlabel(\"Frequency (Hz)\")\n", "ax3.set_title(\"Derivative by Training Datapoint\")\n", "ax3.legend([\"$f_{c,l}$ Derivatives\", \"$f_{c,h}$ Derivatives\"])\n", "\n", "# Show gradient at each training iteration\n", "ax4.plot(np.arange(len(grad_values_band_bode)), grad_values_band_bode, linewidth=3)\n", "cur_grad0, = ax4.plot(0, grad_values_band_bode[0][0], marker=\"o\")\n", "cur_grad1, = ax4.plot(0, grad_values_band_bode[0][1], marker=\"o\")\n", "ax4.set_xlabel(\"Training Iteration\")\n", "ax4.set_ylabel(\"Gradient\")\n", "ax4.set_title(\"Gradients\")\n", "ax4.set_xlim([-1, len(grad_values_band_bode)])\n", "ax4.legend([\"$f_\\mathrm{c,l}$ Grad\", \"$f_\\mathrm{c,h}$ Grad\"])\n", "\n", "plt.tight_layout()\n", "\n", "# Main update function for interactive plots\n", "def update_iter_band_bode(t=0):\n", " learned_tf.set_data(ws / (2 * math.pi), evaluate_bp_bode(ws, *cutoffs_band_bode[t]))\n", " cur_loss.set_data(cutoffs_band_bode[t][0] / (2 * math.pi), cutoffs_band_bode[t][1] / (2 * math.pi))\n", " cur_loss_label.set_text(f\"$f_{{c,l}}$ = {cutoffs_band_bode[t][0] / (2 * math.pi):.0f}\\n$f_{{c,h}}$ = {cutoffs_band_bode[t][1] / (2 * math.pi):.0f}\")\n", " cur_grad0.set_data(t, grad_values_band_bode[t][0])\n", " cur_grad1.set_data(t, grad_values_band_bode[t][1])\n", " cur_bode = BandPassBodePlot(*cutoffs_band_bode[t])\n", " data_losses = loss_fn(cur_bode(train_data_band_bode[0][::subsample]), 20 * torch.log10(train_data_band_bode[1][::subsample]))\n", " data_grads = torch.zeros((len(data_losses), 2))\n", " for i, dl in enumerate(data_losses):\n", " data_grads[i] = torch.tensor(torch.autograd.grad(dl, (cur_bode.low_cutoff, cur_bode.high_cutoff), retain_graph=True))\n", " data_grads_scat1.set_offsets(torch.stack((train_data_band_bode[0][::subsample] / (2 * math.pi), data_grads[:, 0])).T)\n", " data_grads_scat2.set_offsets(torch.stack((train_data_band_bode[0][::subsample] / (2 * math.pi), data_grads[:, 1])).T)\n", " fig.canvas.draw_idle()\n", " \n", "# Include sliders for relevant quantities\n", "ip = interactive(update_iter_band_bode, \n", " t=widgets.IntSlider(value=0, min=0, max=len(cutoffs_band_bode) - 1, step=1, description=\"Training Iteration\", style={'description_width': 'initial'}, layout=Layout(width='100%')))\n", "ip" ] }, { "cell_type": "markdown", "id": "09a08e57", "metadata": {}, "source": [ "## (i) Learn a Color Organ Circuit" ] }, { "cell_type": "code", "execution_count": null, "id": "0fd16f07", "metadata": {}, "outputs": [], "source": [ "# PyTorch model of the color organ circuit\n", "class ColorOrganCircuit(nn.Module):\n", " def __init__(self, R_low=None, R_high=None, R_band_low=None, R_band_high=None):\n", " super().__init__()\n", " self.low = LowPassCircuit(R_low)\n", " self.high = HighPassCircuit(R_high)\n", " self.band = BandPassCircuit(R_band_low, R_band_high)\n", " \n", " def forward(self, freqs):\n", " return torch.stack((self.low(freqs), self.band(freqs), self.high(freqs)))\n", " \n", " \n", "# Generate training data in a uniform log scale of frequences, then evaluate using the true transfer function\n", "R_low_des = 1 / (2 * math.pi * 800 * cap_value)\n", "R_band_low_des = 1 / (2 * math.pi * 4000 * cap_value)\n", "R_band_high_des = 1 / (2 * math.pi * 1000 * cap_value)\n", "R_high_des = 1 / (2 * math.pi * 5000 * cap_value)\n", "def generate_co_training_data(n):\n", " rand_ws = 2 * math.pi * torch.pow(10, torch.rand(n) * 6)\n", " labels = torch.stack((evaluate_lp_circuit(rand_ws, R_low_des), evaluate_bp_circuit(rand_ws, R_band_low_des, R_band_high_des), evaluate_hp_circuit(rand_ws, R_high_des)))\n", " return rand_ws, labels\n", "\n", "# Train a given color organ circuit\n", "def train_co_circuit(circuit, loss_fn, dataset_size, max_training_steps, lr):\n", " \n", " R_values = [[float(circuit.low.R.data), float(circuit.band.R_low.data), float(circuit.band.R_high.data), float(circuit.high.R.data)]]\n", " grad_values = [[np.nan, np.nan, np.nan, np.nan]]\n", " train_data = generate_co_training_data(dataset_size)\n", " print(f\"Initial Resistor Values: LP: {float(circuit.low.R.data):.0f} Ohms, BP (Low): {float(circuit.band.R_low.data):.0f} Ohms, BP (High): {float(circuit.band.R_high.data):.0f} Ohms, HP: {float(circuit.high.R.data):.0f} Ohms\")\n", " \n", " iter_bar = tqdm.trange(max_training_steps, desc=\"Training Iter\")\n", " for i in iter_bar:\n", " pred = circuit(train_data[0])\n", " loss = loss_fn(pred, (train_data[1] > cutoff_mag).float()).mean()\n", " grad = torch.autograd.grad(loss, (circuit.low.R, circuit.band.R_low, circuit.band.R_high, circuit.high.R))\n", " with torch.no_grad():\n", " circuit.low.R -= lr * grad[0]\n", " circuit.band.R_low -= lr * grad[1]\n", " circuit.band.R_high -= lr * grad[2]\n", " circuit.high.R -= lr * grad[3]\n", " \n", " R_values.append([float(circuit.low.R.data), float(circuit.band.R_low.data), float(circuit.band.R_high.data), float(circuit.high.R.data)])\n", " grad_values.append([float(grad[0].data), float(grad[1].data), float(grad[2].data), float(grad[3].data)])\n", " iter_bar.set_postfix_str(f\"Loss: {float(loss.data):.3f}, Rs = {float(circuit.low.R.data):.0f}, {float(circuit.band.R_low.data):.0f}, {float(circuit.band.R_high.data):.0f}, {float(circuit.high.R.data):.0f}\")\n", " if loss.data < 1e-6 or (abs(grad[0].data) < 1e-6 and abs(grad[1].data) < 1e-6):\n", " break\n", "\n", " print(f\"Final Resistor Values: LP: {float(circuit.low.R.data):.0f} Ohms, BP (Low): {float(circuit.band.R_low.data):.0f} Ohms, BP (High): {float(circuit.band.R_high.data):.0f} Ohms, HP: {float(circuit.high.R.data):.0f} Ohms\")\n", " print(f\"Final Cutoff Frequencies: LP: {1 / (2 * math.pi * cap_value * float(circuit.low.R.data)):.0f} Hz, BP (Low): {1 / (2 * math.pi * cap_value * float(circuit.band.R_low.data)):.0f} Hz, BP (High): {1 / (2 * math.pi * cap_value * float(circuit.band.R_high.data)):.0f} Hz, HP: {1 / (2 * math.pi * cap_value * float(circuit.high.R.data)):.0f} Hz\")\n", " return train_data, R_values, grad_values" ] }, { "cell_type": "code", "execution_count": null, "id": "4b7ffa2b", "metadata": { "scrolled": true }, "outputs": [], "source": [ "co = ColorOrganCircuit(200, 200, 200, 200)\n", "loss_fn = lambda x, y: (x - (0.3 + 0.7 * y)) ** 2 # weighted MSE loss\n", "lr = 500\n", "train_data_co, R_values_co, grad_values_co = train_co_circuit(co, loss_fn, dataset_size, max_training_steps, lr)" ] }, { "cell_type": "code", "execution_count": null, "id": "17593172", "metadata": {}, "outputs": [], "source": [ "# Plot transfer function over training\n", "fig, ax1 = plt.subplots(1, 1, figsize=(9, 6))\n", "ws = 2 * math.pi * 10 ** torch.linspace(0, 6, 1000)\n", "subsample = int(dataset_size / 250)\n", "train_data_mask = train_data_co[1][:, ::subsample] > cutoff_mag\n", "learned_tf1, = ax1.semilogx(ws / (2 * math.pi), evaluate_lp_circuit(ws, R_values_co[0][0]), linewidth=3)\n", "learned_tf2, = ax1.semilogx(ws / (2 * math.pi), evaluate_bp_circuit(ws, *R_values_co[0][1:3]), linewidth=3)\n", "learned_tf3, = ax1.semilogx(ws / (2 * math.pi), evaluate_hp_circuit(ws, R_values_co[0][-1]), linewidth=3)\n", "ax1.scatter(train_data_co[0][::subsample][train_data_mask[0]] / (2 * math.pi), np.ones(train_data_mask[0].sum()), c=learned_tf1.get_color(), marker=\"x\")\n", "ax1.scatter(train_data_co[0][::subsample][train_data_mask[1]] / (2 * math.pi), np.ones(train_data_mask[1].sum()), c=learned_tf2.get_color(), marker=\"x\")\n", "ax1.scatter(train_data_co[0][::subsample][train_data_mask[2]] / (2 * math.pi), np.ones(train_data_mask[2].sum()), c=learned_tf3.get_color(), marker=\"x\")\n", "# ax1.scatter(train_data_co[0][::subsample][(~train_data_mask).all(0)] / (2 * math.pi), np.zeros((~(train_data_mask.any(0))).sum()), c=\"k\", marker=\"x\")\n", "ax1.set_xlim([1, 1e6])\n", "ax1.set_title(\"Transfer Function\")\n", "ax1.set_xlabel(\"Frequency (Hz)\")\n", "ax1.set_ylabel(\"Magnitude\")\n", "ax1.legend([\"Learned LP\", \"Learned BP\", \"Learned HP\", \n", " \"TF + Samples (LP)\", \"TF + Samples (BP)\", \"TF + Samples (HP)\", \n", " \"TF - Samples\"], bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1)\n", "\n", "plt.tight_layout()\n", "\n", "# Main update function for interactive plots\n", "def update_iter_co(t=0):\n", " learned_tf1.set_data(ws / (2 * math.pi), evaluate_lp_circuit(ws, R_values_co[t][0]))\n", " learned_tf2.set_data(ws / (2 * math.pi), evaluate_bp_circuit(ws, *R_values_co[t][1:3]))\n", " learned_tf3.set_data(ws / (2 * math.pi), evaluate_hp_circuit(ws, R_values_co[t][-1]))\n", " fig.canvas.draw_idle()\n", " \n", "# Include sliders for relevant quantities\n", "ip = interactive(update_iter_co, \n", " t=widgets.IntSlider(value=0, min=0, max=len(R_values_co) - 1, step=1, description=\"Training Iteration\", style={'description_width': 'initial'}, layout=Layout(width='100%')))\n", "ip" ] }, { "cell_type": "markdown", "id": "d6cfaa10", "metadata": {}, "source": [ "## Visualizing the computation graph for the Color Organ" ] }, { "cell_type": "code", "execution_count": null, "id": "2521bdf5", "metadata": {}, "outputs": [], "source": [ "from torchviz import make_dot\n", "make_dot(co(generate_co_training_data(dataset_size)[0]), params=dict(co.named_parameters()))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.0 ('eecs16b')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" }, "vscode": { "interpreter": { "hash": "645b9769f00f9f3c5f7ee6e4081596b29edeb96d0eb359a6ef047ae12ce17c20" } } }, "nbformat": 4, "nbformat_minor": 5 }