ML 181: B-Splines for Kolmogorov-Arnold Networks (KANs) (15 pts)

What You Need

Purpose

To learn the fundamentals of B-splines, which are used in Kolmogorov-Arnold Networks (KANs). KANs are an alternative to traditional neural networks, which are technically called Multilayer Perceptrons (MLPs).

KANs take longer to train, but they can be far more interpretable and faster to use, because the final model can be written as an analytical equation, typically built from pieces of polynomials.

Using Google Colab

In a browser, go to
https://colab.research.google.com/
If you see a blue "Sign In" button at the top right, click it and log into a Google account.

From the menu, click File, "New notebook".

Importing Libraries

Execute these commands:
import torch
import torch.nn as nn
import numpy as np

import matplotlib.pyplot as plt
The commands complete, without any output, as shown below.

Fitting with B-Splines

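We'll fit a curve with B-splines, which approximate it as a weighted sum of simple basis functions defined on a grid of "control points". (In standard B-spline terminology, these grid points are called knots.)

For this example, we'll use 20 control points evenly spaced in the range from x = -1 to x = 1.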

Execute these commands:

grid = torch.linspace(-1, 1, steps=20) 
x = torch.linspace(-1, 1, steps=1000) # we take the entire domain to plot the basis function as a function of x

# Reshape so that each x can be compared to each control point
grid_ = grid.unsqueeze(dim=0)
x_ = x.unsqueeze(dim=1)

k = 0 # This gives us basis functions for order 0: 1 on [grid[i], grid[i+1]), 0 elsewhere
value1 = (x_ >= grid_[:,:-1]) * (x_ < grid_[:, 1:])

# Higher orders use the Cox-de Boor recursion, where t is the grid:
# B(i,k) = (x - t[i]) / (t[i+k] - t[i]) * B(i,k-1) + (t[i+k+1] - x) / (t[i+k+1] - t[i+1]) * B(i+1,k-1)
k = 1 # This gives us basis functions for order 1
value21 = (x_ - grid_[:, :-(k+1)]) / (grid_[:, k:-1] - grid_[:, :-(k+1)]) * value1[:, :-1]
value22 = (grid_[:, (k+1):] - x_) / (grid_[:, (k+1):] - grid_[:, 1:-k]) * value1[:, 1:]
value2 = value21 + value22

k = 2 # This gives us basis functions for order 2
value31 = (x_ - grid_[:, :-(k+1)]) / (grid_[:, k:-1] - grid_[:, :-(k+1)]) * value2[:, :-1]
value32 = (grid_[:, (k+1):] - x_) / (grid_[:, (k+1):] - grid_[:, 1:-k]) * value2[:, 1:]
value3 = value31 + value32

k = 3 # This gives us basis functions for order 3
value41 = (x_ - grid_[:, :-(k+1)]) / (grid_[:, k:-1] - grid_[:, :-(k+1)]) * value3[:, :-1]
value42 = (grid_[:, (k+1):] - x_) / (grid_[:, (k+1):] - grid_[:, 1:-k]) * value3[:, 1:]
value4 = value41 + value42

fn = lambda x: torch.sin(torch.pi * (x+1)* 8)*(x+1)

all_basis = [(value1 * 1.0, 'k=0'), (value2, 'k=1'), (value3, 'k=2'), (value4, 'k=3')]

fig, axs = plt.subplots(nrows=len(all_basis), figsize=(15, 15), dpi=100)

for i in range(len(all_basis)):
    ax = axs[i]
    value, label = all_basis[i]

    y_correct = fn(x)
    coeff = torch.linalg.lstsq(value, y_correct).solution # find the coefficients
    y_pred = torch.einsum('i, ji -> j', coeff, value) # use these coefficients to evaluate y

    # Sum of squared errors, skipping 150 points at each edge
    residuals = 0
    for j in range(150, len(y_correct) - 150):
        residuals += (y_pred[j] - y_correct[j]) ** 2

    ax.scatter(x, y_correct, color='green', label='target')
    ax.scatter(x, y_pred, color='blue', label="predicted")

    ax.grid()
    ax.set_title(f"{label}" + " Residuals:" + str(round(residuals.item(), 2)))
    ax.legend()
The result is four graphs, as shown below.

Order 0

B-splines with order 0 fit the data with a single constant value on each interval between adjacent control points.

This approximates the curve as a series of horizontal lines, as shown below.

This is a very sloppy fit, missing most of the oscillations in the data.
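Why the order-0 fit is made of horizontal lines: each order-0 basis function is 1 on a single interval and 0 everywhere else, so no two basis functions overlap, and the least-squares weight for each one works out to the average of the target over its interval. A minimal sketch that checks this numerically, assuming the same target function and 20-point grid as above (float64 is used only to make the comparison tight):

```python
import torch

grid = torch.linspace(-1, 1, steps=20, dtype=torch.float64)
x = torch.linspace(-1, 1, steps=1000, dtype=torch.float64)
fn = lambda x: torch.sin(torch.pi * (x + 1) * 8) * (x + 1)
y = fn(x)

# Order-0 basis: column i is 1 where grid[i] <= x < grid[i+1], else 0
B0 = ((x[:, None] >= grid[None, :-1]) & (x[:, None] < grid[None, 1:])).double()

# Least-squares weights, one per interval
coeff = torch.linalg.lstsq(B0, y.unsqueeze(1)).solution.squeeze(1)

# Plain average of the target over the samples inside each interval
means = torch.stack([y[B0[:, i] == 1].mean() for i in range(B0.shape[1])])

print(torch.allclose(coeff, means))  # the weights are the interval averages
```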

Notice the Residuals value: this is the sum of the squared errors over the center portion of the data (samples 150 through 849 of the 1000, roughly -0.7 < x < 0.7), which leaves out the problems near the edges.
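The residual loop in the fitting code skips the first and last 150 of the 1000 samples. A short sketch showing which x range that keeps, with the same sum written as one vectorized expression; residuals() is a helper introduced here for illustration, and the prediction is a made-up stand-in that is off by 0.1 everywhere:

```python
import torch

x = torch.linspace(-1, 1, steps=1000)
start, stop = 150, len(x) - 150   # same bounds as range(150, len(y_correct) - 150)

print(x[start].item(), x[stop - 1].item())  # roughly -0.70 and 0.70
print(stop - start)                          # number of samples that are scored

def residuals(y_pred, y_correct, margin=150):
    """Sum of squared errors, ignoring `margin` samples at each end."""
    diff = y_pred[margin:-margin] - y_correct[margin:-margin]
    return (diff ** 2).sum()

# Hypothetical prediction that is off by 0.1 at every sample
y_correct = torch.sin(torch.pi * (x + 1) * 8) * (x + 1)
y_pred = y_correct + 0.1
print(residuals(y_pred, y_correct).item())  # about 700 * 0.1**2 = 7
```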

Order 1

B-splines with order 1 draw straight lines between adjacent control points, so the fitted curve is made of line segments that bend only at the control points.

This is better than order 0, with a smaller Residuals value, but still misses some of the curves, as shown below.
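Why order 1 gives straight lines: each order-1 basis function equals 1 at one control point and 0 at all the others, so the fitted curve passes through its own weights at the interior control points and is linear in between. A minimal sketch checking this against numpy's np.interp, assuming the same target and 20-point grid (order1_basis is a helper introduced here):

```python
import numpy as np
import torch

def order1_basis(x, grid):
    """Order-1 B-spline basis via one Cox-de Boor step; shape (len(x), len(grid) - 2)."""
    x_, t = x[:, None], grid[None, :]
    B0 = ((x_ >= t[:, :-1]) & (x_ < t[:, 1:])).to(x.dtype)
    left = (x_ - t[:, :-2]) / (t[:, 1:-1] - t[:, :-2]) * B0[:, :-1]
    right = (t[:, 2:] - x_) / (t[:, 2:] - t[:, 1:-1]) * B0[:, 1:]
    return left + right

grid = torch.linspace(-1, 1, steps=20, dtype=torch.float64)
x = torch.linspace(-1, 1, steps=1000, dtype=torch.float64)
y = torch.sin(torch.pi * (x + 1) * 8) * (x + 1)

B1 = order1_basis(x, grid)
coeff = torch.linalg.lstsq(B1, y[:, None]).solution.squeeze(1)

# At each interior control point exactly one hat equals 1, so the curve hits a weight
at_points = order1_basis(grid[1:-1], grid) @ coeff
print(torch.allclose(at_points, coeff))

# Between those control points, the fit is ordinary linear interpolation of the weights
inside = (x >= grid[1]) & (x <= grid[-2])
y_fit = (B1 @ coeff)[inside]
y_interp = np.interp(x[inside].numpy(), grid[1:-1].numpy(), coeff.numpy())
print(np.allclose(y_fit.numpy(), y_interp))
```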

Order 2

B-splines with order 2 build the fit from parabolic pieces, one per interval between control points; each basis function covers three adjacent intervals.

This fit still misses badly on the right side of the chart, as shown below.

Order 3

B-splines with order 3 build the fit from cubic pieces; each basis function covers four adjacent intervals.

This is a better fit, with the lowest Residuals value of all, as shown below.

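The bad errors on the right side occur because there aren't any control points beyond x = 1. Within k intervals of either end, the basis functions overlap less and no longer sum to 1, so the fit has less flexibility there, and at x = 1 itself every basis function is 0, because the order-0 intervals include their left end but not their right end. The errors are most visible on the right because the target's amplitude, (x+1), is largest there.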
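One way to see the edge problem: away from the ends, the basis functions of any order add up to exactly 1 at every x, but within k intervals of either end they add up to less, and at the two endpoints they add up to 0. A minimal sketch, assuming a general version of the recursion used above (bspline_basis is a helper introduced here, not a PyTorch function):

```python
import torch

def bspline_basis(x, grid, k):
    """Order-k B-spline basis via the Cox-de Boor recursion; shape (len(x), len(grid) - 1 - k)."""
    x_, t = x[:, None], grid[None, :]
    # Order 0: 1 on the half-open interval [t[i], t[i+1]), 0 elsewhere
    B = ((x_ >= t[:, :-1]) & (x_ < t[:, 1:])).to(x.dtype)
    for p in range(1, k + 1):
        left = (x_ - t[:, :-(p + 1)]) / (t[:, p:-1] - t[:, :-(p + 1)]) * B[:, :-1]
        right = (t[:, p + 1:] - x_) / (t[:, p + 1:] - t[:, 1:-p]) * B[:, 1:]
        B = left + right
    return B

grid = torch.linspace(-1, 1, steps=20, dtype=torch.float64)
x = torch.linspace(-1, 1, steps=1000, dtype=torch.float64)

k = 3
total = bspline_basis(x, grid, k).sum(dim=1)  # sum of all basis functions at each x

# Full coverage only between grid[k] and grid[-1 - k]
interior = (x >= grid[k]) & (x < grid[-1 - k])
print(torch.allclose(total[interior], torch.ones_like(total[interior])))
print(total[0].item(), total[-1].item())  # no coverage at either endpoint
```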

B-Spline Basis Functions

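B-splines use one "basis function" for each run of adjacent control points: at order k, each basis function is nonzero over k + 1 neighboring intervals, so a grid of n control points gives n - 1 - k basis functions. The final fit is just a sum of weights multiplied by the basis functions, with one weight per basis function.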
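The value1 through value4 assignments repeat one recursion step per order. A compact sketch of the same Cox-de Boor recursion as a loop (bspline_basis is a hypothetical helper, not part of PyTorch or the pykan library), showing that a grid of 20 control points gives 19 - k basis functions at order k and that the fit is a weighted sum, computed as a matrix-vector product of basis values and weights:

```python
import torch

def bspline_basis(x, grid, k):
    """Order-k B-spline basis via the Cox-de Boor recursion; shape (len(x), len(grid) - 1 - k)."""
    x_, t = x[:, None], grid[None, :]
    B = ((x_ >= t[:, :-1]) & (x_ < t[:, 1:])).to(x.dtype)  # order 0
    for p in range(1, k + 1):
        left = (x_ - t[:, :-(p + 1)]) / (t[:, p:-1] - t[:, :-(p + 1)]) * B[:, :-1]
        right = (t[:, p + 1:] - x_) / (t[:, p + 1:] - t[:, 1:-p]) * B[:, 1:]
        B = left + right
    return B

grid = torch.linspace(-1, 1, steps=20, dtype=torch.float64)
x = torch.linspace(-1, 1, steps=1000, dtype=torch.float64)
y = torch.sin(torch.pi * (x + 1) * 8) * (x + 1)

for k in range(4):
    B = bspline_basis(x, grid, k)                               # one column per basis function
    w = torch.linalg.lstsq(B, y[:, None]).solution.squeeze(1)   # one weight per column
    y_fit = B @ w                                               # weighted sum of basis functions
    print(k, B.shape[1], w.shape[0])                            # 19 - k basis functions and weights
```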

To see the basis functions, execute these commands:

# Here, we define the control points. 
grid = torch.linspace(-1, 1, steps=10)

# we take the entire domain to plot the basis function as a function of x
x = torch.linspace(-1, 1, steps=1000) 

# Reshape so that each x can be compared to each control point
grid_ = grid.unsqueeze(dim=0)
x_ = x.unsqueeze(dim=1)

# Base case
k = 0
value1 = (x_ >= grid_[:,:-1]) * (x_ < grid_[:, 1:])

# Other cases use the Cox-de Boor recursion, where t is the grid:
# B(i,k) = (x - t[i]) / (t[i+k] - t[i]) * B(i,k-1) + (t[i+k+1] - x) / (t[i+k+1] - t[i+1]) * B(i+1,k-1)
k = 1 # This gives us basis functions for order 1
value21 = (x_ - grid_[:, :-(k+1)]) / (grid_[:, k:-1] - grid_[:, :-(k+1)]) * value1[:, :-1]
value22 = (grid_[:, (k+1):] - x_) / (grid_[:, (k+1):] - grid_[:, 1:-k]) * value1[:, 1:]
value2 = value21 + value22

k = 2 # This gives us basis functions for order 2
value31 = (x_ - grid_[:, :-(k+1)]) / (grid_[:, k:-1] - grid_[:, :-(k+1)]) * value2[:, :-1]
value32 = (grid_[:, (k+1):] - x_) / (grid_[:, (k+1):] - grid_[:, 1:-k]) * value2[:, 1:]
value3 = value31 + value32

k = 3 # This gives us basis functions for order 3
value41 = (x_ - grid_[:, :-(k+1)]) / (grid_[:, k:-1] - grid_[:, :-(k+1)]) * value3[:, :-1]
value42 = (grid_[:, (k+1):] - x_) / (grid_[:, (k+1):] - grid_[:, 1:-k]) * value3[:, 1:]
value4 = value41 + value42

fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(9, 9), dpi=100)

n_basis_to_plot = 4
all_basis = [(value1, 'k=0'), (value2, 'k=1'), (value3, 'k=2'), (value4, 'k=3')]
for i in range(4):
    
    ax = axs[i // 2, i % 2] 

    value, label = all_basis[i]
    for idx in range(value.shape[-1])[:n_basis_to_plot]:
        ax.scatter(x, value[:, idx], marker='x', label=f"Basis fn {idx}", alpha=0.5)
    ax.grid()
    ax.set_title(label)
    ax.legend()
The output shows four graphs, as shown below.

Orders 0 and 1

The chart below shows the first four basis functions for orders 0 and 1.

Order 0 uses basis functions that are 1 on one interval between adjacent control points, and 0 elsewhere.

For order 1, the basis functions are triangles ("hat" functions) with linear sides, each peaking at a control point and falling to 0 at the neighboring control points on either side.
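On an evenly spaced grid with spacing h, each order-1 basis function has a simple closed form, max(0, 1 - |x - c| / h), a triangle of height 1 centered on control point c. A minimal sketch comparing that formula with one step of the recursion, using the same 10-point grid as the plotting code:

```python
import torch

grid = torch.linspace(-1, 1, steps=10, dtype=torch.float64)
x = torch.linspace(-1, 1, steps=1000, dtype=torch.float64)
x_, t = x[:, None], grid[None, :]
h = grid[1] - grid[0]  # grid spacing, 2/9

# Order 0, then one Cox-de Boor step (same slicing as the k = 1 lines above)
B0 = ((x_ >= t[:, :-1]) & (x_ < t[:, 1:])).to(x.dtype)
B1 = ((x_ - t[:, :-2]) / (t[:, 1:-1] - t[:, :-2]) * B0[:, :-1]
      + (t[:, 2:] - x_) / (t[:, 2:] - t[:, 1:-1]) * B0[:, 1:])

# Closed form: basis function i is a hat of height 1 centered on grid[i + 1]
centers = grid[1:-1]
hats = torch.clamp(1 - (x_ - centers[None, :]).abs() / h, min=0)

print(torch.allclose(B1, hats))
```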

Orders 2 and 3

Orders 2 and 3 use basis functions that are smooth, bell-shaped curves. They are very similar: each order-2 basis function spans three intervals between control points and each order-3 basis function spans four, so order 3 is a bit wider, as shown below.
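For an evenly spaced grid, the textbook values are: an order-2 basis function spans three intervals and peaks at 0.75 halfway between two control points, while an order-3 basis function spans four intervals and peaks at 2/3 exactly on a control point. A minimal sketch that measures these from the recursion (bspline_basis is a helper introduced here; the 9001-sample x grid is chosen so the control points land on samples):

```python
import torch

def bspline_basis(x, grid, k):
    """Order-k B-spline basis via the Cox-de Boor recursion; shape (len(x), len(grid) - 1 - k)."""
    x_, t = x[:, None], grid[None, :]
    B = ((x_ >= t[:, :-1]) & (x_ < t[:, 1:])).to(x.dtype)  # order 0
    for p in range(1, k + 1):
        left = (x_ - t[:, :-(p + 1)]) / (t[:, p:-1] - t[:, :-(p + 1)]) * B[:, :-1]
        right = (t[:, p + 1:] - x_) / (t[:, p + 1:] - t[:, 1:-p]) * B[:, 1:]
        B = left + right
    return B

grid = torch.linspace(-1, 1, steps=10, dtype=torch.float64)
x = torch.linspace(-1, 1, steps=9001, dtype=torch.float64)
h = (grid[1] - grid[0]).item()  # grid spacing

for k in (2, 3):
    b = bspline_basis(x, grid, k)[:, 0]             # first basis function of this order
    support = x[b > 0]
    width = (support[-1] - support[0]).item() / h   # width measured in grid intervals
    peak_x = x[b.argmax()].item()
    print(k, round(b.max().item(), 3), round(peak_x, 3), round(width, 1))
```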

Interpretability

The weights in a KAN model are simply the scaling factors of the basis functions, and each basis function is nonzero over only a small range of inputs, so each weight affects the output only in that range. This makes it far easier to understand how the model works, and what each weight does, than in traditional neural network models.
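A minimal sketch of this locality, assuming order-3 basis functions on the 20-point grid and arbitrary random weights: changing one weight changes the output only inside the k + 1 intervals covered by that weight's basis function. The index i = 7 and the helper bspline_basis are illustrative choices, not part of the lab code:

```python
import torch

def bspline_basis(x, grid, k):
    """Order-k B-spline basis via the Cox-de Boor recursion; shape (len(x), len(grid) - 1 - k)."""
    x_, t = x[:, None], grid[None, :]
    B = ((x_ >= t[:, :-1]) & (x_ < t[:, 1:])).to(x.dtype)  # order 0
    for p in range(1, k + 1):
        left = (x_ - t[:, :-(p + 1)]) / (t[:, p:-1] - t[:, :-(p + 1)]) * B[:, :-1]
        right = (t[:, p + 1:] - x_) / (t[:, p + 1:] - t[:, 1:-p]) * B[:, 1:]
        B = left + right
    return B

torch.manual_seed(0)
grid = torch.linspace(-1, 1, steps=20, dtype=torch.float64)
x = torch.linspace(-1, 1, steps=1000, dtype=torch.float64)
k = 3
B = bspline_basis(x, grid, k)

w = torch.randn(B.shape[1], dtype=torch.float64)  # arbitrary weights
i = 7                                             # hypothetical weight to adjust
w_new = w.clone()
w_new[i] += 1.0

delta = B @ w_new - B @ w                # change in the model's output
moved = x[delta.abs() > 1e-12]           # inputs where the output actually changed
print(moved.min().item(), moved.max().item())
print(grid[i].item(), grid[i + k + 1].item())  # support of basis function i
```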

Because of this locality, the model can be controlled and adjusted more easily and reliably. KAN models can also resist "catastrophic forgetting", which happens when a neural net that is fine-tuned on new data forgets what it learned from the old data.
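A toy illustration of why locality can help, under heavy simplifying assumptions (a single 1-D spline rather than a full KAN, plain gradient descent, and made-up "old" and "new" data): first fit the whole range, then keep training only on new targets for x > 0.5. Weights whose basis functions are zero on the new data get zero gradient, so the curve for x < 0 does not move. bspline_basis is a helper introduced here:

```python
import torch

def bspline_basis(x, grid, k):
    """Order-k B-spline basis via the Cox-de Boor recursion; shape (len(x), len(grid) - 1 - k)."""
    x_, t = x[:, None], grid[None, :]
    B = ((x_ >= t[:, :-1]) & (x_ < t[:, 1:])).to(x.dtype)  # order 0
    for p in range(1, k + 1):
        left = (x_ - t[:, :-(p + 1)]) / (t[:, p:-1] - t[:, :-(p + 1)]) * B[:, :-1]
        right = (t[:, p + 1:] - x_) / (t[:, p + 1:] - t[:, 1:-p]) * B[:, 1:]
        B = left + right
    return B

grid = torch.linspace(-1, 1, steps=20, dtype=torch.float64)
x = torch.linspace(-1, 1, steps=1000, dtype=torch.float64)
B = bspline_basis(x, grid, 3)
y_old = torch.sin(torch.pi * (x + 1) * 8) * (x + 1)

# "Old" task: least-squares fit over the full range
w = torch.linalg.lstsq(B, y_old[:, None]).solution.squeeze(1).clone().requires_grad_()
before = (B @ w).detach().clone()

# "New" task: hypothetical all-zero targets, only for x > 0.5
new = x > 0.5
y_new = torch.zeros(int(new.sum()), dtype=torch.float64)
opt = torch.optim.SGD([w], lr=0.5)
for _ in range(200):
    opt.zero_grad()
    loss = ((B[new] @ w - y_new) ** 2).mean()
    loss.backward()
    opt.step()

after = (B @ w).detach()
far_left = x < 0.0
print(torch.allclose(before[far_left], after[far_left]))  # old region untouched
print((before[new] - after[new]).abs().max().item())      # new region changed
```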

ML 181.1: Adding Noise (5 pts)

Replace the fn definition with this:
torch.manual_seed(0)
fn = lambda x: torch.sin(torch.pi * (x+1)* 8)*(x+1) + 0.1 * torch.randn(x.shape)
Fit the data with B-splines as you did before.
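Noise puts a floor under the Residuals value, because no fit can predict the random part. The residual window covers 700 samples, and noise with standard deviation 0.1 has variance 0.01, so the noise alone contributes roughly 700 * 0.01 = 7 to the sum (a least-squares fit can absorb a little of it, so measured values may come in slightly lower). A short sketch of that estimate, drawing noise the same way the new fn does; it is not meant to reproduce the flag value:

```python
import torch

torch.manual_seed(0)
x = torch.linspace(-1, 1, steps=1000)
noise = 0.1 * torch.randn(x.shape)   # same noise term as in the new fn

# Sum of squared noise over the same window the Residuals value uses
floor = (noise[150:-150] ** 2).sum().item()
print(floor)  # close to 700 * 0.1**2 = 7
```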

The flag is covered by a green rectangle in the image below.

ML 181.2: FM (5 pts)

Replace the fn definition with this:
torch.manual_seed(0)
fn = lambda x: torch.sin(torch.pi * (x+1)*(x+1)) + 0.1 * torch.randn(x.shape)
Fit the data with B-splines.

The flag is covered by a green rectangle in the image below.

ML 181.3: Wild (5 pts)

Replace the fn definition with this:
torch.manual_seed(0)
fn = lambda x: torch.sin(torch.pi * (x+1)*(x+1)*8) + 0.1 * torch.randn(x.shape)
Use a grid with 50 steps rather than 20 (change steps=20 to steps=50 in the grid definition).

Fit the data with B-splines.
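Why this target needs a finer grid: sin(8π(x+1)²) is a chirp whose phase is 8π(x+1)², so near x it completes about 8(x+1) cycles per unit of x, reaching 16 at x = 1. A short sketch of how many of those cycles land inside a single grid interval for 20 and 50 steps (cycles_per_interval is a helper introduced here):

```python
def cycles_per_interval(x, steps):
    """Approximate cycles of sin(8*pi*(x+1)**2) that fall inside one grid interval near x."""
    local_freq = 8 * (x + 1)       # d/dx[8*pi*(x+1)**2] / (2*pi), in cycles per unit x
    spacing = 2 / (steps - 1)      # spacing of torch.linspace(-1, 1, steps)
    return local_freq * spacing

for steps in (20, 50):
    print(steps, round(cycles_per_interval(1.0, steps), 2))
```

With 20 steps, more than one and a half cycles fall inside each interval near x = 1; with 50 steps, it is less than one.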

The flag is covered by a green rectangle in the image below.

References

B-Splines for KAN

Posted 8-8-24
Flag 1 answer updated 5-16-25