package com.astrolabsoftware.FinkBrowser.HBaser.Clusteriser;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.math3.linear.*;

import java.net.URL;
import java.io.File;
import java.io.IOException;
import java.util.Objects;

// Log4J
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

/** <code>ClusterFinder</code> identifies HBase rows with
 * clusters defined by previous clustering algorithm, read from
 * <tt>JSON</tt> model files.
 * <p>
 * A prediction runs three steps:
 * <ol>
 * <li>it standardizes the input with the scaler parameters;</li>
 * <li>it projects the result with the PCA components;</li>
 * <li>it returns the index of the nearest cluster center, provided that
 * center is sufficiently closer than the second nearest one
 * (see {@link #setSeparation(double)}).</li>
 * </ol>
 * @opt attributes
 * @opt operations
 * @opt types
 * @opt visibility
 * @author <a href="mailto:Julius.Hrivnac@cern.ch">J.Hrivnac</a> */
public class ClusterFinder {

  /** Classify one example row and log the assigned cluster.
   * @param args Optional paths of the scaler, PCA and cluster-centers
   *             <tt>JSON</tt> files, in this order. If exactly three
   *             arguments are not given, the default files in
   *             <tt>/tmp</tt> are used.
   * @throws IOException If a model file cannot be read, parsed, or is
   *                     inconsistent with the others. */
  public static void main(String[] args) throws IOException {
    ClusterFinder finder = args.length == 3
                         ? new ClusterFinder(args[0], args[1], args[2])
                         : new ClusterFinder("/tmp/scaler_params.json",
                                             "/tmp/pca_params.json",
                                             "/tmp/cluster_centers.json");
    double[] newData = {1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4}; // Example input
    int cluster = finder.transformAndPredict(newData);
    log.info("Assigned cluster: {}", cluster);
  }

  /** Create from model files on the local filesystem.
   * @param scalerFile   Path of the <tt>JSON</tt> file deserialized into
   *                     <code>ScalerParams</code>.
   * @param pcaFile      Path of the <tt>JSON</tt> file deserialized into
   *                     <code>PCAParams</code>.
   * @param clustersFile Path of the <tt>JSON</tt> file holding the cluster
   *                     centers as an array of number arrays.
   * @throws IOException If a file cannot be read or parsed, a required
   *                     value is missing, or the dimensions of the three
   *                     models do not fit together. */
  public ClusterFinder(String scalerFile,
                       String pcaFile,
                       String clustersFile) throws IOException {
    loadScalerParams(scalerFile);
    loadPCAParams(pcaFile);
    loadClusterCenters(clustersFile);
    checkConsistency();
  }

  /** Create from model files addressed by URLs.
   * @param scalerUrl   URL of the <tt>JSON</tt> file deserialized into
   *                    <code>ScalerParams</code>.
   * @param pcaUrl      URL of the <tt>JSON</tt> file deserialized into
   *                    <code>PCAParams</code>.
   * @param clustersUrl URL of the <tt>JSON</tt> file holding the cluster
   *                    centers as an array of number arrays.
   * @throws IOException If a resource cannot be read or parsed, a required
   *                     value is missing, or the dimensions of the three
   *                     models do not fit together. */
  public ClusterFinder(URL scalerUrl,
                       URL pcaUrl,
                       URL clustersUrl) throws IOException {
    loadScalerParams(scalerUrl);
    loadPCAParams(pcaUrl);
    loadClusterCenters(clustersUrl);
    checkConsistency();
  }

  private void loadScalerParams(String filePath) throws IOException {
    setScalerParams(MAPPER.readValue(new File(filePath), ScalerParams.class));
  }

  private void loadPCAParams(String filePath) throws IOException {
    setPCAParams(MAPPER.readValue(new File(filePath), PCAParams.class));
  }

  private void loadClusterCenters(String filePath) throws IOException {
    setClusterCenters(MAPPER.readValue(new File(filePath), double[][].class));
  }

