# MLX LM

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/LORA.md

# Fine-Tuning with LoRA or QLoRA

You can use the `mlx-lm` package to fine-tune an LLM with low rank adaptation (LoRA) for a target task.[^lora] It also supports quantized LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:

- Mistral
- Llama
- Phi2
- Mixtral
- Qwen2
- Gemma
- OLMo
- MiniCPM
- InternLM2

## Contents

- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
- [Data](#Data)
- [Memory Issues](#Memory-Issues)

## Run

First, make sure you have the training dependencies installed:

```shell
pip install "mlx-lm[train]"
```

The main command is `mlx_lm.lora`. To see a full list of command-line options, run:

```shell
mlx_lm.lora --help
```

In the commands below, the `--model` argument can be any compatible Hugging Face repo or a local path to a converted model.

You can also specify a YAML config with `-c`/`--config`. For more on the format, see the [example YAML](examples/lora_config.yaml). For example:

```shell
mlx_lm.lora --config /path/to/config.yaml
```

If command-line flags are also used, they override the corresponding values in the config.

### Fine-tune

To fine-tune a model, use:

```shell
mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data> \
    --iters 600
```

To fine-tune the full model weights, add the `--fine-tune-type full` flag. Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.

For local data, the `--data` argument must point to a directory containing `train.jsonl` and `valid.jsonl` when using `--train`, and `test.jsonl` when using `--test`. For more details on the data format, see the section on [Data](#Data).

For example, to fine-tune Mistral 7B you can use `--model mistralai/Mistral-7B-v0.1`.

If `--model` points to a quantized model, training uses QLoRA; otherwise it uses regular LoRA.
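To see why LoRA keeps the number of trained parameters small, consider a single linear layer. The following is a minimal plain-Python sketch, not `mlx_lm`'s implementation: it assumes a frozen weight `W` of shape `d_out x d_in`, a rank-`r` update `B @ A` with `B` initialized to zeros as in the LoRA paper, and a hypothetical scalar `scale`. All function names are illustrative.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Product of an (n x k) matrix and a (k x m) matrix, in plain Python."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def lora_trainable_params(d_out: int, d_in: int, r: int) -> int:
    """Parameters in B (d_out x r) plus A (r x d_in); W itself stays frozen."""
    return r * (d_out + d_in)


def lora_forward(x: Matrix, w: Matrix, a: Matrix, b: Matrix, scale: float) -> Matrix:
    """y = x W^T + scale * (x A^T) B^T for a batch of row vectors x (n x d_in)."""
    base = matmul(x, transpose(w))
    update = matmul(matmul(x, transpose(a)), transpose(b))
    return [
        [base[i][j] + scale * update[i][j] for j in range(len(base[0]))]
        for i in range(len(base))
    ]


def fuse_weights(w: Matrix, a: Matrix, b: Matrix, scale: float) -> Matrix:
    """Fold the low-rank update into the base weight: W' = W + scale * B A."""
    delta = matmul(b, a)
    return [
        [w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
        for i in range(len(w))
    ]


# Tiny example: d_out = 3, d_in = 2, r = 1.
x = [[1.0, 2.0]]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
a = [[0.5, -0.5]]
b_init = [[0.0], [0.0], [0.0]]  # B starts at zero, so the adapter is a no-op at first
b_trained = [[1.0], [0.0], [2.0]]

print(lora_forward(x, w, a, b_init, 2.0))  # [[1.0, 2.0, 3.0]], same as the base layer
print(lora_forward(x, w, a, b_trained, 2.0))  # [[0.0, 2.0, 1.0]]
print(matmul(x, transpose(fuse_weights(w, a, b_trained, 2.0))))  # [[0.0, 2.0, 1.0]]
```

For a 4096 x 4096 projection with `r = 8`, `lora_trainable_params(4096, 4096, 8)` gives 65,536 trainable parameters versus 16,777,216 for the full matrix, about 0.4%. The `fuse_weights` helper illustrates the idea behind fusing: because the update is linear, it can be folded into `W` once training is done, so the fused layer gives the same output as base plus adapter.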
By default, the adapter config and learned weights are saved in `adapters/`. You can specify the output location with `--adapter-path`.

You can resume fine-tuning from an existing adapter with `--resume-adapter-file <path_to_adapters.safetensors>`.

#### Logging

You can log training metrics to Weights & Biases using `--report-to wandb`, or to SwanLab using `--report-to swanlab`. Make sure to install the required package beforehand: `pip install wandb` or `pip install swanlab`. You can enable both trackers simultaneously by separating them with a comma, for example `--report-to wandb,swanlab`.

To specify a project name for the logging tracker, use `--project-name <name>`.

#### Prompt Masking

By default, training computes a loss for every token in the sample. To ignore the prompt and compute the loss only on the completion, pass `--mask-prompt`. This is only supported for `chat` and `completion` datasets. For `chat` datasets, the final message in the message list is considered the completion. See the [dataset section](#Data) for more details.

### Evaluate

To compute test set perplexity, use:

```shell
mlx_lm.lora \
    --model <path_to_model> \
    --adapter-path <path_to_adapters> \
    --data <path_to_data> \
    --test
```

### Generate

For generation, use `mlx_lm.generate`:

```shell
mlx_lm.generate \
    --model <path_to_model> \
    --adapter-path <path_to_adapters> \
    --prompt "<your_model_prompt>"
```

## Fuse

You can generate a model fused with the low-rank adapters using the `mlx_lm.fuse` command. This command also lets you optionally:

- Upload the fused model to the Hugging Face Hub.
- Export the fused model to GGUF. Note that GGUF support is limited to Mistral, Mixtral, and Llama style models in fp16 precision.

To see supported options, run:

```shell
mlx_lm.fuse --help
```

To generate the fused model, run:

```shell
mlx_lm.fuse --model <path_to_model>
```

By default, this loads the adapters from `adapters/` and saves the fused model in `fused_model/`. Both locations are configurable.

To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments to `mlx_lm.fuse`.
The latter is the repo name of the original model, which is useful for attribution and model versioning.

For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:

```shell
mlx_lm.fuse \
    --model mistralai/Mistral-7B-v0.1 \
    --upload-repo mlx-community/my-lora-mistral-7b \
    --hf-path mistralai/Mistral-7B-v0.1
```

To export a fused model to GGUF, run:

```shell
mlx_lm.fuse \
    --model mistralai/Mistral-7B-v0.1 \
    --export-gguf
```

This saves the GGUF model in `fused_model/ggml-model-f16.gguf`. You can specify the file name with `--gguf-path`.

## Data

The LoRA command expects you to provide a dataset with `--data`. The MLX Examples GitHub repo has an [example of the WikiSQL data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the correct format.

Datasets can be specified in `*.jsonl` files locally or loaded from Hugging Face.

### Local Datasets

For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data loader expects a `test.jsonl` in the data directory.

Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text` data formats. Here are examples of these formats:

`chat`:

```jsonl
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assist you today?"}]}
```

`tools`:

```jsonl
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
```
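If you generate `tools` records programmatically, serializing each record with Python's standard `json.dumps` keeps it on a single line (newlines inside strings are escaped), and calling `json.dumps` on the arguments dictionary produces the JSON-string form of `arguments` shown in the example. This is a minimal sketch; the helper name `make_tool_record` is hypothetical and the fixed `"call_id"` simply mirrors the example.

```python
import json


def make_tool_record(user_msg: str, fn_name: str, fn_args: dict, tool_schema: dict) -> str:
    """Return one JSONL line in the `tools` format (hypothetical helper)."""
    record = {
        "messages": [
            {"role": "user", "content": user_msg},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call_id",
                        "type": "function",
                        # `arguments` is itself a JSON *string* here (OpenAI/Mistral style).
                        "function": {"name": fn_name, "arguments": json.dumps(fn_args)},
                    }
                ],
            },
        ],
        "tools": [{"type": "function", "function": tool_schema}],
    }
    # json.dumps never emits raw newlines by default, so the record is one line.
    return json.dumps(record)
```

Append each returned string plus `"\n"` to your `train.jsonl` (or `valid.jsonl`/`test.jsonl`) so the file holds exactly one example per line.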
Expanded view of the single-line `tools` example (for readability only; in the dataset file each example must stay on one line):

```json
{
  "messages": [
    { "role": "user", "content": "What is the weather in San Francisco?" },
    {
      "role": "assistant",
      "tool_calls": [
        {
          "id": "call_id",
          "type": "function",
          "function": {
            "name": "get_current_weather",
            "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
          }
        }
      ]
    }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "The city and country, eg. San Francisco, USA"
            },
            "format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
          },
          "required": ["location", "format"]
        }
      }
    }
  ]
}
```

The format of the `arguments` field in a function call varies between models. Common formats include JSON strings and dictionaries. The example uses a JSON string, following the format used by [OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples) and [Mistral AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct). A dictionary format is used in Hugging Face's [chat templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example). Refer to the documentation for the model you are fine-tuning for more details.
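Because the expected `arguments` representation depends on the model's chat template, you may need to convert an existing dataset from one form to the other. A minimal sketch, assuming records shaped like the `tools` example; `convert_arguments` is a hypothetical helper, and which form your model needs is something to confirm in its documentation.

```python
import copy
import json


def convert_arguments(record: dict, to: str) -> dict:
    """Return a copy of a `tools` record with every tool call's `arguments`
    converted to a JSON string (to="str") or a dictionary (to="dict")."""
    if to not in ("str", "dict"):
        raise ValueError("to must be 'str' or 'dict'")
    out = copy.deepcopy(record)  # leave the caller's record untouched
    for message in out.get("messages", []):
        for call in message.get("tool_calls", []):
            fn = call["function"]
            args = fn["arguments"]
            if to == "dict" and isinstance(args, str):
                fn["arguments"] = json.loads(args)
            elif to == "str" and isinstance(args, dict):
                fn["arguments"] = json.dumps(args)
    return out
```

Converting to `"dict"` and back to `"str"` round-trips records like the example, since `json.dumps` uses the same `", "` and `": "` separators that appear in the sample `arguments` string.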
`completions`:

```jsonl
{"prompt": "What is the capital of France?", "completion": "Paris."}
```

For the `completions` data format, you can use different keys for the prompt and completion by specifying the following in the YAML config:

```yaml
prompt_feature: "input"
completion_feature: "output"
```

Here, `"input"` is the expected key instead of the default `"prompt"`, and `"output"` is the expected key instead of `"completion"`.

`text`:

```jsonl
{"text": "This is an example for the model."}
```

The format is determined automatically from the dataset. Keys in each line that the loader does not expect are ignored.

> [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than
> one example per line and do not split an example across multiple lines.

### Hugging Face Datasets

To use Hugging Face datasets, first install the `datasets` package:

```shell
pip install datasets
```

If the Hugging Face dataset is already in a supported format, you can specify it on the command line. For example, pass `--data mlx-community/wikisql` to train on the pre-formatted WikiSQL data.

Otherwise, provide a mapping of keys in the dataset to the features MLX LM expects. Use a YAML config to specify the Hugging Face dataset arguments. For example:

```yaml
hf_dataset:
  path: "billsum"
  prompt_feature: "text"
  completion_feature: "summary"
```

- Use `prompt_feature` and `completion_feature` to specify keys for a `completions` dataset. Use `text_feature` to specify the key for a `text` dataset. Use `chat_feature` to specify the key for a `chat` dataset.

- To specify the train, valid, or test splits, set the corresponding `{train,valid,test}_split` argument.

You can specify a list of Hugging Face datasets with a list of records, each with the same structure as above.
For example:

```yaml
hf_dataset:
  - path: "Open-Orca/OpenOrca"
    train_split: "train[:90%]"
    valid_split: "train[-10%:]"
    prompt_feature: "question"
    completion_feature: "response"
  - path: "trl-lib/ultrafeedback_binarized"
    train_split: "train[:90%]"
    valid_split: "train[-10%:]"
    chat_feature: "chosen"
```

- Arguments specified in `config` will be passed as keyword arguments to [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).

In general, for the `chat`, `tools` and `completions` formats, Hugging Face [chat templates](https://huggingface.co/docs/transformers/main/en/chat_templating) are used. The model's own chat template is applied by default; if the model does not have one, Hugging Face falls back to a default template. For example, the final text for the `chat` example above with Hugging Face's default template becomes:

```text
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello.<|im_end|>
<|im_start|>assistant
How can I assist you today?<|im_end|>
```

If you are unsure which format to use, `chat` or `completions` are good starting points. For custom requirements on the format of the dataset, use the `text` format to assemble the content yourself.

## Memory Issues

Fine-tuning a large model with LoRA requires a machine with a decent amount of memory. Here are some tips to reduce memory use should you need to do so:

1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model with `mlx_lm.convert` and the `-q` flag. Run `mlx_lm.convert --help` for more details.

2. Try using a smaller batch size with `--batch-size`. The default is `4`, so setting this to `2` or `1` will reduce memory consumption, though it may slow training down a little. You can increase the effective batch size without increasing memory use by accumulating gradients with `--grad-accumulation-steps <k>`, which accumulates the gradients of `<k>` batches before updating the parameters.

3. Reduce the number of layers to fine-tune with `--num-layers`. The default is `16`, so you can try `8` or `4`. This reduces the amount of memory needed for backpropagation. It may also reduce the quality of the fine-tuned model if you are fine-tuning with a lot of data.

4. Longer examples require more memory. If it makes sense for your data, you can break your examples into smaller sequences when making the `{train, valid, test}.jsonl` files.

5. Gradient checkpointing lets you trade lower memory use for more computation by recomputing, instead of storing, the intermediate values needed by the backward pass. You can use gradient checkpointing by passing the `--grad-checkpoint` flag. Gradient checkpointing will be more helpful for larger batch sizes or sequence lengths with smaller or quantized models.

For example, on a machine with 32 GB the following should run reasonably fast:

```shell
mlx_lm.lora \
    --model mistralai/Mistral-7B-v0.1 \
    --train \
    --batch-size 1 \
    --num-layers 4 \
    --data mlx-community/wikisql
```

The above command on an M1 Max with 32 GB runs at about 250 tokens per second, using the MLX Examples [`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) dataset.

[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314).

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/benchmark.py

## API Reference

```python
def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="LLM benchmarking script")
    parser.add_argument(
        "--model",
        type=str,
        help=(
            "The path to the local model directory or Hugging Face repo. "
            f"If no model is specified, then {DEFAULT_MODEL} is used."
        ),
        default=None,
    )
    parser.add_argument(
        "--prompt-tokens", "-p", default=512, help="Length of prompt", type=int
    )
    parser.add_argument(
        "--generation-tokens", "-g", default=1024, help="Length of completion", type=int
    )
    parser.add_argument("--batch-size", "-b", default=1, help="Batch size", type=int)
    parser.add_argument(
        "--num-trials", "-n", default=5, help="Number of timing trials", type=int
    )
    parser.add_argument(
        "--pipeline",
        action="store_true",
        help="Use pipelining instead of tensor parallelism",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    mx.random.seed(0)

    group = mx.distributed.init()
    rank = group.rank()
    pipeline_group = group if args.pipeline else None
    tensor_group = group if not args.pipeline else None

    def rprint(*args, **kwargs):
        if rank == 0:
            print(*args, **kwargs)

    model_path = args.model or DEFAULT_MODEL

    if group.size() > 1:
        model, tokenizer, config = sharded_load(
            model_path, pipeline_group, tensor_group, return_config=True
        )
    else:
        model, tokenizer, config = load(
            model_path, return_config=True, tokenizer_config={"trust_remote_code": True}
        )

    # Empty to avoid early stopping
    tokenizer._eos_token_ids = {}

    prompt_tokens = args.prompt_tokens
    generation_tokens = args.generation_tokens
    batch_size = args.batch_size
    vocab_size = config.get("vocab_size") or config["text_config"]["vocab_size"]
    prompts = mx.random.randint(0, vocab_size, (batch_size, prompt_tokens)).tolist()
    prompt = prompts[0]

    def single_bench():
        for response in stream_generate(
            model, tokenizer, prompt, max_tokens=generation_tokens
        ):
            pass
        return response

    def batch_bench():
        return batch_generate(
            model, tokenizer, prompts, max_tokens=generation_tokens
        ).stats

    if batch_size == 1:
        _bench = single_bench
    else:
        _bench = batch_bench
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/cache_prompt.py

## API Reference

```python
def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(
        description="Cache the state of a prompt to be reused with mlx_lm.generate"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable trusting remote code for tokenizer",
    )
    parser.add_argument(
        "--eos-token",
        type=str,
        default=None,
        help="End of sequence token for tokenizer",
    )
    parser.add_argument(
        "--max-kv-size",
        type=int,
        default=None,
        help="Set the maximum key-value cache size",
    )
    parser.add_argument(
        "--prompt-cache-file",
        help="The file to save the prompt cache in",
        required=True,
    )
    parser.add_argument(
        "--prompt",
        required=True,
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    parser.add_argument(
        "--kv-bits",
        type=int,
        help="Number of bits for KV cache quantization. "
        "Defaults to no quantization.",
        default=None,
    )
    parser.add_argument(
        "--kv-group-size",
        type=int,
        help="Group size for KV cache quantization.",
        default=64,
    )
    parser.add_argument(
        "--quantized-kv-start",
        help="When --kv-bits is set, start quantizing the KV cache "
        "from this step onwards.",
        type=int,
        default=DEFAULT_QUANTIZED_KV_START,
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    # Building tokenizer_config
    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
    if args.eos_token is not None:
        tokenizer_config["eos_token"] = args.eos_token

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config=tokenizer_config,
    )

    args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt

    if tokenizer.has_chat_template:
        messages = [{"role": "user", "content": args.prompt}]
        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=False,
            continue_final_message=True,
        )
    else:
        prompt = tokenizer.encode(args.prompt)

    cache = make_prompt_cache(model, args.max_kv_size)
    y = mx.array(prompt)

    # Process the prompt
    start = time.time()
    max_msg_len = 0
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/chat.py

## API Reference

```python
def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable trusting remote code for tokenizer",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument(
        "--xtc-probability",
        type=float,
        default=DEFAULT_XTC_PROBABILITY,
        help="Probability of XTC sampling to happen each next token",
    )
    parser.add_argument(
        "--xtc-threshold",
        type=float,
        default=0.0,
        help="Threshold the probs of each next token candidate to be sampled by XTC",
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--system-prompt",
        default=None,
        help="System prompt to be used for the chat template",
    )
    parser.add_argument(
        "--pipeline",
        action="store_true",
        help="Use pipelining instead of tensor parallelism",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    group = mx.distributed.init()
    rank = group.rank()
    pipeline_group = group if args.pipeline else None
    tensor_group = group if not args.pipeline else None

    def rprint(*args, **kwargs):
        if rank == 0:
            print(*args, **kwargs)

    if args.seed is not None:
        mx.random.seed(args.seed)

    if group.size() > 1:
        if args.adapter_path:
            parser.error("Adapters not supported in distributed mode")
        model, tokenizer = sharded_load(args.model, pipeline_group, tensor_group)
    else:
        model, tokenizer = load(
            args.model,
            adapter_path=args.adapter_path,
            tokenizer_config={
                "trust_remote_code": True if args.trust_remote_code else None
            },
        )

    def print_help():
        rprint("The command list:")
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/convert.py

## API Reference

```python
def mixed_quant_predicate_builder(

def mixed_quant_predicate(
    """Implements mixed quantization predicates with similar choices to, for example, llama.cpp's Q4_K_M.

    Ref: https://github.com/ggerganov/llama.cpp/blob/917786f43d0f29b7c77a0c56767c0fa4df68b1c5/src/llama.cpp#L5265
    By Alex Barron: https://gist.github.com/barronalex/84addb8078be21969f1690c1454855f3

def convert(

def set_dtype(k, v):

def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.

def main():
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/evaluate.py

## API Reference

```python
"""
Adapted from a PyTorch implementation by David Grangier

def _rstrip_until(s, untils):
    """Limit a string to the first occurrence of any substring in untils."""
    l = len(s)
    f = [s.find(u) for u in untils]
    f = [l if x < 0 else x for x in f]
    return s[: min(f)]

def _lstrip(s, pattern):

def _pad_inputs(inputs):

def chat_template_fn(**extra_kwargs):
    def apply_chat_template(self, chat_history, add_generation_prompt=True) -> str:

class MLXLM(LM):
    def __init__(

    def _process_prompt(self, prompt, step_size: int = 2048):

    def _score_fn(self, inputs, cache: Optional[Any] = None, step_size: int = 2048):

    def _tokenize(self, texts):

    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
            `context: str`
                Context string. Implementations of LM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.

        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy`:
                Whether `continuation` would be generated by greedy sampling from `context`.

    def loglikelihood_rolling(self, requests) -> list[float]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
          the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context,).
            string: str
                String for which we are computing overall loglikelihood
        :return: list[tuple[float]]
            A list of tuples (logprob,)
            logprob: float
                The log probability of `context` conditioned on the EOT token.

    def generate_until(self, requests) -> list[str]:
        """Generate greedily until a stopping sequence

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, until).
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list[str]
            A list of strings continuation
            continuation: str
                The generated continuation.
def main():
    help="""A JSON formatted string of arguments for the tokenizer's
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/fuse.py

## API Reference

```python
def parse_arguments() -> argparse.Namespace:

def main() -> None:
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/generate.py

## API Reference

```python
def str2bool(string):


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="LLM inference script")
    parser.add_argument(
        "--model",
        type=str,
        help=(
            "The path to the local model directory or Hugging Face repo. "
            f"If no model is specified, then {DEFAULT_MODEL} is used."
        ),
        default=None,
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable trusting remote code for tokenizer",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--extra-eos-token",
        type=str,
        default=(),
        nargs="+",
        help="Add tokens in the list of eos tokens that stop generation.",
    )
    parser.add_argument(
        "--system-prompt",
        default=None,
        help="System prompt to be used for the chat template",
    )
    parser.add_argument(
        "--prompt",
        "-p",
        default=DEFAULT_PROMPT,
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    parser.add_argument(
        "--prefill-response",
        default=None,
        help="Prefill response to be used for the chat template",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument(
        "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
    )
    parser.add_argument(
        "--top-k", type=int, default=DEFAULT_TOP_K, help="Sampling top-k"
    )
    parser.add_argument(
        "--xtc-probability",
        type=float,
        default=DEFAULT_XTC_PROBABILITY,
        help="Probability of XTC sampling to happen each next token",
    )
    parser.add_argument(
        "--xtc-threshold",
        type=float,
        default=0.0,
        help="Threshold the probs of each next token candidate to be sampled by XTC",
    )
    parser.add_argument(
        "--min-tokens-to-keep",
        type=int,
        default=DEFAULT_MIN_TOKENS_TO_KEEP,
        help="Minimum tokens to keep for min-p sampling.",
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--ignore-chat-template",
        action="store_true",
        help="Use the raw prompt without the tokenizer's chat template.",
    )
    parser.add_argument(
        "--use-default-chat-template",
        action="store_true",
        help="Use the default chat template",
    )
    parser.add_argument(
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/gguf.py

## API Reference

```python
class TokenType(IntEnum):

class GGMLFileType(IntEnum):

class HfVocab:
    def __init__(
    def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
    def get_token_type(
    def get_token_score(self, token_id: int) -> float:
    def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
    def has_newline_token(self):
    def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
    def __repr__(self) -> str:
    def load(path: Path) -> "HfVocab":

def translate_weight_names(name):

def permute_weights(weights, n_head, n_head_kv=None):

def prepare_metadata(config, vocab):

def convert_to_gguf(
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/lora.py

## API Reference

```python
"""^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)

def build_parser():

def train_model(

def evaluate_model(args, model: nn.Module, test_set):

def run(args, training_callback: TrainingCallback = None):

def main():
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/manage.py

## API Reference

```python
def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
    """
    Inspired by:
    - stackoverflow.com/a/8356620/593036
    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data

def ask_for_confirmation(message: str) -> bool:
    """Ask user for confirmation with Y/N prompt.

def main():
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/perplexity.py

## API Reference

```python
"""
Evaluate perplexity (PPL) of MLX models.

def load_data(

def eval_ppl(model, data, batch_size=8):
    """
    Evaluate perplexity on a dataset with standard error calculation.

    Args:
        model: The model to evaluate
        data: Tokenized data tensor
        batch_size: Batch size for evaluation

    Returns:
        tuple: (perplexity, standard_error)

def main():
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/sample_utils.py

## API Reference

```python
def make_sampler(
    """
    Make a sampler function for use with ``generate_step``.

    Args:
        temp (float): The temperature for sampling, if 0 the argmax is used.
          Default: ``0``.
        top_p (float, optional): Nucleus sampling, higher means model considers
          more less likely words.
        min_p (float, optional): The minimum value (scaled by the top token's
          probability) that a token probability must have to be considered.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
          be filtered by min_p sampling.
        top_k (int, optional): The top k tokens ranked by probability to constrain
          the sampling to.
        xtc_probability (float, optional): The probability of applying XTC
            sampling.
        xtc_threshold (float, optional): The threshold the probs need to reach
            for being sampled.
        xtc_special_tokens (list(int), optional): List of special tokens IDs to
            be excluded from XTC sampling.

    Returns:
        Callable[mx.array, mx.array]:
            A sampler which takes log-probabilities and returns tokens.

    def sampler(logprobs):

def make_logits_processors(
    """
    Make logits processors for use with ``generate_step``.

    Args:
        repetition_penalty (float, optional): The penalty factor for repeating
          tokens.
        repetition_context_size (int, optional): The number of tokens to
          consider for repetition penalty. Default: ``20``.
        logit_bias (dictionary, optional): Additive logit bias.

    Returns:
        List[Callable[[mx.array, mx.array], mx.array]]:
            A list of logits processors. Each processor in the list is a
            callable which takes an array of tokens and an array of logits
            and returns the updated logits.

    def logit_bias_processor(_, logits):

def apply_top_k(
    """
    Sample from only the top K tokens ranked by probability.

    Args:
        logprobs: A vector of log probabilities.
        top_k (int): Top k tokens to sample from.

def apply_min_p(
    """
    Apply min-p sampling to the logprobs.

    Min-p keeps all tokens that are above a minimum probability, scaled by the
    probability of the most likely token. As a result, the filter is more
    aggressive given a very high-probability token.

    Args:
        logprobs: A vector of log probabilities.
        min_p (float): Minimum token probability. Typical values are in the
            0.01-0.2 range, comparably selective as setting `top_p` in the
            0.99-0.8 range.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
            be filtered. Default: ``1``.

def apply_top_p(logprobs: mx.array, top_p: float) -> mx.array:
    """
    Apply top-p (nucleus) sampling to logits.

    Args:
        logprobs: A vector of log probabilities.
        top_p: The cumulative probability threshold for top-p filtering.
    Returns:
        token selected based on the top-p criterion.

def apply_xtc(
    """
    Apply XTC sampling to the logits.

    Args:
        logits: The logits from the model's output.
        xtc_probability (float): Probability of XTC sampling to happen for each token
        xtc_threshold (float): The threshold the probs need to reach for being sampled.
        special_tokens_ids (list(int)): List of special tokens IDs to be excluded from XTC sampling.

def categorical_sampling(logits, temp):

def make_repetition_penalty(penalty: float, context_size: int = 20):
    """
    Make repetition penalty processor.

    Paper: https://arxiv.org/abs/1909.05858

    Args:
        penalty (float): The repetition penalty factor to be applied.
        context_size (int): The number of previous tokens to use.
            Default: ``20``.

    Returns:
        Callable[[mx.array, List[int]], mx.array]:
            The repetition penalty processor.
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/server.py

## API Reference

```python
def get_system_fingerprint():

class StopCondition(NamedTuple):

def stopping_criteria(
    """
    Determines whether the token generation should stop based on predefined
    conditions.

    Args:
        tokens (List[int]): The current sequence of generated tokens.
        eos_token_ids (set): The token IDs that represent the end-of-sequence.
            If the last token in ``tokens`` is in the set, the generation should
            stop.
        stop_id_sequences (List[List[[int]]): A list of integer lists, each
            representing a sequence of token IDs. If the end of the `tokens`
            list matches any of these sequences, the generation should stop.
        stop_words (List[str]): The stop words that correspond to the
            ``stop_id_sequences``.

    Returns:
        StopCondition: A named tuple indicating whether the stop condition has
          been met (`stop_met`) and how many tokens should be trimmed from the
          end if it has (`trim_length`) as well as the text that should be
          trimmed.


def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
    """
    Checks if a suffix of s1 has overlap with a prefix of s2.

    Args:
        s1 (Sequence): The first sequence.
        s2 (Sequence): The second sequence.

    Returns:
        bool: Whether a suffix of ``s1`` overlaps a prefix of ``s2``.
    """


def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):


def process_message_content(messages):
    """
    Convert message content to a format suitable for `apply_chat_template`.

    The function operates on messages in place. It converts the 'content' field
    to a string instead of a list of text fragments.

    Args:
        messages (list): A list of dictionaries, where each dictionary may
            have a 'content' key containing a list of dictionaries with 'type'
            and 'text' keys.

    Raises:
        ValueError: If the 'content' type is not supported or if 'text' is
            missing.
    """


class LRUPromptCache:

    class CacheEntry:

    class SearchResult:

    def __init__(self, max_size: int = 10):

    def _search(self, model, tokens):
        """Search the cache for a prompt cache. Return exact or close match."""
        if model not in self._cache:
            return self.SearchResult(model, None, None, None, 0)

        current = self._cache[model]
        last_cache_index = -1
        index = 0

        while index < len(tokens) and tokens[index] in current:
            current = current[tokens[index]]
            if "cache" in current:
                last_cache_index = index
            index += 1

        # Exact match: no need to search for longer or shorter caches
        if last_cache_index == len(tokens) - 1:
            return self.SearchResult(model, tokens, None, None, 0)

        # Find the shorter cache
        shorter = None
        if last_cache_index > 0:
            shorter = tokens[: last_cache_index + 1]

        # Check for caches that are longer
        longer = None
        common_prefix = index
        if index > 0 and last_cache_index <= 0:
            best = None
            stack = [(current, [])]
            while stack:
                current, extra = stack.pop()
                if "cache" in current:
                    if best is None or len(extra) < len(best):
                        best = extra
                else:
                    for tok in current:
                        stack.append((current[tok], extra + [tok]))
            longer = tokens[:index] + best
        return self.SearchResult(model, None, shorter, longer, common_prefix)

    def _get(self, model, tokens):
        current = self._cache[model]
        for tok in tokens:
            current = current[tok]
        return current["cache"]
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tokenizer_utils.py

## API Reference

```python
class StreamingDetokenizer:
    """The streaming detokenizer interface so that we can detokenize one token at a time.

    Example usage is as follows:

        detokenizer = ...

        # Reset the tokenizer state
        detokenizer.reset()

        for token in generate(...):
            detokenizer.add_token(token.item())

            # Contains the whole text so far. Some tokens may not be included
            # since the text usually contains only whole words.
            detokenizer.text

            # Contains the printable segment (usually a word) since the last
            # time it was accessed
            detokenizer.last_segment

            # Contains all the tokens added so far
            detokenizer.tokens

        # Make sure that we detokenize any remaining tokens
        detokenizer.finalize()

        # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
    """

    def reset(self):

    def add_token(self, token):

    def finalize(self):

    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""
        text = self.text
        segment = text[self.offset :]
        self.offset = len(text)
        return segment


class NaiveStreamingDetokenizer(StreamingDetokenizer):

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self._tokenizer.decode([0])
        self.reset()

    def reset(self):
        self.offset = 0
        self.tokens = []
        self._text = ""
        self._current_tokens = []
        self._current_text = ""

    def add_token(self, token):
        self._current_tokens.append(token)
        self.tokens.append(token)

    def finalize(self):
        self._text += self._tokenizer.decode(self._current_tokens)
        self._current_tokens = []
        self._current_text = ""

    @property
    def text(self):
        if self._current_tokens:
            self._current_text = self._tokenizer.decode(self._current_tokens)
            if self._current_text.endswith("\ufffd"):
                self._current_text = self._current_text[:-1]
        if self._current_text and self._current_text[-1] == "\n":
            self._text += self._current_text
            self._current_tokens.clear()
            self._current_text = ""
        return self._text + self._current_text


class SPMStreamingDetokenizer(StreamingDetokenizer):

    def __init__(self, tokenizer, trim_space=True):
        self.trim_space = trim_space
        self._sep = "\u2581".encode()

        # Extract the tokens in a list from id to text
        self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
        for value, tokenid in tokenizer.vocab.items():
            if value.startswith("<0x"):
                # Replace bytes with their value
                self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
            else:
                self.tokenmap[tokenid] = value.encode()

        self.reset()

    def reset(self):
        self.offset = 0
        self._unflushed = b""
        self.text = ""
        self.tokens = []
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/upload.py

## API Reference

```python
def main():
```

---

# Source: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/utils.py

## API Reference

```python
def _unpack_awq_weights(qweight: mx.array) -> mx.array:


def _transform_awq_weights(


def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """


def get_total_parameters(model):
    def nparams(m):


def compute_bits_per_weight(model):


def _download(
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Path: The local file path.
    """


def hf_repo_to_path(hf_repo):


def load_config(model_path: Path) -> dict:


def load_model(
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        lazy (bool): If ``False``, evaluate the model parameters to make sure
            they are loaded in memory before returning; otherwise they will be
            loaded when needed. Default: ``False``.
        strict (bool): Whether or not to raise an exception if weights don't
            match. Default: ``True``.
        model_config (dict, optional): Optional configuration parameters for
            the model. Defaults to an empty dictionary.
        get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
            A function that returns the model class and model args class given
            a config. Defaults to the ``_get_classes`` function.

    Returns:
        Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """

    def _quantize(quantization):
        def class_predicate(p, m):


def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:


def load_tokenizer(model_path, tokenizer_config_extra=None, eos_token_ids=None):
    """Load a huggingface tokenizer and try to infer the type of streaming
    detokenizer to use.
    """


def load(
    """
    Load the model and tokenizer from a given path or a huggingface repository.

    Args:
        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the
            tokenizer. Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA
            layers to the model. Default: ``None``.
        lazy (bool): If ``False``, evaluate the model parameters to make sure they are loaded
            in memory before returning; otherwise they will be loaded when needed.
            Default: ``False``.
        return_config (bool): If ``True``, return the model config as the last item.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]:
            A tuple containing the loaded model, tokenizer and, if requested, the model config.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """


def sharded_load(


def pipeline_load(repo, return_config=False):


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
    """
    Splits the weights into smaller shards.

    Args:
        weights (dict): Model weights.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards.
    """


def create_model_card(path: Union[str, Path], hf_path: Union[str, Path, None]):
    """
    Creates a model card for the local model at ``path``.

    Args:
        path (Union[str, Path]): Local path to the model.
        hf_path (Union[str, Path, None]): Path to the original Hugging Face model.
    """


def upload_to_hub(path: str, upload_repo: str):
```
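
`make_shards` splits a weight dictionary into shards bounded by a maximum size. A minimal sketch of that idea is a greedy pass that starts a new shard whenever the next tensor would push the current one over the limit. The helper name `shard_by_size` and its byte-based `max_shard_bytes` parameter are hypothetical and not part of `mlx_lm`. Values are assumed to expose an `nbytes` attribute, as `mx.array` and NumPy arrays do.

```python
from typing import Any, Dict, List


def shard_by_size(weights: Dict[str, Any], max_shard_bytes: int) -> List[Dict[str, Any]]:
    """Greedily split `weights` into consecutive shards of at most `max_shard_bytes`.

    Hypothetical helper (not the mlx_lm implementation). Each value must expose
    an `nbytes` attribute. Tensors are never split, so a single tensor larger
    than the limit ends up alone in an oversized shard.
    """
    shards: List[Dict[str, Any]] = []
    shard: Dict[str, Any] = {}
    shard_size = 0
    for name, tensor in weights.items():
        # Close the current shard if this tensor would exceed the limit,
        # but never emit an empty shard.
        if shard and shard_size + tensor.nbytes > max_shard_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[name] = tensor
        shard_size += tensor.nbytes
    if shard:
        shards.append(shard)
    return shards


if __name__ == "__main__":
    # memoryview objects expose .nbytes, standing in for real tensors here.
    demo = {f"layer{i}.weight": memoryview(bytes(n)) for i, n in enumerate([600, 300, 500, 200])}
    for shard in shard_by_size(demo, max_shard_bytes=1000):
        print(list(shard))
    # prints:
    # ['layer0.weight', 'layer1.weight']
    # ['layer2.weight', 'layer3.weight']
```

Dictionary insertion order is preserved, so consecutive weights stay together in the same shard where possible. To express the limit in gigabytes, pass something like `max_file_size_gb * 2**30` as the byte limit.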