pandas中one-hot编码的神坑

【pandas中one-hot编码的神坑】机器学习中,经常会用到one-hot编码。pandas中已经提供了这一函数。
但是这里有一个神坑,得到的one-hot编码数据类型是uint8,进行数值计算时会溢出!!!

import pandas as pd import numpy as np a = [1, 2, 3, 1] one_hot = pd.get_dummies(a) print(one_hot.dtypes) print(one_hot) print(-one_hot)

1uint8 2uint8 3uint8 dtype: object 123 0100 1010 2001 3100 123 025500 102550 200255 325500

正确的做法是,将其转换成浮点:

one_hot = one_hot.astype('float') print(-one_hot)

123 0 -1.0 -0.0 -0.0 1 -0.0 -1.0 -0.0 2 -0.0 -0.0 -1.0 3 -1.0 -0.0 -0.0

    推荐阅读