Image by Editor (Kanwal Mehreen) | Canva
Machine learning (ML) is mostly done in Python, which is popular because it is easy to learn and has a rich set of ML libraries. Rust is now emerging as a strong alternative: it is fast, memory-safe, and well suited to running many tasks concurrently. These features make Rust a good fit for high-performance ML.
Linfa is a Rust library for building ML models. In this article, we show how to use it for two common ML tasks: linear regression and k-means clustering.
Why Rust for Machine Learning?
Rust is increasingly being considered for machine learning due to several advantages:
- Performance: Rust is a compiled language with performance close to C and C++. Its low-level control over system resources and its lack of a garbage collector make it well suited to performance-critical workloads such as machine learning.
- Memory Safety: One of Rust’s standout features is its ownership model, which guarantees memory safety without the need for a garbage collector. This eliminates many common programming bugs such as null pointer dereferencing or data races.
- Concurrency: Machine learning often involves large datasets and heavy computation, so parallel processing matters. Rust's ownership and type system let multi-threaded code share data safely: programs that could cause data races are rejected at compile time.
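To make the concurrency point concrete, here is a minimal sketch that uses only the standard library (no Linfa). It splits a dataset into chunks and sums the squares of each chunk on its own scoped thread. The function name `parallel_sum_of_squares` is hypothetical. Each thread receives only a read-only borrow of its own chunk, so the compiler can verify that no data race is possible; code that mutated shared data without synchronization would not compile.

```rust
use std::thread;

/// Hypothetical helper: sums the squares of `data` by splitting it into
/// roughly `n_threads` chunks, each processed on its own scoped thread.
fn parallel_sum_of_squares(data: &[f64], n_threads: usize) -> f64 {
    // Guard against zero threads, then round up so every element lands in a chunk
    let n_threads = n_threads.max(1);
    let chunk_size = ((data.len() + n_threads - 1) / n_threads).max(1);

    thread::scope(|s| {
        // Each spawned thread only gets a shared (&[f64]) borrow of its chunk
        let handles: Vec<_> = data
            .chunks(chunk_size)
            .map(|chunk| s.spawn(move || chunk.iter().map(|v| v * v).sum::<f64>()))
            .collect();
        // Wait for every thread and combine the partial sums
        handles.into_iter().map(|h| h.join().unwrap()).sum()
    })
}

fn main() {
    let data: Vec<f64> = (1..=1000).map(|i| i as f64).collect();
    let total = parallel_sum_of_squares(&data, 4);
    println!("Sum of squares: {}", total);
}
```

`std::thread::scope` (stable since Rust 1.63) lets spawned threads borrow data owned by the caller, so no `Arc` or cloning is needed here.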
What is Linfa?
Linfa is a machine learning library for Rust that offers a range of classical ML algorithms, much like Python's scikit-learn. It builds on ndarray, Rust's n-dimensional array crate, for efficient numerical data handling, and it is split into focused crates such as linfa-linear and linfa-clustering, which this article uses. Linfa includes algorithms such as linear regression, k-means clustering, and support vector machines, so developers can build machine learning models while keeping Rust's speed and safety.
Let’s explore how to use Linfa to build machine learning models with two simple yet essential examples: linear regression and k-means clustering.
Setting Up the Environment
First, ensure that you have Rust installed. If not, use the following command to install it via rustup:
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
Next, add Linfa and related dependencies to your project. Open your Cargo.toml file and include the following:
[dependencies]
linfa = "0.5.0"
linfa-linear = "0.5.0" # For linear regression
linfa-clustering = "0.5.0" # For k-means clustering
ndarray = "0.15.4" # For numerical operations
ndarray-rand = "0.14.0" # For random number generation
With this setup, you are ready to implement machine learning models using Linfa.
Linear Regression in Rust
Linear regression is one of the simplest and most commonly used supervised learning algorithms. It models the relationship between a dependent variable 𝑦 and one or more independent variables 𝑥 by fitting a linear equation to the observed data. In this section, we’ll explore how to implement linear regression using Rust’s Linfa library.
Prepare the Data
To understand and test linear regression, we need to start with a dataset.
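use ndarray::{Array1, Array2, Axis};

fn generate_data() -> Array2<f64> {
    let x = Array1::linspace(1.0, 10.0, 10); // illustrative inputs 1.0, 2.0, ..., 10.0
    let y = x.mapv(|v| 2.0 * v + 1.0); // y = 2x + 1, with no noise
    // Place x and y side by side as the two columns of a 10×2 matrix
    ndarray::stack(Axis(1), &[x.view(), y.view()]).unwrap()
}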
Here, we simulate a simple dataset where the relationship between 𝑥 and 𝑦 follows the formula: y=2x+1.
Train the Model
After preparing the dataset, we train a model with Linfa's LinearRegression. Training determines the coefficients of the linear equation y=mx+c by ordinary least squares, that is, by choosing the slope m and intercept c that minimize the sum of squared differences between predicted and actual values. Linfa's fit method expects a dataset made of records (the features) and targets, so we first split the two columns of our matrix.
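use linfa::prelude::*;
use linfa_linear::{FittedLinearRegression, LinearRegression};
use ndarray::s;

fn train_model(data: Array2<f64>) -> FittedLinearRegression<f64> {
    // Column 0 is the feature x (kept as an n×1 matrix); column 1 is the target y
    let x = data.slice(s![.., 0..1]).to_owned();
    let y = data.column(1).to_owned();
    LinearRegression::new().fit(&Dataset::new(x, y)).unwrap()
}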
Key Points:
- The fit method learns the slope and intercept of the line that best fits the data.
- unwrap panics if training fails; in production code, you would typically propagate the error with the ? operator instead.
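For a single feature, the least-squares line has a closed form: the slope is the covariance of x and y divided by the variance of x, and the intercept is mean(y) - slope * mean(x). The dependency-free sketch below computes it directly, as a way to see what fit is estimating. It is not Linfa's implementation, and the helper name `fit_simple_ols` is hypothetical.

```rust
/// Hypothetical helper: fits y ≈ slope * x + intercept by ordinary least squares
/// for one feature. Returns None for mismatched lengths, fewer than two points,
/// or when all x values are equal (the slope would be undefined).
fn fit_simple_ols(x: &[f64], y: &[f64]) -> Option<(f64, f64)> {
    if x.len() != y.len() || x.len() < 2 {
        return None;
    }
    let n = x.len() as f64;
    let mean_x = x.iter().sum::<f64>() / n;
    let mean_y = y.iter().sum::<f64>() / n;

    // Un-normalized covariance of (x, y) and variance of x
    let mut sxy = 0.0;
    let mut sxx = 0.0;
    for (xi, yi) in x.iter().zip(y) {
        sxy += (xi - mean_x) * (yi - mean_y);
        sxx += (xi - mean_x) * (xi - mean_x);
    }
    if sxx == 0.0 {
        return None;
    }
    let slope = sxy / sxx;
    let intercept = mean_y - slope * mean_x;
    Some((slope, intercept))
}

fn main() {
    // Same relationship as the toy data above: y = 2x + 1
    let x: Vec<f64> = (1..=10).map(|i| i as f64).collect();
    let y: Vec<f64> = x.iter().map(|v| 2.0 * v + 1.0).collect();
    let (slope, intercept) = fit_simple_ols(&x, &y).expect("need two distinct x values");
    println!("slope = {slope:.3}, intercept = {intercept:.3}"); // prints slope = 2.000, intercept = 1.000
}
```

Linfa solves the general multi-feature least-squares problem, but on this one-feature, noise-free data both approaches should recover the same slope and intercept.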
Make Predictions
After training the model, we can use it to predict results for new data.
fn make_predictions(model: &linfa_linear::FittedLinearRegression<f64>, input: Array2<f64>) -> ndarray::Array1<f64> {
    // One prediction per row of `input`
    model.predict(&input)
}

fn main() {
    let data = generate_data();
    let model = train_model(data);
    // Five new x values, shaped as a 5×1 records matrix
    let input = Array2::from_shape_vec((5, 1), vec![11.0, 12.0, 13.0, 14.0, 15.0]).unwrap();
    let predictions = make_predictions(&model, input);
    println!("Predictions: {:?}", predictions);
}
For the input values 11.0, 12.0, 13.0, 14.0 and 15.0, the printed predictions should be 23.0, 25.0, 27.0, 29.0 and 31.0, up to tiny floating-point rounding. These values follow y=2x+1 because the training data contains no noise.
K-Means Clustering in Rust
K-means clustering is an unsupervised learning algorithm that partitions data into k clusters based on similarity.
Prepare the Data
To demonstrate K-means clustering, we generate a random dataset using the ndarray-rand crate.
use ndarray::Array2;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;

fn generate_random_data() -> Array2<f64> {
    // 100 points with 2 features each, drawn uniformly from [0, 10)
    let dist = Uniform::new(0.0, 10.0);
    Array2::random((100, 2), dist)
}
This creates a 100×2 matrix of random points, simulating two-dimensional data.
Train the Model
The train_kmeans_model function uses Linfa’s KMeans module to group the data into k=3 clusters.
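use linfa::traits::{Fit, PredictInplace};
use linfa_clustering::KMeans;

// Returning `impl PredictInplace` avoids spelling out KMeans's generic parameters
fn train_kmeans_model(data: Array2<f64>) -> impl PredictInplace<Array2<f64>, ndarray::Array1<usize>> {
    // Fit works on Linfa datasets, so wrap the unlabeled records, then request 3 clusters
    KMeans::params(3).fit(&linfa::DatasetBase::from(data)).unwrap()
}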
Key Points:
- KMeans::params(3) specifies 3 clusters.
- The fit method learns cluster centroids based on the data.
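A common way to learn k-means centroids is Lloyd's algorithm, which alternates two steps until the centroids stop moving: assign every point to its nearest centroid, then move each centroid to the mean of the points assigned to it. The dependency-free sketch below implements that loop for 2-D points with caller-supplied starting centroids. The function names are hypothetical, and the code illustrates the idea rather than mirroring Linfa's implementation, which chooses initial centroids itself and has its own stopping criteria.

```rust
/// Squared Euclidean distance between two 2-D points.
fn sq_dist(a: [f64; 2], b: [f64; 2]) -> f64 {
    (a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)
}

/// Index of the centroid closest to `p`.
fn nearest(p: [f64; 2], centroids: &[[f64; 2]]) -> usize {
    let mut best = 0;
    for (j, c) in centroids.iter().enumerate() {
        if sq_dist(p, *c) < sq_dist(p, centroids[best]) {
            best = j;
        }
    }
    best
}

/// Hypothetical sketch of Lloyd's iterations starting from `centroids`.
/// Returns the final centroids and one cluster label per point.
fn lloyd(points: &[[f64; 2]], mut centroids: Vec<[f64; 2]>, max_iter: usize) -> (Vec<[f64; 2]>, Vec<usize>) {
    let k = centroids.len();
    let mut labels = vec![0; points.len()];
    for _ in 0..max_iter {
        // Assignment step: label every point with its nearest centroid
        for (label, p) in labels.iter_mut().zip(points) {
            *label = nearest(*p, &centroids);
        }
        // Update step: move each centroid to the mean of its assigned points
        let mut sums = vec![[0.0, 0.0]; k];
        let mut counts = vec![0usize; k];
        for (&label, p) in labels.iter().zip(points) {
            sums[label][0] += p[0];
            sums[label][1] += p[1];
            counts[label] += 1;
        }
        let mut moved = false;
        for j in 0..k {
            // An empty cluster keeps its previous centroid
            if counts[j] > 0 {
                let new_c = [sums[j][0] / counts[j] as f64, sums[j][1] / counts[j] as f64];
                if new_c != centroids[j] {
                    moved = true;
                }
                centroids[j] = new_c;
            }
        }
        // Stop once no centroid changes
        if !moved {
            break;
        }
    }
    (centroids, labels)
}

fn main() {
    // Two well-separated groups of three points each
    let points = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]];
    let (centroids, labels) = lloyd(&points, vec![points[0], points[3]], 100);
    println!("centroids = {:?}", centroids);
    println!("labels = {:?}", labels); // prints labels = [0, 0, 0, 1, 1, 1]
}
```

The result depends on the starting centroids, which is why practical implementations put effort into initialization.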
Assign Clusters
After training, we can assign each data point to one of the clusters.
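use linfa::traits::Predict;

fn assign_clusters(
    model: &impl linfa::traits::PredictInplace<Array2<f64>, ndarray::Array1<usize>>,
    data: Array2<f64>,
) {
    // One cluster index (0, 1 or 2) per row of `data`
    let labels: ndarray::Array1<usize> = model.predict(&data);
    println!("Cluster Labels: {:?}", labels);
}

fn main() {
    let data = generate_random_data();
    // Train on a clone so the original data can still be labelled afterwards
    let model = train_kmeans_model(data.clone());
    assign_clusters(&model, data);
}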
The program prints one label per data point. Each label is a cluster index (0, 1 or 2) identifying which of the three clusters the point was assigned to.
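With 100 points, the raw label list is hard to read. A small helper that counts how many points fall into each cluster gives a quicker summary. This is a hypothetical, dependency-free sketch that works on a plain slice of labels; a contiguous ndarray `Array1<usize>`, such as an owned prediction result, can typically be passed as `labels.as_slice().unwrap()`.

```rust
/// Hypothetical helper: counts how many points were assigned to each of `k` clusters.
/// Panics if any label is >= k.
fn cluster_sizes(labels: &[usize], k: usize) -> Vec<usize> {
    let mut counts = vec![0; k];
    for &label in labels {
        counts[label] += 1;
    }
    counts
}

fn main() {
    // Hypothetical labels standing in for the output of assign_clusters
    let labels = [0, 2, 1, 0, 2, 2];
    println!("{:?}", cluster_sizes(&labels, 3)); // prints [2, 1, 3]
}
```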
Conclusion
Rust is a strong choice for building fast machine learning models. Its ownership system guarantees memory safety without a garbage collector, ruling out whole classes of bugs such as dangling pointers at compile time, and its thread-safety guarantees prevent data races when large datasets are processed in parallel.
The Linfa library makes machine learning in Rust even more accessible, providing ready-to-use algorithms such as linear regression and k-means clustering behind a consistent fit/predict interface.
Jayita Gulati is a machine learning enthusiast and technical writer driven by her passion for building machine learning models. She holds a Master’s degree in Computer Science from the University of Liverpool.