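This article takes a heuristic, build-it-up-step-by-step approach to analyzing the source of Java's HashMap.
When I was first getting started with ACM contests, I often had to solve string problems. Here is a very routine one:
Count how many times each character occurs in a string of lowercase English letters.
If you are not familiar with ASCII, a direct Java implementation looks like this: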
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

public class TestMap {
public static void main(String[] args) {
String input = "abcdefade";
Map<Character, Integer> charMap = new HashMap<Character, Integer>();
for (int i = 0; i < input.length(); i++) {
char c = input.charAt(i);
Integer count = charMap.get(c);
count = count==null?1:count+1;
charMap.put(c, count);
}
for (Entry<Character, Integer> entry : charMap.entrySet()) {
System.out.println(entry.getKey() + ":" + entry.getValue());
}
}
}
Anyone familiar with ACM, however, would usually write it like this:
public class TestMap2 {
public static void main(String[] args) {
String input = "abcdefade";
int[] counts = new int[26];
for (int i = 0; i < input.length(); i++) {
char c = input.charAt(i);
counts[c-97]++;
}
for (int i = 0; i < counts.length; i++) {
if(counts[i] != 0) {
System.out.println((char)(i+97) + ":" + counts[i]);
}
}
}
}
With a small change, we get a fake HashMap:
public class TestMap2 {
public static void main(String[] args) {
String input = "abcdefade";
PseudoMap charMap = new PseudoMap();
for (int i = 0; i < input.length(); i++) {
char c = input.charAt(i);
Integer count = charMap.get(c);
count = count==0?1:count+1;
charMap.put(c, count);
}
charMap.print();
}
private static class PseudoMap {
private int[] counts = new int[26];
public void put(char c, int count) {
counts[c-97] = count;
}
public int get(char c) {
return counts[c-97];
}
public void print() {
for (int i = 0; i < counts.length; i++) {
if(counts[i] != 0) {
System.out.println((char)(i+97) + ":" + counts[i]);
}
}
}
}
}
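Comparing the two versions, it is easy to see that a map is essentially an array indexed by its keys, an application of the familiar trade of space for time. The fake map above has one obvious problem, though: the key must be a char. Extending it to keys of any type is the problem we have to solve, and this is where hashing comes in. In Java, an object's hashCode() computes an int from the object's contents: for an Integer it is the value itself, for a String it is s[0]*31^(n-1) + s[1]*31^(n-2) + ... + s[n-1], and so on. Using the hashCode as the array index removes the restriction to char keys. The code: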
public class TestMap3 {
public static void main(String[] args) {
String input = "abcdefade";
PseudoMap<Character, Integer> charMap = new PseudoMap<Character, Integer>();
for (int i = 0; i < input.length(); i++) {
char c = input.charAt(i);
Integer count = charMap.get(c);
count = count==null?1:count+1;
charMap.put(c, count);
}
charMap.print();
}
private static class PseudoMap<K, V> {
private Node<K, V>[] table = (Node<K,V>[])new Node[1000];
public void put(K key, V value) {
table[key.hashCode()] = new Node<K, V>(key, value);
}
public V get(K key) {
Node<K, V> node = table[key.hashCode()];
return node==null?null:node.getValue();
}
public void print() {
for (int i = 0; i < table.length; i++) {
if(table[i] != null) {
System.out.println(table[i]);
}
}
}
private static class Node<K, V> {
K key;
V value;
public Node(K key, V value) {
this.key = key;
this.value = value;
}
public V getValue() {
return value;
}
@Override
public String toString() {
return key + ":" + value;
}
}
}
}
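To make the String formula above concrete, here is a minimal sketch that evaluates it with Horner's rule and compares the result with String.hashCode(). The class name StringHashDemo and the helper polyHash are made up for illustration; int overflow simply wraps around.

```java
public class StringHashDemo {
    // Hypothetical helper: evaluates s[0]*31^(n-1) + ... + s[n-1]
    // with Horner's rule; int overflow wraps silently.
    static int polyHash(String s) {
        int h = 0;
        for (int i = 0; i < s.length(); i++) {
            h = 31 * h + s.charAt(i);
        }
        return h;
    }

    public static void main(String[] args) {
        String[] samples = {"a", "ab", "hashmap"};
        for (String s : samples) {
            // both numbers on each line should agree
            System.out.println(s + " -> " + polyHash(s) + " / " + s.hashCode());
        }
        // an Integer hashes to its own value
        System.out.println(Integer.valueOf(42).hashCode()); // prints 42
    }
}
```

Note that "ab" already hashes to 3105 (97*31 + 98), which is past the end of the 1000-slot table used in TestMap3.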
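But this brings a new problem: a hashCode is an int, and the int range is far too large to back with an array of that size. So how do we map an int into a bounded range? Let us simplify the question: how do we turn an int into an integer from 0 to 9? The obvious answer is to take the remainder (mod). Because Java's % returns a negative result for a negative hashCode, the code uses Math.floorMod to stay within 0-9. The modified code:
private static class PseudoMap<K, V> {
private int len = 10;
private Node<K, V>[] table = (Node<K,V>[])new Node[len];
public void put(K key, V value) {
// floorMod keeps the index in 0..len-1 even when hashCode() is negative
int i = Math.floorMod(key.hashCode(), len);
table[i] = new Node<K, V>(key, value);
}
public V get(K key) {
int i = Math.floorMod(key.hashCode(), len);
Node<K, V> node = table[i];
return node==null?null:node.getValue();
}
// print() and Node are unchanged from the previous version
}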
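One detail worth checking for a remainder-based index is the sign of the result: Java's % keeps the sign of the dividend, while Math.floorMod (available since Java 8) returns a value in 0..len-1 for a positive len. The sketch below compares the two; ModIndexDemo and its two helper methods are names invented for this illustration.

```java
public class ModIndexDemo {
    // Plain remainder: can be negative when hash is negative.
    static int remIndex(int hash, int len) {
        return hash % len;
    }

    // Floor modulus: always in 0..len-1 for a positive len.
    static int floorModIndex(int hash, int len) {
        return Math.floorMod(hash, len);
    }

    public static void main(String[] args) {
        int len = 10;
        int[] hashes = {7, 17, -7, Integer.MIN_VALUE};
        for (int h : hashes) {
            System.out.println(h + ": % gives " + remIndex(h, len)
                    + ", floorMod gives " + floorModIndex(h, len));
        }
    }
}
```

For the hash -7, the plain remainder is -7, which is not a valid array index, while floorMod gives 3.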
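Decimal is the natural base for everyday arithmetic and binary the natural base for computers, so let us choose len = 2^n. Take len = 8, which is 1000 in binary. Any non-negative int m can be written in binary as bb...bbaaa (each a and b is 0 or 1), that is, m = 8*(bb...bb) + aaa, so the remainder of m divided by 8 is simply aaa. In general, when len = 2^n the remainder is the lowest n bits of m, so m % len = m & (len-1). The mask also produces a non-negative index when m is negative. The modified code: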
private static class PseudoMap<K, V> {
private int len = 1 << 4;
private Node<K, V>[] table = (Node<K,V>[])new Node[len];
public void put(K key, V value) {
table[getIndex(key.hashCode())] = new Node<K, V>(key, value);
}
public V get(K key) {
Node<K, V> node = table[getIndex(key.hashCode())];
return node==null?null:node.getValue();
}
private int getIndex(int hashCode) {
return hashCode & (len-1);
}
}
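The identity m % len = m & (len-1) is easy to check by brute force. The following sketch does so for a range of non-negative values and also shows what the mask returns for a negative hash; MaskIndexDemo and maskIndex are illustrative names, not part of the code above.

```java
public class MaskIndexDemo {
    // Same expression as getIndex above, with len passed in explicitly.
    // len is assumed to be a power of two.
    static int maskIndex(int hash, int len) {
        return hash & (len - 1);
    }

    public static void main(String[] args) {
        int len = 1 << 4; // 16
        int mismatches = 0;
        for (int m = 0; m <= 100000; m++) {
            if (m % len != maskIndex(m, len)) {
                mismatches++;
            }
        }
        System.out.println("mismatches for non-negative m: " + mismatches); // prints 0
        // For a negative hash the mask still yields a valid index,
        // matching Math.floorMod rather than %.
        System.out.println(maskIndex(-7, len)); // prints 9
    }
}
```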
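Taking the remainder solves the array-size problem, but it makes hash collisions more frequent: 1 and 17, for example, both leave remainder 1 when divided by 16. We can handle this by storing several elements in each table slot. Here we implement that with a linked list: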
private static class PseudoMap<K, V> {
private int len = 1 << 4;
private Node<K, V>[] table = (Node<K,V>[])new Node[len];
public void put(K key, V value) {
int index = getIndex(key.hashCode());
Node<K, V> node = table[index];
if(node == null) {
table[index] = new Node<K, V>(key, value, null);
} else {
Node<K, V> tempNode = node;
while(true) {
if(tempNode.getKey().equals(key)) {
tempNode.setValue(value);
break;
}
if(tempNode.getNext() == null) {
// key not found in the chain: insert the new node at the head
table[index] = new Node<K, V>(key, value, node);
break;
} else {
tempNode = tempNode.getNext();
}
}
}
}
public V get(K key) {
V v = null;
Node<K, V> node = table[getIndex(key.hashCode())];
if(node != null) {
while(true) {
if(node.getKey().equals(key)) {
v = node.getValue();
break;
}
if(node.getNext() == null) {
break;
} else {
node = node.getNext();
}
}
}
return v;
}
private int getIndex(int hashCode) {
return hashCode & (len-1);
}
public void print() {
for (int i = 0; i < table.length; i++) {
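// walk the whole chain so that colliding entries are printed too
for (Node<K, V> node = table[i]; node != null; node = node.getNext()) {
System.out.println(node);
}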
}
}
private static class Node<K, V> {
K key;
V value;
Node<K, V> next;
public Node(K key, V value, Node<K, V> next) {
this.key = key;
this.value = value;
this.next = next;
}
public K getKey() {
return key;
}
public V getValue() {
return value;
}
public void setValue(V value) {
this.value = value;
}
public Node<K, V> getNext() {
return next;
}
@Override
public String toString() {
return key + ":" + value;
}
}
}
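To see the chaining at work on the colliding keys 1 and 17 mentioned above, here is a compact, self-contained sketch. It is a simplified variant rather than the class above: ChainDemo, its Integer/String key and value types, and the chainLength helper are all assumptions made for the illustration.

```java
public class ChainDemo {
    // One entry in a bucket's singly linked chain.
    private static final class Node {
        final Integer key;
        String value;
        final Node next;

        Node(Integer key, String value, Node next) {
            this.key = key;
            this.value = value;
            this.next = next;
        }
    }

    private final Node[] table = new Node[16];

    private int indexOf(Integer key) {
        return key.hashCode() & (table.length - 1);
    }

    void put(Integer key, String value) {
        int i = indexOf(key);
        for (Node n = table[i]; n != null; n = n.next) {
            if (n.key.equals(key)) {
                n.value = value; // existing key: overwrite in place
                return;
            }
        }
        table[i] = new Node(key, value, table[i]); // new key: head insertion
    }

    String get(Integer key) {
        for (Node n = table[indexOf(key)]; n != null; n = n.next) {
            if (n.key.equals(key)) {
                return n.value;
            }
        }
        return null;
    }

    // Hypothetical helper for inspection: number of nodes in one bucket.
    int chainLength(int index) {
        int count = 0;
        for (Node n = table[index]; n != null; n = n.next) {
            count++;
        }
        return count;
    }

    public static void main(String[] args) {
        ChainDemo map = new ChainDemo();
        map.put(1, "one");
        map.put(17, "seventeen"); // 17 & 15 == 1, same bucket as key 1
        System.out.println(map.chainLength(1)); // prints 2
        System.out.println(map.get(1) + ", " + map.get(17)); // prints one, seventeen
    }
}
```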
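If len is fixed, the chains will inevitably grow too long as the amount of data increases. Lookup in a linked list is O(n), so having many chains, or chains that are too long, hurts query performance. The remedy is to keep the nodes spread as evenly as possible across the table, which means growing the table on demand. As a simple implementation, we can resize once the number of entries reaches len. The code: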
private static class PseudoMap<K, V> {
private int len = 1 << 4;
private Node<K, V>[] table = (Node<K,V>[])new Node[len];
private int size = 0;
public void put(K key, V value) {
int hash = key.hashCode();
int index = getIndex(hash);
Node<K, V> node = table[index];
if(node == null) {
table[index] = new Node<K,V>(hash,key, value, null);
} else {
Node<K, V> tempNode = node;
while(true) {
if(tempNode.getKey().equals(key)) {
tempNode.setValue(value);
return; // existing key updated: no new entry, so size must not grow
}
if(tempNode.getNext() == null) {
table[index] = new Node<K,V>(hash, key, value, node);
break;
} else {
tempNode = tempNode.getNext();
}
}
}
size++;
if(size >= len) {
resize();
}
}
private void resize() {
int oldLen = len;
len = len << 1;
Node<K, V>[] newTable = (Node<K,V>[])new Node[len];
//copy table => newTable
for (int i = 0; i < oldLen; i++) {
Node<K,V> node = table[i];
if(node == null) {
continue;
}
// if i=1
// when len=16, keys 1 17 33 49 all land in index 1
// when len=32, 1 33 -> index 1, 17 49 -> index 17 (i + oldLen)
while(true) {
int index = (node.hash&oldLen)==0?i:i+oldLen;
newTable[index] = new Node<K, V>(node.getHash(), node.getKey(), node.getValue(), newTable[index]);
if(node.getNext() != null) {
node = node.getNext();
} else {
break;
}
}
}
table = newTable;
}
public V get(K key) {
V v = null;
Node<K, V> node = table[getIndex(key.hashCode())];
if(node != null) {
while(true) {
if(node.getKey().equals(key)) {
v = node.getValue();
break;
}
if(node.getNext() == null) {
break;
} else {
node = node.getNext();
}
}
}
return v;
}
private int getIndex(int hashCode) {
return hashCode & (len-1);
}
public void print() {
for (int i = 0; i < table.length; i++) {
if(table[i] != null) {
System.out.println(table[i]);
}
}
}
private static class Node<K, V> {
int hash;
K key;
V value;
Node<K, V> next;
public Node(int hash, K key, V value, Node<K, V> next) {
this.hash = hash;
this.key = key;
this.value = value;
this.next = next;
}
public int getHash() {
return hash;
}
public K getKey() {
return key;
}
public V getValue() {
return value;
}
public void setValue(V value) {
this.value = value;
}
public Node<K, V> getNext() {
return next;
}
@Override
public String toString() {
return key + ":" + value;
}
}
}
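The index rule used in resize() can be checked in isolation: after doubling, a node in old bucket i stays at i when the bit (hash & oldLen) is 0 and moves to i + oldLen otherwise, which is the same as recomputing hash & (newLen - 1). The sketch below verifies this; ResizeSplitDemo and newIndex are hypothetical names used only here.

```java
public class ResizeSplitDemo {
    // Same rule as in resize(): the bit at position oldLen decides
    // whether the node stays at oldIndex or moves up by oldLen.
    static int newIndex(int hash, int oldIndex, int oldLen) {
        return (hash & oldLen) == 0 ? oldIndex : oldIndex + oldLen;
    }

    public static void main(String[] args) {
        int oldLen = 16;
        int newLen = oldLen << 1;
        int[] hashes = {1, 17, 33, 49};
        for (int h : hashes) {
            int oldIndex = h & (oldLen - 1);
            int split = newIndex(h, oldIndex, oldLen);
            int direct = h & (newLen - 1);
            // split and direct should always agree
            System.out.println(h + ": " + oldIndex + " -> " + split + " (direct: " + direct + ")");
        }
    }
}
```

The benefit of the split rule is that each old bucket only ever feeds two new buckets, so no hash has to be recomputed from the key.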
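That essentially completes our fake HashMap, although many details are still unpolished. If you are interested, read the source of HashMap itself; you will certainly come away with a deeper understanding. The code reflects only my own limited understanding, so if you spot any problems, corrections are welcome.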