Skip to content
11 changes: 0 additions & 11 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ enable =
expression-not-assigned,
confusing-with-statement,
unnecessary-lambda,
assign-to-new-keyword,
redeclared-assigned-name,
pointless-statement,
pointless-string-statement,
Expand Down Expand Up @@ -123,7 +122,6 @@ enable =
invalid-length-returned,
protected-access,
attribute-defined-outside-init,
no-init,
abstract-method,
invalid-overridden-method,
arguments-differ,
Expand Down Expand Up @@ -165,9 +163,7 @@ enable =
### format
# Line length, indentation, whitespace:
bad-indentation,
mixed-indentation,
unnecessary-semicolon,
bad-whitespace,
missing-final-newline,
line-too-long,
mixed-line-endings,
Expand All @@ -187,7 +183,6 @@ enable =
import-self,
preferred-module,
reimported,
relative-import,
deprecated-module,
wildcard-import,
misplaced-future,
Expand Down Expand Up @@ -282,12 +277,6 @@ indent-string = ' '
# black doesn't always obey its own limit. See pyproject.toml.
max-line-length = 100

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check =

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt = no
Expand Down
42 changes: 22 additions & 20 deletions examples/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import csv
from collections import Counter, defaultdict
from tqdm import tqdm
from digest.model_class.digest_model import (
NodeShapeCounts,
NodeTypeCounts,
save_node_shape_counts_csv_report,
save_node_type_counts_csv_report,
)
from digest.model_class.digest_onnx_model import DigestOnnxModel
from utils.onnx_utils import (
get_dynamic_input_dims,
load_onnx,
DigestOnnxModel,
save_node_shape_counts_csv_report,
save_node_type_counts_csv_report,
NodeTypeCounts,
NodeShapeCounts,
)

GLOBAL_MODEL_HEADERS = [
Expand Down Expand Up @@ -71,7 +73,7 @@ def main(onnx_files: str, output_dir: str):
print(f"dim: {dynamic_shape}")

digest_model = DigestOnnxModel(
model_proto, onnx_filepath=onnx_file, model_name=model_name
model_proto, onnx_file_path=onnx_file, model_name=model_name
)

# Update the global model dictionary
Expand All @@ -82,46 +84,46 @@ def main(onnx_files: str, output_dir: str):

global_model_data[model_name] = {
"opset": digest_model.opset,
"parameters": digest_model.model_parameters,
"flops": digest_model.model_flops,
"parameters": digest_model.parameters,
"flops": digest_model.flops,
}

# Model summary text report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.txt")
digest_model.save_txt_report(summary_filepath)
digest_model.save_text_report(summary_filepath)

# Model summary yaml report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.yaml")
digest_model.save_yaml_report(summary_filepath)

# Save csv containing node-level information
nodes_filepath = os.path.join(output_dir, f"{model_name}_nodes.csv")
digest_model.save_nodes_csv_report(nodes_filepath)

# Save csv containing node type counter
node_type_counter = digest_model.get_node_type_counts()
node_type_filepath = os.path.join(
output_dir, f"{model_name}_node_type_counts.csv"
)
if node_type_counter:
save_node_type_counts_csv_report(node_type_counter, node_type_filepath)

digest_model.save_node_type_counts_csv_report(node_type_filepath)

# Update global data structure for node type counter
global_node_type_counter.update(node_type_counter)
global_node_type_counter.update(digest_model.node_type_counts)

# Save csv containing node shape counts per op_type
node_shape_counts = digest_model.get_node_shape_counts()
node_shape_filepath = os.path.join(
output_dir, f"{model_name}_node_shape_counts.csv"
)
save_node_shape_counts_csv_report(node_shape_counts, node_shape_filepath)
digest_model.save_node_shape_counts_csv_report(node_shape_filepath)

# Update global data structure for node shape counter
for node_type, shape_counts in node_shape_counts.items():
for node_type, shape_counts in digest_model.get_node_shape_counts().items():
global_node_shape_counter[node_type].update(shape_counts)

if len(onnx_file_list) > 1:
global_filepath = os.path.join(output_dir, "global_node_type_counts.csv")
global_node_type_counter = NodeTypeCounts(
global_node_type_counter.most_common()
)
save_node_type_counts_csv_report(global_node_type_counter, global_filepath)
global_node_type_counts = NodeTypeCounts(global_node_type_counter.most_common())
save_node_type_counts_csv_report(global_node_type_counts, global_filepath)

global_filepath = os.path.join(output_dir, "global_node_shape_counts.csv")
save_node_shape_counts_csv_report(global_node_shape_counter, global_filepath)
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.0.0",
version="1.2.0",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand All @@ -25,6 +25,8 @@
"platformdirs>=4.2.2",
"pyyaml>=6.0.1",
"psutil>=6.0.0",
"torch",
"transformers",
],
classifiers=[],
entry_points={"console_scripts": ["digest = digest.main:main"]},
Expand Down
14 changes: 12 additions & 2 deletions src/digest/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,23 @@ class WarnDialog(QDialog):

def __init__(self, warning_message: str, parent=None):
super().__init__(parent)
self.setWindowTitle("Warning Message")

self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))

self.setWindowTitle("Warning Message")
self.setWindowFlags(Qt.WindowType.Dialog)
self.setMinimumWidth(300)

self.setWindowModality(Qt.WindowModality.WindowModal)

layout = QVBoxLayout()

# Application Version
layout.addWidget(QLabel("<b>Something went wrong</b>"))
layout.addWidget(QLabel("<b>Warning</b>"))
layout.addWidget(QLabel(warning_message))

ok_button = QPushButton("OK")
ok_button.clicked.connect(self.accept) # Close dialog when clicked
layout.addWidget(ok_button)

self.setLayout(layout)
2 changes: 1 addition & 1 deletion src/digest/gui_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# For EXE releases we can block certain features e.g. to customers

modules:
huggingface: false
huggingface: true
6 changes: 3 additions & 3 deletions src/digest/histogramchartwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs):
super(StackedHistogramWidget, self).__init__(*args, **kwargs)

self.plot_widget = pg.PlotWidget()
self.plot_widget.setMaximumHeight(150)
self.plot_widget.setMaximumHeight(200)
plot_item = self.plot_widget.getPlotItem()
if plot_item:
plot_item.setContentsMargins(0, 0, 0, 0)
Expand All @@ -157,7 +157,6 @@ def __init__(self, *args, **kwargs):
self.bar_spacing = 25

def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=False):

title_color = "rgb(0,0,0)" if set_ticks else "rgb(200,200,200)"
self.plot_widget.setLabel(
"left",
Expand All @@ -173,7 +172,8 @@ def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=Fal
x_positions = list(range(len(op_count)))
total_count = sum(op_count)
width = 0.6
self.plot_widget.setFixedWidth(len(op_names) * self.bar_spacing)
self.plot_widget.setFixedWidth(500)

for count, x_pos, tick in zip(op_count, x_positions, op_names):
x0 = x_pos - width / 2
y0 = 0
Expand Down
Loading
Loading