字典序

Posted by kifish on April 25, 2018
# Evaluate the selected classifier on every data set under rootdir and collect
# the per-set metrics.
idx = 1  # running counter of processed files; it is NOT the number in the file name
cnt = 0  # number of sets for which the returned accuracy was 0
acc_all = []
pre_all = []
recall_all = []
f1_all = []
unable_idx = []  # idx values (positions in the sorted listing) of the sets counted in cnt
rootdir = os.path.abspath('./rawdata/ambiguous_sets')
# os.listdir returns entries in arbitrary order, so iterate over a sorted copy
# to make idx reproducible between runs. The order is lexicographic (e.g. a
# name ending in "10" sorts before one ending in "2"), so idx still does not
# equal the number embedded in the file name; the file name is therefore
# printed with every result.
for filename in sorted(os.listdir(rootdir)):
    if filename.endswith('.DS_Store'):
        continue  # skip macOS Finder metadata files
    filepath = os.path.join(rootdir, filename)
    x_persons_list, x_content_list, x_both, y = get_data(filepath)
    # TF-IDF features of the text content.
    content_words, content_features_weights = train_and_pred.tf_idf(x_content_list)
    # TF-IDF features of the co-author lists.
    coauthor_words, coauthor_features_weights = train_and_pred.tf_idf(x_persons_list)
    # Optional PCA dimensionality reduction (disabled):
    #pca = PCA(n_components=0.70)
    #content_features_weights = pca.fit_transform(content_features_weights)
    #pca = PCA(n_components=0.70)
    #coauthor_features_weights = pca.fit_transform(coauthor_features_weights)
    # Each element of content_features_weights holds the weights of all words
    # of one document.
    content_features_weights = content_features_weights.tolist()
    coauthor_features_weights = coauthor_features_weights.tolist()
    logging.debug('len of content_features_weights : {}'.format(len(content_features_weights)))
    logging.debug('len of coauthor_features_weights : {}'.format(len(coauthor_features_weights)))
    num_samples = len(content_features_weights)
    # Data matrix: each row is the content features of one document followed
    # by its co-author features.
    x = [
        content_features_weights[i] + coauthor_features_weights[i]
        for i in range(num_samples)
    ]
    logging.info('the number of samples in X : {}'.format(len(x)))
    logging.info('the number of labels in y : {}'.format(len(y)))
    # Original author's note: train.lr(x, y) returns np.array.
    acc, pre, recall, f1 = train_and_pred.wrap2(x, y, classifier)
    if acc == 0:
        # An accuracy of 0 is treated as "could not handle this set"; such
        # sets are counted and excluded from the aggregated metrics.
        cnt += 1
        unable_idx.append(idx)
    else:
        acc_all.append(acc)
        pre_all.append(pre)
        recall_all.append(recall)
        f1_all.append(f1)
    # Report the file name as well, because idx alone does not identify the file.
    print('set {} ({}) prediction result :'.format(idx, filename))
    print('accuracy:{}%'.format(acc*100))
    print('precision:{}%'.format(pre*100))
    print('recall:{}%'.format(recall*100))
    print('f1-score:{}'.format(f1))
    idx += 1

有 bug:文件名是按字典序排的,假设 idx 为 209,对应的文件不一定是 './rawdata/ambiguous_sets/ambiguous_set209';并且 os.listdir 返回的顺序是任意的,不保证有序。

变量y冲突了
def wrap2(x, y, classifier, max_attempts=1000):
    """Split the data, train the selected classifier and score its predictions.

    The samples are split at random with ``train_test_split``; the split is
    repeated until the training labels contain at least two distinct classes.

    Args:
        x: Samples, one feature row per sample.
        y: Labels, one hashable label per sample.
        classifier: Key into the module-level ``classifiers`` mapping; the
            selected entry is called as ``classifiers[classifier](X_train, y_train)``
            and must return an object with a ``predict`` method.
        max_attempts: Maximum number of random splits to try before giving up.

    Returns:
        The four values produced by ``calculate_result(y_test, pred)``, bound
        here as accuracy, precision, recall and f1 score.

    Raises:
        ValueError: If ``y`` contains fewer than two distinct classes, or if no
            split with at least two classes in the training labels was found
            within ``max_attempts`` attempts.
    """
    from collections import Counter  # local import keeps the function self-contained

    X = x
    class_counts = Counter(y)
    if len(class_counts) < 2:
        # No split can ever yield two classes in the training labels, so fail
        # fast instead of retrying forever.
        raise ValueError(
            'wrap2 needs at least two distinct classes in y, got {}'.format(len(class_counts))
        )

    # When the rarest class has at most 5 samples, hold out a single sample
    # (test_size=1) instead of 20% of the data.
    few_samples = min(class_counts.values()) <= 5

    for _ in range(max_attempts):
        if few_samples:
            print("minnnnnnnnnn")
            print(y)
            # Random draw of a single test sample.
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1, shuffle=True)
        else:
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True)

        # The training split may contain only one class; check and retry.
        # The parameter y is deliberately not rebound here, so a retry splits
        # the full label sequence again.
        if len(set(y_train)) >= 2:
            break
    else:
        raise ValueError(
            'no split with two or more training classes found in {} attempts'.format(max_attempts)
        )

    model = classifiers[classifier](X_train, y_train)
    # Original author's note: predict returns the class with the highest
    # predicted probability.
    pred = model.predict(X_test)
    m_accuracy, m_precision, m_recall, f1_score = calculate_result(y_test, pred)
    return m_accuracy, m_precision, m_recall, f1_score