I’ve been using conversational formats like alpaca or sharegpt for fine-tuning LLMs with Axolotl, but it always felt unnecessary to limit the model to a conversational format when the use case doesn’t require one.

I’m currently working on a project to classify pull requests in my company’s code repositories. The model needs to look at the PR title, description, and code changes, then categorize the PR and explain its reasoning. I figured there must be a way to fine-tune these models with whatever format fits this specific use case, and sure enough there is: Template-free Axolotl.

This seemed to be exactly what I was looking for, but the emphasis on “masking inputs” confused me:

If you declare a dataset format such as alpaca or chatml, axolotl knows what is an input (i.e. human) vs. an output (i.e. the assistant) and masks the input labels so that your model can focus on predicting the outputs only. You can construct your prompts without a template by using the input_output format, by setting type: input_output in your configuration file. Unlike type: completion, which is also template-free, type: input_output allows you to mask segments of your text.
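To make that concrete, an input_output dataset is a JSONL file where each row is a list of text segments, each flagged as trained-on (label: true) or masked (label: false). The segments schema comes from the Axolotl docs; the example PR text and file name below are my own invention:

```python
import json

# One training row in Axolotl's input_output format. Segments with
# "label": false are masked: the model still sees them as context,
# but the loss is not computed on them. The PR text here is invented.
row = {
    "segments": [
        {"label": False, "text": "<s>PR title: Fix race condition in job scheduler\n"},
        {"label": False, "text": "Description: release the lock even when a job throws\n"},
        {"label": True, "text": "Category: bugfix</s>"},
    ]
}

# Axolotl expects one JSON object per line (JSONL).
with open("prs.jsonl", "a") as f:
    f.write(json.dumps(row) + "\n")
```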

This was the first time I came across the concept of input masking. It felt counterintuitive at first - why would we want our model to ignore parts of the training data? Isn’t all of it important? Wouldn’t masking the input prevent the model from learning to respond properly when it encounters similar cases at inference time? After digging deeper into the Axolotl docs and code, I’ve learned a few things I’d like to share.

What is input masking?

Input masking is a technique where we exclude parts of the training sequence from the loss. The model still sees the masked tokens as context; it just isn’t trained to predict them. It’s like covering up certain words in a sentence and grading the model only on the uncovered parts. But why would we do this? There are several reasons, and a sketch of the mechanics follows this list:

  • Distinction between input and output: In many fine-tuning scenarios, particularly for instruction-following or question-answering tasks, we want the model to learn to generate appropriate responses (outputs) given certain prompts or questions (inputs). By masking the input, we’re emphasizing that the model should focus on generating the correct output rather than simply memorizing and repeating the input.

  • Preventing overfitting to prompt templates: If we train the model on both the input and output, it might start to rely too heavily on specific phrasings or templates in the input. By masking the input, we encourage the model to be more flexible and generalize better to variations in how questions or instructions might be phrased.

  • Focusing on task completion: The goal is often to have the model learn to perform a task or generate relevant information, not to reproduce the exact input format. Masking helps direct the model’s attention to the core task.

  • Efficiency in fine-tuning: With the loss computed only over output tokens, none of the limited fine-tuning signal is spent teaching the model to predict its own prompts; all of it goes toward the behavior we actually want.

  • Handling different prompt structures: This allows for more flexibility in how prompts are structured in our training data, as the model isn’t learning to reproduce specific prompt formats.
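Under the hood, masking is typically implemented by setting the labels of masked positions to -100, the default ignore_index of PyTorch’s CrossEntropyLoss; HuggingFace-style causal-LM training, which Axolotl builds on, follows the same convention. Here’s a minimal sketch with made-up token ids and random logits, just to show where the masking enters the loss:

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # cross_entropy skips positions with this label

# Made-up token ids: a 4-token prompt followed by a 3-token response.
prompt_ids = [101, 2054, 2003, 1996]
response_ids = [3437, 2415, 102]
input_ids = torch.tensor([prompt_ids + response_ids])  # shape (1, 7)

# Labels start as a copy of input_ids; masking the prompt just means
# overwriting its positions with -100. The tokens stay in input_ids,
# so the model still attends to them as context.
labels = input_ids.clone()
labels[0, : len(prompt_ids)] = IGNORE_INDEX

# Stand-in for model output; a real run would use model(input_ids).logits.
vocab_size = 30522
logits = torch.randn(1, input_ids.shape[1], vocab_size)

# Standard causal-LM shift: position t predicts token t+1.
shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
shift_labels = labels[:, 1:].reshape(-1)

# Only the response tokens contribute to the loss and to the gradients.
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=IGNORE_INDEX)
print(loss.item())
```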

To mask or not to mask

Based on the above, there are scenarios where we might want to use input masking when fine-tuning a language model:

  • Chatbots and conversational AI
  • Question-answering models
  • Instruction-following tasks

In these cases, we care more about the model generating good responses than about it reproducing exact input phrasings. However, masking isn’t always the best choice. We might want to avoid it when:

  • Every detail in the input is crucial for the task
  • The model needs to understand and reference specific parts of the input
  • We’re working with specialized or technical content where context is key

My use case: classifying pull requests

As I mentioned previously:

I’m currently working on a project to classify pull requests (PRs) in my company’s code repositories. The model needs to look at the PR title, description, and code changes, then categorize the PR and explain its reasoning.

For this task, I’ve decided not to use input masking. Here’s why:

  • Importance of input details: In PR classification, every detail in the title, description, and code changes could be crucial for making an accurate classification. By training on the full input, we allow the model to learn important patterns and correlations between specific input features and the resulting classifications.

  • Contextual understanding: PRs often contain technical jargon, project-specific terms, or code snippets that are important for understanding the nature of the changes. Training on these inputs helps the model build a better contextual understanding of my company’s development practices and codebase.

  • Reasoning requirement: Since my task includes providing reasoning for the classification, the model needs to learn how to reference and analyze specific parts of the input. This is more effectively achieved if the model is trained to process and potentially reproduce relevant parts of the input.

  • Limited output structure: Unlike more open-ended generation tasks, my output has a specific structure (JSON with classification and reasoning). The risk of the model simply copying large portions of the input into the output is lower, given this constrained format.

  • Potential for implicit feature extraction: By training on the full input, the model might learn to identify subtle features or patterns in PR descriptions or code changes that are indicative of certain classifications. This implicit feature extraction could lead to more accurate classifications.

  • Handling variations in input: PRs can vary greatly in length, structure, and content. Training on full inputs helps the model handle this variability more effectively.
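Concretely, skipping input masking with the input_output format just means flagging every segment as trained-on; the loss then covers the PR text as well as the structured answer. (Axolotl’s templated formats expose a similar switch via the train_on_inputs config option.) A sketch under the same assumptions as the earlier one, with invented PR text and an output schema of my own:

```python
import json

# Every segment is "label": True, so the model is trained to predict the
# PR text and the JSON classification alike. Text and schema are invented.
row = {
    "segments": [
        {"label": True, "text": "<s>PR title: Fix race condition in job scheduler\n"},
        {"label": True, "text": "Diff:\n-    lock.release()\n+    with lock:\n"},
        {"label": True, "text": '{"category": "bugfix", "reasoning": "Replaces manual lock handling with a context manager."}</s>'},
    ]
}

with open("prs.jsonl", "a") as f:
    f.write(json.dumps(row) + "\n")
```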

Let me know if I have missed anything or if you know of a good learning resource on this topic.

