001package com.astrolabsoftware.FinkBrowser.HBaser.Clusteriser;
002
003import com.fasterxml.jackson.databind.ObjectMapper;
004import org.apache.commons.math3.linear.*;
005
006import java.net.URL;
007import java.io.File;
008import java.io.IOException;
009
010// Log4J
011import org.apache.logging.log4j.Logger;
012import org.apache.logging.log4j.LogManager;
013
014/** <code>ClusterFinder</code> identifies HBase rows with
015  * clusters defined by previous clustering algorithm, read from
016  * <tt>JSON</tt> model files.
017  * @opt attributes
018  * @opt operations
019  * @opt types
020  * @opt visibility
021  * @author <a href="mailto:Julius.Hrivnac@cern.ch">J.Hrivnac</a> */
022public class ClusterFinder {
023  
024  public static void main(String[] args) throws IOException {
025    ClusterFinder finder = new ClusterFinder("/tmp/scaler_params.json",
026                                             "/tmp/pca_params.json",
027                                             "/tmp/cluster_centers.json");
028    double[] newData = {1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
029                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
030                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
031                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,
032                        1.2, 3.4, 5.6, 7.8, 2.1, 4.3, 6.5, 8.7, 3.2, 5.4,};  // Example input
033    int cluster = finder.transformAndPredict(newData);
034    log.info("Assigned cluster: " + cluster);
035    }
036
037  public ClusterFinder(String scalerFile,
038                       String pcaFile,
039                       String clustersFile) throws IOException {
040    loadScalerParams(scalerFile);
041    loadPCAParams(pcaFile);
042    loadClusterCenters(clustersFile);
043    }
044
045  public ClusterFinder(URL scalerUrl,
046                       URL pcaUrl,
047                       URL clustersUrl) throws IOException {
048    loadScalerParams(scalerUrl);
049    loadPCAParams(pcaUrl);
050    loadClusterCenters(clustersUrl);
051    }
052  
053  private void loadScalerParams(String filePath) throws IOException {
054    ObjectMapper objectMapper = new ObjectMapper();
055    ScalerParams params = objectMapper.readValue(new File(filePath), ScalerParams.class);
056    _mean = params.mean;
057    _std = params.std;
058    log.debug("Scaler: " + _mean.length);
059    }
060  
061  private void loadPCAParams(String filePath) throws IOException {
062    ObjectMapper objectMapper = new ObjectMapper();
063    PCAParams params = objectMapper.readValue(new File(filePath), PCAParams.class);
064    _pcaComponents = new Array2DRowRealMatrix(params.components);
065    _explainedVariance = params.explained_variance;
066    log.debug("PCA Components: " + _pcaComponents.getColumnDimension() + " * " + _pcaComponents.getRowDimension());
067    }
068    
069  private void loadClusterCenters(String filePath) throws IOException {
070    ObjectMapper objectMapper = new ObjectMapper();
071    _clusterCenters = new Array2DRowRealMatrix(objectMapper.readValue(new File(filePath), double[][].class));
072    log.debug("Cluster Centers: " + _clusterCenters.getColumnDimension() + " * " + _clusterCenters.getRowDimension());
073    }    
074    
075  private void loadScalerParams(URL url) throws IOException {
076    ObjectMapper objectMapper = new ObjectMapper();
077    ScalerParams params = objectMapper.readValue(url, ScalerParams.class);
078    _mean = params.mean;
079    _std = params.std;
080    log.debug("Scaler: " + _mean.length);
081    }
082  
083  private void loadPCAParams(URL url) throws IOException {
084    ObjectMapper objectMapper = new ObjectMapper();
085    PCAParams params = objectMapper.readValue(url, PCAParams.class);
086    _pcaComponents = new Array2DRowRealMatrix(params.components);
087    _explainedVariance = params.explained_variance;
088    log.debug("PCA Components: " + _pcaComponents.getColumnDimension() + " * " + _pcaComponents.getRowDimension());
089    }
090    
091  private void loadClusterCenters(URL url) throws IOException {
092    ObjectMapper objectMapper = new ObjectMapper();
093    _clusterCenters = new Array2DRowRealMatrix(objectMapper.readValue(url, double[][].class));
094    log.debug("Cluster Centers: " + _clusterCenters.getColumnDimension() + " * " + _clusterCenters.getRowDimension());
095    }    
096
097  private double[] standardize(double[] input) {
098    double[] standardized = new double[input.length];
099    for (int i = 0; i < input.length; i++) {
100      if (_std[i] == 0) {
101        standardized[i] = 0;
102        }
103      else {
104        standardized[i] = (input[i] - _mean[i]) / _std[i];
105        }
106      }
107    log.debug("Standardized: " + standardized.length);
108    return standardized;
109    }
110  
111  private double[] applyPCA(double[] standardizedInput) {
112    RealVector inputVector = new ArrayRealVector(standardizedInput);
113    RealVector transformed = _pcaComponents.transpose().operate(inputVector);
114    log.debug("PCA Transformed: " + transformed.getDimension());
115    return transformed.toArray();
116    }
117  
118  /** Find the closest cluster from the transformed data.
119    * @param  transformedData The transformed input data.
120    * @return                 The (number of) the closest cluster.
121    *                         <tt>-1</tt> if it cannot be found with sufficient resolution. */
122  private int findClosestCluster(double[] transformedData) {
123    RealVector transformedVector = new ArrayRealVector(transformedData);
124    double minDistance  = Double.MAX_VALUE;
125    double minDistance2 = Double.MAX_VALUE;
126    int closestCluster = -1;
127    RealVector clusterCenter;
128    double distance;
129    for (int i = 0; i < _clusterCenters.getRowDimension(); i++) {
130      clusterCenter = _clusterCenters.getRowVector(i);
131      distance = transformedVector.getDistance(clusterCenter);
132      if (distance < minDistance2) {
133        if (distance < minDistance) {
134          minDistance2  = minDistance;
135          minDistance   = distance;
136          closestCluster = i;
137          }
138        else {
139          minDistance2 = distance;
140          }
141        }
142      }
143    if (minDistance < _separation * minDistance2) {
144      return closestCluster;
145      }
146    return -1;
147    }
148  
149  /** Transform provided data array and find the closest cluster.
150    * @param  inputData The original input data.
151    * @return           The (number of) the closest cluster.
152    *                   <tt>-1</tt> if it cannot be found with sufficient resolution. */
153  public int transformAndPredict(double[] inputData) {
154    double[] standardized = standardize(inputData);
155    double[] pcaTransformed = applyPCA(standardized);
156    return findClosestCluster(pcaTransformed);
157    }
158    
159  /** Set the minimal separation quotient.
160    * @param separation The minimal separation quotient.
161    *                   The ration between distance to closest and second closest
162    *                   cluster should be smaller than <tt>separation</tt>,
163    *                   otherwise cluster is not considered reliable.
164    *                   <tt>1</tt> gives no restriction. The default is <tt>0.5</tt>. */
165  private static void setSeparation(double separation) {
166    _separation = separation;
167    }
168  
169  private static double _separation = 0.5;  
170    
171  private double[] _mean;
172  
173  private double[] _std;
174  
175  private RealMatrix _pcaComponents;
176  
177  private double[] _explainedVariance;
178  
179  private RealMatrix _clusterCenters;
180    
181  /** Logging . */
182  private static Logger log = LogManager.getLogger(ClusterFinder.class);
183  
184  }