首先我们需要构造一下直线和曲线的数据,用np.linspace函数生成50个横坐标点.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(-3, 3, 50)
然后构造出我们想要画的线的方程
y1 = x ** 2 + 1
y2 = 2 * x + 1
用plot函数画线
plt.plot(x, y2,label='2 * x + 1')
plt.plot(x, y1, color='red', linewidth=1.0, linestyle='--',label='x ** 2 + 1')
plt.show()
plot的第一个参数是所有点的横坐标,第二个参数是所有点的纵坐标.
所以当我们知道两点坐标画直线的时候,只需把横纵坐标分离然后用plot函数就可以画线啦.
points = [(1, 1), (5, 5)]
(xpoints, ypoints) = zip(*points)
plt.plot(xpoints, ypoints)
plt.show()
直方图是一个不错的数据可视化方法.
bar()函数需要两个参数一个是横坐标X,另一个是对应X数值的直方高度.
import matplotlib.pyplot as plt
X = [0, 1, 2, 3, 4, 5]
Y = [2, 5, 6, 8, 3, 2]
plt.bar(X, Y)
plt.show()
bar()函数需要我们预先统计好数据中数据的数量,而hist()函数只需要提供数据,会自动帮我们把每个区间的数据量统计好.
比如说下图,在(0,1]
区间有两个值,所以直方高度为2,在(1,2]
区间没有值所以高度为0,依次类推.
import matplotlib.pyplot as plt
import numpy as np
import cv2
data = [0, 0, 2, 4]
plt.hist(data, bins=5)
plt.show()
最明显的区别是plt.hist()会帮我们统计数据,而plt.bar()需要我们自己统计数据,这在不同的场景会各有优势.
另外,plt.bar()的直方图是画在数值上的,因为我们提供的数值和高度是对应的.而plt.hist()帮我们统计的数据是在区间上的,所以直方也画在了区间里.
在人工智能或者数据分析之前,通常习惯把原始数据画出来,看一下分布情况,让自己对数据有着更直观的理解.
我们可以使用plt.scatter()函数画散点图
import matplotlib.pyplot as plt
import numpy as np
nodes = np.mgrid[0:5, 0:3]
plt.scatter(nodes[0],nodes[1])
plt.show()
import matplotlib.pyplot as plt
#调节图形大小,宽,高
plt.figure(figsize=(6,9))
#定义饼状图的标签,标签是列表
labels = [u'one',u'two',u'three']
#每个标签占多大,会自动去算百分比
sizes = [60,30,10]
colors = ['red','yellowgreen','lightskyblue']
#将某部分爆炸出来, 使用括号,将第一块分割出来,数值的大小是分割出来的与其他两块的间隙
explode = (0.05,0,0)
patches,l_text,p_text = plt.pie(sizes,explode=explode,labels=labels,colors=colors,
labeldistance = 1.1,autopct = '%3.1f%%',shadow = False,
startangle = 90,pctdistance = 0.6)
#labeldistance,文本的位置离远点有多远,1.1指1.1倍半径的位置
#autopct,圆里面的文本格式,%3.1f%%表示小数有三位,整数有一位的浮点数
#shadow,饼是否有阴影
#startangle,起始角度,0,表示从0开始逆时针转,为第一块。一般选择从90度开始比较好看
#pctdistance,百分比的text离圆心的距离
#patches, l_texts, p_texts,为了得到饼图的返回值,p_texts饼图内部文本的,l_texts饼图外label的文本
#改变文本的大小
#方法是把每一个text遍历。调用set_size方法设置它的属性
for t in l_text:
t.set_size=(30)
for t in p_text:
t.set_size=(20)
# 设置x,y轴刻度一致,这样饼图才能是圆的
plt.axis('equal')
plt.legend()
plt.show()
np.histogram()函数可以将Numpy结构的数据绘制成直方图,统计Numpy结构中每个数据出现的次数.
比如我们要分析一张图片在HSV空间上的分布情况,则可以用np.histogram()先计算出数量,在用bar()函数画图
img = plt.imread('./cat.jpg')
hsv = cv2.cvtColor(img,cv2.COLOR_RGB2HSV)
h_hist,h_edges = np.histogram(hsv[:,:,0], bins=32, range=(0, 180))
s_hist,s_edges = np.histogram(hsv[:,:,1], bins=32, range=(0, 256))
v_hist,v_edges = np.histogram(hsv[:,:,2], bins=32, range=(0, 256))
f,(ax1,ax2,ax3) = plt.subplots(1,3)
ax1.bar(h_edges[:-1], h_hist)
ax2.bar(s_edges[:-1], s_hist)
ax3.bar(v_edges[:-1], v_hist)
plt.show()
还可以中plt.hist函数画直方图,hist会自动帮我们计算出每个区间的元素个数.
不过hist只能自动计算1维的数据,所以我们先reshape一下
img = plt.imread('./cat.jpg')
hsv = cv2.cvtColor(img,cv2.COLOR_RGB2HSV)
h = hsv[:,:,0].reshape([-1])
s = hsv[:,:,1].reshape([-1])
v = hsv[:,:,2].reshape([-1])
f,(ax1,ax2,ax3) = plt.subplots(1,3)
ax1.hist(h, 180)
ax2.hist(s, 255)
ax3.hist(v, 255)
plt.show()
当分析数据时,经常有很多数据希望一次打印出来.
matplotlib提供了很多多图绘制函数,可以将多个表格合并到一幅图中打印出来.
最常用的函数就是subplots了,因为它很方便.
matplotlib.pyplot.subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw)
subplots有很多参数,不过一般我们只需要注意这几个:
来看个完整的例子.
# coding=utf-8
import matplotlib.pyplot as plt
img1 = plt.imread('1.jpg')
img2 = plt.imread('2.jpg')
img3 = plt.imread('3.jpg')
fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
ax1.imshow(img1)
ax2.imshow(img2)
ax3.imshow(img3)
plt.show()
有的时候我们希望使用循环来绘制很多图片,这时可以这样做
import matplotlib.pyplot as plt
x = df.iloc[:,0].values
y = df.iloc[:,1].values
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
aslist = axes.ravel()
for i in range(4):
pred = weakClassifier(i).predict(df.iloc[:,0:2].values)
aslist[i].scatter(x = x[pred ==1],y=y[pred ==1],color = 'r')
aslist[i].scatter(x = x[pred ==0],y=y[pred ==0],color = 'b')
如果你想画更复杂,更美观的图,可以试着用一下subplot2grid函数.
# coding=utf-8
import matplotlib.pyplot as plt
ax1 = plt.subplot2grid((3, 3), (0, 0), colspan=3)
ax2 = plt.subplot2grid((3, 3), (1, 0), colspan=2)
ax3 = plt.subplot2grid((3, 3), (1, 2), rowspan=2)
ax4 = plt.subplot2grid((3, 3), (2, 0))
ax5 = plt.subplot2grid((3, 3), (2, 1))
plt.show()