Skip to content

Commit 8fe6ca7

Browse files
committed
chore: refactor code
1 parent 67994a5 commit 8fe6ca7

File tree

1 file changed

+149
-126
lines changed

1 file changed

+149
-126
lines changed

‎hello.py‎

Lines changed: 149 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -3,138 +3,161 @@
33
import matplotlib.pyplot as plt
44
import scipy.special
55
import math
6-
from scipy.stats import multivariate_normal
7-
st.title("Probability Explorer")
8-
9-
# Create two columns for layout
10-
left_col, right_col = st.columns([1, 2])
11-
12-
with left_col:
13-
# Add interactive elements for probability exploration
14-
# Select distribution type
15-
dist_type = st.selectbox(
16-
'Select probability distribution',
17-
['Multivariate Normal', 'Binomial', 'Normal', 'Poisson', 'Uniform']
18-
)
19-
st.write(f'Selected distribution: {dist_type}')
20-
21-
# Parameters based on distribution type
22-
if dist_type == 'Multivariate Normal':
6+
from scipy.stats import multivariate_normal, chi2
7+
8+
class ProbabilityExplorer:
9+
def __init__(self):
10+
st.title("Probability Explorer")
11+
self.left_col, self.right_col = st.columns([1, 2])
12+
self.confidence = 0.95
13+
self.setup_ui()
14+
15+
def setup_ui(self):
16+
with self.left_col:
17+
self.dist_type = st.selectbox(
18+
'Select probability distribution',
19+
['Multivariate Normal', 'Normal', 'Chi-squared', 'Poisson', 'Uniform', 'Binomial']
20+
)
21+
st.write(f'Selected distribution: {self.dist_type}')
22+
self.get_distribution_parameters()
23+
self.auto_update = st.checkbox('Auto-update plot', value=True)
24+
25+
def get_distribution_parameters(self):
26+
if self.dist_type == 'Multivariate Normal':
27+
self.get_multivariate_normal_params()
28+
elif self.dist_type == 'Normal':
29+
self.get_normal_params()
30+
elif self.dist_type == 'Uniform':
31+
self.get_uniform_params()
32+
elif self.dist_type == 'Chi-squared':
33+
self.get_chi_squared_params()
34+
else:
35+
self.get_discrete_params()
36+
37+
def get_multivariate_normal_params(self):
2338
st.write('Mean Vector:')
24-
mean1 = st.slider('μ₁', -5.0, 5.0, 0.0, 0.1)
25-
mean2 = st.slider('μ₂', -5.0, 5.0, 0.0, 0.1)
39+
self.mean1 = st.slider('μ₁', -5.0, 5.0, 0.0, 0.1)
40+
self.mean2 = st.slider('μ₂', -5.0, 5.0, 0.0, 0.1)
2641

2742
st.write('Covariance Matrix:')
28-
var1 = st.slider('σ₁²', 0.1, 5.0, 1.0, 0.1)
29-
var2 = st.slider('σ₂²', 0.1, 5.0, 1.0, 0.1)
30-
corr = st.slider('Correlation ρ', -1.0, 1.0, 0.0, 0.1)
43+
self.var1 = st.slider('σ₁²', 0.1, 5.0, 1.0, 0.1)
44+
self.var2 = st.slider('σ₂²', 0.1, 5.0, 1.0, 0.1)
45+
self.corr = st.slider('Correlation ρ', -1.0, 1.0, 0.0)
3146

