import logging from dataclasses import dataclass from typing import List, Optional import numpy as np import torch from torch.nn import functional as F from transformers import AutoTokenizer, AutoModelForCausalLM from prompts import FIGURE_3_TEMPLATE, FIGURE_5_TEMPLATE logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s") @dataclass class EvaluationExample: # Paper uses: Instruction {input}, Response {response}, Criteria {criteria}, Options {options} question: str answer: str context: Optional[str] = "" criteria: str = "" id: Optional[str] = None MODEL = "PatronusAI/Llama-3-Patronus-Lynx-8B-Instruct" # "flowaicom/Flow-Judge-v0.1" fails miserably even with the simplest examples # "PatronusAI/Llama-3-Patronus-Lynx-8B-Instruct" atleast holds its grounds for simple examples. For comples like below it fails too class HuggingFaceLLM: def __init__(self, model_name: str = "unsloth/Qwen3-4B-Base"): self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto") self.model.eval() logging.info(f"Loaded model {model_name} on device {self.model.device} with dtype {self.model.dtype}") def generate(self, prompt: str, max_new_tokens: int = 512, enable_thinking: bool = False) -> str: """ - If tokenizer supports chat templates (Qwen3), wrap as a single user message and prefer apply_chat_template with enable_thinking when supported. - Otherwise, fall back to plain prompt encoding. """ if hasattr(self.tokenizer, "apply_chat_template"): messages = [{"role": "user", "content": prompt}] try: # If some model supports enable_thinking text = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=enable_thinking, ) except TypeError: # no enable_thinking kwarg text = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) enc = self.tokenizer([text], return_tensors="pt") else: # Fallback (non-chat) enc = self.tokenizer(prompt, return_tensors="pt") inputs = {k: v.to(self.model.device) for k, v in enc.items()} with torch.inference_mode(): logging.debug(f" generate(): do_sample=False, temperature=0.0, max_new_tokens={max_new_tokens}") outputs = self.model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False, # deterministic temperature=0.0, pad_token_id=self.tokenizer.eos_token_id, ) out_text = self.tokenizer.decode( outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True, ).strip() # TODO: I haven't checked how it might have worled with thinking ones. Maybe we need to strop it off. Need to test it out. return out_text # Step 3: last-token probability scoring for Equation (1) def score_option_last_token_probability(self, full_prompt_with_answer_option: str, option_text: str) -> float: """ Compute p_ij = p(o_i | q_c(o_i, a_j)) as the conditional probability of the last token of the option o_i under teacher forcing for the full confusion prompt. Returns probability in [0,1] """ # sometime extra space make the model generate a weird output text = full_prompt_with_answer_option.strip() # Tokenize without adding special tokens to keep alignment with exact prompt text enc = self.tokenizer(text, return_tensors="pt", add_special_tokens=False) input_ids = enc["input_ids"].to(self.model.device) attention_mask = enc.get("attention_mask") if attention_mask is not None: attention_mask = attention_mask.to(self.model.device) # Show last few tokens in the input (including the target token) seq_len = int(input_ids.shape[1]) last_n = min(8, seq_len) tail_ids = input_ids[0, -last_n:].tolist() try: tail_tokens = self.tokenizer.convert_ids_to_tokens(tail_ids) except Exception: tail_tokens = [""] logging.debug(f" score(): seq_len={seq_len}, last_ids={tail_ids}, last_tokens={tail_tokens}") # Need at least 2 tokens to compute next-token probability for the last token if input_ids.shape[1] < 2: logging.warning("Input too short for last-token probability; returning NaN.") return float("nan") with torch.inference_mode(): # Basically the last token only (most recent generated). We want only one token ctx_ids = input_ids[:, :-1] ctx_mask = attention_mask[:, :-1] if attention_mask is not None else None outputs = self.model(input_ids=ctx_ids, attention_mask=ctx_mask) logits = outputs.logits # [1, seq_len-1, vocab] last_logits = logits[:, -1, :] # [1, vocab] probs = F.softmax(last_logits, dim=-1) # [1, vocab] # Not sure how researchers would hve used but different vocan can come like ' Pass' or 'Pass ' or 'PASS' etc so this is just to be sure option_norm = option_text.strip().upper() try: k_match = int(min(5, probs.shape[-1])) topk5 = torch.topk(probs[0], k=k_match) top5_ids = topk5.indices.tolist() top5_vals = [float(v) for v in topk5.values.tolist()] top5_decs = [] for tid in top5_ids: s = self.tokenizer.decode([tid]) top5_decs.append(s) # just print top-k for debugging lines = [] for rank, (tid, p, dec) in enumerate(zip(top5_ids, top5_vals, top5_decs), start=1): lines.append(f"{rank}. id={tid}, dec={dec!r}, p={p:.8f}") logging.debug(" score(): Top-5 next-token probs:\n" + "\n".join(lines)) except Exception as e: logging.warning(f" score(): Failed to compute Top-5 for matching: {e}") top5_ids, top5_vals, top5_decs = [], [], [] selected_prob = 0.001 matched_token = None for tid, p, dec in zip(top5_ids, top5_vals, top5_decs): norm_tok = dec.strip().upper() if option_norm in norm_tok or norm_tok in option_norm: # NOTE: Assuming mistakes like Pass, -Pass etc in labels selected_prob = p matched_token = dec break logging.debug(f" score(): matched_prob_for_option={option_text!r} -> {selected_prob:.8f}, matched_token={matched_token!r}") return float(selected_prob) class ConfusionUncertaintyEvaluator: """ - Step 2: Figure 3 and Figure 5 prompts and initial prediction - Step 3: n^2 confusion prompts with p_ij (Eq. 1) - Step 4: u_i and label computation (Eq. 2-3) """ def __init__(self, llm: HuggingFaceLLM, options: List[str], threshold: float = 0.7): self.llm = llm self.options = options self.threshold = threshold def _format_input_block(self, ex: EvaluationExample) -> str: # Paper uses a generic "Instruction". We pass instruction EXACTLY they are if ex.context: return f"{ex.question}\n{ex.context}" return ex.question def _format_options_block(self) -> str: return "\n".join(self.options) # Step 2: exact Figure 3 prompt def generate_assessment(self, ex: EvaluationExample, target_option: str) -> str: prompt = FIGURE_3_TEMPLATE.format( input=self._format_input_block(ex), response=ex.answer, criteria=ex.criteria, options=self._format_options_block(), option=target_option, ) logging.info(f"Generating assessment for option: {target_option}") return self.llm.generate(prompt) # Step 2: exact Figure 5 prompt def create_confusion_prompt(self, ex: EvaluationExample, assessment_text: str, target_option: str) -> str: """ See the Prompt for more info. You'll understand """ prompt = FIGURE_5_TEMPLATE prompt = prompt.replace("{input}", self._format_input_block(ex)) prompt = prompt.replace("{response}", ex.answer) prompt = prompt.replace("{criteria}", ex.criteria) prompt = prompt.replace("{options}", self._format_options_block()) prompt = prompt.replace("{Explanation for option}", assessment_text) prompt = prompt.replace("{Option}", target_option) return prompt # NOTE: Step 2: In paper they say "initial choce from LLM" it means you can override if you have some judgement form LLM already def get_initial_prediction(self, ex: EvaluationExample) -> str: # Direct decision prompt for the initial choice direct_decision_prompt = ( "Consider the evaluation criteria and choose a final answer.\n\n" f"### Instruction:\n{self._format_input_block(ex)}\n\n" f"###Response:\n{ex.answer}\n\n" f"###Evaluation criteria:\n{ex.criteria}\n{self._format_options_block()}\n\n" "Answer:" ) raw = self.llm.generate(direct_decision_prompt, max_new_tokens=8) parsed = raw.strip().split()[0] if raw else "" logging.debug(f" initial_prediction raw={raw!r} parsed={parsed!r}") return parsed # Step 3: n^2 confusion prompts with p_ij per Equation (1) def build_confusion_matrix(self, ex: EvaluationExample) -> np.ndarray: """ Step 3: Construct n^2 confusion prompts and fill matrix C with p_ij as per Equation (1). Rows: options o_i (in self.options order) Columns: assessments a_j (generated from Figure 3 for each option in self.options order) C[i, j] = p(o_i | q_c(o_i, a_j)) computed as last-token probability of option o_i. """ n = len(self.options) if n == 0: raise ValueError("No options provided to build confusion matrix.") # Generate n biased assessments (Figure 3), one per option, in order logging.debug(f" build_confusion_matrix: n={n}, options={self.options}") logging.info("Generating biased assessments (Figure 3) for all options...") assessments: List[str] = [] for opt in self.options: assessments.append(self.generate_assessment(ex, opt)) # Build n^2 confusion prompts (Figure 5) and score p_ij logging.info("Building confusion matrix (n^2 prompts; Equation (1))...") C = np.zeros((n, n), dtype=np.float32) for i, option_i in enumerate(self.options): for j, assessment_j in enumerate(assessments): prompt_ij = self.create_confusion_prompt(ex, assessment_j, option_i) p_ij = self.llm.score_option_last_token_probability(prompt_ij, option_i) if not np.isfinite(p_ij): logging.warning(f"Non-finite p_ij for option '{option_i}', assessment index {j}; setting to 0.0") p_ij = 0.0 C[i, j] = float(p_ij) logging.debug(f"build_confusion_matrix: C[{i},{j}] option='{option_i}' p_ij={p_ij:.8f}") return C # Step 4: Equations (2) and (3) for u_i and label def calculate_uncertainty(self, confusion_matrix: np.ndarray, initial_prediction: str): """ Step 4: - Compute u_i = (1/n) sum_j p_ij per Equation (2). - Apply labeling per Equation (3) AND the narrative rules in 'Setting Uncertainty Labels': * If exactly one row exceeds threshold (alpha) AND it matches the initially chosen option -> low uncertainty. * Otherwise -> high uncertainty (covers multiple rows exceed, none exceed, or mismatch with initial choice). Returns a dict with u (per-option means), label, exceed_mask, initial_index, and selected_index. """ n = len(self.options) if confusion_matrix.shape != (n, n): raise ValueError(f"Confusion matrix shape {confusion_matrix.shape} does not match number of options {n}.") # Equation (2): per-option mean across assessments (average across columns) u = confusion_matrix.mean(axis=1) # Just a helper in case there are spaces etc. To quickly find and map. It's either a Score, True/False, Pass-Fail etc. import re def _norm(s: str) -> str: return re.sub(r"[^a-z0-9_ ]+", "", s.lower()) initial_idx: Optional[int] = None if initial_prediction: norm_init = _norm(initial_prediction) # Exact normalized match for idx, opt in enumerate(self.options): if _norm(opt) == norm_init: initial_idx = idx break # Just some extra stuff. In case some model gives a different thing. (handles 'Option A' vs 'A', etc.) if initial_idx is None and norm_init: for idx, opt in enumerate(self.options): no = _norm(opt) if no and (no in norm_init or norm_init in no): initial_idx = idx break # Equation (3) and narrative rules with threshold alpha exceed_mask = (u >= self.threshold) count_exceed = int(exceed_mask.sum()) label = "high uncertainty" selected_idx: Optional[int] = None if count_exceed == 1: # Unique row exceeding threshold selected_idx = int(np.argmax(exceed_mask)) if initial_idx is not None and selected_idx == initial_idx: label = "low uncertainty" else: # Mismatch with initial choice OR missing initial choice -> high label = "high uncertainty" else: # 0 or multiple rows exceed -> high uncertainty label = "high uncertainty" return { "u": u, "label": label, "exceed_mask": exceed_mask, "initial_index": initial_idx, "selected_index": selected_idx, } if __name__ == "__main__": options = ["PASS", "FAIL"] # This is from LYNX Halubench so I've used it directly here ex = EvaluationExample( question="Given the context, decide if the answer is correct and grounded in the context.", context=( "Context:\n" "Australia's federal capital is Canberra, chosen in 1908 and formally named in 1913. " "From 1901 to 1927, the federal parliament met in Melbourne; this historical arrangement " "does not make Melbourne the capital. Sydney is the most populous city and is often " "marketed as the 'nation's capital of culture'—a tourism slogan, not a constitutional status. " "Recent news headlines mention 'capital works' in Sydney; these refer to infrastructure spending, " "not the national capital. The national parliament is seated in Canberra." ), answer="Answer: Australia's capital is Sydney — it's the nation's capital in practice given its size and status.", criteria=( "Choose exactly one option based on the following definitions:\n" "PASS: The answer is correct AND directly supported by the context.\n" "FAIL: The answer is incorrect OR not grounded in the context." ), ) llm = HuggingFaceLLM(MODEL) evaluator = ConfusionUncertaintyEvaluator(llm, options, threshold=0.7) initial_prediction = evaluator.get_initial_prediction(ex) C = evaluator.build_confusion_matrix(ex) result = evaluator.calculate_uncertainty(C, initial_prediction) np.set_printoptions(precision=4, suppress=True) print("Options:", options) print("Initial prediction:", initial_prediction) print("Confusion matrix C (p_ij):") print(C) print("u (per-option mean probs):", result["u"]) print("exceed_mask (u_i >= alpha):", result["exceed_mask"]) print("selected_index (unique exceed):", result["selected_index"]) print("initial_index:", result["initial_index"]) print("Uncertainty label:", result["label"])