"""Runs a benchmark test dataset as a series of prompts against a selected model.  It can be used to test
any model type over a longer-running series of prompts, as well as the fact-checking capability. """


import time
import random

from llmware.prompts import Prompt
from datasets import load_dataset


def load_rag_benchmark_tester_dataset():
    """ Loads the benchmark dataset used in the prompt test. """

    dataset_name = "llmware/rag_instruct_benchmark_tester"
    print(f"\n > Loading RAG dataset '{dataset_name}'...")
    dataset = load_dataset(dataset_name)

    # flatten the 'train' split into a simple list of samples
    test_set = [sample for sample in dataset["train"]]

    return test_set


# Run the benchmark test
def test_prompt_rag_benchmark():

    test_dataset = load_rag_benchmark_tester_dataset()

    # SELECTED MODELS - one will be chosen at random below
    selected_test_models = ["llmware/bling-1b-0.1", "llmware/bling-1.4b-0.1", "llmware/bling-falcon-1b-0.1",
                            "llmware/bling-tiny-llama-v0",
                            "bling-phi-3-gguf", "bling-answer-tool", "dragon-yi-answer-tool",
                            "dragon-llama-answer-tool", "dragon-mistral-answer-tool"]

    # randomly select one model from the list
    model_name = random.choice(selected_test_models)

    print(f"\n > Loading model '{model_name}'")
    prompter = Prompt().load_model(model_name)

    print(f"\n > Running RAG Benchmark Test against '{model_name}' - {len(test_dataset)} questions")
    for i, entry in enumerate(test_dataset):

        start_time = time.time()

        prompt = entry["query"]
        context = entry["context"]
        response = prompter.prompt_main(prompt, context=context, prompt_name="default_with_context", temperature=0.3)

        assert response is not None

        # Print results
        time_taken = round(time.time() - start_time, 2)
        print("\n")
        print(f"{i + 1}. llm_response - {response['llm_response']}")
        print(f"{i + 1}. gold_answer - {entry['answer']}")
        print(f"{i + 1}. time_taken - {time_taken}")

        # Fact checking - compare the response against the context passage
        fc = prompter.evidence_check_numbers(response)
        sc = prompter.evidence_comparison_stats(response)
        sr = prompter.evidence_check_sources(response)

        for fc_entry in fc:
            for f, facts in enumerate(fc_entry["fact_check"]):
                print(f"{i + 1}. fact_check - {f} {facts}")

        for sc_entry in sc:
            print(f"{i + 1}. comparison_stats - {sc_entry['comparison_stats']}")

        for sr_entry in sr:
            for s, source in enumerate(sr_entry["source_review"]):
                print(f"{i + 1}. source - {s} {source}")

    return 0
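

# optional entry point - a minimal convenience hook (an addition, assuming direct execution is useful) so the
# benchmark can also be launched as a script, in addition to being collected by a test runner such as pytest
if __name__ == "__main__":

    test_prompt_rag_benchmark()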