Importance of Data Visualization
Data visualization transforms complex data into visual representations that make patterns, trends, and insights easier to understand. It's essential for exploratory data analysis and communicating findings.
Types of Visualizations
| Purpose | Visualization Types |
|---|---|
| Distribution | Histogram, Box plot, KDE |
| Relationship | Scatter plot, Line plot |
| Composition | Pie chart, Stacked bar |
| Comparison | Bar chart, Grouped bar |
| Trend | Line chart, Area chart |
Basic Plots with Matplotlib
import matplotlib.pyplot as plt
import numpy as np
# Basic line plot
plt.figure(figsize=(10, 6))
x = np.linspace(0, 10, 100)
y = np.sin(x)
plt.plot(x, y, label='sin(x)', color='blue')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.title('Sine Wave')
plt.legend()
plt.grid(True)
plt.show()
# Multiple subplots: one function per panel, each with its own title.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes[0, 0].plot(x, np.sin(x))
axes[0, 0].set_title('sin(x)')
axes[0, 1].plot(x, np.cos(x))
axes[0, 1].set_title('cos(x)')
# tan(x) has vertical asymptotes at pi/2 + k*pi. Plotted raw, the line joins
# the large positive and negative values on either side of each asymptote,
# producing spurious near-vertical segments. Replacing large magnitudes with
# NaN breaks the line there, and fixed y-limits keep the curve readable.
tan_x = np.tan(x)
tan_x[np.abs(tan_x) > 10] = np.nan
axes[1, 0].plot(x, tan_x)
axes[1, 0].set_ylim(-10, 10)
axes[1, 0].set_title('tan(x)')
axes[1, 1].plot(x, x**2)
axes[1, 1].set_title('$x^2$')
plt.tight_layout()
plt.show()
Statistical Visualizations
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# Generate sample data (seeded so the figures are reproducible)
np.random.seed(42)
data = np.random.randn(1000)
# Histogram
plt.figure(figsize=(10, 6))
plt.hist(data, bins=30, edgecolor='black', alpha=0.7)
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram')
plt.show()
# Box plot. Vertical is the default orientation; the `vert` keyword is being
# deprecated in recent Matplotlib in favour of `orientation`, so omit it.
plt.figure(figsize=(8, 6))
plt.boxplot(data)
plt.ylabel('Value')
plt.title('Box Plot')
plt.show()
# Violin plot
plt.figure(figsize=(8, 6))
parts = plt.violinplot(data, positions=[0], showmeans=True, showmedians=True)
# By default the mean and median lines share one colour; recolour the median
# so the two summary lines can be told apart.
parts['cmedians'].set_color('red')
plt.ylabel('Value')
plt.title('Violin Plot (red line = median)')
plt.show()
Seaborn Visualizations
# pyplot is needed for plt.figure/plt.title/plt.show below; without this
# import the snippet raises NameError when run on its own.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# Create sample dataframe: two independent numeric columns, a categorical
# column with three levels, and a numeric 'value' column centred near 5.
df = pd.DataFrame({
    'x': np.random.randn(100),
    'y': np.random.randn(100),
    'category': np.random.choice(['A', 'B', 'C'], 100),
    'value': np.random.randn(100) + 5,
})
# Scatter plot, coloured and marker-styled by category
plt.figure(figsize=(10, 6))
sns.scatterplot(data=df, x='x', y='y', hue='category', style='category')
plt.title('Scatter Plot')
plt.show()
# Pair plot of all numeric columns, coloured by category
sns.pairplot(df, hue='category')
plt.show()
# Heatmap of pairwise correlations between the numeric columns
correlation_matrix = df[['x', 'y', 'value']].corr()
plt.figure(figsize=(8, 6))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0)
plt.title('Correlation Heatmap')
plt.show()
# Distribution plots (histogram with KDE overlay) side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
sns.histplot(df['x'], kde=True, ax=axes[0])
axes[0].set_title('Distribution Plot 1')
sns.histplot(df['y'], kde=True, ax=axes[1])
axes[1].set_title('Distribution Plot 2')
plt.tight_layout()
plt.show()
# Bar plot of the mean 'value' per category
plt.figure(figsize=(10, 6))
sns.barplot(data=df, x='category', y='value', estimator=np.mean)
plt.title('Bar Plot with Mean')
plt.show()
# Count plot: number of rows per category
plt.figure(figsize=(8, 6))
sns.countplot(data=df, x='category')
plt.title('Count Plot')
plt.show()
Time Series Visualization
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# numpy is used below (np.cumsum, np.random); the original snippet relied on
# an import from an earlier section.
import numpy as np
# Create time series data: a daily random walk starting around 100.
dates = pd.date_range('2023-01-01', periods=100, freq='D')
df = pd.DataFrame({
    'date': dates,
    'value': np.cumsum(np.random.randn(100)) + 100,
})
# Smooth the *same* series that is plotted as "Original". The first 6 rows are
# NaN because a full 7-day window is not yet available.
df['rolling_mean'] = df['value'].rolling(window=7).mean()
plt.figure(figsize=(14, 6))
plt.plot(df['date'], df['value'], label='Original', alpha=0.7)
plt.plot(df['date'], df['rolling_mean'], label='7-day Rolling Mean', color='red')
plt.xlabel('Date')
plt.ylabel('Value')
plt.title('Time Series Plot')
plt.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
# Multiple time series: two independent random walks, the second offset by 20
df2 = pd.DataFrame({
    'date': dates,
    'series1': np.cumsum(np.random.randn(100)),
    'series2': np.cumsum(np.random.randn(100)) + 20,
})
plt.figure(figsize=(14, 6))
plt.plot(df2['date'], df2['series1'], label='Series 1')
plt.plot(df2['date'], df2['series2'], label='Series 2')
# Shade the gap between the two series
plt.fill_between(df2['date'], df2['series1'], df2['series2'], alpha=0.3)
plt.xlabel('Date')
plt.ylabel('Value')
plt.title('Multiple Time Series')
plt.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
Advanced Visualizations
import matplotlib.pyplot as plt
# Registers the '3d' projection on older Matplotlib (< 3.2); newer versions
# register it automatically, so the name itself is otherwise unused.
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
# 3D Scatter plot of 100 standard-normal points
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
x = np.random.randn(100)
y = np.random.randn(100)
z = np.random.randn(100)
ax.scatter(x, y, z, c='blue', marker='o')
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
ax.set_title('3D Scatter Plot')
plt.show()
# 3D Surface plot of z = sin(sqrt(x^2 + y^2)) on a 40x40 grid over [-5, 5)
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
Z = np.sin(np.sqrt(X**2 + Y**2))
surf = ax.plot_surface(X, Y, Z, cmap='coolwarm')
fig.colorbar(surf, ax=ax)
ax.set_title('3D Surface Plot')
plt.show()
# Radar chart
categories = ['Math', 'Science', 'English', 'History', 'Art']
values = [85, 90, 78, 88, 92]
# Repeat the first point at the end so the polygon closes.
values += values[:1]
angles = np.linspace(0, 2*np.pi, len(categories), endpoint=False).tolist()
angles += angles[:1]
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
ax.plot(angles, values, 'o-', linewidth=2)
ax.fill(angles, values, alpha=0.25)
ax.set_xticks(angles[:-1])
ax.set_xticklabels(categories)
# Anchor the radial axis at 0 (assumes scores are on a 0-100 scale). An
# autoscaled radius range around 78-92 would visually exaggerate the small
# differences between subjects.
ax.set_ylim(0, 100)
ax.set_title('Radar Chart')
plt.show()
Visualizing Model Results
from sklearn.metrics import confusion_matrix, roc_curve, auc
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
# Example ground truth and model scores (e.g. predicted probability of class 1)
y_true = [0, 1, 0, 1, 0, 1, 0, 1]
y_scores = np.array([0.1, 0.9, 0.2, 0.4, 0.3, 0.8, 0.6, 0.7])
# Hard class predictions at a 0.5 threshold -> [0, 1, 0, 0, 0, 1, 1, 1]
y_pred = (y_scores >= 0.5).astype(int)
# Confusion Matrix (rows = actual class, columns = predicted class)
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
# ROC Curve. roc_curve must be given continuous scores, not hard 0/1
# predictions; with hard labels it yields a single operating point instead of
# a curve across thresholds.
fpr, tpr, thresholds = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
# Diagonal reference line: the ROC of a random classifier
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc='lower right')
plt.show()
# Feature Importance, bars sorted from most to least important
feature_names = ['feature1', 'feature2', 'feature3', 'feature4', 'feature5']
importance = np.array([0.35, 0.25, 0.20, 0.12, 0.08])
indices = np.argsort(importance)[::-1]
plt.figure(figsize=(10, 6))
plt.bar(range(len(importance)), importance[indices])
plt.xticks(range(len(importance)), [feature_names[i] for i in indices])
plt.xlabel('Feature')
plt.ylabel('Importance')
plt.title('Feature Importance')
plt.show()
Best Practices for Visualization
- Choose chart types that fit both your data and your message
- Keep it simple: avoid unnecessary complexity
- Use appropriate colors: prefer colorblind-friendly palettes (e.g., viridis or cividis)
- Label clearly: include titles, axis labels, and legends
- Consider your audience: technical vs. non-technical
- Iterate and refine: try several versions and keep the clearest
- Tell a story: guide the viewer through the key insights