32-
# Calculate covariance from correlation
33-
cov12 = corr * np.sqrt(var1 * var2)
34-
35-
# Display covariance matrix
36-
cov_matrix = np.array([[var1, cov12], [cov12, var2]])
47+
self.cov12 = self.corr * np.sqrt(self.var1 * self.var2)
48+
self.cov_matrix = np.array([[self.var1, self.cov12], [self.cov12, self.var2]])
3749
st.write("Covariance Matrix:")
38-
st.write(cov_matrix)
39-
40-
elif dist_type == 'Normal':
41-
mean = st.slider('Mean', -10.0, 10.0, 0.0, 0.1)
42-
std = st.slider('Standard deviation', 0.1, 5.0, 1.0, 0.1)
43-
st.write(f'Mean: {mean}, Standard deviation: {std}')
44-
elif dist_type == 'Uniform':
45-
a = st.slider('Lower bound (a)', -10.0, 10.0, 0.0, 0.1)
46-
b = st.slider('Upper bound (b)', -10.0, 10.0, 1.0, 0.1)
47-
if b <= a:
50+
st.write(self.cov_matrix)
51+
52+
def get_normal_params(self):
53+
self.mean = st.slider('Mean', -10.0, 10.0, 0.0, 0.1)
54+
self.std = st.slider('Standard deviation', 0.1, 5.0, 1.0, 0.1)
55+
st.write(f'Mean: {self.mean}, Standard deviation: {self.std}')
56+
57+
def get_uniform_params(self):
58+
self.a = st.slider('Lower bound (a)', -10.0, 10.0, 0.0, 0.1)
59+
self.b = st.slider('Upper bound (b)', -10.0, 10.0, 1.0, 0.1)
60+
if self.b <= self.a:
4861
st.error('Upper bound must be greater than lower bound')
49-
b = a + 0.1
50-
else:
51-
# Slider for probability value
52-
probability = st.slider('Select a probability value', 0.0, 1.0, 0.5, 0.1)
53-
st.write(f'Selected probability: {probability}')
54-
55-
# Number of trials input
56-
trials = st.number_input('Number of trials', min_value=1, value=100)
57-
st.write(f'Number of trials: {trials}')
58-
59-
# Add auto-update toggle
60-
auto_update = st.checkbox('Auto-update plot', value=True)
61-
62-
# Set confidence level
63-
confidence = 0.95
64-
65-
def calculate_and_plot():
66-
# Add these lines at the start of the function to access the variables
67-
global mean, std, probability, trials, a, b, mean1, mean2, var1, var2, cov12
68-
69-
with right_col:
70-
st.write('Calculating probability distribution...')
71-
72-
# Display formula based on distribution type
73-
if dist_type == 'Multivariate Normal':
74-
st.latex(r'f(x) = \frac{1}{2\pi|\Sigma|^{1/2}} \exp\left(-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\right)')
75-
elif dist_type == 'Normal':
76-
st.latex(r'f(x) = \frac{1}{\sigma\sqrt{2\pi}} e^{-\frac{(x-\mu)^2}{2\sigma^2}}')
77-
elif dist_type == 'Binomial':
78-
st.latex(r'P(X=k) = \binom{n}{k}p^k(1-p)^{n-k}')
79-
elif dist_type == 'Poisson':
80-
st.latex(r'P(X=k) = \frac{\lambda^k e^{-\lambda}}{k!}')
81-
elif dist_type == 'Uniform':
82-
st.latex(r'f(x) = \frac{1}{b-a} \text{ for } a \leq x \leq b')
83-
84-
# Create figure
62+
self.b = self.a + 0.1
63+
64+
def get_chi_squared_params(self):
65+
self.df = st.slider('Degrees of freedom', 1, 30, 1)
66+
st.write(f'Degrees of freedom: {self.df}')
67+
68+
def get_discrete_params(self):
69+
self.probability = st.slider('Select a probability value', 0.0, 1.0, 0.5, 0.1)
70+
st.write(f'Selected probability: {self.probability}')
71+
self.trials = st.number_input('Number of trials', min_value=1, value=100)
72+
st.write(f'Number of trials: {self.trials}')
73+
74+
def display_formula(self):
75+
formulas = {
76+
'Multivariate Normal': r'f(x) = \frac{1}{2\pi|\Sigma|^{1/2}} \exp\left(-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\right)',
77+
'Normal': r'f(x) = \frac{1}{\sigma\sqrt{2\pi}} e^{-\frac{(x-\mu)^2}{2\sigma^2}}',
78+
'Binomial': r'P(X=k) = \binom{n}{k}p^k(1-p)^{n-k}',
79+
'Poisson': r'P(X=k) = \frac{\lambda^k e^{-\lambda}}{k!}',
80+
'Uniform': r'f(x) = \frac{1}{b-a} \text{ for } a \leq x \leq b',
81+
'Chi-squared': r'f(x) = \frac{1}{2^{k/2}\Gamma(k/2)}x^{k/2-1}e^{-x/2}'
82+
}
83+
st.latex(formulas[self.dist_type])
84+
85+
def plot_distribution(self):
8586
fig, ax = plt.subplots()
8687