  private void loadScalerParams(URL url) throws IOException {
    setScalerParams(MAPPER.readValue(url, ScalerParams.class));
  }

  private void loadPCAParams(URL url) throws IOException {
    setPCAParams(MAPPER.readValue(url, PCAParams.class));
  }

  private void loadClusterCenters(URL url) throws IOException {
    setClusterCenters(MAPPER.readValue(url, double[][].class));
  }

  /** Install the scaler parameters, shared by the file and URL loaders.
   * @param params The deserialized scaler parameters (may be <tt>null</tt>
   *               if the document was a <tt>JSON</tt> <tt>null</tt>).
   * @throws IOException If the parameters, <tt>mean</tt> or <tt>std</tt> are missing. */
  private void setScalerParams(ScalerParams params) throws IOException {
    if (params == null || params.mean == null || params.std == null) {
      throw new IOException("Scaler parameters must define both mean and std");
    }
    _mean = params.mean;
    _std  = params.std;
    log.debug("Scaler: {}", _mean.length);
  }

  /** Install the PCA parameters, shared by the file and URL loaders.
   * @param params The deserialized PCA parameters.
   * @throws IOException If the parameters or their <tt>components</tt> are missing. */
  private void setPCAParams(PCAParams params) throws IOException {
    if (params == null || params.components == null) {
      throw new IOException("PCA parameters must define components");
    }
    _pcaComponents = new Array2DRowRealMatrix(params.components);
    _explainedVariance = params.explained_variance;
    log.debug("PCA Components: {} * {}", _pcaComponents.getColumnDimension(), _pcaComponents.getRowDimension());
  }

  /** Install the cluster centers (one center per row), shared by the file and URL loaders.
   * @param centers The deserialized cluster centers.
   * @throws IOException If the centers are missing. */
  private void setClusterCenters(double[][] centers) throws IOException {
    if (centers == null) {
      throw new IOException("Cluster centers are missing");
    }
    _clusterCenters = new Array2DRowRealMatrix(centers);
    log.debug("Cluster Centers: {} * {}", _clusterCenters.getColumnDimension(), _clusterCenters.getRowDimension());
  }

  /** Check that the loaded models fit together, so that a mismatch is
   * reported at construction instead of on the first prediction.
   * The required shapes follow from how {@link #standardize(double[])},
   * {@link #applyPCA(double[])} and {@link #findClosestCluster(double[])}
   * combine the arrays.
   * @throws IOException If the dimensions do not agree. */
  private void checkConsistency() throws IOException {
    int nFeatures = _mean.length;
    if (_std.length != nFeatures) {
      throw new IOException("Scaler mean and std lengths differ: " + nFeatures + " vs " + _std.length);
    }
    if (_pcaComponents.getRowDimension() != nFeatures) {
      throw new IOException("PCA components have " + _pcaComponents.getRowDimension()
                          + " rows, expected " + nFeatures + " (one per input value)");
    }
    if (_clusterCenters.getColumnDimension() != _pcaComponents.getColumnDimension()) {
      throw new IOException("Cluster centers have " + _clusterCenters.getColumnDimension()
                          + " columns, expected " + _pcaComponents.getColumnDimension()
                          + " (one per PCA component)");
    }
  }

  /** Standardize the input with the loaded mean and standard deviation.
   * @param input The raw input values, one per scaler entry.
   * @return The standardized values, <tt>(input - mean) / std</tt> per entry.
   * @throws NullPointerException     If <tt>input</tt> is <tt>null</tt>.
   * @throws IllegalArgumentException If the input length differs from the scaler's. */
  private double[] standardize(double[] input) {
    Objects.requireNonNull(input, "input");
    if (input.length != _mean.length) {
      throw new IllegalArgumentException("Expected " + _mean.length + " input values, got " + input.length);
    }
    double[] standardized = new double[input.length];
    for (int i = 0; i < input.length; i++) {
      // Zero std means the feature was constant; map it to 0 rather than divide by zero.
      // NOTE(review): scikit-learn's StandardScaler instead uses a scale of 1
      // (giving input - mean); confirm which convention the exporter follows.
      standardized[i] = _std[i] == 0 ? 0 : (input[i] - _mean[i]) / _std[i];
    }
    log.debug("Standardized: {}", standardized.length);
    return standardized;
  }

