在数据科学和数值计算的世界里,numpy 就像是一把瑞士军刀,而 np.where() 无疑是其中最锋利的工具之一。今天,我们将深入探索这个功能强大的函数,学会如何用它优雅地处理条件逻辑和数据选择。
什么是 np.where()?
简单来说,np.where() 是 numpy 中用于条件选择和元素定位的核心函数。它有两种主要用法:
- 三元条件替换:根据条件选择不同值
- 索引定位:查找满足条件的元素位置
让我们通过实例来探索这两种用法!
用法一:三元条件替换(条件 ? 值1 : 值2)
这是 np.where() 最常用的形式,语法为:np.where(condition, x, y)
- condition: 布尔数组(true/false)
- x: 当条件为 true 时使用的值
- y: 当条件为 false 时使用的值
基础示例
import numpy as np # 创建示例数组 temperatures = np.array([22, 28, 15, 32, 18, 25]) # 标记高温和低温 result = np.where(temperatures > 25, "高温", "舒适") print(result) # 输出:['舒适' '高温' '舒适' '高温' '舒适' '舒适']
实际应用:成绩分类
scores = np.array([75, 92, 58, 81, 45, 67, 88])
# 根据分数分类
grade = np.where(scores >= 90, "a",
np.where(scores >= 80, "b",
np.where(scores >= 70, "c",
np.where(scores >= 60, "d", "f"))))
print(grade)
# 输出:['c' 'a' 'f' 'b' 'f' 'd' 'b']
多条件组合
data = np.array([12, 25, 7, 18, 30, 5, 22]) # 组合条件:大于10且小于20 result = np.where((data > 10) & (data < 20), data, 0) print(result) # 输出:[12 0 0 18 0 0 0] # 使用 | 表示 or 条件 result = np.where((data < 10) | (data > 20), data, -1) print(result) # 输出:[ -1 25 7 -1 30 5 22]
用法二:定位元素索引
当我们只提供条件参数时,np.where() 会返回满足条件元素的索引。
语法:np.where(condition)
一维数组示例
arr = np.array([0, 5, 0, 8, 0, 3, 0]) # 找到非零元素的索引 non_zero_indices = np.where(arr != 0) print(non_zero_indices) # 输出:(array([1, 3, 5]),) # 提取非零值 print(arr[non_zero_indices]) # 输出:[5 8 3]
二维数组示例
matrix = np.array([[1, 0, 4],
[0, 5, 0],
[7, 0, 9]])
# 找到值大于3的元素位置
rows, cols = np.where(matrix > 3)
print("行索引:", rows) # 输出:[0 1 2 2]
print("列索引:", cols) # 输出:[2 1 0 2]
print("对应值:", matrix[rows, cols]) # 输出:[4 5 7 9]
实际应用:图像处理
# 创建一个简单的图像矩阵 (5x5)
image = np.array([[120, 130, 40, 200, 210],
[30, 145, 255, 180, 10],
[220, 25, 30, 190, 200],
[100, 110, 120, 130, 140],
[50, 60, 70, 80, 90]])
# 找到高光区域(值>200)
highlight_rows, highlight_cols = np.where(image > 200)
print("高光像素位置:")
for r, c in zip(highlight_rows, highlight_cols):
print(f"({r}, {c}) - 值: {image[r, c]}")
# 输出:
# (0, 3) - 值: 200
# (0, 4) - 值: 210
# (1, 2) - 值: 255
# (2, 0) - 值: 220
# (2, 3) - 值: 190 -> 注意:190不大于200,实际应为 (2, 4): 200
# 更正:矩阵中 (2,4) 是200,所以应包含
进阶技巧与注意事项
1. 广播机制
np.where() 支持 numpy 的广播机制,使不同形状的数组能够一起工作:
# 二维条件与一维值组合 condition_2d = np.array([[true, false], [false, true]]) result = np.where(condition_2d, [10, 20], 0) print(result) # 输出: # [[10 0] # [ 0 20]]
2. 直接修改满足条件的值
data = np.array([5, 12, 8, 15, 3, 10]) # 将小于10的值替换为0 data[np.where(data < 10)] = 0 print(data) # 输出:[ 0 12 0 15 0 10]
3. 多维度索引
对于三维或更高维数组,np.where() 同样适用:
# 创建3x3x3数组
cube = np.random.randint(0, 10, (3, 3, 3))
# 找到所有大于8的元素
indices = np.where(cube > 8)
# 输出三维索引
print("维度0:", indices[0])
print("维度1:", indices[1])
print("维度2:", indices[2])
# 访问这些元素
print("满足条件的值:", cube[indices])
性能优势
与 python 循环相比,np.where() 有显著的性能优势:
import time
large_array = np.random.rand(10**6)
# 使用循环
start = time.time()
result_loop = [x*2 if x > 0.5 else x/2 for x in large_array]
print("循环耗时:", time.time() - start)
# 使用 np.where
start = time.time()
result_np = np.where(large_array > 0.5, large_array*2, large_array/2)
print("np.where耗时:", time.time() - start)
测试结果(可能因机器而异):
循环耗时: 0.45秒
np.where耗时: 0.02秒
到此这篇关于详解numpy中np.where() 的两种神奇用法的文章就介绍到这了,更多相关numpy np.where() 用法内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!
发表评论