Skip to content

Commit c663576

Browse files
committed
Add MultiModalModel class for handling multi-modal data processing and inference
1 parent 01a6e4d commit c663576

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

llmware/models.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11240,6 +11240,115 @@ def _load_model(self, model_name, sample=False, temperature=0.0, get_logits=Fals
1124011240
return True
1124111241

1124211242

11243+
class MultiModalModel:

    """A class to handle multi-modal models, supporting text, image, and other data types.

    Each input is routed through an optional per-data-type preprocessor,
    the underlying model backend ("pytorch", "onnx", "openvino", "gguf",
    or "tensorflow") is invoked, and each output is routed through an
    optional per-data-type postprocessor.
    """

    def __init__(self, model_name, model_type, preprocessors=None, postprocessors=None, model=None):

        """Create a multi-modal model wrapper.

        Args:
            model_name: Human-readable identifier for the model.
            model_type: Backend key - one of "pytorch", "onnx", "openvino",
                "gguf", "tensorflow".
            preprocessors: Optional dict mapping data_type -> callable applied
                to that input before inference.
            postprocessors: Optional dict mapping data_type -> callable applied
                to that output after inference.
            model: Optional loaded model object (or, presumably, a model path
                for the "onnx"/"openvino" backends - confirm against callers).
                May also be assigned later via the ``model`` attribute.
        """

        self.model_name = model_name
        self.model_type = model_type
        self.preprocessors = preprocessors or {}
        self.postprocessors = postprocessors or {}

        # Fix: always define the attribute so the guard in _run_model does not
        # depend on hasattr(); previously self.model was never initialized.
        self.model = model

    def add_preprocessor(self, data_type, preprocessor):

        """Add a preprocessor for a specific data type."""

        self.preprocessors[data_type] = preprocessor

    def add_postprocessor(self, data_type, postprocessor):

        """Add a postprocessor for a specific data type."""

        self.postprocessors[data_type] = postprocessor

    def preprocess(self, data_type, data):

        """Preprocess data based on its type - pass-through if no preprocessor registered."""

        if data_type in self.preprocessors:
            return self.preprocessors[data_type](data)
        return data

    def postprocess(self, data_type, data):

        """Postprocess data based on its type - pass-through if no postprocessor registered."""

        if data_type in self.postprocessors:
            return self.postprocessors[data_type](data)
        return data

    def inference(self, inputs):

        """Perform inference on multi-modal inputs.

        Args:
            inputs: dict mapping data_type -> raw input data.

        Returns:
            dict mapping data_type -> postprocessed model output.

        Raises:
            ValueError: if no model is loaded, or model_type is unsupported.
        """

        processed_inputs = {
            data_type: self.preprocess(data_type, data)
            for data_type, data in inputs.items()
        }

        raw_outputs = self._run_model(processed_inputs)

        return {
            data_type: self.postprocess(data_type, output)
            for data_type, output in raw_outputs.items()
        }

    def _run_model(self, inputs):

        """Run the model on preprocessed inputs based on the model type.

        Backend libraries are imported lazily so that only the selected
        backend's dependency needs to be installed.
        """

        # getattr keeps the original hasattr-style safety even if a subclass
        # deletes the attribute.
        if getattr(self, "model", None) is None:
            raise ValueError("Model is not loaded. Please load a model before running inference.")

        if self.model_type == "pytorch":
            # PyTorch inference - each input is run through the model
            # independently with a singleton batch dimension added/removed.
            import torch
            input_tensors = {
                data_type: torch.tensor(data) if isinstance(data, list) else torch.from_numpy(data)
                for data_type, data in inputs.items()
            }
            with torch.no_grad():
                outputs = {
                    data_type: self.model(input_tensor.unsqueeze(0))
                    for data_type, input_tensor in input_tensors.items()
                }
            return {data_type: output.squeeze(0).numpy() for data_type, output in outputs.items()}

        elif self.model_type == "onnx":
            # ONNX inference - NOTE(review): a new session is created per call;
            # assumes self.model is a path/bytes accepted by InferenceSession,
            # and feeds every input to the session's first input name - confirm.
            import onnxruntime as ort
            session = ort.InferenceSession(self.model)
            outputs = {
                data_type: session.run(None, {session.get_inputs()[0].name: data})[0]
                for data_type, data in inputs.items()
            }
            return outputs

        elif self.model_type == "openvino":
            # OpenVINO inference - model is recompiled for CPU on every call.
            from openvino.runtime import Core
            core = Core()
            compiled_model = core.compile_model(self.model, "CPU")
            outputs = {
                data_type: compiled_model([data])[0]
                for data_type, data in inputs.items()
            }
            return outputs

        elif self.model_type == "gguf":
            # GGUF inference (example placeholder)
            # Assuming GGUF uses a specific library for inference
            from llmware.gguf_configs import GGUFInference
            gguf_inference = GGUFInference(self.model)
            outputs = {
                data_type: gguf_inference.run(data)
                for data_type, data in inputs.items()
            }
            return outputs

        elif self.model_type == "tensorflow":
            # TensorFlow inference - mirrors the PyTorch branch with a
            # leading batch axis added via input_tensor[None, ...].
            import tensorflow as tf
            input_tensors = {
                data_type: tf.convert_to_tensor(data) if isinstance(data, list) else tf.constant(data)
                for data_type, data in inputs.items()
            }
            outputs = {
                data_type: self.model(input_tensor[None, ...])
                for data_type, input_tensor in input_tensors.items()
            }
            return {data_type: output.numpy() for data_type, output in outputs.items()}

        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

1124311352
class PyTorchLoader:
1124411353

1124511354
""" PyTorchLoader is a wrapper class that consolidates all of the PyTorch model loading functions

0 commit comments

Comments
 (0)