
gemma3 template forces mm_plugin and breaks text-only Gemma-3-270M; add a text-only template #9671

@huolongteng

Description


Reminder

  • I have read the above rules and searched the existing issues.

System Info

While training google/gemma-3-270m-it with LLaMA-Factory, I encountered a preprocessing failure caused by the default gemma3 template.
Although the dataset is pure text (ShareGPT format) and contains no multimodal fields, the gemma3 template is currently registered with a mandatory mm_plugin. As a result, the data pipeline always enters the multimodal processing path and requires a processor, which does not exist for gemma-3-270m-it (a text-only model).

In template.py, the gemma3 template is defined as a multimodal template:
mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>")

However, gemma-3-270m-it is a text-only model and does not provide a processor, so training fails before it even starts.
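
For reference, this is easy to confirm outside of LLaMA-Factory. The snippet below is a minimal sketch (model id taken from this report; exact behavior may differ across transformers versions) showing that the checkpoint ships a tokenizer but no multimodal processor:

# Minimal sketch: google/gemma-3-270m-it provides a tokenizer but no
# ProcessorMixin, which is what the gemma3 template's mm_plugin expects.
from transformers import AutoProcessor, AutoTokenizer
from transformers.processing_utils import ProcessorMixin

model_id = "google/gemma-3-270m-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)
print(type(tokenizer).__name__)  # a plain (fast) tokenizer, no image processor

# On a text-only checkpoint, AutoProcessor either raises or falls back to the
# tokenizer, so it never yields a real multimodal processor.
try:
    processor = AutoProcessor.from_pretrained(model_id)
    print(isinstance(processor, ProcessorMixin))  # expected: False
except Exception as exc:
    print(f"no processor available: {exc}")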

For example, adding the following template to template.py resolves the issue. It is a text-only variant of the gemma3 template intended for Gemma-3 text-only models (e.g. 270M):

register_template(
    name="gemma3-sp",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    replace_eos=True,
    # Note: no mm_plugin here (text-only)
    template_class=Llama2Template,
)
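
With this template registered, the failing run from the Reproduction section should go through by selecting the new template name. For example (this mirrors the command recorded in the traceback below, with only template= changed, so treat it as a sketch of the invocation rather than a verified command):

llamafactory-cli train examples/train_full/llama3_full_sft.yaml \
    model_name_or_path=google/gemma-3-270m-it \
    dataset=train_gemma3 \
    max_samples=4096 \
    cutoff_len=4096 \
    output_dir=saves/gemma-3-270m/full/sft \
    num_train_epochs=3.0 \
    template=gemma3-sp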

Reproduction

Converting format of dataset (num_proc=16): 5264 examples [00:00, 3385.55 examples/s]
Running tokenizer on dataset (num_proc=16):   0%|          | 0/2632 [00:09<?, ? examples/s]
[rank0]: multiprocess.pool.RemoteTraceback:
[rank0]: """
[rank0]: Traceback (most recent call last):
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/multiprocess/pool.py", line 125, in worker
[rank0]:     result = (True, func(*args, **kwds))
[rank0]:                     ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/datasets/utils/py_utils.py", line 586, in _write_generator_to_queue
[rank0]:     for i, result in enumerate(func(**kwargs)):
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py", line 3674, in _map_single
[rank0]:     for i, batch in iter_outputs(shard_iterable):
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py", line 3624, in iter_outputs
[rank0]:     yield i, apply_function(example, i, offset=offset)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py", line 3547, in apply_function
[rank0]:     processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/data/processor/supervised.py", line 99, in preprocess_dataset
[rank0]:     input_ids, labels = self._encode_data_example(
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/data/processor/supervised.py", line 43, in _encode_data_example
[rank0]:     messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/data/mm_plugin.py", line 513, in process_messages
[rank0]:     self._validate_input(processor, images, videos, audios)
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/data/mm_plugin.py", line 189, in _validate_input
[rank0]:     raise ValueError("Processor was not found, please check and update your model file.")
[rank0]: ValueError: Processor was not found, please check and update your model file.
[rank0]: """
[rank0]: The above exception was the direct cause of the following exception:
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/launcher.py", line 185, in <module>
[rank0]:     run_exp()
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/train/tuner.py", line 126, in run_exp
[rank0]:     _training_function(config={"args": args, "callbacks": callbacks})
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/train/tuner.py", line 88, in _training_function
[rank0]:     run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 52, in run_sft
[rank0]:     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/data/loader.py", line 318, in get_dataset
[rank0]:     train_dict["train"] = _get_preprocessed_dataset(
[rank0]:                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/LLaMA-Factory/src/llamafactory/data/loader.py", line 255, in _get_preprocessed_dataset
[rank0]:     dataset = dataset.map(
[rank0]:               ^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py", line 560, in wrapper
[rank0]:     out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
[rank0]:                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/datasets/arrow_dataset.py", line 3309, in map
[rank0]:     for rank, done, content in iflatmap_unordered(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/datasets/utils/py_utils.py", line 626, in iflatmap_unordered
[rank0]:     [async_result.get(timeout=0.05) for async_result in async_results]
[rank0]:      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/multiprocess/pool.py", line 774, in get
[rank0]:     raise self._value
[rank0]: ValueError: Processor was not found, please check and update your model file.
[rank0]:[W1226 05:47:49.548378528 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources.
For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank0]:[W1226 05:47:50.214391997 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
E1226 05:47:50.897000 5348 torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 0 (pid: 5414) of binary: /usr/bin/python3.12
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 7, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 936, in main
    run(args)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 927, in run
    elastic_launch(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 156, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
/workspace/LLaMA-Factory/src/llamafactory/launcher.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-12-26_05:47:50
  host      : c6761f2f9b3a
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 5414)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
[W1226 05:47:51.766334053 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
Traceback (most recent call last):
  File "/usr/local/bin/llamafactory-cli", line 7, in <module>
    sys.exit(main())
             ^^^^^^
  File "/workspace/LLaMA-Factory/src/llamafactory/cli.py", line 24, in main
    launcher.launch()
  File "/workspace/LLaMA-Factory/src/llamafactory/launcher.py", line 115, in launch
    process = subprocess.run(
              ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['torchrun', '--nnodes', '1', '--node_rank', '0', '--nproc_per_node', '1', '--master_addr', '127.0.0.1', '--master_port', '42317', '/workspace/LLaMA-Factory/src/llamafactory/launcher.py', 'examples/train_full/llama3_full_sft.yaml', 'model_name_or_path=google/gemma-3-270m-it', 'dataset=train_gemma3', 'max_samples=4096', 'cutoff_len=4096', 'output_dir=saves/gemma-3-270m/full/sft', 'num_train_epochs=3.0', 'template=gemma3']' returned non-zero exit status 1.
[W1226 05:47:51.523497700 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
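
Per the traceback, the failure originates in mm_plugin.py, line 189 (_validate_input), which requires a processor as soon as the template carries an mm_plugin. Besides the dedicated text-only template above, another possible direction (a rough sketch of an assumed guard, not the project's actual code) would be to skip that check when an example carries no multimodal inputs:

# Sketch only: names and error message taken from the traceback; the guard
# itself is an assumption, not the current LLaMA-Factory implementation.
def _validate_input(self, processor, images, videos, audios):
    if not images and not videos and not audios:
        return  # pure-text example: no processor needed
    if processor is None:
        raise ValueError("Processor was not found, please check and update your model file.")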

Others

No response
