sklearn.model_selectionscikit-learn 的核心模块之一,用于模型的训练测试分割、交叉验证、超参数搜索以及验证评估。这是构建机器学习工作流的基础模块,提供了多种工具帮助优化模型性能、评估泛化能力。

以下是对 sklearn.model_selection 的详细介绍,包括其核心功能、常用方法、使用场景以及注意事项。


1. 核心功能概览

sklearn.model_selection 模块的主要功能包括:

  1. 数据集划分

    • 提供工具将数据集划分为训练集和测试集。
    • 常用函数:train_test_split
  2. 交叉验证

    • 支持多种交叉验证方法,评估模型的泛化能力。
    • 常用函数:cross_val_score, cross_validate
  3. 超参数搜索

    • 实现网格搜索、随机搜索等自动化超参数优化方法。
    • 常用函数:GridSearchCV, RandomizedSearchCV
  4. 数据划分策略

    • 提供多种划分策略,包括随机划分、分层划分、时间序列划分等。
    • 常用类:KFold, StratifiedKFold, TimeSeriesSplit
  5. 学习曲线和验证曲线

    • 可视化模型在不同参数或数据规模下的性能。
    • 常用函数:learning_curve, validation_curve

2. 数据集划分:train_test_split

在机器学习任务中,将数据集划分为训练集和测试集是第一步操作。

函数:train_test_split

  • 将数据随机分为训练集和测试集。
  • 支持独立划分特征和标签,或者多个数组同时划分。
基本用法
from sklearn.model_selection import train_test_split

X = [[1, 2], [3, 4], [5, 6], [7, 8]]
y = [0, 1, 0, 1]

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

print("X_train:", X_train)
print("y_train:", y_train)
print("X_test:", X_test)
print("y_test:", y_test)
重要参数
  1. test_size:测试集比例(默认值为 0.25)。
  2. train_size:训练集比例,test_sizetrain_size 不能同时为空。
  3. random_state:随机种子,保证结果可复现。
  4. shuffle:是否在划分前打乱数据(默认为 True)。

3. 交叉验证

交叉验证是一种通过多次训练测试分割,评估模型泛化能力的技术。

3.1 cross_val_score

  • 功能:对指定模型进行交叉验证,返回每次验证的评分。
  • 基本用法:
    from sklearn.model_selection import cross_val_score
    from sklearn.linear_model import LogisticRegression
    from sklearn.datasets import load_iris
    
    # 加载数据
    iris = load_iris()
    X, y = iris.data, iris.target
    
    # 创建模型
    model = LogisticRegression(max_iter=200)
    
    # 使用 5 折交叉验证
    scores = cross_val_score(model, X, y, cv=5)
    
    print("Cross-validation scores:", scores)
    print("Mean score:", scores.mean())
    

3.2 cross_validate

  • 功能:不仅返回交叉验证得分,还可以返回训练时间、测试时间等指标。
  • 基本用法:
    from sklearn.model_selection import cross_validate
    
    results = cross_validate(model, X, y, cv=5, return_train_score=True)
    print(results)
    
参数区别
  • cross_val_score 只能返回测试集的评分。
  • cross_validate 可以返回训练集评分、训练时间等更多信息。

4. 超参数搜索

在机器学习中,超参数对模型性能的影响很大。sklearn.model_selection 提供了两种自动化搜索超参数的方法:网格搜索和随机搜索。

4.1 网格搜索:GridSearchCV

  • 遍历所有参数组合,找到最优参数。

  • 基本用法

    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC
    
    # 创建模型和参数网格
    model = SVC()
    param_grid = {'kernel': ['linear', 'rbf'], 'C': [1, 10]}
    
    # 网格搜索
    grid_search = GridSearchCV(model, param_grid, cv=5)
    grid_search.fit(X, y)
    
    print("Best parameters:", grid_search.best_params_)
    print("Best score:", grid_search.best_score_)
    
  • 注意事项

    • 网格搜索耗时较长,适合搜索参数空间较小的情况。

4.2 随机搜索:RandomizedSearchCV

  • 随机采样参数组合,比网格搜索效率更高。
  • 基本用法
    from sklearn.model_selection import RandomizedSearchCV
    from sklearn.ensemble import RandomForestClassifier
    import numpy as np
    
    # 创建模型和参数分布
    model = RandomForestClassifier()
    param_dist = {'n_estimators': [10, 50, 100], 'max_depth': [3, 5, None]}
    
    # 随机搜索
    random_search = RandomizedSearchCV(model, param_distributions=param_dist, n_iter=10, cv=5, random_state=42)
    random_search.fit(X, y)
    
    print("Best parameters:", random_search.best_params_)
    print("Best score:", random_search.best_score_)
    

5. 数据划分策略

在交叉验证或数据集划分时,可以根据任务需求选择不同的划分策略。

5.1 K 折交叉验证(KFold

  • 将数据随机分为 K 份,每次使用其中一份作为测试集,其余作为训练集。
  • 基本用法
    from sklearn.model_selection import KFold
    
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    for train_idx, test_idx in kf.split(X):
        print("Train indices:", train_idx, "Test indices:", test_idx)
    

5.2 分层 K 折交叉验证(StratifiedKFold

  • 在每一折中,保证类别的分布比例与原始数据一致(适用于类别不平衡问题)。
  • 用法与 KFold 类似,只需替换为 StratifiedKFold

5.3 时间序列拆分(TimeSeriesSplit

  • 针对时间序列数据,保证训练集始终早于测试集。
  • 基本用法
    from sklearn.model_selection import TimeSeriesSplit
    
    tscv = TimeSeriesSplit(n_splits=5)
    for train_idx, test_idx in tscv.split(X):
        print("Train indices:", train_idx, "Test indices:", test_idx)
    

6. 学习曲线和验证曲线

6.1 学习曲线:learning_curve

  • 分析模型在不同训练数据规模下的表现。
  • 基本用法
    from sklearn.model_selection import learning_curve
    import matplotlib.pyplot as plt
    
    train_sizes, train_scores, test_scores = learning_curve(model, X, y, cv=5)
    plt.plot(train_sizes, train_scores.mean(axis=1), label='Train')
    plt.plot(train_sizes, test_scores.mean(axis=1), label='Test')
    plt.legend()
    plt.show()
    

6.2 验证曲线:validation_curve

  • 分析模型在不同超参数值下的表现。
  • 基本用法
    from sklearn.model_selection import validation_curve
    
    param_range = [1, 10, 100]
    train_scores, test_scores = validation_curve(model, X, y, param_name='C', param_range=param_range, cv=5)
    

7. 注意事项

  1. 随机种子

    • 确保随机划分的结果可复现,建议设置 random_state
  2. 类别不平衡问题

    • 使用分层划分策略(如 StratifiedKFold)保证类别比例一致。
  3. 时间序列数据

    • 避免使用普通交叉验证,建议使用 TimeSeriesSplit
  4. 参数选择

    • 网格搜索适合小范围精确搜索,随机搜索适合大范围粗略搜索。

总结

sklearn.model_selection 是机器学习工作流中必不可少的模块,涵盖数据划分、模型评估、超参数搜索以及学习曲线分析等多个功能。通过灵活运用这些工具,可以构建高效的模型优化流程,提高模型的泛化能力和性能。在具体场景中,根据数据特点选择合适的划分策略和评估方法,是保证模型成功的关键。

Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