KNN on Hadoop MapReduce: two cases

Tags: hadoop, hdfs, big data

First consider the case where the training set is large and the test set is small.
The test set is made a global file. Because it is small, every map task reads it from HDFS and computes the distance between its training record and each test record, emitting Key = test record ID, Value = "label,distance".
The reducer then sorts the distances for each test record, keeps the n smallest (3 by default), and puts their labels into a List. A HashSet (one kind of Set) is used to deduplicate those labels; the deduplicated labels are compared against the List with two nested for loops, and a counter is incremented on every match. The most frequent label becomes the prediction for that test record. The output is key = test record ID, value = "predicted label: ... true label: ...".
Key points (a standalone sketch of points 1 and 4 follows this list):
1. The values are "label,distance" strings, so a custom Comparator is passed to Collections.sort; its compare method parses out the distance and compares it numerically, sorting in ascending order so that the first n entries are the nearest neighbours.
2. The global test-data file is parsed into lists of attribute values in the Mapper's setup().
3. The global ground-truth label file is parsed into a list in the Reducer's setup().
4. A Set has no positional get(), so the deduplicated labels from the HashSet are copied into another List, which can then be indexed in step with a count array.
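
A minimal, self-contained sketch of the sort-and-vote step from points 1 and 4, using made-up sample labels and distances (the class name SortAndVoteSketch and the values are only illustrative; the reducer below contains the same logic):

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;

public class SortAndVoteSketch {
  public static void main(String[] args) {
    // Hypothetical "label,distance" values, as the mapper would emit them
    List<String> sortvalue = new ArrayList<String>();
    sortvalue.add("Iris-setosa,0.91");
    sortvalue.add("Iris-versicolor,0.17");
    sortvalue.add("Iris-versicolor,0.55");
    sortvalue.add("Iris-setosa,0.42");
    int n = 3;

    // Point 1: sort ascending by the numeric distance field
    Collections.sort(sortvalue, new Comparator<String>() {
      @Override
      public int compare(String o1, String o2) {
        return Double.compare(Double.parseDouble(o1.split(",")[1]),
                              Double.parseDouble(o2.split(",")[1]));
      }
    });

    // Labels of the n nearest neighbours
    List<String> labels = new ArrayList<String>();
    for (int i = 0; i < n && i < sortvalue.size(); i++) {
      labels.add(sortvalue.get(i).split(",")[0]);
    }

    // Point 4: deduplicate with a HashSet, copy into a List so it can be indexed,
    // then count matches with two nested loops and take the most frequent label
    List<String> labelset = new ArrayList<String>(new HashSet<String>(labels));
    int[] count = new int[labelset.size()];
    for (int i = 0; i < labelset.size(); i++) {
      for (int j = 0; j < labels.size(); j++) {
        if (labelset.get(i).equals(labels.get(j))) {
          count[i] += 1;
        }
      }
    }
    int max = 0;
    for (int i = 1; i < count.length; i++) {
      if (count[i] > count[max]) {
        max = i;
      }
    }
    // Prints "predicted label: Iris-versicolor" for the sample values above
    System.out.println("predicted label: " + labelset.get(max));
  }
}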

Code:

package wordcount;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.StringTokenizer;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;

public class knn{

public static class TokenizerMapper
    extends Mapper<Object, Text, IntWritable, Text> {

  // HDFS path of the test-set file
  private String localFiles;
  // Parsed test records (four attribute values each)
  private List<List<Double>> test = new ArrayList<List<Double>>();

  @Override
  public void setup(Context context) throws IOException, InterruptedException {
    Configuration conf = context.getConfiguration();
    // HDFS path of the test set, passed in through the configuration
    localFiles = conf.getStrings("test")[0];
    FileSystem fs = FileSystem.get(URI.create(localFiles), conf);
    FSDataInputStream hdfsInStream = fs.open(new Path(localFiles));
    // Read the test set from HDFS
    InputStreamReader isr = new InputStreamReader(hdfsInStream, "utf-8");
    String line;
    BufferedReader br = new BufferedReader(isr);
    while ((line = br.readLine()) != null) {
      StringTokenizer itr = new StringTokenizer(line);
      while (itr.hasMoreTokens()) {
        // Each token is one comma-separated test record
        String[] tmp = itr.nextToken().split(",");
        List<Double> data = new ArrayList<Double>();
        for (String i : tmp) {
          data.add(Double.parseDouble(i));
        }
        test.add(data);
      }
    }
    // All test records are now in memory
    System.out.println("test data");
    System.out.println(test);
  }

  public void map(Object key, Text value, Context context)
      throws IOException, InterruptedException {
    StringTokenizer itr = new StringTokenizer(value.toString());
    while (itr.hasMoreTokens()) {
      // Split the training record on commas
      String[] tmp = itr.nextToken().split(",");
      // The training record's label
      String label = tmp[4];
      // The training record's four attribute values
      List<Double> data = new ArrayList<Double>();
      for (int i = 0; i <= 3; i++) {
        data.add(Double.parseDouble(tmp[i]));
      }

      for (int i = 0; i < test.size(); i++) {
        // Each test record
        List<Double> tmp2 = test.get(i);
        // Euclidean distance between this training record and the test record
        double dis = 0;
        for (int j = 0; j < 4; j++) {
          dis += Math.pow(tmp2.get(j) - data.get(j), 2);
        }
        dis = Math.sqrt(dis);
        // Value is "label,distance"
        String out = label + "," + String.valueOf(dis);
        // Key i is the index of the test record
        context.write(new IntWritable(i), new Text(out));
      }
    }
  }
}

public static class IntSumReducer
    extends Reducer<IntWritable, Text, IntWritable, Text> {

  // HDFS path of the test-label file
  private String localFiles;
  // Ground-truth labels of the test records
  private List<String> tgt = new ArrayList<String>();
  // Number of neighbours
  private int n;

  // Read the ground-truth labels of the test set
  @Override
  public void setup(Context context) throws IOException, InterruptedException {
    Configuration conf = context.getConfiguration();
    // HDFS path of the test labels, passed in through the configuration
    localFiles = conf.getStrings("label")[0];
    // Read n, default 3
    n = conf.getInt("n", 3);
    FileSystem fs = FileSystem.get(URI.create(localFiles), conf);
    FSDataInputStream hdfsInStream = fs.open(new Path(localFiles));
    // Read the label file from HDFS
    InputStreamReader isr = new InputStreamReader(hdfsInStream, "utf-8");
    String line;
    BufferedReader br = new BufferedReader(isr);
    while ((line = br.readLine()) != null) {
      StringTokenizer itr = new StringTokenizer(line);
      while (itr.hasMoreTokens()) {
        // One label per line
        tgt.add(itr.nextToken());
      }
    }
    System.out.println("test labels");
    System.out.println(tgt);
  }

  public void reduce(IntWritable key, Iterable<Text> values, Context context)
      throws IOException, InterruptedException {

    // Collect the "label,distance" values into a list so they can be sorted
    List<String> sortvalue = new ArrayList<String>();
    for (Text val : values) {
      sortvalue.add(val.toString());
    }

    // Sort ascending by the distance field
    Collections.sort(sortvalue, new Comparator<String>() {
      @Override
      public int compare(String o1, String o2) {
        double x = Double.parseDouble(o1.split(",")[1]);
        double y = Double.parseDouble(o2.split(",")[1]);
        return Double.compare(x, y);
        // for descending order: return Double.compare(y, x);
      }
    });

    // Labels of the n nearest neighbours
    List<String> labels = new ArrayList<String>();
    for (int i = 0; i < n; i++) {
      labels.add(sortvalue.get(i).split(",")[0]);
    }

    // Deduplicate the labels, then copy them into a list so they can be indexed
    Set<String> set = new HashSet<String>();
    set.addAll(labels);
    List<String> labelset = new ArrayList<String>(set);

    // Count how often each distinct label occurs among the n neighbours
    // (int arrays are zero-initialized, so no explicit reset loop is needed)
    int[] count = new int[labelset.size()];
    for (int i = 0; i < labelset.size(); i++) {
      for (int j = 0; j < labels.size(); j++) {
        if (labelset.get(i).equals(labels.get(j))) {
          count[i] += 1;
        }
      }
    }

    // Index of the most frequent label
    int max = 0;
    for (int i = 1; i < count.length; i++) {
      if (count[i] > count[max]) {
        max = i;
      }
    }

    context.write(key, new Text("predicted label: " + labelset.get(max)
        + "\t" + "true label: " + tgt.get(key.get())));
  }
}

public static void main(String[] args) throws Exception {
  Configuration conf = new Configuration();
  // String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
  // The test set is small, so its data and label files are passed in through the
  // configuration and read from HDFS inside the mapper's and reducer's setup().
  conf.setStrings("test", "hdfs://localhost:9000/user/hadoop/knn/input/iris_test_data.csv");
  conf.setStrings("label", "hdfs://localhost:9000/user/hadoop/knn/input/iris_test_lable.csv");
  // n is passed in on the command line
  conf.setInt("n", Integer.parseInt(args[0]));
  String[] otherArgs = new String[] {
      "hdfs://localhost:9000/user/hadoop/knn/input/iris_train.csv",
      "hdfs://localhost:9000/user/hadoop/knn/output/" };
  if (otherArgs.length < 2) {
    System.err.println("Usage: wordcount <in> [<in>...] <out>");
    System.exit(2);
  }
  Job job = Job.getInstance(conf, "knn");
  job.setJarByClass(knn.class);
  job.setMapperClass(TokenizerMapper.class);
  job.setReducerClass(IntSumReducer.class);
  job.setOutputKeyClass(IntWritable.class);
  job.setOutputValueClass(Text.class);
  for (int i = 0; i < otherArgs.length - 1; ++i) {
    // The large training set is the job input, so it is split across the mappers
    FileInputFormat.addInputPath(job, new Path(otherArgs[i]));
  }
  FileOutputFormat.setOutputPath(job, new Path(otherArgs[otherArgs.length - 1]));
  System.exit(job.waitForCompletion(true) ? 0 : 1);
}
}

Results: (output screenshot omitted)

Now consider the opposite case: the test set is large and the training set is small.
This time the training set is the global file and the test set is distributed across the nodes.
Only small changes to the previous program are needed. The parsing of the training set and of the ground-truth labels moves into the Mapper's setup(). In map(), each test record is compared against every training record: the distances are computed, sorted in ascending order, the first n are kept, and the label is chosen among them by the same vote as before.
How can the prediction be compared with the true label?
The code keeps a variable ID that is incremented once per record. This is not actually reliable: the test set is split across different nodes, so the order in which records reach a mapper no longer matches the order of the label file; it happens to line up here only because the test set is still small. The instructor said the comparison can simply be dropped and the predictions output as they are, but the code below still does the comparison (I did not bother to change it). A possible workaround is sketched first.
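
The sketch below is only an illustration and is not used in the full program. It assumes the default TextInputFormat, whose map() key is the byte offset of the current line in the input file (a LongWritable): that offset stays the same no matter which node processes the split, so it can serve as the record identifier instead of a per-mapper counter. The fragment would sit inside the knn class next to TokenizerMapper and would need import org.apache.hadoop.io.LongWritable; added to the imports; the prediction string is a placeholder for the kNN vote computed exactly as in the map() method below.

  // Sketch only: relies on the default TextInputFormat, whose map key is the
  // byte offset of the line in the test file and therefore identifies the
  // record even when the file is split across several mappers.
  public static class OffsetKeyMapper extends Mapper<LongWritable, Text, LongWritable, Text> {
    @Override
    public void map(LongWritable key, Text value, Context context)
        throws IOException, InterruptedException {
      // Placeholder: run the same distance/sort/vote steps as in map() below.
      String prediction = "predicted-label";
      // Emit the byte offset as the test-record identifier.
      context.write(key, new Text(prediction));
    }
  }

The full program for this case (still using the ID counter):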
package wordcount;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.StringTokenizer;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
public class knn{

public static class TokenizerMapper
    extends Mapper<Object, Text, Text, Text> {

  // Sequence number of the test record (see the caveat above: not reliable once
  // the test file is split across several mappers)
  private int ID = 0;
  // HDFS path of the training-set file
  private String localFiles;
  // HDFS path of the ground-truth label file
  private String localFiles1;
  // Parsed training records (four attribute values plus the label)
  private List<List<Object>> train = new ArrayList<List<Object>>();
  // Number of neighbours
  private int n;
  // Ground-truth labels of the test records
  private List<String> tgt = new ArrayList<String>();

  @Override
  public void setup(Context context) throws IOException, InterruptedException {
    // The global files are passed in through the configuration
    Configuration conf = context.getConfiguration();
    // HDFS path of the training set
    localFiles = conf.getStrings("train")[0];
    FileSystem fs = FileSystem.get(URI.create(localFiles), conf);
    FSDataInputStream hdfsInStream = fs.open(new Path(localFiles));
    // Read the training set from HDFS
    InputStreamReader isr = new InputStreamReader(hdfsInStream, "utf-8");
    String line;
    BufferedReader br = new BufferedReader(isr);
    while ((line = br.readLine()) != null) {
      // Each line is a single comma-separated record, so no inner token loop is needed
      StringTokenizer itr = new StringTokenizer(line);
      String[] tmp = itr.nextToken().split(",");
      List<Object> data = new ArrayList<Object>();
      for (int i = 0; i <= 3; i++) {
        // The four attribute values
        data.add(Double.parseDouble(tmp[i]));
      }
      // The label is already a String, so it is added outside the loop
      data.add(tmp[4]);
      train.add(data);
    }

    // HDFS path of the ground-truth labels
    localFiles1 = conf.getStrings("label")[0];
    // Read n, default 3
    n = conf.getInt("n", 3);
    FileSystem fsa = FileSystem.get(URI.create(localFiles1), conf);
    FSDataInputStream hdfsInStreama = fsa.open(new Path(localFiles1));
    // Read the label file from HDFS
    InputStreamReader isra = new InputStreamReader(hdfsInStreama, "utf-8");
    String linea;
    BufferedReader bra = new BufferedReader(isra);
    while ((linea = bra.readLine()) != null) {
      StringTokenizer itra = new StringTokenizer(linea);
      while (itra.hasMoreTokens()) {
        // One label per line
        tgt.add(itra.nextToken());
      }
    }
  }

  public void map(Object key, Text value, Context context)
      throws IOException, InterruptedException {
    StringTokenizer itr = new StringTokenizer(value.toString());
    while (itr.hasMoreTokens()) {
      ID++;
      // Split the test record on commas
      String[] tmp = itr.nextToken().split(",");
      // The test record's four attribute values
      List<Double> data = new ArrayList<Double>();
      List<String> sortvalue = new ArrayList<String>();
      for (int i = 0; i <= 3; i++) {
        data.add(Double.parseDouble(tmp[i]));
      }

      // Distance from this test record to every training record
      String[] out = new String[train.size()];
      for (int i = 0; i < train.size(); i++) {
        List<Object> tmp2 = train.get(i);
        // Euclidean distance
        double dis = 0;
        for (int j = 0; j < 4; j++) {
          dis += Math.pow((Double) tmp2.get(j) - data.get(j), 2);
        }
        dis = Math.sqrt(dis);
        // out[i] is "label,distance"
        out[i] = (String) tmp2.get(4) + "," + String.valueOf(dis);
        sortvalue.add(out[i]);
      }

      // Sort ascending by the distance field
      Collections.sort(sortvalue, new Comparator<String>() {
        @Override
        public int compare(String o1, String o2) {
          double x = Double.parseDouble(o1.split(",")[1]);
          double y = Double.parseDouble(o2.split(",")[1]);
          return Double.compare(x, y);
        }
      });

      // Labels of the n nearest neighbours
      List<String> labels = new ArrayList<String>();
      for (int i = 0; i < n; i++) {
        labels.add(sortvalue.get(i).split(",")[0]);
      }

      // Deduplicate the labels, then copy them into a list so they can be indexed
      Set<String> set = new HashSet<String>();
      set.addAll(labels);
      List<String> labelset = new ArrayList<String>(set);

      // Count how often each distinct label occurs among the n neighbours
      // (int arrays are zero-initialized, so no explicit reset loop is needed)
      int[] count = new int[labelset.size()];
      for (int i = 0; i < labelset.size(); i++) {
        for (int j = 0; j < labels.size(); j++) {
          if (labelset.get(i).equals(labels.get(j))) {
            count[i] += 1;
          }
        }
      }

      // Index of the most frequent label
      int max = 0;
      for (int i = 1; i < count.length; i++) {
        if (count[i] > count[max]) {
          max = i;
        }
      }

      // ID counts records within this mapper; it only matches the label file
      // when the whole test set reaches a single mapper (see the note above)
      context.write(new Text("predicted label: " + labelset.get(max) + "\t"),
          new Text("true label: " + tgt.get(ID - 1)));
    }
  }
}

public static void main(String[] args) throws Exception {
  Configuration conf = new Configuration();
  // The training set is small this time, so it and the test-label file are passed
  // in through the configuration and read from HDFS in the mapper's setup().
  conf.setStrings("train", "hdfs://localhost:9000/user/hadoop/knn/input/iris_train.csv");
  conf.setStrings("label", "hdfs://localhost:9000/user/hadoop/knn/input/iris_test_lable.csv");
  // n is passed in on the command line (args[0] is the first argument)
  conf.setInt("n", Integer.parseInt(args[0]));
  String[] otherArgs = new String[] {
      "hdfs://localhost:9000/user/hadoop/knn/input/iris_test_data.csv",
      "hdfs://localhost:9000/user/hadoop/knn/output1/" };
  if (otherArgs.length < 2) {
    System.err.println("Usage: wordcount <in> [<in>...] <out>");
    System.exit(2);
  }
  Job job = Job.getInstance(conf, "knn");
  job.setJarByClass(knn.class);
  job.setMapperClass(TokenizerMapper.class);
  // The mapper emits Text keys and Text values
  job.setOutputKeyClass(Text.class);
  job.setOutputValueClass(Text.class);
  // Map-only job: even without a reducer this line is required so the map
  // output is written straight to HDFS
  job.setNumReduceTasks(0);
  for (int i = 0; i < otherArgs.length - 1; ++i) {
    // The large test set is the job input, so it is split across the mappers
    FileInputFormat.addInputPath(job, new Path(otherArgs[i]));
  }
  FileOutputFormat.setOutputPath(job, new Path(otherArgs[otherArgs.length - 1]));
  System.exit(job.waitForCompletion(true) ? 0 : 1);
}
}

Result screenshot: (omitted)

