How We Rebuilt Our ML Infrastructure to Handle 100M+ Daily Predictions

Gopal Singh
4 min read · Nov 7, 2024

ML Architecture Evolution: From Microservices to Scalable ML Systems

Introduction

When our team started building our product recommendation engine, we followed the standard advice: “Use microservices for flexibility and scalability.” Six months and $45K in technical debt later, we learned that conventional wisdom doesn’t always apply to ML systems.

This article details our journey from a simple microservice architecture to a robust ML system handling 100M+ daily predictions. More importantly, it explains why common software architecture patterns might not be the best fit for ML systems.

The Initial Architecture

Our first architecture was what you’d expect from any modern application:

# Initial FastAPI service
from fastapi import FastAPI

from recommender import RecommenderModel  # placeholder import for our in-house model class

app = FastAPI()
model = RecommenderModel()  # loaded once per process, at import time

@app.post("/predict")
async def get_recommendations(user_id: str):
    # get_user_features pulls the user's features from our feature store (defined elsewhere)
    user_features = await get_user_features(user_id)
    recommendations = model.predict(user_features)
    return {"recommendations": recommendations}

Simple, clean, and in line with all the microservice best practices. But it had several hidden problems:

  1. Cold Starts: Every new instance loaded the model from scratch
  2. Resource Utilization: Idle instances still consumed full memory
  3. Data Freshness: Features were often stale
  4. Scaling Issues: Linear resource scaling with request volume

The Evolution

1. Request Handling Layer

The first major change was splitting read and write paths:

# Separated read/write paths with batching
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
batch_processor = BatchProcessor()        # queues and executes batched predictions (defined elsewhere)
realtime_processor = RealtimeProcessor()  # low-latency path for priority traffic (defined elsewhere)

class PredictionRequest(BaseModel):
    user_id: str
    context: dict

@app.post("/predict/batch")
async def batch_predictions(requests: List[PredictionRequest]):
    # Queue the whole batch and hand back an ID the client can poll
    batch_id = batch_processor.queue_batch(requests)
    return {"batch_id": batch_id}

@app.get("/predict/{batch_id}")
async def get_batch_results(batch_id: str):
    # Retrieve results for a previously queued batch
    results = batch_processor.get_results(batch_id)
    return results

@app.post("/predict/realtime")
async def realtime_predict(request: PredictionRequest):
    # Priority users skip the batch queue entirely
    if is_priority_user(request.user_id):
        return await realtime_processor.predict(request)
    return await batch_processor.quick_predict(request)
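
From the client's point of view, the batch path is submit-then-poll. A rough sketch with httpx (the base URL, payload shape, and polling cadence are assumptions, not part of the service above):

# Client-side usage of the batch API (illustrative; base URL and polling are assumptions)
import asyncio
import httpx

async def fetch_batch_recommendations(user_ids):
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Submit the whole batch in one request
        payload = [{"user_id": uid, "context": {}} for uid in user_ids]
        resp = await client.post("/predict/batch", json=payload)
        batch_id = resp.json()["batch_id"]

        # Poll until results are ready; a real client would add backoff and a timeout
        while True:
            resp = await client.get(f"/predict/{batch_id}")
            if resp.status_code == 200 and resp.json():
                return resp.json()
            await asyncio.sleep(0.5)

# asyncio.run(fetch_batch_recommendations(["u_1", "u_2"]))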

Key improvements:

  • Batch processing for efficiency
  • Priority queuing for important requests
  • Circuit breakers for system protection (a minimal sketch follows below)
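
The circuit breaker itself isn't shown in the snippet above. Here is a minimal sketch of the pattern we mean, with hand-picked thresholds and reusing realtime_processor from the snippet above as the guarded call:

# Minimal circuit breaker sketch (thresholds are illustrative, not tuned values)
import time

class CircuitBreaker:
    def __init__(self, failure_threshold: int = 5, reset_timeout: float = 30.0):
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        self.failures = 0
        self.opened_at = None

    def allow_request(self) -> bool:
        # Closed: traffic flows normally
        if self.opened_at is None:
            return True
        # Open: reject calls until the reset timeout elapses, then try again
        if time.monotonic() - self.opened_at >= self.reset_timeout:
            self.opened_at = None
            self.failures = 0
            return True
        return False

    def record_success(self) -> None:
        self.failures = 0

    def record_failure(self) -> None:
        self.failures += 1
        if self.failures >= self.failure_threshold:
            self.opened_at = time.monotonic()

breaker = CircuitBreaker()

async def guarded_predict(request):
    # Serve a degraded response instead of hammering a failing downstream
    if not breaker.allow_request():
        return {"recommendations": [], "degraded": True}
    try:
        result = await realtime_processor.predict(request)
        breaker.record_success()
        return result
    except Exception:
        breaker.record_failure()
        raise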

2. Prediction Layer

The biggest performance gain came from intelligent caching:

# Intelligent caching layer
import json
from typing import Dict

from redis import Redis

class PredictionCache:
    def __init__(self):
        self.redis = Redis()
        self.local_cache = {}  # first cache level, local to the process

    def get_cached_prediction(self, user_id: str, context: Dict):
        # Note: functools.lru_cache doesn't work here because the context dict isn't hashable,
        # so we build a deterministic key by serializing it with sorted keys
        cache_key = f"{user_id}:{json.dumps(context, sort_keys=True)}"

        # Level 1: local in-process cache
        if cache_key in self.local_cache:
            return self.local_cache[cache_key]

        # Level 2: shared Redis cache
        cached = self.redis.get(cache_key)
        if cached:
            self.local_cache[cache_key] = cached
            return cached

        # Cache miss: generate a new prediction (model call, defined elsewhere)
        prediction = self.generate_prediction(user_id, context)

        # Cache with a TTL that reflects how quickly this entry goes stale
        ttl = self.calculate_ttl(user_id, context)
        self.redis.setex(cache_key, ttl, prediction)
        self.local_cache[cache_key] = prediction

        return prediction

    def calculate_ttl(self, user_id: str, context: Dict) -> int:
        # Dynamic TTL based on user behavior and context
        # is_active_user / is_high_volatility_context are business-rule helpers defined elsewhere
        base_ttl = 3600  # 1 hour
        if is_active_user(user_id):
            base_ttl = 1800  # 30 minutes
        if is_high_volatility_context(context):
            base_ttl = 300  # 5 minutes
        return base_ttl

This caching system provided:

  • Multi-level caching (local + Redis)
  • Dynamic TTL based on user behavior
  • Pre-computation for popular items (sketched after this list)
  • Intelligent cache invalidation
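
Pre-computation isn't shown in the class above; the usual approach is a periodic job that warms both cache levels for the most-requested users. A minimal sketch, assuming a hypothetical get_popular_user_ids helper and the PredictionCache from the previous snippet:

# Periodic cache warming for popular users (illustrative sketch)
import asyncio

async def warm_popular_predictions(cache: PredictionCache, interval_seconds: int = 600):
    while True:
        # get_popular_user_ids is a hypothetical helper that returns the most
        # frequently requested user IDs from recent traffic
        for user_id in get_popular_user_ids(limit=1000):
            # Populates both cache levels before real traffic asks for them
            cache.get_cached_prediction(user_id, context={})
        await asyncio.sleep(interval_seconds)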

3. Data Processing Layer

Event-driven processing transformed our data freshness:

# Event-driven data processing
from datetime import datetime
from typing import Dict

class FeatureStore:
    def __init__(self):
        self.feature_validators = {}
        self.quality_gates = {}      # feature name -> callable returning True/False
        self.feature_store = None    # client for the underlying feature storage (injected)
        self.model_updater = None    # client that queues incremental model updates (injected)

    async def process_event(self, event: Dict):
        # Process events as they arrive
        # extract_features / get_affected_models are implemented per event type (omitted here)
        features = self.extract_features(event)
        if self.validate_features(features):
            await self.update_features(features)
            await self.trigger_recomputation(features)

    def validate_features(self, features: Dict) -> bool:
        # Data quality gates: reject the whole event if any feature fails its gate
        for feature, value in features.items():
            if feature in self.quality_gates:
                if not self.quality_gates[feature](value):
                    log_quality_issue(feature, value)  # logging helper defined elsewhere
                    return False
        return True

    async def update_features(self, features: Dict):
        # Incremental feature updates
        for feature, value in features.items():
            await self.feature_store.update(
                feature_name=feature,
                value=value,
                timestamp=datetime.now(),
            )

    async def trigger_recomputation(self, features: Dict):
        # Trigger async model updates if needed
        affected_models = self.get_affected_models(features)
        for model in affected_models:
            await self.model_updater.queue_update(model)

This approach enabled:

  • Real-time feature updates
  • Data quality monitoring (see the example after this list)
  • Incremental model updates
  • Feature freshness tracking
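
Quality gates in the class above are just callables keyed by feature name. Registering one and pushing an event through the store looks roughly like this (the event shape, the gate threshold, and the per-event-type helpers are assumptions):

# Registering a quality gate and processing one event (illustrative)
import asyncio

store = FeatureStore()

# Reject obviously bad values before they ever reach the feature store
store.quality_gates["purchase_amount"] = lambda value: 0 <= value < 100_000

event = {"type": "purchase", "user_id": "u_123", "purchase_amount": 59.99}
asyncio.run(store.process_event(event))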

The Counter-Intuitive Learnings

1. Fewer Services, Better Performance

Instead of having separate services for every function, we consolidated into three main components:

  • Request Handler
  • Prediction Engine
  • Data Processor

This reduced network hops, simplified monitoring, and improved reliability.

2. Caching > Computational Optimization

We spent weeks optimizing our model inference before realizing that intelligent caching would give us better results with less complexity.

3. Data Flow > Service Architecture

Focus on optimizing how data moves through your system rather than how code is organized into services.

Performance Metrics

Before and after comparison:

| Metric              | Before  | After     | Improvement |
|---------------------|---------|-----------|-------------|
| Latency (p95)       | 2000ms  | 150ms     | 92.5%       |
| Cost per prediction | $0.001  | $0.0001   | 90%         |
| System uptime       | 99.9%   | 99.99%    | 0.09%       |
| Data freshness      | 4 hours | 5 minutes | 98%         |

Key Takeaways

1. Start Simple

  • Begin with a monolith
  • Identify bottlenecks through monitoring
  • Split only what causes problems

2. Cache Strategically

  • Implement multi-level caching
  • Use dynamic TTLs
  • Pre-compute popular predictions

3. Monitor Everything

  • Track business metrics (a minimal metrics sketch follows this list)
  • Monitor data quality
  • Set up alerting for degradation
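
Even lightweight metrics make a difference here. A minimal sketch with prometheus_client (the metric names, and the cache object from the earlier snippet, are our own choices, not a standard):

# Minimal metrics sketch with prometheus_client (metric names are illustrative)
from prometheus_client import Counter, Histogram

PREDICTION_LATENCY = Histogram(
    "prediction_latency_seconds", "End-to-end prediction latency"
)
CACHE_HITS = Counter("prediction_cache_hits_total", "Predictions served from cache")
QUALITY_FAILURES = Counter(
    "feature_quality_failures_total", "Feature values rejected by quality gates"
)

def timed_predict(cache, user_id: str, context: dict):
    # Time the full prediction path; alerting fires on p95 regressions
    with PREDICTION_LATENCY.time():
        return cache.get_cached_prediction(user_id, context)

# CACHE_HITS.inc() and QUALITY_FAILURES.inc() would be called inside the cache
# lookup and the quality gates, respectively.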

4. Build Fallbacks

  • Implement degraded service modes
  • Have backup prediction strategies (sketched below)
  • Design for failure
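
A degraded mode can be as simple as serving popularity-based recommendations when the model path fails. A sketch, reusing realtime_processor and PredictionRequest from earlier and assuming a hypothetical get_top_items helper:

# Fallback prediction strategy (illustrative sketch)
async def predict_with_fallback(request: PredictionRequest):
    try:
        # Primary path: model-backed, possibly cached, prediction
        return await realtime_processor.predict(request)
    except Exception:
        # Degraded mode: popularity-based recommendations, flagged so clients know
        # get_top_items is a hypothetical helper backed by a precomputed popular-items list
        return {"recommendations": get_top_items(limit=20), "degraded": True}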

Conclusion

The microservices architecture pattern, while excellent for many applications, isn’t always the best fit for ML systems. The key is to understand your specific requirements and constraints, then design accordingly.

Remember:

  1. Complex problems don’t always need complex solutions
  2. Focus on data flow over service boundaries
  3. Cache aggressively, update lazily
  4. Build fallbacks before optimizations

Written by Gopal Singh

Data Scientist, Programmer, Writer. Let’s connect: https://www.linkedin.com/in/theunblunt