001package com.astrolabsoftware.FinkBrowser.HBaser.Clusteriser;
002
003import com.fasterxml.jackson.databind.ObjectMapper;
004import org.apache.commons.math3.linear.*;
005
006import java.util.List;
007import java.util.Arrays;
008import java.net.URL;
009import java.io.File;
010import java.io.IOException;
011
012// Log4J
013import org.apache.logging.log4j.Logger;
014import org.apache.logging.log4j.LogManager;
015
016/** <code>ClusterFinder</code> identifies HBase rows with
017  * clusters defined by previous clustering algorithm, read from
018  * <tt>JSON</tt> model files.
019  * @opt attributes
020  * @opt operations
021  * @opt types
022  * @opt visibility
023  * @author <a href="mailto:Julius.Hrivnac@cern.ch">J.Hrivnac</a> */
024public class ClusterFinder {
025  
026  public static void main(String[] args) throws IOException {
027    ClusterFinder finder = new ClusterFinder("/tmp/scaler_params.json",
028                                             "/tmp/pca_params.json",
029                                             "/tmp/cluster_centers.json");
030    double[] newData = {1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
031                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
032                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
033                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
034                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,};  // Example input
035    int cluster = finder.transformAndPredict(newData);
036    log.info("Assigned cluster: " + cluster);
037    }
038
039  public ClusterFinder(String scalerFile,
040                       String pcaFile,
041                       String clustersFile) throws IOException {
042    loadScalerParams(scalerFile);
043    loadPCAParams(pcaFile);
044    loadClusterCenters(clustersFile);
045    }
046
047  public ClusterFinder(URL scalerUrl,
048                       URL pcaUrl,
049                       URL clustersUrl) throws IOException {
050    loadScalerParams(scalerUrl);
051    loadPCAParams(pcaUrl);
052    loadClusterCenters(clustersUrl);
053    }
054  
055  private void loadScalerParams(String filePath) throws IOException {
056    ObjectMapper objectMapper = new ObjectMapper();
057    ScalerParams params = objectMapper.readValue(new File(filePath), ScalerParams.class);
058    _mean = params.mean;
059    _std = params.std;
060    log.debug("Scaler: " + _mean.length);
061    }
062  
063  private void loadPCAParams(String filePath) throws IOException {
064    ObjectMapper objectMapper = new ObjectMapper();
065    PCAParams params = objectMapper.readValue(new File(filePath), PCAParams.class);
066    _pcaComponents = new Array2DRowRealMatrix(params.components);
067    _explainedVariance = params.explained_variance;
068    log.debug("PCA Components: " + _pcaComponents.getColumnDimension() + " * " + _pcaComponents.getRowDimension());
069    }
070    
071  private void loadClusterCenters(String filePath) throws IOException {
072    ObjectMapper objectMapper = new ObjectMapper();
073    _clusterCenters = new Array2DRowRealMatrix(objectMapper.readValue(new File(filePath), double[][].class));
074    log.debug("Cluster Centers: " + _clusterCenters.getColumnDimension() + " * " + _clusterCenters.getRowDimension());
075    }    
076    
077  private void loadScalerParams(URL url) throws IOException {
078    ObjectMapper objectMapper = new ObjectMapper();
079    ScalerParams params = objectMapper.readValue(url, ScalerParams.class);
080    _mean = params.mean;
081    _std = params.std;
082    log.debug("Scaler: " + _mean.length);
083    }
084  
085  private void loadPCAParams(URL url) throws IOException {
086    ObjectMapper objectMapper = new ObjectMapper();
087    PCAParams params = objectMapper.readValue(url, PCAParams.class);
088    _pcaComponents = new Array2DRowRealMatrix(params.components);
089    _explainedVariance = params.explained_variance;
090    log.debug("PCA Components: " + _pcaComponents.getColumnDimension() + " * " + _pcaComponents.getRowDimension());
091    }
092    
093  private void loadClusterCenters(URL url) throws IOException {
094    ObjectMapper objectMapper = new ObjectMapper();
095    _clusterCenters = new Array2DRowRealMatrix(objectMapper.readValue(url, double[][].class));
096    log.debug("Cluster Centers: " + _clusterCenters.getColumnDimension() + " * " + _clusterCenters.getRowDimension());
097    }    
098
099  private double[] standardize(double[] input) {
100    double[] standardized = new double[input.length];
101    for (int i = 0; i < input.length; i++) {
102      if (_std[i] == 0) {
103        standardized[i] = 0;
104        }
105      else {
106        standardized[i] = (input[i] - _mean[i]) / _std[i];
107        }
108      }
109    log.debug("Standardized: " + standardized.length);
110    return standardized;
111    }
112  
113  private double[] applyPCA(double[] standardizedInput) {
114    RealVector inputVector = new ArrayRealVector(standardizedInput);
115    RealVector transformed = _pcaComponents.transpose().operate(inputVector);
116    log.debug("PCA Transformed: " + transformed.getDimension());
117    return transformed.toArray();
118    }
119  
120  /** Find the closest cluster from the transformed data.
121    * @param  transformedData The transformed input data.
122    * @return                 The (number of) the closest cluster.
123    *                         <tt>-1</tt> if it cannot be found with sufficient resolution. */
124  private int findClosestCluster(double[] transformedData) {
125    RealVector transformedVector = new ArrayRealVector(transformedData);
126    double minDistance  = Double.MAX_VALUE;
127    double minDistance2 = Double.MAX_VALUE;
128    int closestCluster = -1;
129    RealVector clusterCenter;
130    double distance;
131    for (int i = 0; i < _clusterCenters.getRowDimension(); i++) {
132      clusterCenter = _clusterCenters.getRowVector(i);
133      distance = transformedVector.getDistance(clusterCenter);
134      if (distance < minDistance2) {
135        if (distance < minDistance) {
136          minDistance2  = minDistance;
137          minDistance   = distance;
138          closestCluster = i;
139          }
140        else {
141          minDistance2 = distance;
142          }
143        }
144      }
145    if (minDistance < _separation * minDistance2) {
146      return closestCluster;
147      }
148    return -1;
149    }
150  
151  /** Transform provided data array and find the closest cluster.
152    * @param  inputData The original input data.
153    * @return           The (number of) the closest cluster.
154    *                   <tt>-1</tt> if it cannot be found with sufficient resolution. */
155  public int transformAndPredict(double[] inputData) {
156    double[] standardized = standardize(inputData);
157    double[] pcaTransformed = applyPCA(standardized);
158    return findClosestCluster(pcaTransformed);
159    }
160    
161  /** Set the minimal separation quotient.
162    * @param separation The minimal separation quotient.
163    *                   The ration between distance to closest and second closest
164    *                   cluster should be smaller than <tt>separation</tt>,
165    *                   otherwise cluster is not considered reliable.
166    *                   <tt>1</tt> gives no restriction. The default is <tt>0.5</tt>. */
167  private static void setSeparation(double separation) {
168    _separation = separation;
169    }
170  
171  private static double _separation = 0.5;  
172    
173  private double[] _mean;
174  
175  private double[] _std;
176  
177  private RealMatrix _pcaComponents;
178  
179  private double[] _explainedVariance;
180  
181  private RealMatrix _clusterCenters;
182    
183  /** Logging . */
184  private static Logger log = LogManager.getLogger(ClusterFinder.class);
185  
186  }