1 minute read

3092. Most Frequent IDs

Solution

/**
 * LeetCode 3092 "Most Frequent IDs".
 *
 * <p>Keeps a running count per ID and, after each step, reports the largest count held by any ID
 * seen so far. Counts live in an indexed binary max-heap, so an ID that is already present can
 * have its count changed in place and be re-positioned in O(log n).
 */
class Solution {
    /**
     * Applies each {@code (nums[i], freq[i])} adjustment in order and records the maximum count
     * after every step.
     *
     * @param nums IDs to adjust; {@code nums[i]} is paired with {@code freq[i]}
     * @param freq signed amount added to the count of {@code nums[i]} at step {@code i}
     * @return array whose {@code i}-th entry is the largest count among all IDs seen so far,
     *     taken after step {@code i} has been applied
     */
    public long[] mostFrequentIDs(int[] nums, int[] freq) {
        MyPQ pq = new MyPQ(nums.length);
        long[] res = new long[nums.length];

        for (int cur = 0; cur < nums.length; cur++) {
            if (pq.contains(nums[cur])) {
                pq.update(nums[cur], freq[cur]);
            } else {
                pq.add(new Node(nums[cur], freq[cur]));
            }
            res[cur] = pq.getCurMax();
        }

        return res;
    }

    /**
     * Array-backed binary max-heap of {@link Node}s ordered by {@code count}.
     *
     * <p>It also keeps an ID-to-node map, so a node already in the heap can have its count
     * changed and then be re-sifted. Each node records its own slot in {@link Node#index}, which
     * is kept in sync by {@link #swap}.
     */
    static class MyPQ {
        /** Heap storage; only indices {@code [0, length)} are occupied. */
        Node[] arr;
        /** ID -> node, used for membership tests and in-place updates. */
        Map<Integer, Node> nodes;
        /** Number of nodes currently stored in the heap. */
        int length = 0;

        /** @param n capacity: the maximum number of distinct IDs the heap will ever hold */
        public MyPQ(int n) {
            arr = new Node[n];
            nodes = new HashMap<>();
        }

        /**
         * Inserts a node whose ID is not yet in the heap. The caller is responsible for checking
         * {@link #contains} first. Exceeding the constructor capacity throws
         * {@link ArrayIndexOutOfBoundsException}.
         */
        public void add(Node node) {
            nodes.put(node.val, node);
            node.index = length;
            arr[length++] = node;
            heapInsert(node.index);
        }

        /**
         * Adds {@code countChange} to the count of an ID already in the heap, then restores heap
         * order. A larger count can only move the node up and a smaller one can only move it
         * down, so only one sift direction is needed.
         */
        public void update(int value, int countChange) {
            Node node = nodes.get(value);
            node.count += countChange;
            if (countChange > 0) {
                heapInsert(node.index);
            } else {
                heapify(node.index);
            }
        }

        /** Returns whether a node for ID {@code v} has been added. */
        public boolean contains(int v) {
            return nodes.containsKey(v);
        }

        /** Returns the largest count in the heap, or 0 when the heap is empty. */
        public long getCurMax() {
            return length == 0 ? 0 : arr[0].count;
        }

        /** Sifts the node at {@code index} toward the root while its count exceeds its parent's. */
        private void heapInsert(int index) {
            while (index > 0) {
                int parent = (index - 1) / 2;
                if (arr[parent].count >= arr[index].count) {
                    break;
                }
                swap(parent, index);
                index = parent;
            }
        }

        /** Sifts the node at {@code index} toward the leaves while a child's count exceeds it. */
        private void heapify(int index) {
            int leftChild = index * 2 + 1;
            while (leftChild < length) {
                int rightChild = leftChild + 1;
                // On a tie between the children, prefer the right one.
                int maxChild =
                        rightChild < length && arr[rightChild].count >= arr[leftChild].count
                                ? rightChild
                                : leftChild;

                if (arr[maxChild].count <= arr[index].count) {
                    break;
                }
                swap(maxChild, index);
                index = maxChild;
                leftChild = index * 2 + 1;
            }
        }

        /** Exchanges two heap slots and updates each moved node's stored index. */
        private void swap(int i, int j) {
            Node tmp = arr[i];
            arr[i] = arr[j];
            arr[j] = tmp;
            arr[i].index = i;
            arr[j].index = j;
        }
    }

    /** One ID, its running count, and its current slot in a {@link MyPQ}. */
    static class Node {
        int val;
        long count;
        /** Current position of this node in {@code MyPQ.arr}; maintained by {@code MyPQ}. */
        int index;

        public Node(int v, int c) {
            val = v;
            count = c;
        }
    }
}