Distributed ML Example

This example shows how to distribute ML inference across multiple Erlang nodes, leveraging the BEAM’s built-in clustering.

Architecture

                    ┌─────────────────┐
                    │   Web Server    │
                    │  (Hornbeam)     │
                    └────────┬────────┘
                             │
              ┌──────────────┼──────────────┐
              │              │              │
       ┌──────▼─────┐ ┌──────▼─────┐ ┌──────▼─────┐
       │  GPU Node  │ │  GPU Node  │ │  GPU Node  │
       │  worker1@  │ │  worker2@  │ │  worker3@  │
       └────────────┘ └────────────┘ └────────────┘

Web Server (Hornbeam)

# app.py
import asyncio
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from hornbeam_erlang import rpc_call, nodes, node
from hornbeam_ml import cached_inference

# ASGI application object; Hornbeam loads it as "app:app" (module app, attribute app).
app = FastAPI()

# ============================================================
# Models
# ============================================================

class InferenceRequest(BaseModel):
    """Request body for the /infer and /infer/parallel endpoints."""

    texts: list[str]  # texts to embed; one embedding is returned per text
    model: str = "default"  # model name forwarded to ml_worker:encode on the GPU workers

class InferenceResponse(BaseModel):
    """Response body for the /infer endpoint."""

    embeddings: list[list[float]]  # one vector per input text
    node: str  # node that computed the embeddings (this node on local fallback)
    cached: int  # embeddings the worker reported as cache hits; 0 on local fallback

# ============================================================
# Distributed Inference
# ============================================================

def get_gpu_nodes():
    """Return the connected nodes that look like GPU workers.

    A node qualifies when its name begins with 'gpu' or 'worker', so
    worker nodes must be started with a name using one of those prefixes.
    """
    worker_prefixes = ('gpu', 'worker')
    return [name for name in nodes() if name.startswith(worker_prefixes)]

def select_node(nodes_list):
    """Pick a node from *nodes_list* in round-robin order.

    This is plain round-robin, not load-aware. A shared counter
    ('node_selector', bumped with state_incr) is taken modulo the list
    length. Assuming state_incr returns the incremented value, successive
    calls cycle through the list.

    Returns:
        The selected node, or None when *nodes_list* is empty.
    """
    if not nodes_list:
        return None
    from hornbeam_erlang import state_incr

    idx = state_incr('node_selector') % len(nodes_list)
    return nodes_list[idx]

# Fallback model used when no GPU worker is connected. It is loaded on first
# use and then reused, rather than being constructed on every request.
_local_model = None


def _get_local_model():
    """Return the process-wide fallback SentenceTransformer, loading it once."""
    global _local_model
    if _local_model is None:
        from sentence_transformers import SentenceTransformer
        _local_model = SentenceTransformer('all-MiniLM-L6-v2')
    return _local_model


def _encode_locally(texts):
    """Encode *texts* with the local fallback model (blocking call)."""
    return _get_local_model().encode(texts).tolist()


@app.post("/infer", response_model=InferenceResponse)
async def distributed_inference(request: InferenceRequest):
    """Run inference on one GPU node, falling back to local inference.

    select_node() chooses the target node round-robin. When no GPU worker
    is connected, this process encodes the texts with all-MiniLM-L6-v2.
    request.model is ignored on that fallback path.

    Blocking work (model inference and the Erlang RPC, which can wait up
    to 60 s) runs in a worker thread via asyncio.to_thread, so it does not
    stall the event loop for other requests.
    """
    gpu_nodes = get_gpu_nodes()

    if not gpu_nodes:
        embeddings = await asyncio.to_thread(_encode_locally, request.texts)
        return InferenceResponse(
            embeddings=embeddings,
            node=node(),
            cached=0
        )

    target_node = select_node(gpu_nodes)

    result = await asyncio.to_thread(
        rpc_call,
        target_node,
        'ml_worker',
        'encode',
        [request.texts, request.model],
        timeout_ms=60000,
    )

    return InferenceResponse(
        embeddings=result['embeddings'],
        node=target_node,
        cached=result.get('cached', 0)
    )

def _split_evenly(items, n_parts):
    """Split *items* into at most *n_parts* contiguous, non-empty chunks.

    Each chunk holds ceil(len(items) / n_parts) items (the last may be
    shorter). Empty *items* yields [] instead of a zero range step.
    """
    if not items:
        return []
    size = (len(items) + n_parts - 1) // n_parts
    return [items[i:i + size] for i in range(0, len(items), size)]


@app.post("/infer/parallel")
async def parallel_inference(request: InferenceRequest):
    """Split the texts across all GPU nodes and encode the chunks in parallel.

    Each node gets one contiguous chunk of request.texts. The returned
    'embeddings' list is in the same order as request.texts.
    'nodes_used' is the number of nodes that actually received a chunk,
    which can be fewer than the connected nodes when there are few texts.

    Raises:
        HTTPException: 503 when no GPU node is connected.
    """
    gpu_nodes = get_gpu_nodes()

    if not gpu_nodes:
        raise HTTPException(status_code=503, detail="No GPU nodes available")

    chunks = _split_evenly(request.texts, len(gpu_nodes))
    assignments = list(zip(gpu_nodes, chunks))

    # gather() returns results in argument order, so the concatenated
    # embeddings line up with request.texts regardless of which node
    # finishes first. Each blocking rpc_call runs in its own thread.
    results = await asyncio.gather(*(
        asyncio.to_thread(
            rpc_call,
            worker,
            'ml_worker',
            'encode',
            [chunk, request.model],
            timeout_ms=60000,
        )
        for worker, chunk in assignments
    ))

    embeddings = [emb for result in results for emb in result['embeddings']]

    return {
        'embeddings': embeddings,
        'nodes_used': len(assignments)
    }

# ============================================================
# Cluster Status
# ============================================================

@app.get("/cluster")
async def cluster_status():
    """Report this node, the cluster size and the health of each GPU worker.

    Every GPU worker is probed concurrently with ml_worker:info/0 (5 s
    timeout). The blocking rpc_call runs in a worker thread via
    asyncio.to_thread, so unreachable workers do not stall the event loop
    one after another.

    A worker whose probe raises is reported as 'offline' with the error
    text. Otherwise, its info map is merged into its entry.
    """
    gpu_nodes = get_gpu_nodes()

    status = {
        'this_node': node(),
        'gpu_nodes': [],
        'total_nodes': len(nodes())
    }

    async def probe(gpu_node):
        # Build this worker's entry for status['gpu_nodes'].
        try:
            info = await asyncio.to_thread(
                rpc_call, gpu_node, 'ml_worker', 'info', [], timeout_ms=5000
            )
            return {'node': gpu_node, 'status': 'online', **info}
        except Exception as e:
            return {'node': gpu_node, 'status': 'offline', 'error': str(e)}

    # gather() keeps the entries in the same order as gpu_nodes.
    status['gpu_nodes'] = list(await asyncio.gather(*(probe(n) for n in gpu_nodes)))
    return status

GPU Worker Node

Erlang Module

%% ml_worker.erl
%% RPC entry points on a GPU worker node. The web server calls these
%% remotely via rpc_call(Node, ml_worker, Fun, Args). The ML functions
%% delegate to the Python module ml_service through py:call/3.
-module(ml_worker).
-export([encode/2, info/0, predict/2, get_load/0]).

%% Encode a list of texts with the named model. Embedding caching is done
%% on the Python side (ml_service.encode). Returns whatever that call returns.
encode(Texts, Model) ->
    py:call('ml_service', 'encode', [Texts, Model]).

%% Run ml_service.predict on a single input with the named model.
predict(Input, Model) ->
    py:call('ml_service', 'predict', [Input, Model]).

%% Status map for the web server's /cluster endpoint: loaded models,
%% py:memory_stats() output and embedding-cache statistics.
info() ->
    #{
        model => py:call('ml_service', 'get_model_info', []),
        memory => py:memory_stats(),
        cache => py:call('ml_service', 'cache_stats', [])
    }.

%% Load figure used by the web server's least-loaded node selection.
%% It is the total length of this node's scheduler run queues: a coarse
%% proxy for BEAM scheduler pressure, not a measure of GPU utilisation.
get_load() ->
    erlang:statistics(total_run_queue_lengths).

Python Service

# ml_service.py
from sentence_transformers import SentenceTransformer
from hornbeam_ml import cached_inference, cache_stats as get_cache_stats
from hornbeam_erlang import state_get, state_set

# Cache of loaded SentenceTransformer instances, keyed by the name passed to
# get_model() (an alias such as 'default' or a full model name). Entries are
# created lazily on first use, not when this module is imported.
models = {}

# Short aliases accepted by get_model(). Any other name is handed to
# SentenceTransformer unchanged.
_MODEL_ALIASES = {
    'default': 'all-MiniLM-L6-v2',
    'large': 'all-mpnet-base-v2',
    'multilingual': 'paraphrase-multilingual-MiniLM-L12-v2'
}


def get_model(name='default'):
    """Return the SentenceTransformer for *name*, loading it on first use.

    *name* is resolved through _MODEL_ALIASES. The loaded model is stored
    in `models` under the name as given, so later calls reuse it.
    """
    if name in models:
        return models[name]
    loaded = SentenceTransformer(_MODEL_ALIASES.get(name, name))
    models[name] = loaded
    return loaded

def embedding_cache_key(model_name, text):
    """Return the shared-state key under which *text*'s embedding is cached.

    The key uses a SHA-256 digest of the UTF-8 text rather than the built-in
    hash(). hash(str) is salted per interpreter process, so keys built from
    it change across processes and restarts. Two texts whose hash() values
    collide would also share (and return) the same cached embedding.
    """
    import hashlib

    digest = hashlib.sha256(text.encode('utf-8')).hexdigest()
    return f'emb:{model_name}:{digest}'


def encode(texts, model_name='default'):
    """Encode *texts* with *model_name*, reusing cached embeddings.

    Embeddings are cached in shared state (state_get/state_set) per
    (model name, text). All cache misses are encoded in a single batched
    model.encode call, and a text repeated within one request is encoded
    only once.

    Args:
        texts: list of strings to embed.
        model_name: model alias or name understood by get_model().

    Returns:
        dict with 'embeddings' (one list of floats per input text, in input
        order) and 'cached' (number of inputs not freshly encoded here: cache
        hits plus repeats of a text already encoded in this call).
    """
    model = get_model(model_name)

    keys = [embedding_cache_key(model_name, text) for text in texts]
    embeddings = [state_get(key) for key in keys]

    # Distinct cache misses in first-seen order (dicts keep insertion order).
    pending = {}
    for key, text, emb in zip(keys, texts, embeddings):
        if not emb:
            pending.setdefault(key, text)

    if pending:
        # One batched call for all misses; tolist() gives one list per text.
        vectors = model.encode(list(pending.values())).tolist()
        fresh = dict(zip(pending, vectors))
        for key, vector in fresh.items():
            state_set(key, vector)
        embeddings = [
            emb if emb else fresh[key]
            for key, emb in zip(keys, embeddings)
        ]

    return {
        'embeddings': embeddings,
        'cached': len(texts) - len(pending)
    }

def predict(input_data, model_name='default'):
    """Encode *input_data* with the named model and return plain lists.

    This is an embedding call: it returns model.encode(input_data)
    converted with tolist(), so the result can be passed back to Erlang.
    """
    vectors = get_model(model_name).encode(input_data)
    return vectors.tolist()

def get_model_info():
    """Describe the model cache.

    'loaded_models' lists the names cached before this call. Reading
    'default_dimensions' goes through get_model(), which loads the default
    model if it was not loaded yet.
    """
    # Snapshot the cache first: get_model() below may add 'default' to it.
    loaded_names = list(models)
    dimension = get_model().get_sentence_embedding_dimension()
    return {
        'loaded_models': loaded_names,
        'default_dimensions': dimension
    }

def cache_stats():
    """Return hornbeam_ml's cache statistics under a name Erlang can py:call."""
    stats = get_cache_stats()
    return stats

Starting the Cluster

Start GPU Worker Nodes

# On gpu-server-1. Repeat on every GPU server with its own node name
# (worker2@gpu-server-2, ...). The name must start with "worker" or "gpu":
# the web server's get_gpu_nodes() only selects nodes with those prefixes.
# NOTE(review): "mysecret" is a placeholder cookie; use a real secret, and
# the same value on every node in the cluster.
erl -name worker1@gpu-server-1 -setcookie mysecret

# In the Erlang shell of this node: start the erlang_python application
# (presumably what provides the py:call/3 used by ml_worker).
# NOTE(review): ml_worker must be compiled/loaded and ml_service.py
# importable on this node; those steps are not shown here.
application:ensure_all_started(erlang_python).

Start Web Server

# On web-server, using the same cookie as the GPU worker nodes
erl -name web@web-server -setcookie mysecret

# In the Erlang shell: connect to every GPU worker node started above
# (add one ping per additional worker, e.g. worker3@...)
net_adm:ping('worker1@gpu-server-1').
net_adm:ping('worker2@gpu-server-2').

# Start Hornbeam serving the ASGI object `app` from module `app` (app.py).
# NOTE(review): pythonpath presumably must include the directory holding
# app.py; confirm "distributed_ml" matches your layout.
hornbeam:start("app:app", #{
    worker_class => asgi,
    pythonpath => ["distributed_ml"]
}).

Load Balancing Strategies

1. Round-Robin (Simple)

def select_node_round_robin(nodes_list):
    """Pick the next node in round-robin order.

    Uses its own shared counter ('node_rr') taken modulo the list length.

    Returns:
        The selected node, or None when *nodes_list* is empty (instead of
        raising ZeroDivisionError from the modulo).
    """
    if not nodes_list:
        return None
    # app.py does not import state_incr at module level, so import it here.
    from hornbeam_erlang import state_incr

    idx = state_incr('node_rr') % len(nodes_list)
    return nodes_list[idx]

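2. Least Loaded

Asks each node for its load via `ml_worker:get_load/0` (1 s timeout) and picks the lowest; a node that errors or times out is ranked last. This requires every worker's `ml_worker` module to export `get_load/0`.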

def select_node_least_loaded(nodes_list):
    """Pick the node reporting the lowest load.

    Each node is queried with ml_worker:get_load/0 (1 s timeout). A node
    whose query raises, for example a timeout or a missing function, is
    treated as infinitely loaded so it is chosen only if nothing else
    answers. Ties go to the earliest node in *nodes_list*.

    Returns:
        The selected node, or None when *nodes_list* is empty.
    """
    if not nodes_list:
        return None

    def load_of(n):
        try:
            return rpc_call(n, 'ml_worker', 'get_load', [], timeout_ms=1000)
        except Exception:
            # Rank failing nodes last instead of aborting the selection.
            # Unlike a bare except, this lets KeyboardInterrupt/SystemExit propagate.
            return float('inf')

    return min(nodes_list, key=load_of)

3. Consistent Hashing (for caching)

import hashlib

def select_node_consistent(nodes_list, key):
    """Select a node for *key* so that the same key keeps landing on the same node.

    Uses rendezvous (highest-random-weight) hashing. Every node is scored
    with an MD5 digest of "<node>:<key>", and the highest score wins.

    Unlike hash(key) % len(nodes), adding or removing a node only remaps
    the keys that node owned. This preserves cache locality on the
    remaining workers when the cluster changes. The result does not depend
    on the order of *nodes_list*.

    Returns:
        The chosen node, or None when *nodes_list* is empty.
    """
    if not nodes_list:
        return None

    def score(node_name):
        digest = hashlib.md5(f'{node_name}:{key}'.encode()).hexdigest()
        return int(digest, 16)

    return max(nodes_list, key=score)

Fault Tolerance

# Model for the last-resort local path, loaded on first use and then reused.
_fallback_model = None


def local_inference(texts):
    """Encode *texts* in this process; used when every GPU node has failed.

    Returns a dict with 'embeddings' and 'cached', the same keys that
    ml_service.encode returns through ml_worker:encode, so callers can
    treat both paths alike.
    """
    global _fallback_model
    if _fallback_model is None:
        from sentence_transformers import SentenceTransformer
        _fallback_model = SentenceTransformer('all-MiniLM-L6-v2')
    return {'embeddings': _fallback_model.encode(texts).tolist(), 'cached': 0}


async def resilient_inference(texts, retries=2):
    """Inference with automatic failover.

    Tries up to retries + 1 GPU nodes (chosen by select_node). A node whose
    RPC raises is logged and dropped from the candidate list. If no node
    succeeds, the texts are encoded locally. Blocking calls run in worker
    threads so the event loop stays responsive.
    """
    import logging

    log = logging.getLogger(__name__)
    gpu_nodes = get_gpu_nodes()

    for _ in range(retries + 1):
        if not gpu_nodes:
            break

        target = select_node(gpu_nodes)

        try:
            return await asyncio.to_thread(
                rpc_call,
                target,
                'ml_worker',
                'encode',
                [texts, 'default'],
                timeout_ms=30000,
            )
        except Exception as e:
            log.warning("Node %s failed: %s", target, e)
            gpu_nodes.remove(target)

    # All candidate nodes failed (or none were connected): encode locally.
    return await asyncio.to_thread(local_inference, texts)

Monitoring

@app.get("/metrics")
async def metrics():
    """Return cluster-wide request counters and cluster size from shared state.

    Counters absent from shared state are reported as 0.
    NOTE(review): none of the handlers in this example write
    'metrics:requests', 'metrics:distributed' or 'metrics:local', so those
    stay 0 unless code elsewhere updates them.
    """
    from hornbeam_erlang import state_get

    def counter(key):
        return state_get(key) or 0

    request_counts = {
        'total': counter('metrics:requests'),
        'distributed': counter('metrics:distributed'),
        'local_fallback': counter('metrics:local')
    }
    cluster_info = {
        'nodes': len(get_gpu_nodes()),
        'node_selector_count': counter('node_selector')
    }
    return {'requests': request_counts, 'cluster': cluster_info}

Next Steps