13. Introduction to Seaborn

  • seaborn is built on top of matplotlib and works well with pandas.

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

13.1. Scatter plot

13.1.1. Data

13.1.1.1. student_data

student_data = pd.read_csv("data/student-alcohol-consumption.csv", index_col=0)
student_data
     school  sex  age  famsize  Pstatus  Medu  Fedu  traveltime  failures  schoolsup  ...  goout  Dalc  Walc  health  absences   G1   G2   G3  location  study_time
  0      GP    F   18      GT3        A     4     4           2         0        yes  ...      4     1     1       3         6    5    6    6     Urban  2 to 5 hours
  1      GP    F   17      GT3        T     1     1           1         0         no  ...      3     1     1       3         4    5    5    6     Urban  2 to 5 hours
  2      GP    F   15      LE3        T     1     1           1         3        yes  ...      2     2     3       3        10    7    8   10     Urban  2 to 5 hours
  3      GP    F   15      GT3        T     4     2           1         0         no  ...      2     1     1       5         2   15   14   15     Urban  5 to 10 hours
  4      GP    F   16      GT3        T     3     3           1         0         no  ...      2     1     2       5         4    6   10   10     Urban  2 to 5 hours
...     ...  ...  ...      ...      ...   ...   ...         ...       ...        ...  ...    ...   ...   ...     ...       ...  ...  ...  ...       ...  ...
390      MS    M   20      LE3        A     2     2           1         2         no  ...      4     4     5       4        11    9    9    9     Urban  2 to 5 hours
391      MS    M   17      LE3        T     3     1           2         0         no  ...      5     3     4       2         3   14   16   16     Urban  <2 hours
392      MS    M   21      GT3        T     1     1           1         3         no  ...      3     3     3       3         3   10    8    7     Rural  <2 hours
393      MS    M   18      LE3        T     3     2           3         0         no  ...      1     3     4       5         0   11   12   10     Rural  <2 hours
394      MS    M   19      LE3        T     1     1           1         0         no  ...      3     3     3       5         5    8    9    9     Urban  <2 hours

395 rows × 29 columns

13.1.1.2. mpg

mpg = pd.read_csv("data/mpg.csv")
mpg
      mpg  cylinders  displacement  horsepower  weight  acceleration  model_year  origin  name
  0  18.0          8         307.0       130.0    3504          12.0          70     usa  chevrolet chevelle malibu
  1  15.0          8         350.0       165.0    3693          11.5          70     usa  buick skylark 320
  2  18.0          8         318.0       150.0    3436          11.0          70     usa  plymouth satellite
  3  16.0          8         304.0       150.0    3433          12.0          70     usa  amc rebel sst
  4  17.0          8         302.0       140.0    3449          10.5          70     usa  ford torino
...   ...        ...           ...         ...     ...           ...         ...     ...  ...
393  27.0          4         140.0        86.0    2790          15.6          82     usa  ford mustang gl
394  44.0          4          97.0        52.0    2130          24.6          82  europe  vw pickup
395  32.0          4         135.0        84.0    2295          11.6          82     usa  dodge rampage
396  28.0          4         120.0        79.0    2625          18.6          82     usa  ford ranger
397  31.0          4         119.0        82.0    2720          19.4          82     usa  chevy s-10

398 rows × 9 columns

13.1.2. Basic scatter plot

sns.scatterplot(x="absences", y="G3", 
                data=student_data);
../_images/intro_seaborn_10_0.png
  • G3 is the grade on the third exam. The plot suggests that the more absences a student has, the lower the grade tends to be.

13.1.3. hue (i.e. color)

sns.scatterplot(x="absences", y="G3", 
                data=student_data, 
                hue="location");
../_images/intro_seaborn_13_0.png
  • With location mapped to color as a third variable, the pattern holds for both urban and rural students: more absences go with lower grades.

13.1.3.1. Choosing the colors for the third variable

hue_colors = {
    "Urban": "black",
    "Rural": "red"
}
sns.scatterplot(x="absences", y="G3", 
                data=student_data, 
                hue="location",
                hue_order = ["Rural", "Urban"],
                palette = hue_colors);
../_images/intro_seaborn_16_0.png

13.1.3.2. Choosing the order of the third variable's levels

sns.scatterplot(x="absences", y="G3", 
                data=student_data, 
                hue="location",
                hue_order = ["Rural", "Urban"]); # Rural first, then Urban
../_images/intro_seaborn_18_0.png

13.1.4. size

sns.scatterplot(
    x="horsepower", 
    y="mpg",
    data=mpg,
    size="cylinders"
);
../_images/intro_seaborn_20_0.png
  • More cylinders go with higher horsepower and worse fuel economy (lower mpg).

sns.scatterplot(
    x="horsepower", 
    y="mpg",
    data=mpg,
    size="cylinders",
    hue = "cylinders"
);
../_images/intro_seaborn_22_0.png
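  • Adding color makes the pattern easier to see.

  • Because cylinders is a numeric variable, seaborn maps it to a color gradient when it is used for hue, which makes the ordering easier to read.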
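The gradient behavior described above depends on the dtype of the hue column. A minimal sketch, assuming a small hand-made frame (`cars`, `cars_cat` and the axes names are placeholders, not part of the original notebook), that draws the same points once with a numeric hue and once with the column converted to a pandas categorical, in which case seaborn is expected to use distinct colors instead of a gradient:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data standing in for the mpg dataset
cars = pd.DataFrame({
    "horsepower": [130, 165, 95, 88, 110, 105],
    "mpg": [18.0, 15.0, 24.0, 27.0, 21.0, 20.0],
    "cylinders": [8, 8, 4, 4, 6, 6],
})

fig, (ax_num, ax_cat) = plt.subplots(1, 2, figsize=(10, 4))

# Numeric hue: colors come from a sequential (gradient) palette
sns.scatterplot(x="horsepower", y="mpg", data=cars,
                hue="cylinders", ax=ax_num)

# Categorical hue: same values, but stored as a pandas category
cars_cat = cars.assign(cylinders=cars["cylinders"].astype("category"))
sns.scatterplot(x="horsepower", y="mpg", data=cars_cat,
                hue="cylinders", ax=ax_cat)
```

Which version reads better depends on whether the order of the levels matters for the comparison.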

13.1.5. style (marker style)

sns.scatterplot(
    x="acceleration", 
    y="mpg",
    data=mpg,
    style ="origin",
    hue = "origin"
);
../_images/intro_seaborn_25_0.png
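  • The usa cars are the most numerous. Unlike the japan and europe cars, part of the usa group clusters in the lower-left corner: fast acceleration and poor fuel economy.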

sns.scatterplot(
    x = "absences",
    y = "G3",
    data = student_data,
    style = "traveltime",
    hue = "traveltime"
);
../_images/intro_seaborn_27_0.png

13.1.6. alpha
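This subsection has no example in the notes. A minimal sketch, assuming that alpha here means marker transparency: `sns.scatterplot` forwards the `alpha` keyword to matplotlib, so overlapping points show up as darker regions. The frame `toy` is hypothetical stand-in data.

```python
import pandas as pd
import seaborn as sns

# Hypothetical toy data with several overlapping points
toy = pd.DataFrame({
    "absences": [0, 0, 0, 2, 2, 4, 4, 4, 6, 10],
    "G3": [10, 10, 10, 12, 12, 9, 9, 9, 8, 6],
})

# alpha ranges from 0 (fully transparent) to 1 (fully opaque)
ax = sns.scatterplot(x="absences", y="G3", data=toy, alpha=0.4)
```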

13.1.7. facet_grid-style plots

  • What ggplot does with facet_grid/facet_wrap is done in seaborn with relplot().

  • relplot stands for relational plot; it covers both scatter plots and line plots.

  • Use relplot when you want ggplot-style faceting, as with facet_wrap.
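To make the difference from `scatterplot` concrete: `relplot` is a figure-level function that returns a `FacetGrid` holding one Axes per panel, not a single Axes. A minimal sketch on hypothetical toy data (`toy` and the axis label text are placeholders):

```python
import pandas as pd
import seaborn as sns

# Hypothetical toy data standing in for student_data
toy = pd.DataFrame({
    "absences": [0, 2, 4, 6, 8, 10, 1, 3],
    "G3": [15, 14, 12, 11, 9, 8, 13, 10],
    "location": ["Urban"] * 4 + ["Rural"] * 4,
})

# kind="scatter" or kind="line" selects the underlying plot type
g = sns.relplot(x="absences", y="G3", data=toy,
                kind="scatter", col="location")

# The grid exposes its panels as a 2-D array: 1 row, 2 columns here
n_rows, n_cols = g.axes.shape

# Grid-level helpers apply to every panel at once
g.set_axis_labels("Absences", "G3")
```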

13.1.7.1. Faceting by column

sns.relplot(x="absences", 
            y="G3",
            data=student_data, 
            kind = "scatter",
            col = "location");
../_images/intro_seaborn_32_0.png

13.1.7.2. Faceting by column in a specified order

sns.relplot(x="absences", 
            y="G3",
            data=student_data, 
            kind = "scatter",
            col = "location",
            col_order = ["Rural", "Urban"]);
../_images/intro_seaborn_34_0.png

13.1.7.3. Setting the number of columns per row

  • When faceting by column, col_wrap sets how many panels are drawn before wrapping to a new row.

sns.relplot(x="absences", 
            y="G3",
            data=student_data, 
            kind = "scatter",
            col = "study_time",
            col_wrap = 2);
../_images/intro_seaborn_37_0.png
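A small sketch of how `col_wrap` interacts with the number of levels, again on hypothetical toy data: with three study_time levels and `col_wrap=2`, the grid holds three panels laid out as two on the first row and one on the second. `col_order` is added here only to make the panel order explicit.

```python
import pandas as pd
import seaborn as sns

# Hypothetical toy data with three study_time levels
toy = pd.DataFrame({
    "absences": [0, 3, 6, 1, 4, 7, 2, 5, 8],
    "G3": [15, 12, 9, 14, 11, 8, 16, 13, 10],
    "study_time": ["<2 hours"] * 3 + ["2 to 5 hours"] * 3 + ["5 to 10 hours"] * 3,
})

order = ["<2 hours", "2 to 5 hours", "5 to 10 hours"]
g = sns.relplot(x="absences", y="G3", data=toy, kind="scatter",
                col="study_time", col_order=order, col_wrap=2)

# With col_wrap the panels are stored as a flat sequence, one per level
n_panels = len(g.axes.flat)
```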

13.1.7.4. Faceting by row

  • The same approach works by row.

sns.relplot(x="absences", 
            y="G3",
            data=student_data, 
            kind = "scatter",
            row = "location");
../_images/intro_seaborn_40_0.png

13.1.7.5. Faceting by column & row (R's facet_grid)

  • To get facet_grid (two variables crossed), set both col and row.

sns.relplot(x="absences", 
            y="G3",
            data=student_data, 
            kind = "scatter",
            col = "study_time",
            row = "location");
../_images/intro_seaborn_43_0.png
  • The row and column options used above still apply:

sns.relplot(x="G1", y="G3", # grades for the first and third periods
            data=student_data,
            kind="scatter", 
            col="schoolsup", # whether the student receives school support
            col_order=["yes", "no"],
            row = "famsup", # whether the student receives family support
            row_order = ["yes", "no"]);
../_images/intro_seaborn_45_1.png
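When both row and col are set, the returned grid has one panel per combination of levels. A minimal sketch on hypothetical toy data (`toy` and the title template are illustrative choices) showing the grid shape and how `set_titles` can shorten the default panel titles:

```python
import pandas as pd
import seaborn as sns

# Hypothetical toy data: 2 locations x 3 study_time levels
toy = pd.DataFrame({
    "absences": [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 12],
    "G3": [15, 13, 11, 9, 14, 12, 10, 8, 16, 12, 9, 7],
    "location": ["Urban", "Rural"] * 6,
    "study_time": ["<2 hours"] * 4 + ["2 to 5 hours"] * 4 + ["5 to 10 hours"] * 4,
})

g = sns.relplot(x="absences", y="G3", data=toy, kind="scatter",
                row="location", row_order=["Rural", "Urban"],
                col="study_time",
                col_order=["<2 hours", "2 to 5 hours", "5 to 10 hours"])

# axes is indexed as [row, column]: 2 rows by 3 columns here
grid_shape = g.axes.shape

# Replace the default "var = value | var = value" titles
g.set_titles(template="{row_name} | {col_name}")
```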

13.2. Line plot

13.2.1. Basic line plot

13.2.2. multiple line plot

13.2.3. line plot with CI

13.3. count plots (bar chart)

13.3.1. Basic countplot

sns.countplot(x = "school", 
              data = student_data);
../_images/intro_seaborn_52_0.png
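A countplot needs only one variable because the bar heights are row counts per category. A minimal sketch on hypothetical toy data (`toy`, `counts` and `heights` are placeholder names) comparing the bar heights with `value_counts()`:

```python
import pandas as pd
import seaborn as sns

# Hypothetical toy data: 4 students at GP, 2 at MS
toy = pd.DataFrame({"school": ["GP", "GP", "GP", "GP", "MS", "MS"]})

# The same numbers the plot will show, computed with pandas
counts = toy["school"].value_counts()

ax = sns.countplot(x="school", data=toy)

# Each bar is a rectangle whose height is the count for that category
heights = [p.get_height() for p in ax.patches]
```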

13.3.2. Two-variable countplot

palette_colors = {"Rural": "green", "Urban": "blue"}

sns.countplot(x = "school", 
              data = student_data, 
              hue = "location", 
              palette = palette_colors);
../_images/intro_seaborn_54_0.png
countries = pd.read_csv("data/countries-of-the-world.csv")
countries
Country Region Population Area (sq. mi.) Pop. Density (per sq. mi.) Coastline (coast/area ratio) Net migration Infant mortality (per 1000 births) GDP ($ per capita) Literacy (%) Phones (per 1000) Arable (%) Crops (%) Other (%) Climate Birthrate Deathrate Agriculture Industry Service
0 Afghanistan ASIA (EX. NEAR EAST) 31056997 647500 48,0 0,00 23,06 163,07 700.0 36,0 3,2 12,13 0,22 87,65 1 46,6 20,34 0,38 0,24 0,38
1 Albania EASTERN EUROPE 3581655 28748 124,6 1,26 -4,93 21,52 4500.0 86,5 71,2 21,09 4,42 74,49 3 15,11 5,22 0,232 0,188 0,579
2 Algeria NORTHERN AFRICA 32930091 2381740 13,8 0,04 -0,39 31 6000.0 70,0 78,1 3,22 0,25 96,53 1 17,14 4,61 0,101 0,6 0,298
3 American Samoa OCEANIA 57794 199 290,4 58,29 -20,71 9,27 8000.0 97,0 259,5 10 15 75 2 22,46 3,27 NaN NaN NaN
4 Andorra WESTERN EUROPE 71201 468 152,1 0,00 6,6 4,05 19000.0 100,0 497,2 2,22 0 97,78 3 8,71 6,25 NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
222 West Bank NEAR EAST 2460492 5860 419,9 0,00 2,98 19,62 800.0 NaN 145,2 16,9 18,97 64,13 3 31,67 3,92 0,09 0,28 0,63
223 Western Sahara NORTHERN AFRICA 273008 266000 1,0 0,42 NaN NaN NaN NaN NaN 0,02 0 99,98 1 NaN NaN NaN NaN 0,4
224 Yemen NEAR EAST 21456188 527970 40,6 0,36 0 61,5 800.0 50,2 37,2 2,78 0,24 96,98 1 42,89 8,3 0,135 0,472 0,393
225 Zambia SUB-SAHARAN AFRICA 11502010 752614 15,3 0,00 0 88,29 800.0 80,6 8,2 7,08 0,03 92,9 2 41 19,93 0,22 0,29 0,489
226 Zimbabwe SUB-SAHARAN AFRICA 12236805 390580 31,3 0,00 0 67,69 1900.0 90,7 26,8 8,32 0,34 91,34 2 28,01 21,84 0,179 0,243 0,579

227 rows × 20 columns

countries.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 227 entries, 0 to 226
Data columns (total 20 columns):
 #   Column                              Non-Null Count  Dtype  
---  ------                              --------------  -----  
 0   Country                             227 non-null    object 
 1   Region                              227 non-null    object 
 2   Population                          227 non-null    int64  
 3   Area (sq. mi.)                      227 non-null    int64  
 4   Pop. Density (per sq. mi.)          227 non-null    object 
 5   Coastline (coast/area ratio)        227 non-null    object 
 6   Net migration                       224 non-null    object 
 7   Infant mortality (per 1000 births)  224 non-null    object 
 8   GDP ($ per capita)                  226 non-null    float64
 9   Literacy (%)                        209 non-null    object 
 10  Phones (per 1000)                   223 non-null    object 
 11  Arable (%)                          225 non-null    object 
 12  Crops (%)                           225 non-null    object 
 13  Other (%)                           225 non-null    object 
 14  Climate                             205 non-null    object 
 15  Birthrate                           224 non-null    object 
 16  Deathrate                           223 non-null    object 
 17  Agriculture                         212 non-null    object 
 18  Industry                            211 non-null    object 
 19  Service                             212 non-null    object 
dtypes: float64(1), int64(2), object(17)
memory usage: 35.6+ KB
sns.scatterplot(x = "GDP ($ per capita)", y = "Literacy (%)", data = countries);
../_images/intro_seaborn_57_0.png
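The info() output above lists Literacy (%) as object and not as a numeric dtype, because this file writes decimals with a comma (for example 36,0). Plotted as-is, such a column is handled as text. A minimal sketch of one way to convert it before plotting, using three rows copied from the table above; `countries_toy` is a placeholder name, and the inline CSV is an assumption about how the real file quotes these values:

```python
import io

import pandas as pd

# Three rows copied from the table above; the quoting is assumed
csv_text = '''Country,GDP ($ per capita),Literacy (%)
Afghanistan,700.0,"36,0"
Albania,4500.0,"86,5"
Algeria,6000.0,"70,0"
'''
countries_toy = pd.read_csv(io.StringIO(csv_text))

# Swap the decimal comma for a point, then convert the strings to floats
countries_toy["Literacy (%)"] = (
    countries_toy["Literacy (%)"]
    .str.replace(",", ".", regex=False)
    .astype(float)
)
```

Once the column is numeric, the same `sns.scatterplot` call places literacy on a continuous axis. `pd.read_csv` also accepts a `decimal=","` argument, which may handle the conversion at load time.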

13.4. count plot