import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Date;
import java.util.Iterator;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Minimal NIO request/response demo server built on two selectors.
 *
 * <p>The {@link Accept} thread owns {@code acceptSelector} and only accepts connections. It hands
 * each accepted channel to the {@link RWThread} thread through {@code waitRegeditChannel} and wakes
 * {@code rwSelector}, so the read/write thread performs the {@code register} call itself.
 * Registering from the accept thread while the read/write thread is blocked in {@code select()}
 * would contend for the selector's lock and could stall registration.
 *
 * <p>Protocol: a request is a line terminated by {@code '\n'} (an optional preceding {@code '\r'}
 * is stripped). Bytes received before the terminator are buffered in the key's attachment.
 * {@code gettime} is answered with the current date and clears the buffer; any request ending in
 * {@code clear} clears the buffer; any other request gets an "unrecognized request" reply and stays
 * buffered, so later input is appended to it until a {@code clear} arrives.
 *
 * @author coffee
 */
public class NIoTest {
    private static final Logger logger = LoggerFactory.getLogger(NIoTest.class);
    private Selector acceptSelector;
    private Selector rwSelector;
    /** Channels accepted by {@link Accept} and waiting to be registered by {@link RWThread}. */
    private BlockingQueue<SocketChannel> waitRegeditChannel = new LinkedBlockingQueue<SocketChannel>();

    public static void main(String[] args) {
        NIoTest ns = new NIoTest();
        ns.start();
    }

    /**
     * Binds the server socket to 127.0.0.1:8888 and starts the accept and read/write loops on a
     * two-thread pool. On an I/O failure during setup, the error is logged and any selector or
     * channel opened so far is closed.
     */
    public void start() {
        InetSocketAddress localAddress = new InetSocketAddress("127.0.0.1", 8888);
        ServerSocketChannel serverSocketChannel = null;
        try {
            acceptSelector = Selector.open();
            rwSelector = Selector.open();
            serverSocketChannel = ServerSocketChannel.open();
            // A channel must be non-blocking before it can be registered with a selector.
            serverSocketChannel.configureBlocking(false);
            ServerSocket socket = serverSocketChannel.socket();
            // Disable address/port reuse (SO_REUSEADDR).
            socket.setReuseAddress(false);
            // NOTE(review): SO_TIMEOUT governs ServerSocket.accept(); this server accepts via the
            // non-blocking channel instead, so this setting likely has no effect.
            socket.setSoTimeout(60000);
            socket.setReceiveBufferSize(1024);
            socket.bind(localAddress);
            serverSocketChannel.register(acceptSelector, SelectionKey.OP_ACCEPT);
            // A pool is not strictly required; it just runs the two long-lived loops.
            Executor executor = Executors.newFixedThreadPool(2);
            executor.execute(new Accept());
            executor.execute(new RWThread());
        } catch (IOException e) {
            logger.error("failed to start NIO server on {}", localAddress, e);
            closeQuietly(serverSocketChannel);
            closeQuietly(acceptSelector);
            closeQuietly(rwSelector);
        }
    }