87-
# Generate data based on selected distribution
88-
if dist_type == 'Multivariate Normal':
89-
# Create grid of points
90-
x, y = np.mgrid[-5:5:.01, -5:5:.01]
91-
pos = np.dstack((x, y))
92-
93-
# Define distribution parameters
94-
mean = [mean1, mean2]
95-
cov = [[var1, cov12], [cov12, var2]]
96-
97-
# Create multivariate normal distribution
98-
rv = multivariate_normal(mean, cov)
99-
100-
# Calculate pdf
101-
z = rv.pdf(pos)
102-
103-
# Create contour plot
104-
plt.contourf(x, y, z, levels=20, cmap='viridis')
105-
plt.colorbar(label='Probability Density')
106-
107-
ax.set_xlabel('X₁')
108-
ax.set_ylabel('X₂')
109-
110-
elif dist_type == 'Normal':
111-
x = np.linspace(mean - 4*std, mean + 4*std, 100)
112-
y = np.exp(-((x - mean)**2)/(2*std**2))/(std*np.sqrt(2*np.pi))
113-
ax.plot(x, y)
114-
elif dist_type == 'Binomial':
115-
x = np.arange(0, trials + 1)
116-
y = [scipy.special.comb(trials, k) * (probability**k) * ((1-probability)**(trials-k)) for k in x]
117-
ax.plot(x, y)
118-
elif dist_type == 'Poisson':
119-
x = np.arange(0, trials + 1)
120-
y = [(probability**k * np.exp(-probability))/math.factorial(k) for k in x]
121-
ax.plot(x, y)
122-
elif dist_type == 'Uniform':
123-
x = np.linspace(a - 0.5, b + 0.5, 100)
124-
y = np.where((x >= a) & (x <= b), 1/(b-a), 0)
125-
ax.plot(x, y)
126-
127-
ax.set_title(f'{dist_type} Distribution')
88+
if self.dist_type == 'Multivariate Normal':
89+
self.plot_multivariate_normal(ax)
90+
elif self.dist_type == 'Normal':
91+
self.plot_normal(ax)
92+
elif self.dist_type == 'Binomial':
93+
self.plot_binomial(ax)
94+
elif self.dist_type == 'Poisson':
95+
self.plot_poisson(ax)
96+
elif self.dist_type == 'Uniform':
97+
self.plot_uniform(ax)
98+
elif self.dist_type == 'Chi-squared':
99+
self.plot_chi_squared(ax)
100+
101+
ax.set_title(f'{self.dist_type} Distribution')
128102
ax.grid(True)
129-
130-
# Display plot in Streamlit
131-
st.pyplot(fig)
132-
st.success(icon="🔥", body="Distribution calculated!")
133-
134-
# Calculate either on button press or automatically based on toggle
135-
if auto_update:
136-
calculate_and_plot()
137-
else:
138-
with left_col:
139-
if st.button('Calculate Distribution'):
140-
calculate_and_plot()
103+
return fig
104+
105+
def plot_multivariate_normal(self, ax):
106+
x, y = np.mgrid[-5:5:.01, -5:5:.01]
107+
pos = np.dstack((x, y))
108+
mean = [self.mean1, self.mean2]
109+
cov = [[self.var1, self.cov12], [self.cov12, self.var2]]
110+
rv = multivariate_normal(mean, cov)
111+
z = rv.pdf(pos)
112+
plt.contourf(x, y, z, levels=20, cmap='viridis')
113+
plt.colorbar(label='Probability Density')
114+
ax.set_xlabel('X₁')
115+
ax.set_ylabel('X₂')
116+
117+
def plot_normal(self, ax):
118+
x = np.linspace(self.mean - 4*self.std, self.mean + 4*self.std, 100)
119+
y = np.exp(-((x - self.mean)**2)/(2*self.std**2))/(self.std*np.sqrt(2*np.pi))
120+
ax.plot(x, y)
121+
122+
def plot_binomial(self, ax):
123+
x = np.arange(0, self.trials + 1)
124+
y = [scipy.special.comb(self.trials, k) * (self.probability**k) *
125+
((1-self.probability)**(self.trials-k)) for k in x]
126+
ax.plot(x, y)
127+
128+
def plot_poisson(self, ax):
129+
x = np.arange(0, self.trials + 1)
130+
y = [(self.probability**k * np.exp(-self.probability))/math.factorial(k) for k in x]
131+
ax.plot(x, y)
132+
133+
def plot_uniform(self, ax):
134+
x = np.linspace(self.a - 0.5, self.b + 0.5, 100)
135+
y = np.where((x >= self.a) & (x <= self.b), 1/(self.b-self.a), 0)
136+
ax.plot(x, y)
137+
138+
def plot_chi_squared(self, ax):
139+
x = np.linspace(0, max(30, self.df*3), 200)
140+
y = chi2.pdf(x, self.df)
141+
ax.plot(x, y)
142+
ax.set_xlabel('x')
143+
ax.set_ylabel('Probability Density')
144+
145+
def calculate_and_plot(self):
146+
with self.right_col:
147+
st.write('Calculating probability distribution...')
148+
self.display_formula()
149+
fig = self.plot_distribution()
150+
st.pyplot(fig)
151+
st.success(icon="🔥", body="Distribution calculated!")
152+
153+
def run(self):
154+
if self.auto_update:
155+
self.calculate_and_plot()
156+
else:
157+
with self.left_col:
158+
if st.button('Calculate Distribution'):
159+
self.calculate_and_plot()
160+
161+
# Initialize and run the app
162+
app = ProbabilityExplorer()
163+
app.run()

0 commit comments

Comments
 (0)