
Why batch ASR won’t work for voice assistants, and how streaming models transcribe speech as you speak in under 200ms.

Introduction

Every time you say “Hey Google” or ask Alexa a question, you’re interacting with a streaming Automatic Speech Recognition (ASR) system. Unlike traditional batch ASR systems that wait for you to finish speaking before transcribing, streaming ASR must:

  • Emit words as you speak (not after)
  • Maintain < 200ms latency for first token
  • Handle millions of concurrent audio streams
  • Work reliably in noisy environments
  • Run on both cloud and edge devices
  • Adapt to different accents and speaking styles

This is fundamentally different from batch models like OpenAI’s Whisper, which achieve impressive accuracy but require the entire utterance before processing. For interactive voice assistants, that delay is unacceptable: users expect immediate feedback.

What you’ll learn:

  • Why streaming requires different model architectures
  • RNN-Transducer (RNN-T) and CTC for streaming
  • How to maintain state across audio chunks
  • Latency optimization techniques (quantization, pruning, caching)
  • Scaling to millions of concurrent streams
  • Cold start and speaker adaptation
  • Real production systems (Google, Amazon, Apple)

Problem Definition

Design a production streaming ASR system that transcribes speech in real-time for a voice assistant platform.

Functional Requirements

  1. Streaming Transcription
    • Output tokens incrementally as user speaks
    • No need to wait for end of utterance
    • Partial results updated continuously
  2. Low Latency
    • First token latency: < 200ms (time from start of speech to first word)
    • Per-token latency: < 100ms (time between subsequent words)
    • End-of-utterance latency: < 500ms (finalized transcript)
  3. No Future Context
    • Cannot “look ahead” into future audio (processing must be causal)
    • Limited right context window (e.g., 320ms)
    • Must work with incomplete information
  4. State Management
    • Maintain conversational context across chunks
    • Remember acoustic and linguistic state
    • Handle variable-length inputs
  5. Multi-Language Support
    • 20+ languages
    • Automatic language detection
    • Code-switching (mixing languages)

Non-Functional Requirements

  1. Accuracy
    • Clean speech: WER < 5% (Word Error Rate; see the sketch after this list)
    • Noisy speech: WER < 15%
    • Accented speech: WER < 10%
    • Far-field: WER < 20%
  2. Throughput
    • 10M concurrent audio streams globally
    • 10k QPS per regional cluster
    • Auto-scaling based on load
  3. Availability
    • 99.99% uptime (< 1 hour downtime/year)
    • Graceful degradation on failures
    • Multi-region failover
  4. Cost Efficiency
    • < $0.01 per minute of audio (cloud)
    • < 100ms inference time on edge devices
    • GPU/CPU optimization
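
Word Error Rate (WER), used in the accuracy targets above, is the word-level edit distance (substitutions + deletions + insertions) divided by the number of reference words. A minimal sketch of the computation:

def word_error_rate(reference, hypothesis):
    """
    WER = (substitutions + deletions + insertions) / number of reference words,
    computed with a standard word-level edit distance.
    """
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,          # deletion
                dp[i][j - 1] + 1,          # insertion
                dp[i - 1][j - 1] + cost,   # substitution or match
            )
    return dp[len(ref)][len(hyp)] / max(len(ref), 1)

# "on" deleted, "lights" → "light" substituted → 2 errors / 4 reference words = 0.5
print(word_error_rate("turn on the lights", "turn the light"))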

Out of Scope

  • Audio storage and archival
  • Speaker diarization (who is speaking)
  • Speech translation
  • Emotion/sentiment detection
  • Voice biometric authentication

Streaming vs Batch ASR: Key Differences

Batch ASR (e.g., Whisper)

def batch_asr(audio):
    # Wait for complete audio
    complete_audio = wait_for_end_of_speech(audio)
    
    # Process entire sequence at once
    # Can use bidirectional models, look at future context
    features = extract_features(complete_audio)
    transcript = model(features)  # Has access to all frames
    
    return transcript

# Latency: duration + processing time
# For 10-second audio: 10 seconds + 2 seconds = 12 seconds

Pros:

  • Can use future context → better accuracy
  • Simpler architecture (no state management)
  • Can use attention over full sequence

Cons:

  • High latency (must wait for end)
  • Poor user experience for voice assistants
  • Cannot provide real-time feedback

Streaming ASR

def streaming_asr(audio_stream):
    state = initialize_state()
    
    for audio_chunk in audio_stream:  # Process 100ms chunks
        # Can only look at past + limited future
        features = extract_features(audio_chunk)
        tokens, state = model(features, state)  # Causal processing
        
        if tokens:
            yield tokens  # Emit immediately
    
    # Finalize
    final_tokens = finalize(state)
    yield final_tokens

# Latency: ~200ms for first token, ~100ms per subsequent token
# For 10-second audio: words appear while the user is still speaking;
# the finalized transcript arrives a few hundred ms after speech ends (vs ~12s for batch)

Pros:

  • Low latency (immediate feedback)
  • Better user experience
  • Can interrupt/correct in real-time

Cons:

  • More complex (state management)
  • Slightly lower accuracy (no full future context)
  • Harder to train

Architecture Overview

Audio Input (100ms chunks @ 16kHz)
    ↓
Voice Activity Detection (VAD)
    ├─ Speech detected → Continue
    └─ Silence detected → Skip processing
    ↓
Feature Extraction
    ├─ Mel Filterbank (80 dims)
    ├─ Normalization
    └─ Delta features (optional)
    ↓
Streaming Acoustic Model
    ├─ Encoder (Conformer/RNN)
    ├─ Prediction Network
    └─ Joint Network
    ↓
Decoder (Beam Search)
    ├─ Language Model Fusion
    ├─ Beam Management
    └─ Token Emission
    ↓
Post-Processing
    ├─ Punctuation
    ├─ Capitalization
    └─ Inverse Text Normalization
    ↓
Transcription Output

Component 1: Voice Activity Detection (VAD)

Why VAD is Critical

Problem: Processing silence wastes 50-70% of compute.

Solution: Filter out non-speech audio before expensive ASR processing.

# Without VAD
total_audio = 10 seconds
speech = 3 seconds (30%)
silence = 7 seconds (70% wasted compute)

# With VAD
processed_audio = 3 seconds (save 70% compute)

VAD Approaches

Option 1: Energy-Based (Simple)

def energy_vad(audio_chunk, threshold=0.01):
    """
    Classify based on audio energy
    """
    energy = np.sum(audio_chunk ** 2) / len(audio_chunk)
    return energy > threshold

Pros: Fast (< 1ms), no model needed
Cons: Fails in noisy environments, no semantic understanding

Option 2: ML-Based (Robust)

import torch

class SileroVAD:
    """
    Using Silero VAD (open-source, production-ready)
    Model size: 1MB, Latency: ~2ms
    """
    def __init__(self):
        self.model, self.utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad'
        )
        self.get_speech_timestamps = self.utils[0]
    
    def is_speech(self, audio, sampling_rate=16000):
        """
        Args:
            audio: torch.Tensor, shape (samples,)
            sampling_rate: int
        
        Returns:
            bool: True if speech detected
        """
        speech_timestamps = self.get_speech_timestamps(
            audio, 
            self.model,
            sampling_rate=sampling_rate,
            threshold=0.5
        )
        
        return len(speech_timestamps) > 0

# Usage
vad = SileroVAD()

for audio_chunk in audio_stream:
    if vad.is_speech(audio_chunk):
        # Process with ASR
        process_asr(audio_chunk)
    else:
        # Skip, save compute
        continue

Pros: Robust to noise, semantic understanding
Cons: Adds 2ms latency, requires model

Production VAD Pipeline

class ProductionVAD:
    def __init__(self):
        self.vad = SileroVAD()
        self.speech_buffer = []
        self.silence_frames = 0
        self.max_silence_frames = 3   # 300ms of silence (3 × 100ms chunks)
    
    def process_chunk(self, audio_chunk):
        """
        Buffer management with hysteresis
        """
        is_speech = self.vad.is_speech(audio_chunk)
        
        if is_speech:
            # Reset silence counter
            self.silence_frames = 0
            
            # Add to buffer
            self.speech_buffer.append(audio_chunk)
            
            return 'speech', audio_chunk
        
        else:
            # Increment silence counter
            self.silence_frames += 1
            
            # Keep buffering for a bit (hysteresis)
            if self.silence_frames < self.max_silence_frames:
                self.speech_buffer.append(audio_chunk)
                return 'speech', audio_chunk
            
            else:
                # End of utterance
                if self.speech_buffer:
                    complete_utterance = np.concatenate(self.speech_buffer)
                    self.speech_buffer = []
                    return 'end_of_utterance', complete_utterance
                
                return 'silence', None

Key design decisions:

  • Hysteresis: Continue processing for 300ms after silence to avoid cutting off speech
  • Buffering: Accumulate audio for end-of-utterance finalization
  • State management: Track silence duration to detect utterance boundaries

Component 2: Feature Extraction

Log Mel Filterbank Features

Why Mel scale? Human perception of pitch is logarithmic, not linear.

import librosa
import numpy as np

def extract_mel_features(audio, sr=16000, n_mels=80):
    """
    Extract 80-dimensional log mel filterbank features
    
    Args:
        audio: np.array, shape (samples,)
        sr: sampling rate (Hz)
        n_mels: number of mel bands
    
    Returns:
        features: np.array, shape (time, n_mels)
    """
    # Frame audio: 25ms window, 10ms stride
    frame_length = int(0.025 * sr)  # 400 samples
    hop_length = int(0.010 * sr)     # 160 samples
    
    # Short-Time Fourier Transform
    stft = librosa.stft(
        audio,
        n_fft=512,
        hop_length=hop_length,
        win_length=frame_length,
        window='hann'
    )
    
    # Magnitude spectrum
    magnitude = np.abs(stft)
    
    # Mel filterbank
    mel_basis = librosa.filters.mel(
        sr=sr,
        n_fft=512,
        n_mels=n_mels,
        fmin=0,
        fmax=sr/2
    )
    
    # Apply mel filters
    mel_spec = np.dot(mel_basis, magnitude)
    
    # Log compression (humans perceive loudness logarithmically)
    log_mel = np.log(mel_spec + 1e-6)
    
    # Transpose to (time, frequency)
    return log_mel.T

Output: 100 frames per second (one every 10ms), each with 80 dimensions

Normalization

def normalize_features(features, mean=None, std=None):
    """
    Normalize to zero mean, unit variance
    
    Can use global statistics or per-utterance
    """
    if mean is None:
        mean = np.mean(features, axis=0, keepdims=True)
    if std is None:
        std = np.std(features, axis=0, keepdims=True)
    
    normalized = (features - mean) / (std + 1e-6)
    return normalized

Global vs Per-Utterance:

  • Global normalization: Use statistics from training data (faster, more stable)
  • Per-utterance normalization: Adapt to current speaker/environment (better for diverse conditions)
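
In a streaming setting you cannot wait for the full utterance to compute per-utterance statistics, so a common compromise is a causal running estimate. A minimal sketch (the decay constant here is an assumption, not a tuned value):

import numpy as np

class StreamingNormalizer:
    """
    Causal feature normalization: update running mean/variance frame by frame
    with an exponential moving average, so no future frames are needed.
    """
    def __init__(self, n_mels=80, decay=0.99):
        self.decay = decay
        self.mean = np.zeros(n_mels)
        self.var = np.ones(n_mels)

    def normalize(self, features):
        # features: (time, n_mels), processed frame by frame (causal)
        out = np.empty_like(features)
        for t, frame in enumerate(features):
            self.mean = self.decay * self.mean + (1 - self.decay) * frame
            self.var = self.decay * self.var + (1 - self.decay) * (frame - self.mean) ** 2
            out[t] = (frame - self.mean) / np.sqrt(self.var + 1e-6)
        return out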

SpecAugment (Training Only)

def spec_augment(features, time_mask_max=30, freq_mask_max=10):
    """
    Data augmentation for training
    Randomly mask time and frequency bands
    """
    # Time masking
    t_mask_len = np.random.randint(0, time_mask_max)
    t_mask_start = np.random.randint(0, features.shape[0] - t_mask_len)
    features[t_mask_start:t_mask_start+t_mask_len, :] = 0
    
    # Frequency masking
    f_mask_len = np.random.randint(0, freq_mask_max)
    f_mask_start = np.random.randint(0, features.shape[1] - f_mask_len)
    features[:, f_mask_start:f_mask_start+f_mask_len] = 0
    
    return features

Impact: Improves robustness by 10-20% relative WER reduction


Component 3: Streaming Acoustic Models

RNN-Transducer (RNN-T)

Why RNN-T for streaming?

  1. Naturally causal: Doesn’t need future frames
  2. Emits tokens dynamically: Can output 0, 1, or multiple tokens per frame
  3. No external alignment: Learns alignment jointly with transcription

Architecture:

     Encoder (processes audio)
           ↓
     h_enc[t] (acoustic embedding)
           ↓
     Prediction Network (processes previous tokens)
           ↓
     h_pred[u] (linguistic embedding)
           ↓
     Joint Network (combines both)
           ↓
     Softmax over vocabulary + blank

Implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class StreamingRNNT(nn.Module):
    def __init__(self, vocab_size=1000, enc_dim=512, pred_dim=256, joint_dim=512):
        super().__init__()
        
        # Encoder: audio features → acoustic representation
        self.encoder = ConformerEncoder(
            input_dim=80,
            output_dim=enc_dim,
            num_layers=18,
            num_heads=8
        )
        
        # Prediction network: previous tokens → linguistic representation
        self.prediction_net = nn.LSTM(
            input_size=vocab_size,
            hidden_size=pred_dim,
            num_layers=2,
            batch_first=True
        )
        
        # Joint network: combine acoustic + linguistic
        self.joint_net = nn.Sequential(
            nn.Linear(enc_dim + pred_dim, joint_dim),
            nn.Tanh(),
            nn.Linear(joint_dim, vocab_size + 1)  # +1 for blank token
        )
        
        self.blank_idx = vocab_size
    
    def forward(self, audio_features, prev_tokens, encoder_state=None, predictor_state=None):
        """
        Args:
            audio_features: (batch, time, 80)
            prev_tokens: (batch, seq_len)
            encoder_state: hidden state from previous chunk
            predictor_state: (h, c) from previous tokens
        
        Returns:
            logits: (batch, time, seq_len, vocab_size+1)
            new_encoder_state: updated encoder state
            new_predictor_state: updated predictor state
        """
        # Encode audio
        h_enc, new_encoder_state = self.encoder(audio_features, encoder_state)
        # h_enc: (batch, time, enc_dim)
        
        # Encode previous tokens
        # Convert tokens to one-hot
        prev_tokens_onehot = F.one_hot(prev_tokens, num_classes=self.prediction_net.input_size)
        h_pred, new_predictor_state = self.prediction_net(
            prev_tokens_onehot.float(),
            predictor_state
        )
        # h_pred: (batch, seq_len, pred_dim)
        
        # Joint network: combine all pairs of (time, token_history)
        # Expand dimensions for broadcasting
        h_enc_exp = h_enc.unsqueeze(2)  # (batch, time, 1, enc_dim)
        h_pred_exp = h_pred.unsqueeze(1)  # (batch, 1, seq_len, pred_dim)
        
        # Concatenate
        h_joint = torch.cat([
            h_enc_exp.expand(-1, -1, h_pred.size(1), -1),
            h_pred_exp.expand(-1, h_enc.size(1), -1, -1)
        ], dim=-1)
        # h_joint: (batch, time, seq_len, enc_dim+pred_dim)
        
        # Project to vocabulary
        logits = self.joint_net(h_joint)
        # logits: (batch, time, seq_len, vocab_size+1)
        
        return logits, new_encoder_state, new_predictor_state
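
The forward pass above produces the (time × token) logit lattice; training then minimizes the RNN-T loss, which marginalizes over all alignments. A minimal training-step sketch, assuming torchaudio’s rnnt_loss and hypothetical tensors audio_feats, targets, feat_lengths, and target_lengths (the leading index-0 token is just a placeholder start symbol for the prediction network):

import torch
import torchaudio

# model: StreamingRNNT from above (vocab_size=1000 → blank index 1000)
# audio_feats: (B, T, 80) float; targets: (B, U) int token ids
# feat_lengths: valid encoder frames per example (after subsampling); target_lengths: valid tokens

# Prepend a placeholder start token so the joint network sees U+1 prediction positions
start = torch.zeros(targets.size(0), 1, dtype=targets.dtype)
prev_tokens = torch.cat([start, targets], dim=1)

logits, _, _ = model(audio_feats, prev_tokens)   # (B, T', U+1, vocab_size+1)

loss = torchaudio.functional.rnnt_loss(
    logits.float(),
    targets.int(),
    logit_lengths=feat_lengths.int(),
    target_lengths=target_lengths.int(),
    blank=model.blank_idx,
)
loss.backward()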

Conformer Encoder

Why Conformer? Combines convolution (local patterns) + self-attention (long-range dependencies)

class ConformerEncoder(nn.Module):
    def __init__(self, input_dim=80, output_dim=512, num_layers=18, num_heads=8):
        super().__init__()
        
        # Subsampling: 4x downsampling to reduce sequence length
        self.subsampling = Conv2dSubsampling(input_dim, output_dim, factor=4)
        
        # Conformer blocks
        self.conformer_blocks = nn.ModuleList([
            ConformerBlock(output_dim, num_heads) 
            for _ in range(num_layers)
        ])
    
    def forward(self, x, state=None):
        # x: (batch, time, input_dim)
        
        # Subsampling
        x = self.subsampling(x)
        # x: (batch, time//4, output_dim)
        
        # Conformer blocks
        for block in self.conformer_blocks:
            x, state = block(x, state)
        
        return x, state

class ConformerBlock(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        
        # Feed-forward module 1
        self.ff1 = FeedForwardModule(dim)
        
        # Multi-head self-attention
        self.attention = MultiHeadSelfAttention(dim, num_heads)
        
        # Convolution module
        self.conv = ConvolutionModule(dim, kernel_size=31)
        
        # Feed-forward module 2
        self.ff2 = FeedForwardModule(dim)
        
        # Layer norms
        self.norm_ff1 = nn.LayerNorm(dim)
        self.norm_att = nn.LayerNorm(dim)
        self.norm_conv = nn.LayerNorm(dim)
        self.norm_ff2 = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)
    
    def forward(self, x, state=None):
        # Feed-forward 1 (half-step residual)
        residual = x
        x = self.norm_ff1(x)
        x = residual + 0.5 * self.ff1(x)
        
        # Self-attention
        residual = x
        x = self.norm_att(x)
        x, state = self.attention(x, state)
        x = residual + x
        
        # Convolution
        residual = x
        x = self.norm_conv(x)
        x = self.conv(x)
        x = residual + x
        
        # Feed-forward 2 (half-step residual)
        residual = x
        x = self.norm_ff2(x)
        x = residual + 0.5 * self.ff2(x)
        
        # Final norm
        x = self.norm_out(x)
        
        return x, state

Key features:

  • Macaron-style: Feed-forward at both beginning and end
  • Depthwise convolution: Captures local patterns efficiently
  • Relative positional encoding: Better for variable-length sequences
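
The ConvolutionModule referenced above is not defined in the snippet; a minimal causal sketch (kernel size and Swish activation follow the usual Conformer recipe, but the details here are assumptions) could look like:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvolutionModule(nn.Module):
    """
    Minimal causal Conformer convolution module sketch:
    pointwise conv → GLU → depthwise conv (left-padded, no future frames) → norm → Swish → pointwise conv
    """
    def __init__(self, dim, kernel_size=31):
        super().__init__()
        self.pointwise_in = nn.Conv1d(dim, 2 * dim, kernel_size=1)  # GLU halves channels back to dim
        self.glu = nn.GLU(dim=1)
        self.causal_padding = kernel_size - 1                       # pad only on the left (past)
        self.depthwise = nn.Conv1d(dim, dim, kernel_size, groups=dim)
        self.norm = nn.BatchNorm1d(dim)
        self.activation = nn.SiLU()                                 # Swish
        self.pointwise_out = nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        # x: (batch, time, dim) → (batch, dim, time) for Conv1d
        x = x.transpose(1, 2)
        x = self.glu(self.pointwise_in(x))
        x = F.pad(x, (self.causal_padding, 0))                      # causal: sees only past frames
        x = self.depthwise(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_out(x)
        return x.transpose(1, 2)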

Streaming Constraints

Problem: Self-attention in Conformer uses entire sequence → not truly streaming

Solution: Limited lookahead window

class StreamingAttention(nn.Module):
    def __init__(self, dim, num_heads, left_context=1000, right_context=32):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads)
        self.left_context = left_context   # Look at past 10 seconds
        self.right_context = right_context  # Look ahead 320ms
    
    def forward(self, x, cache=None):
        # x: (batch, time, dim)
        
        if cache is not None:
            # Concatenate with cached past frames
            x = torch.cat([cache, x], dim=1)
        
        # Apply attention with limited context
        batch_size, seq_len, dim = x.shape
        
        # Create attention mask for causal attention with limited right context
        # PyTorch expects attn_mask shape (target_len, source_len)
        mask = self.create_streaming_mask(seq_len, self.right_context).to(x.device)
        
        # Attention
        # nn.MultiheadAttention expects (time, batch, dim)
        x_tbf = x.transpose(0, 1)
        x_att_tbf, _ = self.attention(x_tbf, x_tbf, x_tbf, attn_mask=mask)
        x_att = x_att_tbf.transpose(0, 1)
        
        # Cache for next chunk
        new_cache = x[:, -self.left_context:, :]
        
        # Return only new frames (not cached ones)
        if cache is not None:
            x_att = x_att[:, cache.size(1):, :]
        
        return x_att, new_cache
    
    def create_streaming_mask(self, seq_len, right_context):
        """
        Create mask where each position can attend to:
        - All past positions
        - Up to right_context future positions
        """
        # True = disallowed: mask out everything more than right_context steps ahead
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1 + right_context)
        return mask.bool()

Greedy Decoding (Fast, Suboptimal)

def greedy_decode(model, audio_features):
    """
    Always pick highest-probability token
    Fast but misses better hypotheses
    """
    tokens = []
    state = None
    
    for frame in audio_features:
        logits, state = model(frame, tokens, state)
        best_token = torch.argmax(logits)
        
        if best_token != BLANK:
            tokens.append(best_token)
    
    return tokens

Pros: O(T) time, minimal memory
Cons: Can’t recover from mistakes, 10-20% worse WER

Beam Search (Better Accuracy)

class BeamSearchDecoder:
    def __init__(self, beam_size=10, blank_idx=0):
        self.beam_size = beam_size
        self.blank_idx = blank_idx
    
    def decode(self, model, audio_features):
        """
        Maintain top-k hypotheses at each time step
        """
        # Initial beam: empty hypothesis
        beams = [Hypothesis(tokens=[], score=0.0, state=None)]
        
        for frame in audio_features:
            candidates = []
            
            for beam in beams:
                # Get logits for this beam
                logits, new_state = model(frame, beam.tokens, beam.state)
                log_probs = F.log_softmax(logits, dim=-1)
                
                # Extend with each possible token
                for token_idx, log_prob in enumerate(log_probs):
                    if token_idx == self.blank_idx:
                        # Blank: don't emit token, just update score
                        candidates.append(Hypothesis(
                            tokens=beam.tokens,
                            score=beam.score + log_prob,
                            state=beam.state
                        ))
                    else:
                        # Non-blank: emit token
                        candidates.append(Hypothesis(
                            tokens=beam.tokens + [token_idx],
                            score=beam.score + log_prob,
                            state=new_state
                        ))
            
            # Prune to top beam_size hypotheses
            candidates.sort(key=lambda h: h.score, reverse=True)
            beams = candidates[:self.beam_size]
        
        # Return best hypothesis
        return beams[0].tokens

class Hypothesis:
    def __init__(self, tokens, score, state):
        self.tokens = tokens
        self.score = score
        self.state = state

Complexity: O(T × B × V) where T=time, B=beam size, V=vocabulary size
Typical parameters: B=10, V=1000 → manageable

Language Model Fusion

Problem: Acoustic model doesn’t know linguistic patterns (grammar, common phrases)

Solution: Integrate language model (LM) scores

def beam_search_with_lm(acoustic_model, lm, audio_features, lm_weight=0.3, beam_size=10):
    """
    Combine acoustic model + language model scores
    """
    beams = [Hypothesis(tokens=[], score=0.0, state=None)]
    
    for frame in audio_features:
        candidates = []
        
        for beam in beams:
            logits, new_state = acoustic_model(frame, beam.tokens, beam.state)
            acoustic_log_probs = F.log_softmax(logits, dim=-1)
            
            for token_idx, acoustic_log_prob in enumerate(acoustic_log_probs):
                if token_idx == BLANK:
                    # Blank token
                    combined_score = beam.score + acoustic_log_prob
                    candidates.append(Hypothesis(
                        tokens=beam.tokens,
                        score=combined_score,
                        state=beam.state
                    ))
                else:
                    # Get LM score for this token
                    lm_log_prob = lm.score(beam.tokens + [token_idx])
                    
                    # Combine scores
                    combined_score = (
                        beam.score +
                        acoustic_log_prob +
                        lm_weight * lm_log_prob
                    )
                    
                    candidates.append(Hypothesis(
                        tokens=beam.tokens + [token_idx],
                        score=combined_score,
                        state=new_state
                    ))
            
        candidates.sort(key=lambda h: h.score, reverse=True)
        beams = candidates[:beam_size]
    
    return beams[0].tokens

LM types:

  • N-gram LM (KenLM): Fast (< 1ms), large memory (GBs)
  • Neural LM (LSTM/Transformer): Slower (5-20ms), better quality

Production choice: N-gram for first-pass, neural LM for rescoring top hypotheses
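
A sketch of that second pass, assuming the first pass returns an n-best list of Hypothesis objects and a neural LM exposing a (hypothetical) score(tokens) log-probability method:

def rescore_nbest(nbest, neural_lm, am_weight=1.0, lm_weight=0.3):
    """
    Second-pass rescoring: re-rank first-pass hypotheses by combining their
    first-pass (acoustic + n-gram) score with a neural LM score.
    """
    best_hyp, best_score = None, float('-inf')
    for hyp in nbest:
        lm_score = neural_lm.score(hyp.tokens)  # assumed API: returns a log-probability
        total = am_weight * hyp.score + lm_weight * lm_score
        if total > best_score:
            best_hyp, best_score = hyp, total
    return best_hyp

# Usage: nbest = top hypotheses from the first-pass beam search
# final_tokens = rescore_nbest(nbest, neural_lm).tokens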


Latency Optimization

Target Breakdown

Total latency budget: 200ms

VAD:                    2ms
Feature extraction:     5ms
Encoder forward:       80ms  ← Bottleneck
Decoder (beam search): 10ms
Post-processing:        3ms
Network overhead:      20ms
Total:               120ms ✓ (80ms margin)

Technique 1: Model Quantization

INT8 Quantization: Convert float32 weights to int8

import torch
import torch.nn as nn
import torch.quantization as quantization

# Post-training quantization (easiest)
model_fp32 = load_model()
model_fp32.eval()

# Fuse operations (Conv+BN+ReLU → single op)
model_fused = quantization.fuse_modules(
    model_fp32,
    [['conv', 'bn', 'relu']]
)

# Quantize
model_int8 = quantization.quantize_dynamic(
    model_fused,
    {nn.Linear, nn.LSTM},  # dynamic quantization covers Linear/LSTM; convolutions need static quantization
    dtype=torch.qint8
)

# Save
torch.save(model_int8.state_dict(), 'model_int8.pth')

# Results:
# - Model size: 200MB → 50MB (4x smaller)
# - Inference speed: 80ms → 30ms (2.7x faster)
# - Accuracy: WER 5.2% → 5.4% (0.2% degradation)

Why quantization works:

  • Smaller memory footprint: Fits in L1/L2 cache
  • Faster math: INT8 operations 4x faster than FP32 on CPU
  • Minimal accuracy loss: Neural networks are surprisingly robust

Technique 2: Knowledge Distillation

Train small model to mimic large model

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=3.0):
    """
    Soft targets from teacher help student learn better
    """
    # Soften probabilities with temperature
    student_soft = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
    
    # KL divergence
    loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
    loss = loss * (temperature ** 2)
    
    return loss

# Training
teacher = large_model  # 18 layers, 80ms inference
student = small_model  # 8 layers, 30ms inference

for audio, transcript in training_data:
    # Get teacher predictions (no backprop)
    with torch.no_grad():
        teacher_logits = teacher(audio)
    
    # Student predictions
    student_logits = student(audio)
    
    # Distillation loss
    loss = distillation_loss(student_logits, teacher_logits)
    
    # Optimize
    loss.backward()
    optimizer.step()

# Results:
# - Student (8 layers): 30ms, WER 5.8%
# - Teacher (18 layers): 80ms, WER 5.0%
# - Without distillation: 30ms, WER 7.2%
# → Distillation closes the gap!

Technique 3: Pruning

Remove unimportant weights

import torch.nn.utils.prune as prune

def prune_model(model, amount=0.4):
    """
    Remove 40% of weights with minimal accuracy loss
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            # L1 unstructured pruning
            prune.l1_unstructured(module, name='weight', amount=amount)
            
            # Remove pruning reparameterization
            prune.remove(module, 'weight')
    
    return model

# Results:
# - 40% pruning: WER 5.0% → 5.3%, Speed +20%
# - 60% pruning: WER 5.0% → 6.2%, Speed +40%

Technique 4: Caching

Cache intermediate results across chunks

class StreamingASRWithCache:
    def __init__(self, model):
        self.model = model
        self.encoder_cache = None
        self.decoder_state = None
    
    def process_chunk(self, audio_chunk):
        # Extract features (no caching needed, fast)
        features = extract_features(audio_chunk)
        
        # Encoder: reuse cached hidden states
        encoder_out, self.encoder_cache = self.model.encoder(
            features,
            cache=self.encoder_cache
        )
        
        # Decoder: maintain beam state
        tokens, self.decoder_state = self.model.decoder(
            encoder_out,
            state=self.decoder_state
        )
        
        return tokens
    
    def reset(self):
        """Call at end of utterance"""
        self.encoder_cache = None
        self.decoder_state = None

Savings:

  • Without cache: Process all frames every chunk → 100ms
  • With cache: Process only new frames → 30ms (3.3x speedup)

Scaling to Millions of Users

Throughput Analysis

Per-stream compute:

  • Encoder: 30ms (after optimization)
  • Decoder: 10ms
  • Total: 40ms per 100ms audio chunk

CPU/GPU capacity:

  • CPU (16 cores): ~50 concurrent streams
  • GPU (T4): ~200 concurrent streams

For 10M concurrent streams:

  • GPUs needed: 10M / 200 = 50,000 GPUs
  • Cost @ $0.50/hr: $25k/hour = $18M/month

Way too expensive! Need further optimization.
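
A quick back-of-the-envelope helper (using the illustrative $0.50/hr and streams-per-GPU figures from this section; the batching strategy below raises streams per GPU):

import math

def monthly_gpu_cost(concurrent_streams, streams_per_gpu, usd_per_gpu_hour=0.50, hours_per_month=720):
    """Rough capacity/cost estimate for serving a given number of concurrent streams."""
    gpus = math.ceil(concurrent_streams / streams_per_gpu)
    return gpus, gpus * usd_per_gpu_hour * hours_per_month

print(monthly_gpu_cost(10_000_000, 200))       # unbatched: 50,000 GPUs, ~$18M/month
print(monthly_gpu_cost(10_000_000, 200 * 32))  # with batching (next section): ~1,560 GPUs, ~$0.56M/month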

Strategy 1: Batching

Batch multiple streams together

def batch_inference(audio_chunks, batch_size=32):
    """
    Process 32 streams simultaneously on GPU
    """
    # Pad to same length
    max_len = max(len(chunk) for chunk in audio_chunks)
    padded = [
        np.pad(chunk, (0, max_len - len(chunk)))
        for chunk in audio_chunks
    ]
    
    # Stack into batch
    batch = torch.tensor(padded)  # (32, max_len, 80)
    
    # Single forward pass
    outputs = model(batch)  # ~40ms for 32 streams
    
    return outputs

# Results:
# - Without batching: 40ms per stream
# - With batching (32): 40ms / 32 = 1.25ms per stream (32x speedup)
# - GPU needed: 10M / (200 × 32) = 1,562 GPUs
# - Cost: ~$0.56M/month (32x cheaper!)

Strategy 2: Regional Deployment

Deploy closer to users to reduce latency

North America: 3M users → 500 GPUs → 3 data centers
Europe: 2M users → 330 GPUs → 2 data centers
Asia: 4M users → 660 GPUs → 4 data centers
...

Total: ~1,500 GPUs globally

Benefits:

  • Lower network latency (30ms → 10ms)
  • Better fault isolation
  • Regulatory compliance (data residency)

Strategy 3: Hybrid Cloud-Edge

Run simple queries on-device, complex queries on cloud

def route_request(audio, user_context):
    # Estimate query complexity
    if is_simple_command(audio):  # "play music", "set timer"
        return on_device_asr(audio)  # 30ms, free, offline
    
    elif is_dictation(audio):  # Long-form transcription
        return cloud_asr(audio)  # 80ms, $0.01/min, high accuracy
    
    else:  # Conversational query
        return cloud_asr(audio)  # Best quality for complex queries

Distribution:

  • 70% simple commands → on-device
  • 30% complex queries → cloud
  • Effective cloud load: 3M concurrent (70% savings!)

Production Example: Putting It All Together

import asyncio
import json
import time

import numpy as np
import torch
import websockets

class ProductionStreamingASR:
    def __init__(self):
        # Load optimized model
        self.model = self.load_optimized_model()
        
        # VAD
        self.vad = SileroVAD()
        
        # Session management
        self.sessions = {}  # session_id → StreamingSession
        
        # Metrics
        self.metrics = Metrics()
    
    def load_optimized_model(self):
        """Load quantized, pruned model"""
        model = StreamingRNNT(vocab_size=1000)
        
        # Load pre-trained weights
        checkpoint = torch.load('rnnt_optimized.pth')
        model.load_state_dict(checkpoint)
        
        # Quantize
        model_quantized = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear, torch.nn.LSTM},
            dtype=torch.qint8
        )
        
        model_quantized.eval()
        return model_quantized
    
    async def handle_stream(self, websocket, path):
        """Handle websocket connection from client"""
        session_id = generate_session_id()
        session = StreamingSession(session_id, self.model, self.vad)
        self.sessions[session_id] = session
        
        try:
            async for message in websocket:
                # Receive audio chunk (binary, 100ms @ 16kHz)
                audio_bytes = message
                audio_array = np.frombuffer(audio_bytes, dtype=np.int16)
                audio_float = audio_array.astype(np.float32) / 32768.0
                
                # Process
                start_time = time.time()
                result = session.process_chunk(audio_float)
                latency = (time.time() - start_time) * 1000  # ms
                
                # Send partial transcript
                if result:
                    await websocket.send(json.dumps({
                        'type': 'partial',
                        'transcript': result['text'],
                        'tokens': result['tokens'],
                        'is_final': result['is_final']
                    }))
                
                # Track metrics
                self.metrics.record_latency(latency)
        
        except websockets.ConnectionClosed:
            # Finalize session
            final_transcript = session.finalize()
            print(f"Session {session_id} ended: {final_transcript}")
        
        finally:
            # Cleanup
            del self.sessions[session_id]
    
    def run(self, host='0.0.0.0', port=8765):
        """Start WebSocket server"""
        start_server = websockets.serve(self.handle_stream, host, port)
        asyncio.get_event_loop().run_until_complete(start_server)
        print(f"Streaming ASR server running on ws://{host}:{port}")
        asyncio.get_event_loop().run_forever()

class StreamingSession:
    def __init__(self, session_id, model, vad):
        self.session_id = session_id
        self.model = model
        self.vad = vad
        
        # State
        self.encoder_cache = None
        self.decoder_state = None
        self.partial_transcript = ""
        self.audio_buffer = []
    
    def process_chunk(self, audio):
        # VAD check
        if not self.vad.is_speech(audio):
            return None
        
        # Extract features
        features = extract_mel_features(audio)
        
        # Encode
        encoder_out, self.encoder_cache = self.model.encoder(
            features,
            cache=self.encoder_cache
        )
        
        # Decode (beam search)
        tokens, self.decoder_state = self.model.decoder(
            encoder_out,
            state=self.decoder_state,
            beam_size=5
        )
        
        # Convert tokens to text
        new_text = self.model.tokenizer.decode(tokens)
        self.partial_transcript += new_text
        
        return {
            'text': new_text,
            'tokens': tokens,
            'is_final': False
        }
    
    def finalize(self):
        """End of utterance processing"""
        # Post-processing
        final_transcript = post_process(self.partial_transcript)
        
        # Reset state
        self.encoder_cache = None
        self.decoder_state = None
        self.partial_transcript = ""
        
        return final_transcript

# Run server
if __name__ == '__main__':
    server = ProductionStreamingASR()
    server.run()

Key Takeaways

  • RNN-T architecture enables true streaming without future context
  • Conformer encoder combines convolution + attention for the best accuracy
  • State management is critical for maintaining context across chunks
  • Quantization + pruning achieve 4x compression and 3x speedup with < 1% WER loss
  • Batching provides a 32x throughput improvement on GPUs
  • Hybrid cloud-edge routing reduces cloud load by 70%
  • VAD saves 50-70% of compute by filtering silence


Conclusion

Streaming ASR is a fascinating blend of signal processing, deep learning, and systems engineering. The key challenges (low latency, high throughput, and maintaining accuracy without future context) require careful architectural choices and aggressive optimization.

As voice interfaces become ubiquitous, streaming ASR systems will continue to evolve. Future directions include:

  • Multi-modal models (audio + video for better accuracy)
  • Personalization (adapt to individual speaking styles)
  • Emotion recognition (detect sentiment, stress, sarcasm)
  • On-device models (< 10MB, < 50ms, works offline)

The fundamentals covered here (RNN-T, streaming architectures, optimization techniques) will remain relevant as the field advances.

Now go build a voice assistant that feels truly conversational! 🎤🚀


Originally published at: arunbaby.com/speech-tech/0001-streaming-asr

If you found this helpful, consider sharing it with others who might benefit.