Back to Mistral Rs

Custom logits processor that modifies token probabilities during generation

docs/src/content/docs/examples/rust/advanced/logits-processor.md

v0.8.6 · 2.1 KB
Original Source
<!-- generated by docs/scripts/render_examples.py; edit the source example instead -->

Custom logits processor that modifies token probabilities during generation.

Run with: `cargo run --release --example logits_processor -p mistralrs`

rust
//! Custom logits processor that modifies token probabilities during generation.
//!
//! Run with: `cargo run --release --example logits_processor -p mistralrs`

use std::sync::Arc;

use anyhow::Result;
use mistralrs::{
    CustomLogitsProcessor, IsqBits, ModelBuilder, PagedAttentionMetaBuilder, RequestBuilder,
    Tensor, TextMessageRole,
};
use rand::Rng;

/// Example [`CustomLogitsProcessor`] that keeps only the logits at or above a threshold.
///
/// A logit `>= threshold` passes through unchanged. A logit below it is multiplied
/// by zero.
///
/// NOTE(review): below-threshold logits become `0.0`, not `-inf`. A negative logit
/// under the threshold is therefore *raised* to `0.0` rather than suppressed.
/// Masking with a large negative value instead (e.g. via `where_cond`) would exclude
/// those tokens. Confirm which behavior this example intends.
struct ThresholdLogitsProcessor {
    /// Cutoff compared inclusively (via `ge`) against every logit.
    threshold: f64,
}

impl CustomLogitsProcessor for ThresholdLogitsProcessor {
    fn apply(&self, logits: &Tensor, _context: &[u32]) -> mistralrs::Result<Tensor> {
        // `ge` yields a mask that is 1 where the comparison holds and 0 elsewhere.
        // Cast it to the logits' dtype so it can be multiplied elementwise.
        let keep = logits
            .ge(self.threshold)?
            .to_dtype(logits.dtype())?;
        logits.broadcast_mul(&keep)
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let model = ModelBuilder::new("Qwen/Qwen3-4B")
        .with_auto_isq(IsqBits::Four)
        .with_logging()
        .with_paged_attn(PagedAttentionMetaBuilder::default().build()?)
        .build()
        .await?;

    let mut rng = rand::rng();
    let random_value: f64 = rng.random_range(0.0..=1.0);
    let threshold: f64 = rng.random_range(0.0..=0.5);

    let request = RequestBuilder::new()
        .add_logits_processor(Arc::new(move |logits: &Tensor, _context: &[u32]| {
            logits * random_value
        }))
        .add_logits_processor(Arc::new(ThresholdLogitsProcessor { threshold }))
        .add_message(
            TextMessageRole::User,
            "Please write a mathematical equation where a few numbers are added.",
        );

    let response = model.send_chat_request(request).await?;

    println!("{}", response.choices[0].message.content.as_ref().unwrap());

    Ok(())
}

Source: mistralrs/examples/advanced/logits_processor/main.rs