@@ -111,7 +111,7 @@ def format_type(_types, is_ml=False):
111111 for _type in _types :
112112 i_type_list = []
113113 if is_ml :
114- if _type .startswith ("Union[" ):
114+ if is_ml and _type .startswith ("Union[" ):
115115 # TODO: Improve code, should not lower() for all. e.g., MyClass
116116 types_split = [
117117 x .replace (" " , "" ).lower ()
@@ -124,15 +124,31 @@ def format_type(_types, is_ml=False):
124124 # i_type_list.append(_t.split("[")[0].lower())
125125 else :
126126 for _t in _type :
127- if _t .startswith ("Union[" ):
127+ if _t and _t .startswith ("Union[" ):
128128 types_split = [
129129 x .replace (" " , "" ).lower ()
130130 for x in _t .split ("Union[" )[1 ].split ("]" )[0 ].split ("," )
131131 ]
132132 i_type_list .extend (types_split )
133+ elif _t and _t .startswith ("Optional[" ):
134+ types_split = [
135+ x .replace (" " , "" ).lower ()
136+ for x in _t .split ("Optional[" )[1 ].split ("]" )[0 ].split ("," )
137+ ]
138+ types_split .append ("Nonetype" )
139+ i_type_list .extend (types_split )
140+ elif _t and _t .startswith ("Type[" ):
141+ types_split = [
142+ x .replace (" " , "" ).lower ()
143+ for x in _t .split ("Type[" )[1 ].split ("]" )[0 ].split ("," )
144+ ]
145+ i_type_list .extend (types_split )
146+ elif _t and _t in ["None" , "Unknown" ]:
147+ i_type_list .append ("Nonetype" )
133148 else :
134149 # TODO: Maybe no translation should be done here
135- i_type_list .append (_t .lower ())
150+ if _t :
151+ i_type_list .append (_t .lower ())
136152 # i_type_list.append(_t.split("[")[0].lower())
137153 type_formatted .append (list (set (i_type_list )))
138154
@@ -176,10 +192,14 @@ def check_match(
176192 if expected .get ("file" ) != out .get ("file" ):
177193 return False
178194
179- # check if line_number match
195+ # # check if line_number match
180196 if expected .get ("line_number" ) != out .get ("line_number" ):
181197 return False
182198
199+ # if "col_offset" in expected and "col_offset" in out:
200+ if expected ["col_offset" ] != out ["col_offset" ]:
201+ return False
202+
183203 if "col_offset" in expected and "col_offset" in out :
184204 if expected ["col_offset" ] != out ["col_offset" ]:
185205 return False
@@ -658,3 +678,97 @@ def benchmark_count(benchmark_path):
658678 _a , _functions , _params , _variables = get_fact_stats (json_files )
659679 total_result .append ([cat , _a , _functions , _params , _variables ])
660680 return total_result
681+
682+
def normalize_type(type_str, nested_level=0):
    """
    Normalize a type string so differently spelled annotations compare equal.

    The following steps are applied in order:

    1. Surrounding double quotes are removed.
    2. The ``builtins.`` and ``typing.`` module prefixes are removed wherever
       they occur as a whole dotted component.
    3. Known aliases are mapped to a canonical name (e.g. ``integer`` -> ``int``,
       ``NoneType`` -> ``None``). Only an exact, whole-string match is mapped.
    4. Generic types (``Base[...]``) keep the last dotted segment of ``Base``.
       At the top level every parameter is normalized recursively; for a
       generic that is itself a parameter (``nested_level > 0``) the whole
       parameter list is collapsed to ``Any``.
    5. Any other dotted name is reduced to its last segment.

    Examples:
        'builtins.str' -> 'str'
        'typing.Tuple[builtins.str, builtins.float]' -> 'Tuple[str, float]'
        'musictaxonomy.spotify.models.SpotifyUser' -> 'SpotifyUser'
        'collections.OrderedDict[str, int]' -> 'OrderedDict[str, int]'
        'List[List[Tuple[str]]]' -> 'List[List[Any]]'

    Args:
        type_str: The type string to normalize, or None.
        nested_level: Depth of ``type_str`` inside an enclosing generic;
            callers normally leave this at 0.

    Returns:
        The normalized type string, or None when ``type_str`` is None.
    """
    # Function-scope import keeps this block self-contained.
    import re

    if type_str is None:
        return None

    # Remove extra quotes if present
    if type_str.startswith('"') and type_str.endswith('"'):
        type_str = type_str.strip('"')

    # Aliases that are rewritten to one canonical spelling.
    type_aliases = {
        "integer": "int",
        "string": "str",
        "dictonary": "dict",  # misspelling kept for backward compatibility
        "dictionary": "dict",
        "method": "Callable",
        "func": "Callable",
        "function": "Callable",
        "none": "None",
        "Nonetype": "None",
        "nonetype": "None",
        "NoneType": "None",
        "Text": "str",
    }

    # Drop the "builtins." / "typing." prefixes. The word boundary stops the
    # pattern from eating the tail of a longer identifier such as "mytyping.".
    type_str = re.sub(r"\b(?:builtins|typing)\.", "", type_str)

    # Apply the alias table (exact match on the whole string only).
    type_str = type_aliases.get(type_str, type_str)

    # Handle generic types (e.g., Tuple[], List[], Dict[])
    if "[" in type_str and "]" in type_str:
        base_type, generic_content = type_str.split("[", 1)
        generic_content = generic_content.rsplit("]", 1)[0]

        # A qualified base such as "collections.OrderedDict" is reduced to its
        # last segment, exactly like a non-generic dotted name.
        base_type = base_type.split(".")[-1]

        # Parameters of a nested generic are not compared in detail, so there
        # is no need to parse them.
        if nested_level > 0:
            return f"{base_type}[Any]"

        # Split the parameter list on commas that are not inside brackets.
        generic_params = []
        bracket_level = 0
        param = ""
        for char in generic_content:
            if char == "," and bracket_level == 0:
                generic_params.append(param.strip())
                param = ""
                continue
            if char == "[":
                bracket_level += 1
            elif char == "]":
                bracket_level -= 1
            param += char
        if param:
            generic_params.append(param.strip())

        normalized_params = [
            normalize_type(item, nested_level + 1) for item in generic_params
        ]
        return f"{base_type}[{', '.join(normalized_params)}]"

    # Handle fully qualified names by extracting the last segment
    if "." in type_str:
        return type_str.split(".")[-1]

    # Return the simplified type
    return type_str
768+
769+
def normalize_types(types):
    """
    Return a new list holding the normalized form of every type string in
    ``types``, in the same order. Each entry is passed to ``normalize_type``
    with its default nesting level.
    """
    return list(map(normalize_type, types))
0 commit comments