Skip to content

Commit 646dbb6

Browse files
authored
Merge pull request #9 from aws-samples/claude3branch
Claude3branch merge request to main branch
2 parents fc5bf2a + 4d8e9e7 commit 646dbb6

8 files changed

+328
-134
lines changed

‎.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__/

‎BedrockTextToSql_for_Athena.ipynb

Lines changed: 121 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,25 @@
8181
"**Prerequisite**\n",
8282
"\n",
8383
"The following are prerequisites that need to be accomplished before executing this notebook.\n",
84-
"- A Sagemaker instance with a role having access to bedrock, glue,athena, s3,lakeformation\n",
84+
"This Notebook can be executed via a Sagemaker instance or via a VS Code editor\n",
85+
"- Create a role having access to bedrock, glue,athena, s3,lakeformation. \n",
86+
"- Assign the role to the Sagemaker instance or to the instance where VS Code editor is running\n",
8587
"- Glue Database and tables. Provided spark notebook to create.\n",
8688
"- An Amazon OpenSearch cluster for storing embeddings. Here, OpenSearch credentials are in notebooks. However, the OpenSearch cluster's access credentials (username and password) can be stored in AWS Secrets Manager by following steps described [here](https://docs.aws.amazon.com/secretsmanager/latest/userguide/managing-secrets.html).\n",
8789
"\n",
88-
"**The overall workflow for this notebook is as follows:**\n",
89-
"1. Download data from source https://developer.imdb.com/non-commercial-datasets/#titleakastsvgz and upload to S3.\n",
90-
"1. Create database and load datasets in Glue. Make sure see of the you are able to query through athena. \n",
91-
"1. Install the required Python packages (specifically boto version mentioned)\n",
92-
"1. Create embedding and vector store.Do a similarity search with embeddings stored in the OpenSearch index for an input query.\n",
90+
"**The workflow for this notebook is as follows:**\n",
91+
"1. Create an S3 bucket with the name \"knowledgebase-<ACCOUNT_ID>\" \n",
92+
" - create a folder \"input\" in that bucket\n",
93+
"2. Download data from source \n",
94+
" - https://developer.imdb.com/non-commercial-datasets/#titleakastsvgz and upload to S3 bucket from step 1 and into the \"input\" folder\n",
95+
" - https://developer.imdb.com/non-commercial-datasets/#titlebasicstsvgz and upload to S3 bucket from step 1 and into the \"input\" folder\n",
96+
"3. Glue Steps\n",
97+
" - Create a glue database \"imdb_stg\" \n",
98+
" - Create a glue crawler \"text-2-sql-crawler\" with the datasource set to the S3 bucket created in step 1. Run the crawler.\n",
99+
"   - 2 tables should be created in the Glue data catalog. Make sure you are able to query through Athena. \n",
100+
"4. From the Bedrock console, Create a new knowledgebase \n",
101+
"1. Install the required Python packages \n",
102+
"1. Create embedding and vector store. Do a similarity search with embeddings stored in the OpenSearch index for an input query.\n",
93103
"1. Execute this notebook to generate SQL."
94104
]
95105
},
@@ -108,7 +118,7 @@
108118
},
109119
{
110120
"cell_type": "code",
111-
"execution_count": 2,
121+
"execution_count": null,
112122
"id": "efc3af34-9c4b-4e95-9147-cf498a74a0c2",
113123
"metadata": {
114124
"pycharm": {
@@ -118,8 +128,16 @@
118128
},
119129
"outputs": [],
120130
"source": [
121-
"# !pip3 install boto3==1.34.8\n",
122-
"# !pip3 install jq"
131+
"!pip3 install boto3\n",
132+
"!pip3 install jq\n",
133+
"\n",
134+
"!pip3 install langchain\n",
135+
"!pip3 install langchain-community langchain-core\n",
136+
"!pip3 install pandas\n",
137+
"!pip3 install opensearch-py\n",
138+
"!pip3 install langchain-aws\n",
139+
"!pip3 install requests-aws4auth\n",
140+
"!pip3 install botocore"
123141
]
124142
},
125143
{
@@ -148,13 +166,14 @@
148166
"source": [
149167
"import boto3\n",
150168
"from botocore.config import Config\n",
151-
"from langchain.llms.bedrock import Bedrock\n",
152-
"from langchain.embeddings import BedrockEmbeddings"
169+
"from langchain_community.embeddings import BedrockEmbeddings\n",
170+
"from langchain_aws import BedrockLLM\n",
171+
"import traceback"
153172
]
154173
},
155174
{
156175
"cell_type": "code",
157-
"execution_count": 10,
176+
"execution_count": 4,
158177
"id": "ef96ac35-3597-4bd8-80ca-035c6c98050b",
159178
"metadata": {
160179
"pycharm": {
@@ -168,15 +187,14 @@
168187
"import json\n",
169188
"import os,sys\n",
170189
"import re\n",
171-
"sys.path.append(\"/home/ec2-user/SageMaker/llm_bedrock_v0/\")\n",
172190
"import time\n",
173191
"import pandas as pd\n",
174192
"import io"
175193
]
176194
},
177195
{
178196
"cell_type": "code",
179-
"execution_count": null,
197+
"execution_count": 5,
180198
"id": "482a9055-6cc2-419c-839f-ca1326b04957",
181199
"metadata": {
182200
"pycharm": {
@@ -223,23 +241,15 @@
223241
},
224242
{
225243
"cell_type": "code",
226-
"execution_count": 7,
244+
"execution_count": null,
227245
"id": "436fb4f1-cd34-4146-bc96-da5e3608d720",
228246
"metadata": {
229247
"pycharm": {
230248
"name": "#%%\n"
231249
},
232250
"tags": []
233251
},
234-
"outputs": [
235-
{
236-
"name": "stdout",
237-
"output_type": "stream",
238-
"text": [
239-
"{'modelArn': 'arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-tg1-large', 'modelId': 'amazon.titan-tg1-large', 'modelName': 'Titan Text Large', 'providerName': 'Amazon', 'inputModalities': ['TEXT'], 'outputModalities': ['TEXT'], 'responseStreamingSupported': True, 'customizationsSupported': [], 'inferenceTypesSupported': ['ON_DEMAND'], 'modelLifecycle': {'status': 'ACTIVE'}}\n"
240-
]
241-
}
242-
],
252+
"outputs": [],
243253
"source": [
244254
"session = boto3.session.Session()\n",
245255
"bedrock_client = session.client('bedrock')\n",
@@ -260,31 +270,30 @@
260270
},
261271
{
262272
"cell_type": "code",
263-
"execution_count": 11,
273+
"execution_count": null,
264274
"id": "72e44bd9-8787-447a-b5e0-0961547bafef",
265275
"metadata": {
266276
"pycharm": {
267277
"name": "#%%\n"
268278
},
269279
"tags": []
270280
},
271-
"outputs": [
272-
{
273-
"name": "stderr",
274-
"output_type": "stream",
275-
"text": [
276-
"athena client created \n",
277-
"s3 client created !!\n"
278-
]
279-
}
280-
],
281+
"outputs": [],
281282
"source": [
282283
"rqstath=AthenaQueryExecute()"
283284
]
284285
},
286+
{
287+
"cell_type": "markdown",
288+
"id": "c46db72d",
289+
"metadata": {},
290+
"source": [
291+
"### Step 4.1 Update the variables"
292+
]
293+
},
285294
{
286295
"cell_type": "code",
287-
"execution_count": null,
296+
"execution_count": 8,
288297
"id": "e6b0cedd-0fcb-4a3b-a7be-50bd25197230",
289298
"metadata": {
290299
"pycharm": {
@@ -293,7 +302,28 @@
293302
},
294303
"outputs": [],
295304
"source": [
296-
"ebropen=EmbeddingBedrockOpenSearch()"
305+
"\n",
306+
"index_name = 'bedrock-knowledge-base-default-index' \n",
307+
"domain = 'https://OPENSEARCH.aoss.amazonaws.com' ##-- update here with your OpenSearch domain\n",
308+
"region = 'us-east-1' ##-- update here with your AWS region\n",
309+
"vector_name = 'bedrock-knowledge-base-default-vector'\n",
310+
"fieldname = 'id'\n",
311+
" "
312+
]
313+
},
314+
{
315+
"cell_type": "code",
316+
"execution_count": null,
317+
"id": "9af3a226",
318+
"metadata": {},
319+
"outputs": [],
320+
"source": [
321+
"ebropen2=EmbeddingBedrockOpenSearch(domain, vector_name, fieldname)\n",
322+
"if ebropen2 is None:\n",
323+
" print(\"ebropen2 is null\")\n",
324+
"else:\n",
325+
" attrs = vars(ebropen2)\n",
326+
" print(', '.join(\"%s: %s\" % item for item in attrs.items()))"
297327
]
298328
},
299329
{
@@ -313,7 +343,7 @@
313343
},
314344
{
315345
"cell_type": "code",
316-
"execution_count": 14,
346+
"execution_count": 11,
317347
"id": "fdcd87a9-f028-43ab-94a4-8d9b5b5bd163",
318348
"metadata": {
319349
"pycharm": {
@@ -324,15 +354,25 @@
324354
"outputs": [],
325355
"source": [
326356
"class RequestQueryBedrock:\n",
327-
" def __init__(self):\n",
328-
" # self.model_id = \"anthropic.claude-v2\"\n",
329-
" self.bedrock_client = Clientmodules.createBedrockRuntimeClient()\n",
357+
" def __init__(self, ebropen2):\n",
358+
" \n",
359+
" ##self.bedrock_client = Clientmodules.createBedrockRuntimeClient()\n",
360+
" self.ebropen2 = ebropen2\n",
361+
" \n",
362+
"\n",
363+
" self.bedrock_client = ebropen2.bedrock_client\n",
364+
" if self.bedrock_client is None:\n",
365+
" self.bedrock_client = Clientmodules.createBedrockRuntimeClient()\n",
366+
" else : \n",
367+
" print(\"the bedrock_client is not null\")\n",
330368
" self.language_model = LanguageModel(self.bedrock_client)\n",
331369
" self.llm = self.language_model.llm\n",
332-
" def getOpenSearchEmbedding(self,index_name,user_query):\n",
333-
" vcindxdoc=ebropen.getDocumentfromIndex(index_name=index_name)\n",
334-
" documnet=ebropen.getSimilaritySearch(user_query,vcindxdoc)\n",
335-
" return ebropen.format_metadata(documnet)\n",
370+
" \n",
371+
" def getOpenSearchEmbedding(self, index_name,user_query):\n",
372+
" vcindxdoc=self.ebropen2.getDocumentfromIndex(index_name=index_name)\n",
373+
" documnet=self.ebropen2.getSimilaritySearch(user_query,vcindxdoc)\n",
374+
" #return self.ebropen2.format_metadata(documnet)\n",
375+
" return self.ebropen2.get_data(documnet)\n",
336376
" \n",
337377
" def generate_sql(self,prompt, max_attempt=4) ->str:\n",
338378
" \"\"\"\n",
@@ -355,8 +395,9 @@
355395
" logger.info(f'we are in Try block to generate the sql and count is :{attempt+1}')\n",
356396
" generated_sql = self.llm.predict(prompt)\n",
357397
" query_str = generated_sql.split(\"```\")[1]\n",
358-
" query_str = \" \".join(query_str.split(\"\\n\")).strip()\n",
398+
" query_str = \" \".join(query_str.split(\"\\n\")).strip() \n",
359399
" sql_query = query_str[3:] if query_str.startswith(\"sql\") else query_str\n",
400+
" print(sql_query)\n",
360401
" # return sql_query\n",
361402
" syntaxcheckmsg=rqstath.syntax_checker(sql_query)\n",
362403
" if syntaxcheckmsg=='Passed':\n",
@@ -374,6 +415,7 @@
374415
" prompts.append(prompt)\n",
375416
" attempt += 1\n",
376417
" except Exception as e:\n",
418+
" print(e)\n",
377419
" logger.error('FAILED')\n",
378420
" msg = str(e)\n",
379421
" error_messages.append(msg)\n",
@@ -382,32 +424,17 @@
382424
]
383425
},
384426
{
385-
"cell_type": "code",
386-
"execution_count": 15,
387-
"id": "13da2f51-6032-4564-8943-3c36e55b025f",
388-
"metadata": {
389-
"pycharm": {
390-
"name": "#%%\n"
391-
},
392-
"tags": []
393-
},
394-
"outputs": [
395-
{
396-
"name": "stderr",
397-
"output_type": "stream",
398-
"text": [
399-
"bedrock runtime client created \n"
400-
]
401-
}
402-
],
427+
"cell_type": "markdown",
428+
"id": "d5cc21a1",
429+
"metadata": {},
403430
"source": [
404-
"rqst=RequestQueryBedrock()"
431+
"Create an instance of RequestQueryBedrock class"
405432
]
406433
},
407434
{
408435
"cell_type": "code",
409-
"execution_count": 16,
410-
"id": "552419cd-3827-41ee-9aab-ca93c756b9c9",
436+
"execution_count": null,
437+
"id": "13da2f51-6032-4564-8943-3c36e55b025f",
411438
"metadata": {
412439
"pycharm": {
413440
"name": "#%%\n"
@@ -416,12 +443,12 @@
416443
},
417444
"outputs": [],
418445
"source": [
419-
"index_name = 'llm_vector_db_metadata_indx2'"
446+
"rqst=RequestQueryBedrock(ebropen2)"
420447
]
421448
},
422449
{
423450
"cell_type": "code",
424-
"execution_count": 17,
451+
"execution_count": 13,
425452
"id": "db6b4e5f-b52e-4209-aab7-403dabc61239",
426453
"metadata": {
427454
"pycharm": {
@@ -434,10 +461,12 @@
434461
"def userinput(user_query):\n",
435462
" logger.info(f'Searching metadata from vector store')\n",
436463
" # vector_search_match=rqst.getEmbeddding(user_query)\n",
437-
" vector_search_match=rqst.getOpenSearchEmbedding(index_name,user_query)\n",
438-
" # print(vector_search_match)\n",
439-
" details=\"It is important that the SQL query complies with Athena syntax. During join if column name are same please use alias ex llm.customer_id in select statement. It is also important to respect the type of columns: if a column is string, the value should be enclosed in quotes. If you are writing CTEs then include all the required columns. While concatenating a non string column, make sure cast the column to string. For date columns comparing to string , please cast the string input.\"\n",
464+
" vector_search_match=rqst.getOpenSearchEmbedding( index_name,user_query)\n",
465+
" \n",
466+
" \n",
467+
" details=\"It is important that the SQL query complies with Athena syntax. During join if column name are same please use alias ex llm.customer_id in select statement. It is also important to respect the type of columns: if a column is string, the value should be enclosed in quotes. If you are writing CTEs then include all the required columns. While concatenating a non string column, make sure cast the column to string. For date columns comparing to string , please cast the string input. Alwayws use the database name along with the table name\"\n",
440468
" final_question = \"\\n\\nHuman:\"+details + vector_search_match + user_query+ \"n\\nAssistant:\"\n",
469+
" print(\"FINAL QUESTION :::\" + final_question)\n",
441470
" answer = rqst.generate_sql(final_question)\n",
442471
" return answer"
443472
]
@@ -456,7 +485,7 @@
456485
},
457486
{
458487
"cell_type": "code",
459-
"execution_count": 18,
488+
"execution_count": 24,
460489
"id": "21d5b62a-4446-43d2-a4bf-6d51061160a9",
461490
"metadata": {
462491
"pycharm": {
@@ -466,7 +495,11 @@
466495
},
467496
"outputs": [],
468497
"source": [
469-
"user_query='show me all the titles in US region'"
498+
"#user_query='how many titles exist '\n",
499+
"#user_query = 'show me top 10 title by user rating'\n",
500+
"#user_query = 'show me top 10 titles in US region'\n",
501+
"#user_query = 'which year was a movie/title made'\n",
502+
"user_query = 'how many titles are from the US region'"
470503
]
471504
},
472505
{
@@ -528,14 +561,23 @@
528561
},
529562
{
530563
"cell_type": "code",
531-
"execution_count": null,
564+
"execution_count": 28,
532565
"id": "3a1af6ad-036b-4eeb-b640-8642a75da17b",
533566
"metadata": {
534567
"pycharm": {
535568
"name": "#%%\n"
536569
}
537570
},
538-
"outputs": [],
571+
"outputs": [
572+
{
573+
"name": "stdout",
574+
"output_type": "stream",
575+
"text": [
576+
" us_title_count\n",
577+
"0 1534894\n"
578+
]
579+
}
580+
],
539581
"source": [
540582
"print(QueryOutput)"
541583
]
@@ -1178,9 +1220,9 @@
11781220
],
11791221
"instance_type": "ml.t3.medium",
11801222
"kernelspec": {
1181-
"display_name": "Python 3 (Data Science 3.0)",
1223+
"display_name": "Python 3",
11821224
"language": "python",
1183-
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-310-v1"
1225+
"name": "python3"
11841226
},
11851227
"language_info": {
11861228
"codemirror_mode": {
@@ -1192,7 +1234,7 @@
11921234
"name": "python",
11931235
"nbconvert_exporter": "python",
11941236
"pygments_lexer": "ipython3",
1195-
"version": "3.10.6"
1237+
"version": "3.12.1"
11961238
}
11971239
},
11981240
"nbformat": 4,

0 commit comments

Comments
 (0)