Using PySpark, Spark MLlib, and the Databricks Lakehouse Platform, this article demonstrates how to build a scalable taxi fare prediction model from the ground up. We begin by ingesting and preprocessing large-scale trip data using PySpark’s distributed computing capabilities. Through feature engineering and exploratory data analysis, we prepare the dataset for regression modeling. Leveraging Spark MLlib’s pipeline architecture, we train and evaluate a fare prediction model at scale. Finally, we integrate the workflow within the Databricks environment to enable reproducibility, performance optimization, and future deployment. This end-to-end approach showcases how modern cloud-based big data tools streamline the development of intelligent transportation analytics.
Accurate fare prediction is critical for taxi services and ride-sharing platforms. It sets customer expectations, improves service transparency, and helps drivers optimize their routes. But predicting fares accurately is complex — it’s influenced by multiple factors including distance, time of day, passenger count, and even weather conditions.
Our objective is to build a machine learning pipeline that:
- Analyzes patterns in historical taxi trip data
- Identifies the key factors influencing fare amounts
- Creates a model that accurately predicts fares for new trips
- Provides insights for business optimization
For this project, I used the New York City taxi trip dataset containing millions of records with timestamps, location IDs, trip distances, passenger counts, and fare amounts. Each record represents a completed taxi journey with the following key features:
- Pickup and dropoff timestamps
- Trip distance
- Passenger count
- Rate code (standard, airport, etc.)
- Payment type
- Fare amount and additional charges
The project leverages Databricks and Apache Spark for distributed computing, handling the large volume of data efficiently. Here’s what our tech stack looks like:
- Databricks: Cloud-based platform for collaborative data science
- PySpark: Python API for Apache Spark
- MLlib: Spark’s machine learning library
- Pandas: For easier data manipulation and visualization
- Matplotlib & Seaborn: For data visualization
Before diving into the modeling phase, I needed to transform the raw data into a format suitable for machine learning. Here’s what the cleaning process involved:
# Imports used throughout the preprocessing steps
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType

# Explicitly cast column types to avoid StringIndexer issues
df_clean = df.withColumn("passenger_count", col("passenger_count").cast(IntegerType()))
df_clean = df_clean.withColumn("RatecodeID", col("RatecodeID").cast(IntegerType()))

# Fill nulls for columns used in modeling
df_clean = df_clean.na.fill(1, ["passenger_count"])
df_clean = df_clean.na.fill(1, ["RatecodeID"])
Outliers can significantly impact model performance, especially for regression tasks. I filtered out trips with:
- Negative fare amounts
- Unreasonably long trip durations (>24 hours)
- Extremely long distances (>100 miles)
- Unrealistic speeds (>80 mph)
# Filter out unreasonable durations and distances
# (trip_duration_minutes is the derived column created in the feature-engineering step below)
df_clean = df_clean.filter(
    (df_clean.trip_duration_minutes > 0) &
    (df_clean.trip_duration_minutes < 24 * 60) &
    (df_clean.trip_distance >= 0) &
    (df_clean.trip_distance < 100)
)
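The negative-fare and speed rules from the list above follow the same pattern; a minimal sketch, using the fare_amount column and the derived avg_speed_mph column from the feature-engineering step:

# Drop trips with non-positive fares or physically implausible speeds
df_clean = df_clean.filter(
    (df_clean.fare_amount > 0) &
    (df_clean.avg_speed_mph < 80)
)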
To enhance the predictive power of the model, I created several derived features:
from pyspark.sql.functions import hour, dayofweek, unix_timestamp, round, when

# Create time-based features
df_clean = df_clean.withColumn("pickup_hour", hour("tpep_pickup_datetime")) \
    .withColumn("pickup_day", dayofweek("tpep_pickup_datetime"))

# Calculate trip duration in minutes
df_clean = df_clean.withColumn(
    "trip_duration_minutes",
    round((unix_timestamp("tpep_dropoff_datetime") - unix_timestamp("tpep_pickup_datetime")) / 60, 2)
)

# Calculate speed (mph)
df_clean = df_clean.withColumn(
    "avg_speed_mph",
    when(col("trip_duration_minutes") > 0,
         round(col("trip_distance") / (col("trip_duration_minutes") / 60), 2))
    .otherwise(0)
)

# Create rush hour flag (weekday mornings 7-9 and evenings 16-18)
df_clean = df_clean.withColumn(
    "is_rush_hour",
    (((col("pickup_hour").between(7, 9)) | (col("pickup_hour").between(16, 18))) &
     (col("pickup_day").between(2, 6)))
    .cast("int")
)
Understanding the data was essential before building the model. Here are some key insights:
Analyzing the fare distribution revealed a right-skewed pattern, with most trips costing between $5 and $20, though there were some outliers with significantly higher fares.
As expected, trip distance and fare amount were strongly related: a simple linear fit on distance explains roughly 86% of the variance in fares (R² ≈ 0.86). However, the relationship wasn’t perfectly linear, suggesting other influencing factors.
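That relationship is cheap to verify directly on the Spark DataFrame; as a quick check, squaring the Pearson correlation approximates the R² of a simple linear fit:

# Pearson correlation between distance and fare, computed across the full dataset
r = df_clean.stat.corr("trip_distance", "fare_amount")
print(f"r = {r:.2f}, r^2 = {r ** 2:.2f}")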
Fares varied by time of day, with peak hours showing higher average fares. Weekend trips tended to be longer and more expensive than weekday trips.
# Fare by hour of day
hourly_fares = df_clean.groupBy("pickup_hour").agg(
    {"fare_amount": "avg", "trip_distance": "avg"}
).orderBy("pickup_hour")

# Visualize
import matplotlib.pyplot as plt

hourly_fares_pd = hourly_fares.toPandas()
plt.figure(figsize=(12, 6))
plt.plot(hourly_fares_pd['pickup_hour'], hourly_fares_pd['avg(fare_amount)'], marker='o')
plt.title('Average Fare by Hour of Day')
plt.xlabel('Hour of Day')
plt.ylabel('Average Fare ($)')
plt.show()
With a clean dataset and good understanding of the key relationships, I constructed a machine learning pipeline using PySpark’s ML libraries.
from pyspark.ml.feature import StringIndexer, VectorAssembler

# Prepare feature columns
categorical_cols = ["passenger_count", "RatecodeID", "payment_type", "is_rush_hour"]
numeric_cols = ["trip_distance", "trip_duration_minutes", "avg_speed_mph", "pickup_hour"]

# Create StringIndexer transformers with handleInvalid="keep"
indexers = [
    StringIndexer(
        inputCol=col_name,
        outputCol=f"{col_name}_idx",
        handleInvalid="keep"
    ) for col_name in categorical_cols
]

# VectorAssembler to combine all features
assembler = VectorAssembler(
    inputCols=[f"{c}_idx" for c in categorical_cols] + numeric_cols,
    outputCol="features",
    handleInvalid="keep"
)
I evaluated three different algorithms to find the best performer:
- Random Forest Regression: Known for handling non-linear relationships
- Gradient Boosted Trees: Often provides high accuracy for structured data
- Linear Regression: Simple but sometimes effective for clear relationships
The models were trained using an 80/20 train-test split:
# Split data
train_data, test_data = df_clean.randomSplit([0.8, 0.2], seed=42)

# Define multiple models
from pyspark.ml import Pipeline
from pyspark.ml.regression import RandomForestRegressor, GBTRegressor, LinearRegression

rf = RandomForestRegressor(
    featuresCol="features",
    labelCol="fare_amount",
    numTrees=20,
    maxDepth=10
)

gbt = GBTRegressor(
    featuresCol="features",
    labelCol="fare_amount",
    maxIter=10,
    maxDepth=8
)

lr = LinearRegression(
    featuresCol="features",
    labelCol="fare_amount",
    regParam=0.1,
    elasticNetParam=0.8
)

# Define pipelines for each model
pipeline_rf = Pipeline(stages=indexers + [assembler, rf])
pipeline_gbt = Pipeline(stages=indexers + [assembler, gbt])
pipeline_lr = Pipeline(stages=indexers + [assembler, lr])
After training and evaluating each model on the test set, I compared their performance using Root Mean Squared Error (RMSE):
| Model | RMSE | Features Used |
| --- | --- | --- |
| Random Forest | 9.43 | All features |
| GBT | 9.21 | All features |
| Linear Regression | 10.87 | All features |
| Simple Linear | 12.34 | Distance, time only |
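For reference, a minimal sketch of the fit-and-score loop behind these numbers (the loop structure and the results dictionary are illustrative, not copied verbatim from the notebook):

from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(
    labelCol="fare_amount", predictionCol="prediction", metricName="rmse"
)

# Fit each pipeline on the training split and score it on the held-out split
results = {}
for name, pipeline in [("rf", pipeline_rf), ("gbt", pipeline_gbt), ("lr", pipeline_lr)]:
    model = pipeline.fit(train_data)
    predictions = model.transform(test_data)
    results[name] = evaluator.evaluate(predictions)
print(results)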
The Gradient Boosted Trees model performed best with an RMSE of 9.21, corresponding to a typical prediction error of roughly $9; note that RMSE weights large errors more heavily than a plain average error would.
Understanding which features influence the prediction most is valuable for business insights. The Random Forest model revealed that the top predictors were:
- Trip distance (0.42)
- Trip duration (0.31)
- Average speed (0.11)
- Pickup hour (0.06)
- Passenger count (0.04)
This confirms our intuition but also reveals the substantial impact of time-based features.
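These importance scores come from the featureImportances vector of the fitted Random Forest stage. A minimal sketch of how to extract them, where rf_pipeline_model is a placeholder for the fitted Random Forest PipelineModel from the training step:

# rf_pipeline_model: the fitted Random Forest PipelineModel (placeholder name)
rf_stage = rf_pipeline_model.stages[-1]  # the RandomForestRegressionModel is the last stage
feature_names = [f"{c}_idx" for c in categorical_cols] + numeric_cols
for name, importance in sorted(
        zip(feature_names, rf_stage.featureImportances.toArray()),
        key=lambda pair: pair[1], reverse=True):
    print(f"{name}: {importance:.2f}")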
To make the model usable in production, I created a simple prediction function and a SQL User-Defined Function (UDF):
# Function to make predictions with the saved model
def predict_fare(model_path, trip_data):
    """
    Makes fare predictions using the saved model.

    Args:
        model_path: Path to the saved model
        trip_data: Spark DataFrame with trip features
    """
    from pyspark.ml import PipelineModel

    # Load the model
    loaded_model = PipelineModel.load(model_path)

    # Make predictions
    predictions = loaded_model.transform(trip_data)
    return predictions
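predict_fare assumes the winning pipeline was persisted first. A minimal sketch of saving and reusing it, where the DBFS path, gbt_pipeline_model, and new_trips_df are placeholders rather than names from the original notebook:

# Persist the best fitted pipeline (path and variable names are placeholders)
model_path = "/dbfs/models/taxi_fare_gbt"
gbt_pipeline_model.write().overwrite().save(model_path)

# Score a DataFrame of new trips with the helper defined above
new_trip_predictions = predict_fare(model_path, new_trips_df)
new_trip_predictions.select("trip_distance", "prediction").show(5)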
# SQL UDF for simple predictions
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

@udf(returnType=DoubleType())
def simple_fare_predictor(distance, passengers, hour_of_day):
    """Simple UDF for fare prediction in SQL"""
    base_fare = 3.0
    distance_fare = distance * 2.5

    # Time factors
    rush_hour = hour_of_day in [7, 8, 9, 17, 18, 19]
    night_time = hour_of_day >= 22 or hour_of_day <= 5

    time_multiplier = 1.0
    if rush_hour:
        time_multiplier = 1.2
    elif night_time:
        time_multiplier = 1.1

    # Passenger factor
    passenger_factor = 1.0 + (min(passengers, 6) - 1) * 0.05

    return (base_fare + distance_fare) * time_multiplier * passenger_factor

# Register the UDF
spark.udf.register("simple_fare_predictor", simple_fare_predictor)
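Once registered, the UDF can be called directly from Spark SQL; for example:

# Estimate the fare for a 5.2-mile trip with 2 passengers picked up at 8 AM
spark.sql("SELECT simple_fare_predictor(5.2, 2, 8) AS estimated_fare").show()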
Beyond just prediction, our analysis revealed several valuable business insights:
- Optimal Pricing Windows: The most profitable hours for taxi operation are 7–9 AM and 5–7 PM on weekdays.
- Trip Efficiency Metrics: Longer trips tend to be more profitable per minute than shorter trips, suggesting potential for incentivizing longer rides.
- Passenger Count Impact: Trips with 3+ passengers had higher average fares, indicating an opportunity for promoting shared rides.
- Location-Based Analysis: Certain pickup-dropoff pairs consistently generated higher fares, pointing to potential for strategic driver positioning.
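The location-based finding came from a straightforward pickup/dropoff aggregation. A minimal sketch, assuming the dataset’s standard PULocationID and DOLocationID columns:

from pyspark.sql.functions import avg, count

# Average fare and trip volume for each pickup/dropoff pair with meaningful volume
top_pairs = (df_clean.groupBy("PULocationID", "DOLocationID")
             .agg(avg("fare_amount").alias("avg_fare"), count("*").alias("trips"))
             .filter("trips > 100")
             .orderBy("avg_fare", ascending=False))
top_pairs.show(10)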
Building this system wasn’t without challenges:
The dataset contained anomalies like zero-distance trips with substantial fares and unrealistic speeds. I addressed these by creating domain-specific filters based on physics and business rules.
While GBT provided the best accuracy, it was also the most computationally expensive. For real-time predictions, the Random Forest model offered a better balance of accuracy and inference speed.
Creating meaningful features from timestamp data required domain knowledge. For example, understanding that rush hour patterns differ by day of week significantly improved predictions.
There are several ways the model could be enhanced:
- Geospatial Features: Incorporate mapping data to calculate actual route distances instead of straight-line distances.
- External Data Integration: Add weather conditions, local events, and traffic data to better account for contextual factors.
- Time Series Analysis: Implement more sophisticated time-based models to capture seasonal patterns.
- Real-time Updates: Develop an online learning system that continuously updates as new trip data becomes available.
Building a taxi fare prediction system with PySpark and Databricks demonstrates how machine learning can solve real-world business problems at scale. The distributed nature of Spark allows processing millions of records efficiently, while Databricks provides a collaborative environment for the entire data science workflow.
Beyond technical implementation, this project illustrates how data science can deliver business value through better pricing strategies, operational efficiency, and enhanced customer experience.
The complete code for this project is available on GitHub.