    /** Closes {@code resource} if it is non-null, logging (not rethrowing) any failure. */
    private static void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            logger.warn("failed to close {}", resource, e);
        }
    }

    /**
     * Accept loop: selects on {@code acceptSelector}, accepts pending connections, switches them to
     * non-blocking mode, and queues them for registration by the read/write thread.
     */
    public class Accept implements Runnable {
        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                try {
                    int count = acceptSelector.select(500);
                    if (count > 0) {
                        Iterator<SelectionKey> keys = acceptSelector.selectedKeys().iterator();
                        while (keys.hasNext()) {
                            SelectionKey key = keys.next();
                            // Must be removed: the selector never clears its selected-key set itself.
                            keys.remove();
                            ServerSocketChannel serverSocketChannel = (ServerSocketChannel) key.channel();
                            // Only an accepted connection yields a client channel. In non-blocking
                            // mode accept() returns null when no connection is actually pending.
                            SocketChannel socketChannel = serverSocketChannel.accept();
                            if (socketChannel == null) {
                                continue;
                            }
                            socketChannel.configureBlocking(false);
                            // Do NOT call socketChannel.register(rwSelector, OP_READ) here: it would
                            // contend with the read/write thread's select() for the selector's lock
                            // and could take a long time to get registered. Hand it off instead.
                            waitRegeditChannel.put(socketChannel);
                            // A pool of selectors could be used to handle accepted connections
                            // concurrently; that extension is left to the reader.
                            rwSelector.wakeup();
                        }
                    }
                } catch (InterruptedException e) {
                    // Restore the interrupt flag and stop; continuing would make select() return
                    // immediately on every iteration.
                    Thread.currentThread().interrupt();
                    return;
                } catch (Exception e) {
                    logger.error("accept loop failed", e);
                }
            }
        }
    }

    /**
     * Read/write loop: registers queued channels with {@code rwSelector}, then reads requests and
     * writes replies for every readable key.
     */
    private class RWThread implements Runnable {
        @Override
        public void run() {
            while (true) {
                try {
                    registerPendingChannels();
                    int count = rwSelector.select(1000);
                    if (count > 0) {
                        Iterator<SelectionKey> keys = rwSelector.selectedKeys().iterator();
                        while (keys.hasNext()) {
                            SelectionKey key = keys.next();
                            keys.remove();
                            // Could be extended to hand the data to a thread pool for higher
                            // throughput, but memory usage must then be watched carefully.
                            processKey(key);
                        }
                    }
                } catch (Exception e) {
                    logger.error("read/write loop failed", e);
                }
            }
        }

        /**
         * Registers every channel queued by the accept thread with {@code rwSelector} for reads.
         * Runs on this thread so that registration never contends with a concurrent select().
         * A channel that cannot be registered (e.g. already closed) is closed and skipped without
         * affecting the others.
         */
        private void registerPendingChannels() {
            SocketChannel socketChannel;
            while ((socketChannel = waitRegeditChannel.poll()) != null) {
                try {
                    // TODO (original author): this registration step needs rework.
                    socketChannel.register(rwSelector, SelectionKey.OP_READ);
                    logger.debug("注册了一个连接:{}", socketChannel.socket());
                } catch (IOException e) {
                    logger.warn("failed to register {}", socketChannel, e);
                    closeQuietly(socketChannel);
                }
            }
        }

        /**
         * Reads available bytes from the key's channel and, once a full line has been buffered,
         * writes the reply. Closes the channel on end-of-stream or on a read/write failure.
         */
        private void processKey(SelectionKey key) {
            SocketChannel socketChannel = (SocketChannel) key.channel();
            ByteBuffer bb = ByteBuffer.allocate(1024);
            int count;
            try {
                // With a breakpoint here you can see that the OS TCP stack buffers incoming data;
                // one read() returns everything that has arrived so far (up to the buffer size).
                count = socketChannel.read(bb);
            } catch (IOException e) {
                // e.g. connection reset by peer: the channel is unusable, so release it rather than
                // letting select() report the same broken key forever.
                logger.warn("read failed, closing {}", socketChannel, e);
                closeChannel(key, socketChannel);
                return;
            }
            if (count < 0) {
                // End of stream (peer closed the connection): close our side too.
                closeChannel(key, socketChannel);
                return;
            }
            if (count == 0) {
                // Nothing was read; nothing to buffer or answer.
                return;
            }
            // Buffer position/limit handling (flip/clear) is essential to correct NIO code.
            bb.flip();
            byte[] tmpbytes = new byte[bb.limit()];
            bb.get(tmpbytes);
            logger.debug("接受信息为:{}", new String(tmpbytes));
            if (isCache(key, tmpbytes)) {
                // No complete line yet; the bytes are buffered in the key's attachment.
                return;
            }
            byte[] bytes = (byte[]) key.attachment();
            String requestStr = new String(bytes);
            logger.debug("请求字符串:{}", requestStr);
            bb.clear();
            if (requestStr.equals("gettime")) {
                bb.put(new Date().toString().getBytes());
                key.attach(new byte[0]);
            } else if (requestStr.endsWith("clear")) {
                // endsWith lets a client recover after unrecognized requests have piled up.
                key.attach(new byte[0]);
                putGb2312(bb, "缓存已清理");
            } else {
                // Unrecognized requests are intentionally left buffered (see class doc).
                putGb2312(bb, "不能识别的请求");
            }
            bb.flip();
            try {
                socketChannel.write(bb);
                if (bb.hasRemaining()) {
                    // A non-blocking write may be partial; this demo does not queue the rest.
                    logger.warn("partial write to {}: {} bytes not sent", socketChannel, bb.remaining());
                }
            } catch (IOException e) {
                logger.warn("write failed, closing {}", socketChannel, e);
                closeChannel(key, socketChannel);
            }
        }

        /** Writes {@code message} encoded as GB2312 into {@code bb}; logs if the charset is missing. */
        private void putGb2312(ByteBuffer bb, String message) {
            try {
                bb.put(message.getBytes("GB2312"));
            } catch (UnsupportedEncodingException e) {
                logger.error("GB2312 charset is not supported", e);
            }
        }

        /** Closes the client channel and cancels its key. */
        private void closeChannel(SelectionKey key, SocketChannel socketChannel) {
            closeQuietly(socketChannel);
            // A cancelled key is only removed from the selector on the next select().
            key.cancel();
        }

        /**
         * Appends {@code tmpbytes} to the bytes buffered in the key's attachment and checks whether
         * a complete line has arrived.
         *
         * @param key the key whose attachment holds the buffered bytes (or {@code null} if none)
         * @param tmpbytes the newly read bytes
         * @return {@code true} if the buffered data does not yet end with {@code '\n'} (everything
         *     is kept in the attachment); {@code false} if it does, in which case the attachment
         *     is replaced by the request with its trailing {@code "\n"} or {@code "\r\n"} removed
         */
        private boolean isCache(SelectionKey key, byte[] tmpbytes) {
            Object obj = key.attachment();
            byte[] cached = (obj != null) ? (byte[]) obj : new byte[0];
            int sumLength = cached.length + tmpbytes.length;
            byte[] combined = new byte[sumLength];
            System.arraycopy(cached, 0, combined, 0, cached.length);
            System.arraycopy(tmpbytes, 0, combined, cached.length, tmpbytes.length);
            if (sumLength == 0 || combined[sumLength - 1] != '\n') {
                key.attach(combined);
                return true;
            }
            // Strip the '\n' terminator and, if present, the '\r' before it, so both LF and CRLF
            // clients produce the same request string.
            int end = sumLength - 1;
            if (end > 0 && combined[end - 1] == '\r') {
                end--;
            }
            byte[] request = new byte[end];
            System.arraycopy(combined, 0, request, 0, end);
            key.attach(request);
            return false;
        }
    }
}
|