
Commit 03f5444

Darren Oberst authored and committed
updating model tests
1 parent 0d4a82f commit 03f5444

9 files changed: +163 −378 lines

tests/experimental/inference_server.py

Lines changed: 0 additions & 206 deletions
This file was deleted.

tests/models/test_all_generative_models.py

Lines changed: 0 additions & 53 deletions
This file was deleted.
Lines changed: 6 additions & 9 deletions
@@ -1,3 +1,5 @@
+""" Basic connectivity tests to cloud API providers. """
+
 import os
 from llmware.prompts import Prompt
 
@@ -7,39 +9,34 @@
 google_api_key = os.environ.get("GOOGLE_API_KEY","")
 cohere_api_key = os.environ.get("COHERE_API_KEY", "")
 
+
 # Simple test to make sure we are reaching OpenAI
 def test_openai():
     prompter = Prompt(llm_name="gpt-4", llm_api_key=openai_api_key)
     response = prompter.completion("what is artificial intelligence?")
     llm_response = response["llm_response"]
     assert 'artificial' in llm_response.lower()
 
+
 # Simple test to make sure we are reaching Google
 def test_google():
     prompter = Prompt(llm_name="text-bison@001", llm_api_key=google_api_key)
     response = prompter.completion("what is artificial intelligence?")
     llm_response = response["llm_response"]
     assert 'artificial' in llm_response.lower()
 
+
 # Simple test to make sure we are reaching Anthropic
 def test_anthropic():
     prompter = Prompt(llm_name="claude-instant-v1", llm_api_key=anthropic_api_key)
     response = prompter.completion("what is artificial intelligence?")
     llm_response = response["llm_response"]
     assert 'artificial' in llm_response.lower()
 
+
 # Simple test to make sure we are reaching AI21
 def test_ai21():
     prompter = Prompt(llm_name="j2-grande-instruct", llm_api_key=ai21_api_key)
     response = prompter.completion("what is artificial intelligence?")
     llm_response = response["llm_response"]
     assert 'artificial' in llm_response.lower()
-
-# Simple test to make sure we are reaching Cohere. Disabling due to Cohere temporarily rate-limiting summarization for Trial accounts
-# def test_cohere():
-#     user_managed_secrets_setup()
-#     prompter = Prompt(llm_name="summarize-medium", llm_api_key=os.environ["USER_MANAGED_COHERE_API_KEY"])
-#     response = prompter.completion("what is artificial intelligence?")
-#     llm_response = response["llm_response"]
-#     print(llm_response)
-#     assert 'artificial' in llm_response.lower()
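
The pattern these connectivity tests share can also be run standalone. Below is a minimal sketch, assuming OPENAI_API_KEY is set in the environment; the model name, prompt, and response handling are taken directly from test_openai above:

import os
from llmware.prompts import Prompt

# same connectivity check as test_openai, run outside the test harness
prompter = Prompt(llm_name="gpt-4", llm_api_key=os.environ.get("OPENAI_API_KEY", ""))
response = prompter.completion("what is artificial intelligence?")

# completion() returns a dict; the generated text sits under "llm_response"
print(response["llm_response"])

Each provider test follows the same shape, differing only in llm_name and the API key pulled from the environment.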
Lines changed: 49 additions & 0 deletions
This file was added.
@@ -0,0 +1,49 @@
+
+""" Test that GGUF models are loading correctly in local environment. By default, will run through a series of
+different GGUF models in the ModelCatalog to spot-check that the model is correctly loading and
+successfully completing an inference:
+
+# tests several different underlying models:
+
+# bling-answer-tool -> tiny-llama (1b)
+# bling-phi-3-gguf -> phi-3 (3.8b)
+# dragon-yi-answer-tool -> yi (6b)
+# dragon-llama-answer-tool -> llama-2 (7b)
+# llama-2-7b-chat-gguf -> llama-2-chat (7b)
+# dragon-mistral-answer-tool -> mistral-1 (7b)
+
+"""
+
+
+from llmware.models import ModelCatalog
+
+
+def test_gguf_model_load():
+
+    # feel free to adapt this model list
+
+    model_list = ["bling-answer-tool",
+                  "bling-phi-3-gguf",
+                  "dragon-yi-answer-tool",
+                  "dragon-llama-answer-tool",
+                  "llama-2-7b-chat-gguf",
+                  "dragon-mistral-answer-tool"]
+
+    # please note that the unusually short and simple prompt at times actually yields more variability in the model
+    # response - we are only testing for successful loading and inference
+
+    sample_prompt = ("The company stock declined by $12 after poor earnings results."
+                     "\nHow much did the stock price decline?")
+
+    for model_name in model_list:
+
+        print("\nmodel name: ", model_name)
+
+        model = ModelCatalog().load_model(model_name, temperature=0.0, sample=False)
+
+        response = model.inference(sample_prompt)
+
+        print(f"{model_name} - response: ", response)
+
+        assert response is not None
+
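
For spot-checking a single model outside the test loop, the same load-and-infer steps can be isolated. A minimal sketch, using one model name from the list above, with decoding settings that mirror the test's:

from llmware.models import ModelCatalog

# temperature=0.0 with sample=False requests greedy decoding, which keeps
# the output as repeatable as the model allows - same settings as the test
model = ModelCatalog().load_model("bling-answer-tool", temperature=0.0, sample=False)

response = model.inference("The company stock declined by $12 after poor earnings results."
                           "\nHow much did the stock price decline?")

# the test only asserts that response is not None; printing it shows the
# full structure that inference() returns
print(response)

Disabling sampling is what makes a simple pass/fail assertion workable here; with sampling on, the deliberately short prompt noted in the test comments would add even more run-to-run variability.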
