Commit ca96688

Fit model footprint in CPU if GPU memory is small
1 parent d549976 commit ca96688

File tree: 3 files changed (+108 -75 lines)

src/train/server/src/app.py

Lines changed: 9 additions & 9 deletions
@@ -41,23 +41,23 @@ def completion():
             return jsonify({"error": "Missing required params"}), 400

         # Get the required attributes from the request body
-        model_name = request.json["model"]
+        model_name = request.json["model_name"]
         training_data = request.json["training_data"]
         hf_token = request.json["hf_token"]
         deploy_to_hugging_face = request.json["deploy_to_hugging_face"]
         model_path = request.json["model_path"]

-        dataset_path = ""  # TODO: Make CSV from json received in training data
-        # Save that CSV locally
+        print(training_data, hf_token, deploy_to_hugging_face, model_path)

-        llm_train = LLMTrain(model_name, dataset_path)
-        # Call make completion which calls LiteLLM which calls Vertex AI
-        endpont = ""
-        if not endpont:
-            raise ValueError("ResponseUndefined")
+        llm_train = LLMTrain(model_name, training_data)
+        llm_train.run_train(model_name, training_data, deploy_to_hugging_face, model_path)
+
+        # endpont = ""
+        # if not endpont:
+        #     raise ValueError("ResponseUndefined")

         # Return response
-        return jsonify({"response": response,
+        return jsonify({"response": "",
                         "success": True}), 200
     except Exception as e:
         app.logger.error(str(e))
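For reference, a minimal sketch of a client request whose body matches the keys this handler now reads from request.json. The route path, port, model id, token, and file path below are placeholders, not taken from this commit:

    import requests

    # Hypothetical endpoint; the actual route and port come from the Flask app.
    url = "http://localhost:5000/completion"

    payload = {
        "model_name": "tiiuae/falcon-7b",  # placeholder model id
        "training_data": [
            {"input": "midjourney prompt for a sunset",
             "output": "golden hour over the sea, warm tones"}
        ],
        "hf_token": "<hf_token>",          # placeholder token
        "deploy_to_hugging_face": False,
        "model_path": "models/finetuned",  # placeholder path
    }

    resp = requests.post(url, json=payload)
    print(resp.status_code, resp.json())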
Lines changed: 34 additions & 39 deletions
@@ -1,60 +1,55 @@
+accelerate @ git+https://github.com/huggingface/accelerate.git@c9fbb71e37e7f64f5df54b39270bdabe82f1b893
 aiohttp==3.8.6
 aiosignal==1.3.1
-appdirs==1.4.4
 async-timeout==4.0.3
 attrs==23.1.0
-blinker==1.6.2
-cachetools==5.3.1
+bitsandbytes==0.39.0
+blinker==1.7.0
 certifi==2023.7.22
-charset-normalizer==3.3.0
-click==8.1.3
-filelock==3.12.4
-Flask==2.3.2
-Flask-Cors==3.0.10
+charset-normalizer==3.3.2
+click==8.1.7
+datasets==2.12.0
+dill==0.3.6
+einops==0.6.1
+filelock==3.13.1
+Flask==3.0.0
 frozenlist==1.4.0
-fsspec==2023.9.2
-google-api-core==2.12.0
-google-auth==2.23.3
-google-cloud-aiplatform==1.35.0
-google-cloud-bigquery==3.12.0
-google-cloud-core==2.3.3
-google-cloud-resource-manager==1.10.4
-google-cloud-storage==2.12.0
-google-crc32c==1.5.0
-google-resumable-media==2.6.0
-googleapis-common-protos==1.61.0
-grpc-google-iam-v1==0.12.6
-grpcio==1.59.0
-grpcio-status==1.59.0
-huggingface-hub==0.17.3
+fsspec==2023.10.0
+gunicorn==21.2.0
+huggingface-hub==0.19.0
 idna==3.4
 importlib-metadata==6.8.0
 itsdangerous==2.1.2
 Jinja2==3.1.2
-jsonify==0.5
-litellm==0.8.4
-MarkupSafe==2.1.2
+loralib==0.1.1
+MarkupSafe==2.1.3
+mpmath==1.3.0
 multidict==6.0.4
+multiprocess==0.70.14
+networkx==3.1
 numpy==1.24.4
-openai==0.28.1
 packaging==23.2
-proto-plus==1.22.3
-protobuf==4.24.4
-pyasn1==0.5.0
-pyasn1-modules==0.3.0
+pandas==2.0.3
+peft @ git+https://github.com/huggingface/peft.git@42a184f7423fc0bbc102a085851a8fb6e40132ad
+psutil==5.9.6
+pyarrow==14.0.1
 python-dateutil==2.8.2
-python-dotenv==1.0.0
+pytz==2023.3.post1
 PyYAML==6.0.1
 regex==2023.10.3
 requests==2.31.0
-rsa==4.9
-shapely==2.0.2
+responses==0.18.0
+safetensors==0.4.0
 six==1.16.0
-tiktoken==0.5.1
-tokenizers==0.14.1
+sympy==1.12
+tokenizers==0.13.3
+torch==2.0.1
 tqdm==4.66.1
+transformers @ git+https://github.com/huggingface/transformers.git@e03a9cc0cd7623a8d5208d7a4206f628b2bd5513
 typing_extensions==4.8.0
-urllib3==2.0.6
-Werkzeug==2.3.4
+tzdata==2023.3
+urllib3==2.0.7
+Werkzeug==3.0.1
+xxhash==3.4.1
 yarl==1.9.2
-zipp==3.15.0
+zipp==3.17.0

src/train/server/src/train.py

Lines changed: 65 additions & 27 deletions
@@ -6,7 +6,10 @@
 import torch
 import torch.nn as nn
 import transformers
-from datasets import load_dataset
+from datasets import (
+    load_dataset,
+    Dataset
+)
 from peft import (
     LoraConfig,
     PeftConfig,
@@ -27,30 +30,40 @@
 Train LLMs
 """

+
 class LLMTrain:
     # Initialize the class with model and data path
-    def __init__(self, MODEL_NAME, dataset_path) -> None:
+    def __init__(self, MODEL_NAME, training_data) -> None:
         self.MODEL_NAME = MODEL_NAME
-        self.dataset_path = dataset_path
+        self.training_data = training_data

     # Method to create transformer model and tokenizer
     def create_model_and_tokenizer(self):
-        # Define Quantization configuration to optimize model
+        # Define Quantization configuration to optimize model
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.bfloat16
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            load_in_8bit_fp32_cpu_offload=True  # Set offloading to CPU.
         )
+        # Create a device map
+        device_map = {
+            0: ["transformer.h.0.", "transformer.h.1."],
+            1: ["transformer.h.2.", "transformer.h.3."],
+            -1: ["transformer.h.4.", "transformer.h.5.", "transformer.h.6.", "transformer.h.7."]
+        }
         # Create Transformer model based on given model name
         model = AutoModelForCausalLM.from_pretrained(
             self.MODEL_NAME,
-            device_map="auto",
+            device_map=device_map,  # Pass a custom device map
             trust_remote_code=True,
             quantization_config=bnb_config
         )
         # Create a tokenizer for the designated model
         tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
+        tokenizer.pad_token = tokenizer.eos_token
+        self.tokenizer = tokenizer
         return model, tokenizer

     # Method to prepare and configure the model for training
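The commit message targets spilling part of the model onto the CPU when GPU memory is small. As a point of comparison, here is a minimal sketch of how transformers/accelerate commonly express that: quantize with bitsandbytes, allow fp32 CPU offload, and give device_map="auto" a per-device memory budget. The model id and memory figures are placeholders, not taken from this repo:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    MODEL_NAME = "tiiuae/falcon-7b"  # placeholder model id

    # 4-bit quantization; modules that end up on the CPU are allowed to stay in fp32.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        llm_int8_enable_fp32_cpu_offload=True,
    )

    # Let accelerate split the model: GPU 0 gets a capped budget, the rest spills to CPU RAM.
    # An explicit map of the form {"transformer.h.4": "cpu", ...} also works; module names
    # depend on the architecture.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        max_memory={0: "10GiB", "cpu": "48GiB"},  # placeholder budgets
        trust_remote_code=True,
    )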
@@ -66,32 +79,51 @@ def prepare_and_configure_model(self, model):
             bias="none",
             task_type="CAUSAL_LM"
         )
-        model = get_peft_model(model, config) # Apply the defined configuration to the model
+        # Apply the defined configuration to the model
+        model = get_peft_model(model, config)
         self.print_trainable_parameters(model)
         return model

     # Method to generate result based on user provided prompt
     def generate_future_with_prompt(self, model, tokenizer, prompt):
         generation_config = model.generation_config
         device = "cuda:0"
-        # Encoding the prompt using tokenizer
+        # Encoding the prompt using tokenizer
         encoding = tokenizer(prompt, return_tensors="pt").to(device)
         with torch.inference_mode():
             outputs = model.generate(
-                input_ids = encoding.input_ids,
-                attention_mask = encoding.attention_mask,
-                generation_config = generation_config
+                input_ids=encoding.input_ids,
+                attention_mask=encoding.attention_mask,
+                generation_config=generation_config
             )
         print(tokenizer.decode(outputs[0], skip_special_tokens=True))

-    # Method to load and tokenize the dataset
-    def load_and_tokenize_data(self, tokenizer):
-        data = load_dataset("csv", data_files=self.dataset_path)
-        data = data["train"].shuffle().map(self.generate_and_tokenize_prompt)
-        return data
+    """
+    Method to load and tokenize the dataset.
+    It expects an array of objects, each object of the format:
+    {
+        'input': '{{user_input}}',
+        'output': '{{model_output}}'
+    }
+    """
+
+    def load_training_data(self, data):
+        # Convert array of objects to dictionary format
+        data_dict = {
+            'input': [obj['input'] for obj in data],
+            'output': [obj['output'] for obj in data]
+        }
+        d = Dataset.from_dict(data_dict)
+        d = d.shuffle().map(
+            self.generate_and_tokenize_prompt,
+            batched=True,
+            remove_columns=["input", "output"],
+            load_from_cache_file=False
+        )
+        return d

     # Method to fine tune the model
-    def fine_tune_model(self, model, data, tokenizer):
+    def fine_tune_model(self, model, data, tokenizer):
         training_args = transformers.TrainingArguments(
             per_device_train_batch_size=1,
             gradient_accumulation_steps=4,
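Per the docstring above, load_training_data now takes the request's training_data list directly instead of a CSV path. A small sketch of data in the expected shape (the output strings are made-up placeholders; the first input string is the sample prompt used elsewhere in this commit):

    training_data = [
        {"input": "midjourney prompt for a girl sit on the mountain",
         "output": "a girl sitting on a mountain ridge at sunrise, cinematic lighting"},
        {"input": "midjourney prompt for a rainy city street",
         "output": "neon-lit street in the rain, reflections on wet asphalt, 35mm"},
    ]

    # Inside the class this becomes a datasets.Dataset, then is shuffled and tokenized in batches:
    # data = self.load_training_data(training_data)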
@@ -109,24 +141,29 @@ def fine_tune_model(self, model, data, tokenizer):
             model=model,
             train_dataset=data,
             args=training_args,
-            data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
+            data_collator=transformers.DataCollatorForLanguageModeling(
+                tokenizer, mlm=False)
         )
         return trainer

     # Run a complete training cycle
-    def run_train(self, MODEL_NAME, dataset_path, deploy_to_hf, model_path):
+    def run_train(self, MODEL_NAME, training_data, deploy_to_hf, model_path):
         self.MODEL_NAME = MODEL_NAME
-        self.dataset_path = dataset_path
+        print("create_model_and_tokenizer")
         model, tokenizer = self.create_model_and_tokenizer()
+        print("prepare_and_configure_model")
         model = self.prepare_and_configure_model(model)

         prompt = """
 <human>: midjourney prompt for a girl sit on the mountain
 <assistant>:
 """.strip()
+        print("generating future with prompt")
         self.generate_future_with_prompt(model, tokenizer, prompt)
+        print("loading training data")
+        data = self.load_training_data(training_data)
+        print("\n\ndata:\n\n", data)

-        data = self.load_and_tokenize_data(tokenizer)
         trainer = self.fine_tune_model(model, data, tokenizer)
         trainer.train()

@@ -143,15 +180,16 @@ def deploy_to_hugging_face(self, model, model_path):
     # Generate dialog prompt with human and assistant tags
     def generate_prompt(self, data_point):
         return f"""
-<human>: {data_point["User"]}
-<assistant>: {data_point["Prompt"]}
+<human>: {data_point["input"]}
+<assistant>: {data_point["output"]}
 """.strip()

     # Tokenize the generated dialog prompt
     def generate_and_tokenize_prompt(self, data_point):
         full_prompt = self.generate_prompt(data_point)
         # padding and truncation are set to True for handling sequences of different length.
-        tokenized_full_prompt = self.tokenizer(full_prompt, padding=True, truncation=True)
+        tokenized_full_prompt = self.tokenizer(
+            full_prompt, padding=True, truncation=True)
         return tokenized_full_prompt

     # Print the number of parameters that are trainable in the model
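With the dataset keys renamed to input/output, generate_prompt turns a single record into a two-turn dialog string. A quick sketch with placeholder text:

    data_point = {"input": "midjourney prompt for a rainy city street",
                  "output": "neon-lit street in the rain, reflections on wet asphalt"}

    # generate_prompt(data_point) returns (after .strip()):
    # <human>: midjourney prompt for a rainy city street
    # <assistant>: neon-lit street in the rain, reflections on wet asphalt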
@@ -163,9 +201,9 @@ def print_trainable_parameters(self, model):
         all_param = 0

         for _, param in model.named_parameters():
-            all_param += param.numel() # Total parameters
+            all_param += param.numel()  # Total parameters
             if param.requires_grad:
-                trainable_params += param.numel() # Trainable parameters
+                trainable_params += param.numel()  # Trainable parameters
         print(
             f"trainable params: {trainable_params} || all params: {all_param} || trainables%: {100 * trainable_params / all_param}"
-        )
+        )
