Adding Custom Feature Embeddings
The Phoenix model uses a hash-based embedding system to represent entities like users and posts. This architecture allows the model to learn representations for millions of entities without needing a manual feature engineering pipeline.
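To make the hash-based approach concrete, here is a minimal, self-contained sketch of multi-hash embedding lookup. All names, table sizes, and hash functions below are illustrative assumptions, not the actual Phoenix implementation: each entity id is hashed by several independent hash functions, and each hash selects a row in its own embedding table.

```python
import numpy as np

TABLE_SIZE = 1024   # rows per embedding table (assumed for illustration)
EMB_DIM = 8         # embedding dimension (assumed for illustration)
NUM_HASHES = 2      # matches num_user_hashes = 2 in HashConfig

rng = np.random.default_rng(0)
tables = rng.normal(size=(NUM_HASHES, TABLE_SIZE, EMB_DIM))

def multi_hash(entity_id: int) -> np.ndarray:
    # Each hash function maps the raw id into its own table; salting with
    # the function index keeps the hash functions independent.
    return np.array([hash((entity_id, salt)) % TABLE_SIZE
                     for salt in range(NUM_HASHES)])

def lookup(entity_id: int) -> np.ndarray:
    idx = multi_hash(entity_id)
    # Gather one vector per hash function -> shape [NUM_HASHES, EMB_DIM]
    return np.stack([tables[h, idx[h]] for h in range(NUM_HASHES)])

vecs = lookup(123456789)
print(vecs.shape)  # (2, 8)
```

Because collisions in any single table are disambiguated by the other hash functions, the model can cover millions of entities with tables far smaller than the entity count.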
If you want to introduce a new signal—such as a user’s preferred language or a post's media category—follow this guide to extend the embedding logic.
1. Update the Hash Configuration
First, define how many hash functions will be used for your new feature. This is controlled in the HashConfig dataclass within phoenix/recsys_model.py.
Open phoenix/recsys_model.py and add your new configuration parameter:
```python
@dataclass
class HashConfig:
    """Configuration for hash-based embeddings."""

    num_user_hashes: int = 2
    num_item_hashes: int = 2
    num_author_hashes: int = 2
    # Add your custom feature hash count here
    num_language_hashes: int = 1
```
2. Extend the Data Batch Structures
The model handles data in two stages: the raw hash values (RecsysBatch) and the looked-up embedding vectors (RecsysEmbeddings). You must update both in phoenix/recsys_model.py.
Update RecsysBatch
Add the field for the raw hash values that will be passed from your data pipeline:
```python
class RecsysBatch(NamedTuple):
    user_hashes: jax.typing.ArrayLike
    history_post_hashes: jax.typing.ArrayLike
    # ... existing fields
    language_hashes: jax.typing.ArrayLike  # New field
```
Update RecsysEmbeddings
Add the field that will hold the actual embedding vectors after they are retrieved from the embedding tables:
```python
@dataclass
class RecsysEmbeddings:
    user_embeddings: jax.typing.ArrayLike
    candidate_post_embeddings: jax.typing.ArrayLike
    # ... existing fields
    language_embeddings: jax.typing.ArrayLike  # New field
```
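The guide does not show the lookup stage that turns the raw hashes in RecsysBatch into the vectors in RecsysEmbeddings. As a rough sketch of that step—table size, field values, and shapes are assumptions for illustration—a batched gather does the conversion:

```python
import numpy as np

# Hypothetical lookup stage: raw hash indices from the batch are gathered
# from an embedding table to produce embedding vectors. The table size and
# dimension here are illustrative, not the Phoenix values.
TABLE_SIZE, EMB_DIM = 1024, 64

language_table = np.random.default_rng(0).normal(size=(TABLE_SIZE, EMB_DIM))

# batch.language_hashes: [Batch, num_language_hashes]
language_hashes = np.array([[7], [511], [0]])

# Gather -> [Batch, num_language_hashes, EMB_DIM], the shape expected by
# the language_embeddings field of RecsysEmbeddings.
language_embeddings = language_table[language_hashes]
print(language_embeddings.shape)  # (3, 1, 64)
```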
3. Modify the Reduction Logic
The "reduction" functions are responsible for taking multiple hash embeddings and projecting them into a single vector that the Grok transformer understands. Depending on whether your feature is a User feature or a Post (Item) feature, you will modify block_user_reduce or block_history_reduce.
Example: Adding a User-level Feature
In block_user_reduce, you need to concatenate your new embedding with the existing ones before the linear projection:
```python
def block_user_reduce(
    user_hashes: jnp.ndarray,
    user_embeddings: jnp.ndarray,
    language_embeddings: jnp.ndarray,  # Add parameter
    num_user_hashes: int,
    num_language_hashes: int,  # Add parameter
    emb_size: int,
    # ...
):
    B = user_embeddings.shape[0]
    D = emb_size

    # Concatenate the new embeddings along the feature dimension.
    # Combined shape: [B, 1, (num_user_hashes + num_language_hashes) * D]
    combined = jnp.concatenate([
        user_embeddings.reshape((B, 1, num_user_hashes * D)),
        language_embeddings.reshape((B, 1, num_language_hashes * D)),
    ], axis=-1)

    # The projection matrix automatically scales to the new combined size
    embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
    proj_mat_1 = hk.get_parameter(
        "proj_mat_1",
        [combined.shape[-1], D],
        init=embed_init,
    )
    user_embedding = jnp.dot(combined.astype(proj_mat_1.dtype), proj_mat_1)
    # ...
```
4. Update the Model Call Site
Finally, ensure the Phoenix model (or the PhoenixRetrievalModel in phoenix/recsys_retrieval_model.py) passes the new embedding fields into the reduction functions.
If you are modifying the retrieval pipeline, update the User Tower logic to include your new feature:
```python
# Inside PhoenixRetrievalModel
user_emb, user_mask = block_user_reduce(
    user_hashes=batch.user_hashes,
    user_embeddings=embeddings.user_embeddings,
    language_embeddings=embeddings.language_embeddings,  # Pass the new field
    num_user_hashes=self.config.hash_config.num_user_hashes,
    num_language_hashes=self.config.hash_config.num_language_hashes,
    emb_size=self.config.emb_size,
)
```
5. Verify with a Test Case
After adding a custom embedding, run the existing test suite to ensure the shapes and projections are still valid. You can use phoenix/test_recsys_model.py as a template to create a new test:
```python
def test_custom_feature_projection(self):
    # 1. Initialize your new config
    config = HashConfig(num_user_hashes=2, num_language_hashes=1)

    # 2. Mock input embeddings [Batch, Num_Hashes, Dimension]
    user_embs = jnp.ones((4, 2, 64))
    lang_embs = jnp.ones((4, 1, 64))

    # 3. Call the reduction function
    out, mask = block_user_reduce(
        user_hashes=jnp.ones((4, 2)),
        user_embeddings=user_embs,
        language_embeddings=lang_embs,
        num_user_hashes=2,
        num_language_hashes=1,
        emb_size=64,
    )

    # 4. Assert the output is projected back to the model dimension (D=64)
    assert out.shape == (4, 1, 64)
```
Key Considerations
- Initialization Scale: If you add many new features, you may need to adjust the embed_init_scale to prevent activation variance from exploding.
- Padding: Remember that by convention, hash 0 is reserved for padding. Ensure your data pipeline respects this so that the user_padding_mask correctly identifies valid users.
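To illustrate the padding convention, here is a minimal sketch of deriving a validity mask from raw hashes, assuming hash 0 marks padded slots; the field and mask names mirror the guide but the exact masking logic in Phoenix may differ:

```python
import numpy as np

# Hypothetical batch of user hashes: [Batch, num_user_hashes].
# Hash value 0 is reserved for padding by convention.
user_hashes = np.array([
    [17, 42],  # valid user
    [0, 0],    # padded slot
])

# A row is a valid user if any of its hash slots is non-zero.
user_padding_mask = (user_hashes != 0).any(axis=-1)
print(user_padding_mask)  # [ True False]
```

If your data pipeline emits 0 for a real language bucket, that row would be silently treated as padding, which is why the convention must be respected end to end.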