Why Deep Learning Struggles with Tables? | by Ori Golfryd | Dec, 2024


Deep learning has brought significant advances in image recognition and natural language processing. On tabular data, however, it has been less successful. In this short article, we’ll explore what deep learning is, what tabular data is, why deep learning struggles with it, and which new methods researchers are developing to address these challenges.

Deep learning is a type of machine learning that uses artificial neural networks, which are loosely inspired by the way the brain processes information. Picture layers of interconnected nodes (neurons) that transform their inputs and pass the results on. Each layer learns to recognize different features of the data, building up from simple to complex patterns.

For example, in image recognition, the first layer might detect edges, the next layer might recognize shapes, and higher layers identify objects like cats or cars by combining these features.
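
To make the layered picture concrete, here is a minimal sketch of a forward pass through two layers in plain Python. The layer sizes and weights are made up for illustration; a real network would learn them from data.

```python
def relu(v):
    """Element-wise ReLU: negative activations are set to zero."""
    return [max(0.0, x) for x in v]


def dense(x, weights, bias):
    """One fully connected layer: a weighted sum of the inputs per output neuron."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


# Hypothetical, hand-picked parameters: 3 inputs -> 2 hidden neurons -> 1 output.
W1 = [[0.5, -1.0, 0.0],
      [1.0, 1.0, 1.0]]
b1 = [0.0, -1.0]
W2 = [[1.0, 2.0]]
b2 = [0.5]


def forward(x):
    hidden = relu(dense(x, W1, b1))  # first layer: simple features
    return dense(hidden, W2, b2)     # second layer: combines those features


print(forward([1.0, 2.0, 3.0]))  # [10.5]
```

Each call to `dense` is one layer; stacking more of them, with a non-linearity in between, is what makes the network "deep".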

Tabular data is data organized into rows and columns, like a table in a spreadsheet. Each row is one sample (e.g., a customer), and each column is a feature (e.g., name, age, or income).

This type of data is common in many fields:

  • Finance: Transaction records, account balances.
  • Healthcare: Patient information, test results.
  • Retail: Sales data, inventory levels.

Gradient-boosted decision tree libraries such as XGBoost and LightGBM work very well on this kind of data. They handle mixed feature types with little preprocessing and capture non-linear relationships between features.

Despite its success in other areas, deep learning often doesn’t do as well with tabular data. Here’s why:

1. Complex Feature Interactions

Tabular data often has complicated relationships between features, for example how age and income together affect whether a person buys a product. Tree-based methods such as decision trees are good at spotting these interactions, but deep learning models can miss them unless they are specifically designed to look for them.
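
A toy illustration of such an interaction, using made-up customers and thresholds: the purchase label below depends on age and income jointly, so no threshold on a single feature classifies every row, while two nested splits (the kind a depth-2 decision tree learns) do.

```python
# Hypothetical customers: (age, income in thousands, bought?)
rows = [
    (25, 30, 0), (30, 90, 1), (28, 75, 1), (35, 40, 0),
    (55, 35, 1), (60, 95, 0), (48, 80, 0), (52, 45, 1),
]


def accuracy(predict):
    """Fraction of rows for which predict(age, income) matches the label."""
    return sum(predict(age, income) == label for age, income, label in rows) / len(rows)


def best_single_split(column):
    """Best accuracy reachable by thresholding one feature alone (either direction)."""
    best = 0.0
    for threshold in sorted({r[column] for r in rows}):
        for positive_side in (0, 1):
            def predict(age, income, t=threshold, side=positive_side):
                value = (age, income)[column]
                return side if value >= t else 1 - side
            best = max(best, accuracy(predict))
    return best


def nested_split(age, income):
    """Two nested splits: the income rule flips depending on the age branch."""
    if age < 40:
        return 1 if income >= 50 else 0
    return 1 if income < 50 else 0


print(best_single_split(0), best_single_split(1))  # both stay below 1.0
print(accuracy(nested_split))                      # 1.0
```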

2. Small Datasets

Deep learning models usually need a lot of data to work well. However, in many real-world situations, datasets are relatively small. This makes it hard for deep learning models to generalize.

3. No Spatial or Sequential Structure

Deep learning excels with data that has spatial (like images) or sequential (like text or audio) structures. However, tabular data doesn’t have these structures, so it can be harder for deep learning to find useful patterns.

4. Overfitting

Deep learning models have a lot of parameters to learn, which can make them overfit. This means they do well on training data but poorly on new, unseen data. This is a big problem, especially with small datasets.
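
A quick back-of-the-envelope check shows how easily the parameter count outgrows a small table. The layer widths and the dataset size below are assumptions chosen for illustration.

```python
def count_mlp_parameters(layer_sizes):
    """Weights plus biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


# Hypothetical setup: 20 input columns, two hidden layers of 256 units, one output.
n_params = count_mlp_parameters([20, 256, 256, 1])
n_rows = 1_000  # an assumed small tabular dataset

print(n_params)           # 71425
print(n_params / n_rows)  # roughly 71 parameters per training row
```

With dozens of free parameters per row, the network has more than enough capacity to memorize the training set, which is why regularization matters so much in this setting.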

5. Computational Resources

Deep learning models take a lot of time and computer power to train, often needing special hardware like GPUs. This can make them less practical for quick analyses or in situations with limited computational resources.

Researchers are coming up with new ways to make deep learning work better with tabular data. Here are some examples:

1. TabTransformer

TabTransformer uses the transformer model, originally designed for text, to handle tabular data [1].

How does it work?

  • Embedding Categorical Features: It turns categorical variables (like “red,” “blue,” “green”) into continuous numerical vectors called embeddings. This helps the model understand relationships between categories. For example, if “red” and “blue” often lead to similar outcomes in the data, their embeddings will be close together in the model’s space.
  • Self-Attention Mechanism: The model uses self-attention to weigh the importance of each feature relative to others within a data sample. For example, when predicting if a customer will make a purchase, the model might learn that “age” and “income” are more important than “favorite color”.
  • Capturing Feature Interactions: By stacking multiple attention layers, TabTransformer can model complex interactions between features.
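
The first two steps can be sketched in plain Python as follows. This is a simplified illustration, not the TabTransformer implementation: the column names and vocabularies are hypothetical, the embeddings are random rather than trained, and the tokens are used directly as queries, keys, and values, whereas a real transformer layer applies learned projections first.

```python
import math
import random

random.seed(0)
EMBED_DIM = 4

# Hypothetical categorical columns and their vocabularies.
vocab = {
    "color": ["red", "blue", "green"],
    "plan": ["free", "premium"],
    "region": ["north", "south"],
}

# One randomly initialised vector per (column, category); training would tune these.
embeddings = {
    (col, cat): [random.gauss(0.0, 1.0) for _ in range(EMBED_DIM)]
    for col, cats in vocab.items() for cat in cats
}


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(tokens):
    """Scaled dot-product self-attention over the feature tokens of one row."""
    dim = len(tokens[0])
    scale = math.sqrt(dim)
    weights = [
        softmax([sum(q * k for q, k in zip(query, key)) / scale for key in tokens])
        for query in tokens
    ]
    outputs = [
        [sum(w * value[d] for w, value in zip(row, tokens)) for d in range(dim)]
        for row in weights
    ]
    return outputs, weights


row = {"color": "red", "plan": "premium", "region": "south"}
tokens = [embeddings[(col, cat)] for col, cat in row.items()]
contextual, attention = self_attention(tokens)

# attention[i][j]: how much feature i draws on feature j within this row.
for name, weights_row in zip(row, attention):
    print(name, [round(w, 2) for w in weights_row])
```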

Why is it helpful?

By embedding categorical features, the model can capture subtle relationships that traditional methods might miss. Moreover, self-attention lets the model focus on important features and how they interact, which can improve prediction accuracy.

2. SAINT (Self-Attention and Intersample Attention Transformer)

SAINT builds on TabTransformer by also looking at relationships between different data samples [2].

How does it work?

  • Self-Attention: Like TabTransformer, it weighs the importance of features within each sample.
  • Intersample Attention: The model looks at other samples in the dataset to find patterns and similarities. For example, if several customers have similar shopping habits, SAINT learns from these similarities to make better predictions.
  • Contrastive Pre-Training: SAINT uses a pre-training step where it learns to distinguish between similar and different samples, which helps it generalize better.
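
Intersample attention can be pictured as the same attention operation applied across rows instead of across features. The sketch below is a simplification under stated assumptions: the batch is made up, and each row is used directly as query, key, and value, whereas SAINT works on embedded features with learned projections.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def intersample_attention(batch):
    """Attention across rows: each output row is a weighted mix of all rows in the batch."""
    scale = math.sqrt(len(batch[0]))
    mixed, weights = [], []
    for query in batch:
        scores = [sum(q * k for q, k in zip(query, key)) / scale for key in batch]
        w = softmax(scores)
        weights.append(w)
        mixed.append([sum(wi * row[d] for wi, row in zip(w, batch))
                      for d in range(len(query))])
    return mixed, weights


# Hypothetical batch of already-scaled customer rows (visit frequency, basket size).
batch = [
    [1.0, 0.9],    # customer A
    [0.9, 1.0],    # customer B, similar habits to A
    [-1.0, -0.8],  # customer C, very different habits
]
mixed, weights = intersample_attention(batch)

# Customer A draws far more on the similar customer B than on customer C.
print([round(w, 2) for w in weights[0]])
```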

Why is it helpful?

By considering how samples relate to each other, SAINT can find patterns that might be missed when looking at samples individually. Additionally, intersample attention can help the model make better predictions on new, unseen data.

3. FT-Transformer (Feature Tokenizer Transformer)

FT-Transformer makes the transformer model simpler for tabular data [3].

How does it work?

  • Unified Processing of Features: It turns both numerical and categorical features into embeddings, treating them the same way.
  • Simplified Architecture: It focuses on the essential parts of the transformer architecture that are most beneficial for tabular data.
  • Regularization Techniques: It uses dropout to limit overfitting, together with layer normalization to keep training stable.
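
The unified processing step can be sketched as a small "feature tokenizer": a numerical value scales a per-column vector, a categorical value selects an embedding, and both get a per-column bias. The schema and the random parameters below are assumptions for illustration; in the real model these parameters are learned.

```python
import random

random.seed(0)
TOKEN_DIM = 4


def random_vector():
    return [random.gauss(0.0, 1.0) for _ in range(TOKEN_DIM)]


# Hypothetical schema: two numerical columns and one categorical column.
numerical_columns = ["age", "income"]
categorical_columns = {"plan": ["free", "premium"]}

# Per-column parameters (randomly initialised here; learned in practice).
num_weight = {col: random_vector() for col in numerical_columns}
cat_embedding = {col: {cat: random_vector() for cat in cats}
                 for col, cats in categorical_columns.items()}
bias = {col: random_vector() for col in [*numerical_columns, *categorical_columns]}


def tokenize(row):
    """Map every feature of a row to a TOKEN_DIM-sized vector (one token per column)."""
    tokens = []
    for col in numerical_columns:
        # Numerical feature: scalar value times a per-column vector, plus a bias.
        tokens.append([row[col] * w + b for w, b in zip(num_weight[col], bias[col])])
    for col in categorical_columns:
        # Categorical feature: embedding lookup, plus a bias.
        tokens.append([e + b for e, b in zip(cat_embedding[col][row[col]], bias[col])])
    return tokens


tokens = tokenize({"age": 0.3, "income": -1.2, "plan": "premium"})
print(len(tokens), len(tokens[0]))  # 3 4
```

After this step every column is a vector of the same size, so the transformer layers that follow no longer need to know which columns were numerical and which were categorical.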

Why is it helpful?

By efficiently handling different types of features and simplifying the model, FT-Transformer can achieve high performance without requiring excessive computational resources.

4. NODE (Neural Oblivious Decision Ensembles)

NODE combines decision trees with neural networks [4].

How does it work?

  • Differentiable Decision Trees: It uses tree-like structures that can be trained with gradient descent, making the trees “differentiable”.
  • Ensemble Learning: It combines multiple trees to improve accuracy.
  • Feature Selection: The model automatically learns which features are most important for the task.
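
A minimal sketch of one differentiable oblivious tree follows. It is a simplification, not the NODE implementation: sigmoid gates and hard-coded feature indices stand in for the entmax-based choices described in the paper [4], and the thresholds and leaf values are made up.

```python
import math
from itertools import product


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def soft_oblivious_tree(x, feature_index, threshold, leaf_values, temperature=1.0):
    """Output of one differentiable oblivious tree of depth len(feature_index).

    Every level applies the same (feature, threshold) test to all nodes at that
    depth, so a leaf is identified by one left/right decision per level.
    """
    # Soft probability of going "right" at each level.
    go_right = [sigmoid((x[f] - t) / temperature)
                for f, t in zip(feature_index, threshold)]
    output = 0.0
    leaf_weights = []
    for leaf, path in enumerate(product((0, 1), repeat=len(go_right))):
        # Weight of a leaf = product of the soft decisions along its path.
        weight = math.prod(p if turn else 1.0 - p for p, turn in zip(go_right, path))
        leaf_weights.append(weight)
        output += weight * leaf_values[leaf]
    return output, leaf_weights


# Hypothetical depth-2 tree: level 0 tests feature 0, level 1 tests feature 1.
output, leaf_weights = soft_oblivious_tree(
    x=[0.8, -0.5],
    feature_index=[0, 1],
    threshold=[0.0, 0.0],
    leaf_values=[-1.0, 0.5, 0.2, 2.0],
)
print([round(w, 3) for w in leaf_weights], round(output, 3))
```

Because the output is a smooth function of the thresholds and leaf values, gradient descent can adjust all of them; lowering `temperature` makes the soft splits approach hard ones.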

Why is it helpful?

By mimicking decision trees within a neural network framework, NODE can leverage the strengths of both methods.

5. TabPFN (Tabular Prior-Data Fitted Network)

TabPFN is designed for small datasets and produces predictions within seconds, without training a new model for each dataset [5].

How does it work?

  • Pre-Training on Synthetic Data: The model is pre-trained on lots of synthetic datasets, learning general patterns that can be applied to new data.
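  • In-Context Prediction: On a new task, the labelled training rows and the rows to predict are passed to the pre-trained network together, and the predictions come out of a single forward pass, with no gradient-based training on the new dataset.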
  • Probabilistic Predictions: It gives not only predictions but also how confident it is about them.
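
The calling pattern, labelled rows and a query row in, class probabilities out, can be illustrated with a deliberately crude stand-in. The sketch below is not TabPFN: a softmax over negative squared distances plays the role that the pre-trained transformer plays in the real model, and the data is made up. It only shows what a prediction that conditions directly on the labelled rows and returns a confidence looks like.

```python
import math


def predict_proba(train_X, train_y, query, temperature=1.0):
    """Class probabilities for `query`, computed directly from the labelled rows.

    Stand-in: similarity-weighted voting replaces the pre-trained network.
    No parameters are updated anywhere in this function.
    """
    scores = [-sum((a - b) ** 2 for a, b in zip(row, query)) / temperature
              for row in train_X]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    proba = {}
    for w, label in zip(weights, train_y):
        proba[label] = proba.get(label, 0.0) + w / total
    return proba


# Hypothetical, already-scaled training rows with their labels.
train_X = [[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.9, 1.1]]
train_y = ["no", "no", "yes", "yes"]

proba = predict_proba(train_X, train_y, query=[0.8, 1.0])
print({label: round(p, 3) for label, p in proba.items()})
```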

Why is it helpful?

TabPFN can perform well on small datasets, which is often a limitation for deep learning models. Its ability to provide uncertainty estimates adds an extra layer of insight.

6. DeepGBM

DeepGBM combines gradient boosting machines (GBMs) with deep learning [6].

How does it work?

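  • CatNN: A neural network component built on embeddings that handles sparse categorical features.
  • GBDT2NN: A neural network component distilled from a trained gradient-boosted tree model, so that it reproduces what the trees learned from dense numerical features.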

Why is it helpful?

By combining the feature learning capabilities of neural networks with the predictive power of GBMs, DeepGBM aims to achieve better results than using either method alone.

7. ResNet for Tabular Data

ResNet is a deep learning architecture originally designed for image processing. It has been adapted for tabular data to address some challenges deep learning models face in this domain [3].

How does it work?

  • Residual Connections: Each block adds its input directly to its output through a skip connection, so the signal can bypass the block’s layers. This helps the model learn better by avoiding issues like vanishing gradients (when gradients become too small to update the model effectively).
  • Batch Normalization: It normalizes inputs to layers, which helps stabilize and speed up training.
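
The residual connection itself fits in a few lines. The sketch below assumes a single dense layer with ReLU as the block's transformation and leaves batch normalization out for brevity; the weights are placeholders.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def dense(x, weights, bias):
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def residual_block(x, weights, bias):
    """y = x + f(x): the skip connection adds the block's input to its transformed output."""
    transformed = relu(dense(x, weights, bias))
    return [xi + ti for xi, ti in zip(x, transformed)]


x = [1.0, -2.0, 0.5]
zero_weights = [[0.0] * 3 for _ in range(3)]
zero_bias = [0.0] * 3

# With f(x) = 0 the block reduces to the identity: the input passes through unchanged.
print(residual_block(x, zero_weights, zero_bias))  # [1.0, -2.0, 0.5]
```

Since the input always has a direct path to the output, a block only needs to learn a correction on top of the identity, which is what makes stacking many blocks practical.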

Why is it helpful?

ResNet can allow training of deeper models that can capture more complex patterns. It also addresses issues like vanishing gradients, which can hurt model performance.

Even though deep learning has challenges, it has some unique benefits for tabular data:

1. Handling Complex, High-Dimensional Data

In cases where there are many features (high-dimensional data), deep learning models can find patterns that traditional methods might miss.

2. Integrating Different Data Types

Deep learning can combine tabular data with unstructured data like images or text. For example, in healthcare, a model might use patient records (tabular data) and MRI scans (images) together to improve diagnosis.

3. Anomaly Detection

Deep learning models can learn normal patterns in data and identify when something doesn’t fit, which is useful for detecting anomalies.

4. Automatic Feature Learning

Neural networks can automatically learn which features are important, reducing the need for manual feature engineering.

From my perspective, making deep learning work well with tabular data involves creativity and mixing different approaches. Here are some insights:

  • Hybrid Models Are Promising: Combining deep learning with traditional methods can leverage the strengths of both. For example, use a neural network to create features that capture complex interactions, then feed these into a gradient boosting model for prediction. This way, you get the neural network’s feature learning and the boosting model’s predictive power.
  • Data Augmentation Can Help: Techniques that increase the amount of training data, such as creating synthetic data points, might improve model performance.
  • Interpretability is Key: Deep learning models are often seen as “black boxes,” which can be a problem in fields where understanding decisions is crucial.
  • Experimentation is Essential: There’s no one-size-fits-all solution; different datasets may need different approaches. In some cases, ensemble methods like random forests may outperform deep learning models. It’s important to test multiple models and settings to find the best fit.
  • Focusing on Data Quality: High-quality data is essential for the success of both deep learning models and traditional models. Cleaning data to handle missing values, outliers, and inconsistencies can significantly improve model performance.
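
As a sketch of the data augmentation idea from the list above, the snippet below creates synthetic rows by interpolating between two existing rows of the same class. This is the SMOTE-style idea in its simplest form, without the nearest-neighbour search the full technique uses; it applies to numerical columns only, and the rows are made up.

```python
import random


def interpolate_rows(rows, n_new, rng):
    """Create synthetic rows on the line segment between two random rows of one class."""
    synthetic = []
    for _ in range(n_new):
        a, b = rng.sample(rows, 2)
        t = rng.random()
        synthetic.append([ai + t * (bi - ai) for ai, bi in zip(a, b)])
    return synthetic


# Hypothetical minority-class rows: (age, income in thousands).
minority = [[25.0, 30.0], [31.0, 42.0], [28.0, 35.0]]
new_rows = interpolate_rows(minority, n_new=5, rng=random.Random(0))

for row in new_rows:
    print([round(v, 1) for v in row])
```

Whether such synthetic rows help depends on the dataset, so it is worth comparing validation scores with and without them.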

Deep learning faces challenges when working with tabular data, but new methods are helping to overcome these obstacles. By adapting models to better suit the nature of tabular data and combining deep learning with traditional techniques, we can unlock new possibilities for data analysis and prediction.

[1] Huang, Khetan, Cvitkovic & Karnin TabTransformer: Tabular data modeling using contextual embeddings (2020).

GitHub Repository: github.com/lucidrains/tab-transformer-pytorch.

[2] Somepalli, Goldblum, Schwarzschild, Bruss & Goldstein SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training (2021).

GitHub Repository: github.com/somepago/saint.

[3] Gorishniy, Rubachev, Khrulkov & Babenko Revisiting deep learning models for tabular data (2021).

GitHub Repository: github.com/yandex-research/rtdl-revisiting-models.

[4] Popov, Morozov & Babenko Neural oblivious decision ensembles for deep learning on tabular data (2019).

GitHub Repository: github.com/Qwicen/node.

[5] Hollmann, Müller, Eggensperger & Hutter TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second (2022).

GitHub Repository: github.com/automl/TabPFN.

[6] Ke, Xu, Zhang, Bian & Liu DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks (2019).

GitHub Repository: github.com/motefly/DeepGBM.