  /** Project standardized input onto the PCA components.
   * @param standardizedInput The standardized input, one value per component-matrix row.
   * @return The projection, one value per component-matrix column. */
  private double[] applyPCA(double[] standardizedInput) {
    // v^T * C is the same as C^T * v, but it avoids building the transposed
    // matrix on every call.
    // NOTE(review): this treats the stored matrix as (features x components).
    // scikit-learn's PCA.components_ is (components x features), so the
    // exporter presumably stores its transpose; confirm.
    RealVector transformed = _pcaComponents.preMultiply(new ArrayRealVector(standardizedInput, false));
    log.debug("PCA Transformed: {}", transformed.getDimension());
    return transformed.toArray();
  }

  /** Find the closest cluster from the transformed data.
   * @param transformedData The transformed input data.
   * @return The (number of) the closest cluster.
   *         <tt>-1</tt> if it cannot be found with sufficient resolution,
   *         i.e. the closest distance is not smaller than the separation
   *         quotient times the second-closest distance. */
  private int findClosestCluster(double[] transformedData) {
    RealVector transformedVector = new ArrayRealVector(transformedData, false);
    double minDistance  = Double.MAX_VALUE; // closest distance seen so far
    double minDistance2 = Double.MAX_VALUE; // second-closest distance seen so far
    int closestCluster = -1;
    for (int i = 0; i < _clusterCenters.getRowDimension(); i++) {
      double distance = transformedVector.getDistance(_clusterCenters.getRowVector(i));
      if (distance < minDistance2) {
        if (distance < minDistance) {
          minDistance2 = minDistance;
          minDistance = distance;
          closestCluster = i;
        }
        else {
          minDistance2 = distance;
        }
      }
    }
    if (minDistance < _separation * minDistance2) {
      return closestCluster;
    }
    return -1;
  }

  /** Transform provided data array and find the closest cluster.
   * @param inputData The original input data.
   * @return The (number of) the closest cluster.
   *         <tt>-1</tt> if it cannot be found with sufficient resolution.
   * @throws NullPointerException     If <tt>inputData</tt> is <tt>null</tt>.
   * @throws IllegalArgumentException If <tt>inputData</tt> does not have
   *                                  one value per scaler entry. */
  public int transformAndPredict(double[] inputData) {
    double[] standardized = standardize(inputData);
    double[] pcaTransformed = applyPCA(standardized);
    return findClosestCluster(pcaTransformed);
  }

  /** Set the minimal separation quotient for this finder.
   * The ratio between the distance to the closest cluster and the distance
   * to the second closest cluster must be smaller than <tt>separation</tt>.
   * Otherwise the cluster is not considered reliable.
   * A value of <tt>1</tt> rejects only exact ties. Larger values relax the
   * check further. The default is <tt>0.5</tt>.
   * @param separation The minimal separation quotient, must be positive.
   * @throws IllegalArgumentException If <tt>separation</tt> is not positive or is NaN. */
  public void setSeparation(double separation) {
    if (!(separation > 0)) { // also rejects NaN
      throw new IllegalArgumentException("Separation must be positive, got " + separation);
    }
    _separation = separation;
  }

  /** Default minimal separation quotient. */
  private static final double DEFAULT_SEPARATION = 0.5;

  /** JSON reader shared by all loaders (ObjectMapper is thread-safe once configured). */
  private static final ObjectMapper MAPPER = new ObjectMapper();

  private double _separation = DEFAULT_SEPARATION;

  private double[] _mean;

  private double[] _std;

  private RealMatrix _pcaComponents;

  private double[] _explainedVariance;

  private RealMatrix _clusterCenters;

  /** Logging . */
  private static final Logger log = LogManager.getLogger(ClusterFinder.class);

}