Training Retrieval Models
The Phoenix retrieval system uses a two-tower architecture to discover relevant out-of-network content. By encoding users and candidate posts into a shared embedding space, the system can perform efficient similarity searches across millions of posts in real time.
This guide walks you through configuring and training the Phoenix retrieval model using the provided JAX and Haiku implementation.
Overview of the Two-Tower Architecture
The retrieval model consists of two distinct components:
- User Tower: A Grok-based transformer that processes a user's action sequence, history, and features to produce a fixed-length vector representing their current interests.
- Candidate Tower: A projection layer that encodes post and author features into the same vector space as the user vector.
When both vectors are L2-normalized, the dot product between a user vector and a candidate vector represents their similarity, making it ideal for Approximate Nearest Neighbor (ANN) search.
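In miniature, this scoring step looks like the following self-contained JAX sketch, where toy vectors stand in for the tower outputs (the helper `l2_normalize` and the example values are illustrative, not part of the Phoenix API):

```python
import jax.numpy as jnp

def l2_normalize(x, eps=1e-12):
    # Scale each vector to unit length so dot product equals cosine similarity.
    norm = jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis=-1, keepdims=True), eps))
    return x / norm

# A user vector and three candidate vectors in a shared space.
user_vec = l2_normalize(jnp.array([1.0, 2.0, 3.0]))
candidates = l2_normalize(jnp.array([
    [1.0, 2.0, 3.0],     # same direction as the user -> similarity 1.0
    [-1.0, -2.0, -3.0],  # opposite direction -> similarity -1.0
    [3.0, -1.0, 0.0],    # unrelated direction -> similarity near 0
]))

# One matmul scores the user against every candidate; ANN search
# approximates exactly this operation over a much larger corpus.
scores = candidates @ user_vec
```

Because every vector has unit length, the scores are bounded in [-1, 1] and directly comparable across candidates.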
Step 1: Configure the Retrieval Model
To train the model, you must first define the PhoenixRetrievalModelConfig. This configuration controls the transformer depth, the embedding dimensions, and the sequence lengths for user history.
from grok import TransformerConfig
from recsys_model import HashConfig
from recsys_retrieval_model import PhoenixRetrievalModelConfig
# 1. Define the underlying Grok transformer architecture
transformer_cfg = TransformerConfig(
    emb_size=256,
    widening_factor=4,
    key_size=64,
    num_q_heads=8,
    num_kv_heads=4,
    num_layers=6,
    attn_output_multiplier=0.125,
)
# 2. Define the retrieval-specific settings
retrieval_config = PhoenixRetrievalModelConfig(
    model=transformer_cfg,
    emb_size=256,
    history_seq_len=128,   # Number of past actions to consider
    candidate_seq_len=32,  # Context window for candidates
    hash_config=HashConfig(
        num_user_hashes=2,
        num_item_hashes=2,
        num_author_hashes=2,
    ),
)
Step 2: Prepare Training Batches
The model expects a RecsysBatch containing feature hashes and a RecsysEmbeddings object containing the pre-looked-up vectors from your embedding tables.
The system uses hash-based embeddings to handle the massive cardinality of users and posts on X.
from recsys_model import RecsysBatch, RecsysEmbeddings
# A batch typically includes:
# - User hashes (ID, location, language)
# - History hashes (List of post IDs the user engaged with)
# - Candidate hashes (The target post and its author)
# - Actions (Like, Reply, Retweet, etc.)
batch = RecsysBatch(
    user_hashes=user_hashes,
    history_post_hashes=history_posts,
    history_author_hashes=history_authors,
    history_actions=actions,
    history_product_surface=surfaces,
    candidate_post_hashes=target_posts,
    candidate_author_hashes=target_authors,
    candidate_product_surface=target_surfaces,
)
Step 3: Initialize and Run Training
Training is managed via the RetrievalModelRunner. During training, the model attempts to minimize the distance between the User Tower's output and the Candidate Tower's output for positive engagements.
import jax

from runners import RetrievalModelRunner

# Initialize the runner with your config
runner = RetrievalModelRunner(config=retrieval_config)

# Initialize parameters using a sample batch and its looked-up embeddings
rng = jax.random.PRNGKey(42)
params = runner.init(rng, batch, embeddings)

# Execute a training step. `opt_state` is your optimizer's state, created
# before the first step. The loss uses in-batch negatives to optimize retrieval.
logs, params, opt_state = runner.update(params, opt_state, batch, embeddings)
print(f"Training Loss: {logs['loss']}")
Step 4: Exporting for Inference
Once trained, the User Tower and Candidate Tower can be used independently:
- User Tower: Runs online when a user opens their "For You" feed to generate a "Query Vector."
- Candidate Tower: Runs offline (or during ingestion) to generate "Candidate Vectors" for all new posts, which are then indexed in a vector database.
Generating User Representations
from runners import RecsysRetrievalInferenceRunner

# Use the inference runner to get the user's embedding vector
inference_runner = RecsysRetrievalInferenceRunner(config=retrieval_config)
user_vec = inference_runner.get_user_representation(params, batch, embeddings)

# user_vec is now a normalized array of shape [batch, emb_size],
# ready for dot-product search against the post corpus.
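At small scale, the vector-database lookup can be emulated with an exact brute-force search. The sketch below uses random stand-ins for the two towers' outputs; only the scoring and top-k logic reflects what an ANN index approximates:

```python
import jax
import jax.numpy as jnp

EMB_SIZE = 4
NUM_POSTS = 100

def l2_normalize(x):
    return x / jnp.linalg.norm(x, axis=-1, keepdims=True)

# Stand-ins for the tower outputs: one query vector and a candidate corpus.
key_u, key_p = jax.random.split(jax.random.PRNGKey(0))
user_vec = l2_normalize(jax.random.normal(key_u, (EMB_SIZE,)))
post_corpus = l2_normalize(jax.random.normal(key_p, (NUM_POSTS, EMB_SIZE)))

# Exact top-k retrieval by dot product; a vector database replaces this
# brute-force scan once the corpus reaches production scale.
scores = post_corpus @ user_vec
top_scores, top_ids = jax.lax.top_k(scores, k=5)
```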
Key Training Considerations
Sequence Lengths
The history_seq_len determines how much context the transformer sees.
- Short sequences (e.g., 16-32): Capture immediate intent and "session-based" interests.
- Long sequences (e.g., 128-256): Capture long-term interests but increase computational cost and latency.
Negative Sampling
The retrieval model is trained using In-Batch Negatives. For a given user in a batch, the positive candidate is the post they actually engaged with. All other candidate posts in the same training batch are treated as negatives. To improve model quality, ensure your batch size is sufficiently large (typically 1024 or higher) to provide a diverse set of negatives.
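The paragraph above can be sketched as a standard in-batch sampled-softmax loss. This is a minimal illustration of the technique, not Phoenix's exact loss function; the temperature value and the toy one-hot vectors are assumptions for the example:

```python
import jax
import jax.numpy as jnp

def in_batch_softmax_loss(user_vecs, cand_vecs, temperature=0.05):
    """Contrastive loss where each user's positive is the diagonal candidate."""
    # [B, B] similarity matrix: entry (i, j) scores user i against candidate j.
    logits = (user_vecs @ cand_vecs.T) / temperature
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # The positive for user i is candidate i; the other B-1 rows of the
    # batch act as negatives, which is why larger batches help.
    batch_ids = jnp.arange(user_vecs.shape[0])
    return -log_probs[batch_ids, batch_ids].mean()

# Toy check: perfectly matched one-hot pairs vs. deliberately shuffled ones.
towers = jnp.eye(4)
aligned = in_batch_softmax_loss(towers, towers)
mismatched = in_batch_softmax_loss(towers, jnp.roll(towers, 1, axis=0))
```

When user and candidate vectors line up on the diagonal the loss is near zero; shuffling the candidates makes every diagonal entry a mismatch and the loss grows accordingly.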
L2 Normalization
The CandidateTower automatically applies L2 normalization to its output. Do not disable this: the similarity search logic relies on the vectors lying on the unit hypersphere for consistent dot-product scoring.
# Internal CandidateTower projection logic (EPS is a small positive
# constant that guards against division by zero for all-zero embeddings)
candidate_norm_sq = jnp.sum(candidate_embeddings**2, axis=-1, keepdims=True)
candidate_norm = jnp.sqrt(jnp.maximum(candidate_norm_sq, EPS))
candidate_representation = candidate_embeddings / candidate_norm