Install MLX LM:

```bash
pip install mlx-lm
```

And run:

```bash
python reason.py
```
The default model is `mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit`. You can
specify a different model with `--model`.
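For example, to point the script at a smaller distilled variant (any local path or Hugging Face repo with an MLX-compatible model should work; the repo below is just an illustration):

```bash
python reason.py --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-4bit
```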
To see all the options:

```bash
python reason.py --help
```
The full `reason.py` script:

```python
# Copyright © 2023-2024 Apple Inc.

import argparse
import json
from functools import partial

import mlx.core as mx
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.utils import load, stream_generate

DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 4096
DEFAULT_MODEL = "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit"


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    mx.random.seed(args.seed)

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    # "Wait" nudges the model to keep reasoning; "</think>" closes the
    # thinking section so the model moves on to its final answer.
    wait_token = "Wait"
    wait_token_id = tokenizer.convert_tokens_to_ids(wait_token)
    end_think_token = "</think>"
    end_think_token_id = tokenizer.convert_tokens_to_ids(end_think_token)
    think_more_prompt = mx.array([wait_token_id], mx.uint32)
    end_think_prompt = mx.array(
        tokenizer.encode(end_think_token + "\n", add_special_tokens=False), mx.uint32
    )

    generator = partial(
        stream_generate,
        model=model,
        tokenizer=tokenizer,
        sampler=make_sampler(args.temp, args.top_p),
    )

    print(f"[INFO] Starting reasoning session with {args.model}. To exit, enter 'q'.")
    while True:
        prompt_cache = make_prompt_cache(model)
        query = input(">> ")
        if query == "q":
            break
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        while True:
            max_tokens = args.max_tokens
            end_think_idx = None

            # Generate the thinking section, stopping early if the model
            # emits the wait token.
            for response in generator(
                prompt=prompt,
                max_tokens=max_tokens,
                prompt_cache=prompt_cache,
            ):
                if response.token == wait_token_id:
                    break
                elif response.token == end_think_token_id:
                    end_think_idx = prompt_cache[0].offset
                print(response.text, flush=True, end="")
                max_tokens -= response.generation_tokens

            # If the model never emitted </think> (e.g. generation stopped on
            # the wait token), insert </think> ourselves and generate the
            # final answer.
            if end_think_idx is None:
                print(end_think_token, flush=True)
                end_think_idx = prompt_cache[0].offset
                prompt = end_think_prompt
                # Trim the wait token from the cache
                trim_prompt_cache(prompt_cache, 1)

            # Generate the answer
            for response in generator(
                prompt=prompt,
                max_tokens=max_tokens,
                prompt_cache=prompt_cache,
            ):
                print(response.text, flush=True, end="")
                max_tokens -= response.generation_tokens

            think_more = input(
                "\n\n\033[31mWould you like me to think more? (y/n):\033[0m "
            )
            if think_more == "y":
                # Trim the prompt cache to just before the </think> token and
                # feed "Wait" so the model resumes its chain of thought.
                print("<think>")
                print(wait_token, flush=True, end="")
                num_to_trim = prompt_cache[0].offset - end_think_idx + 1
                max_tokens += num_to_trim
                trim_prompt_cache(prompt_cache, num_to_trim)
                prompt = think_more_prompt
            else:
                break
        print()


if __name__ == "__main__":
    main()
```
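The part worth calling out is how the script reuses the key-value cache between rounds: when you answer "y" to "think more", it trims the cache back to just before `</think>` and feeds a `Wait` token, so the model resumes its chain of thought instead of starting over. Below is a minimal, standalone sketch of that cache-rewind pattern. It is not part of `reason.py`; the query, the trim count, and the follow-up "Wait" prompt are placeholder choices for illustration.

```python
# Minimal sketch of the prompt-cache pattern used by reason.py: generate into
# a reusable cache, then trim tokens off the end of the cache to rewind to an
# earlier state before generating again.
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache
from mlx_lm.utils import load, stream_generate

model, tokenizer = load("mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit")
cache = make_prompt_cache(model)

# Placeholder query; reason.py reads this from stdin.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Why is the sky blue?"}],
    add_generation_prompt=True,
)

# First pass: the prompt and every generated token are stored in `cache`.
for response in stream_generate(
    model, tokenizer, prompt, max_tokens=256, prompt_cache=cache
):
    print(response.text, end="", flush=True)

# Rewind the last few generated tokens (an arbitrary count here; reason.py
# computes it from the position of </think>) ...
trim_prompt_cache(cache, 8)

# ... and continue from the rewound state. Only the new "Wait" tokens are
# processed; the earlier prompt and generation are reused from the cache.
wait_tokens = tokenizer.encode("Wait", add_special_tokens=False)
for response in stream_generate(
    model, tokenizer, wait_tokens, max_tokens=64, prompt_cache=cache
):
    print(response.text, end="", flush=True)
```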