import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.function.Consumer;
/**
 * LRMCseedsTUD_streamsafe.java
 *
 * Streaming LRMC seeder for graph-classification datasets (e.g., the TUDataset
 * collections ENZYMES, PROTEINS, COLLAB, D&D). Accepts either a single
 * canonical 0-indexed undirected edge-list file (one "u v" pair per line with
 * u < v) or a directory of such files, and produces seeds JSON via the
 * streaming reconstruction entry point in clique2_ablations_parallel2.
 *
 * Usage:
 *   Single file:
 *     java -Xmx4g LRMCseedsTUD_streamsafe input_graph.txt output_seeds.json [DIAM|INV_SQRT_LAMBDA2] [epsilon]
 *
 *   Directory of files:
 *     java -Xmx4g LRMCseedsTUD_streamsafe input_dir output_dir [DIAM|INV_SQRT_LAMBDA2] [epsilon]
 *
 * Input edge lists are expected to be 0-indexed and undirected, with no
 * self-loops and one edge per line with u < v. The output JSON mirrors the
 * LRMCseedsReddit_streamsafe format.
 */
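// Illustrative example (file names hypothetical): given "g0.txt" containing
//   0 1
//   0 2
//   1 2
// running
//   java LRMCseedsTUD_streamsafe g0.txt g0_seeds.json DIAM 1e-6
// writes one JSON cluster for the per-component peak snapshot, plus singleton
// clusters for any vertices no peak covers (see PeakTracker.writeJson).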
public class LRMCseedsTUD_streamsafe {

    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err.println("Usage: java LRMCseedsTUD_streamsafe <input_file_or_dir> <output_file_or_dir> [alpha_kind] [epsilon]");
            return;
        }
        final Path inPath = Paths.get(args[0]);
        final Path outPath = Paths.get(args[1]);
        final AlphaKind alphaKind = (args.length >= 3 ? parseAlpha(args[2]) : AlphaKind.DIAM);
        final double eps = (args.length >= 4 ? Double.parseDouble(args[3]) : 1e-6);
        if (Files.isDirectory(inPath)) {
            Files.createDirectories(outPath);
            try (java.util.stream.Stream<Path> stream = Files.list(inPath)) {
                stream.filter(Files::isRegularFile)
                        .filter(p -> {
                            String s = p.getFileName().toString().toLowerCase(Locale.ROOT);
                            return s.endsWith(".txt") || s.endsWith(".csv");
                        })
                        .forEach(p -> {
                            try {
                                GraphData Gi = loadEdgeList0Based(p);
                                // Strip the extension case-insensitively, matching the filter above.
                                String base = p.getFileName().toString().replaceFirst("(?i)\\.(txt|csv)$", "");
                                Path outFile = outPath.resolve(base + ".json");
                                runOnce(Gi, outFile, alphaKind, eps);
                            } catch (Exception e) {
                                throw new RuntimeException("Failed on " + p, e);
                            }
                        });
            }
        } else {
            GraphData G = loadEdgeList0Based(inPath);
            System.out.printf(Locale.US, "# Loaded edge list: n=%d, m=%d%n", G.n, G.m);
            runOnce(G, outPath, alphaKind, eps);
        }
    }
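
    // Run a single graph end-to-end: stream LRMC snapshots through a PeakTracker
    // and write the resulting seeds JSON.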
    static void runOnce(GraphData G, Path outSeeds, AlphaKind alphaKind, double eps) throws Exception {
        PeakTracker tracker = new PeakTracker(G, eps, alphaKind);
        // Streaming entry point (required by current pipeline)
        clique2_ablations_parallel2.runLaplacianRMCStreaming(G.adj1Based, tracker);
        tracker.writeJson(outSeeds);
        System.out.println("# Done. wrote " + outSeeds.toAbsolutePath());
    }
    // Streaming peak tracker (same output schema as the Reddit seeder)
    static final class PeakTracker implements Consumer<clique2_ablations_parallel2.SnapshotDTO> {
        final GraphData G;
        final double epsilon;
        final AlphaKind alphaKind;
        final boolean[] inC; // scratch membership flags for the current snapshot
        final Map<Integer, Integer> bestIdxByComp = new LinkedHashMap<>(); // component id -> snapshot id of its peak
        final Map<Integer, Double> bestScoreByComp = new HashMap<>();      // component id -> best score seen so far
        final List<Rec> arrivals = new ArrayList<>();                      // every snapshot, in arrival order
        int idx = 0; // next snapshot id

        static final class Rec {
            final int compId;
            final int sid;
            final double score;
            final int[] nodes;

            Rec(int compId, int sid, double score, int[] nodes) {
                this.compId = compId; this.sid = sid; this.score = score; this.nodes = nodes;
            }
        }

        PeakTracker(GraphData G, double epsilon, AlphaKind alphaKind) {
            this.G = G; this.epsilon = epsilon; this.alphaKind = alphaKind;
            this.inC = new boolean[G.n];
        }
        @Override
        public void accept(clique2_ablations_parallel2.SnapshotDTO s) {
            final int[] nodes = s.nodes;
            final int k = nodes.length;
            if (k == 0) return;
            for (int u : nodes) inC[u] = true;
            final double Q = s.Q;
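            // Score k / (Q + epsilon): larger snapshots with smaller Q score higher;
            // epsilon keeps the ratio finite when Q is near zero.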
            final double sc = k / (Q + epsilon);
            final int compId = s.componentId;
            final int sid = idx++;
            if (!bestIdxByComp.containsKey(compId) || sc > bestScoreByComp.get(compId)) {
                bestIdxByComp.put(compId, sid);
                bestScoreByComp.put(compId, sc);
            }
            arrivals.add(new Rec(compId, sid, sc, Arrays.copyOf(nodes, nodes.length)));
            for (int u : nodes) inC[u] = false; // reset flags for the next snapshot
        }
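
        // Emitted JSON (same schema as the Reddit seeder); illustrative shape:
        //   {"meta":{"epsilon":...,"alpha_kind":"DIAM","n":...,"m":...,"mode":"..."},
        //    "clusters":[
        //      {"cluster_id":0,"component_id":3,"snapshot_id":17,"score":...,"k_seed":5,"members":[...],"seed_nodes":[...]},
        //      {"cluster_id":1,"component_id":-1,"snapshot_id":-1,"score":0.0,"k_seed":1,"members":[7],"seed_nodes":[7],"is_singleton":true}
        //    ]}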
        void writeJson(Path outJson) throws IOException {
            final int n = G.n;
            boolean[] covered = new boolean[n];
            try (BufferedWriter w = Files.newBufferedWriter(outJson, StandardCharsets.UTF_8)) {
                w.write("{\n");
                w.write("\"meta\":{");
                w.write("\"epsilon\":" + epsilon);
                w.write(",\"alpha_kind\":\"" + (alphaKind == AlphaKind.DIAM ? "DIAM" : "INV_SQRT_LAMBDA2") + "\"");
                w.write(",\"n\":" + G.n);
                w.write(",\"m\":" + G.m);
                w.write(",\"mode\":\"peaks_per_component+singletons(stream)\"");
                w.write("},\n");
                w.write("\"clusters\":[\n");
                boolean first = true;
                int nextClusterId = 0;
                // Emit one cluster per component: the arrival whose snapshot id is that component's peak.
                for (Rec r : arrivals) {
                    Integer best = bestIdxByComp.get(r.compId);
                    if (best != null && best == r.sid) {
                        if (!first) w.write(",\n");
                        first = false;
                        w.write(" {\"cluster_id\":" + (nextClusterId++));
                        w.write(",\"component_id\":" + r.compId);
                        w.write(",\"snapshot_id\":" + r.sid);
                        w.write(",\"score\":" + r.score);
                        w.write(",\"k_seed\":" + r.nodes.length);
                        w.write(",\"members\":" + intArrayToJson(r.nodes));
                        w.write(",\"seed_nodes\":" + intArrayToJson(r.nodes));
                        w.write("}");
                        for (int u : r.nodes) covered[u] = true;
                    }
                }
                // Emit singleton clusters for vertices not covered by any peak.
                for (int u = 0; u < n; u++) {
                    if (!covered[u]) {
                        if (!first) w.write(",\n");
                        first = false;
                        int[] singleton = new int[]{u};
                        w.write(" {\"cluster_id\":" + (nextClusterId++));
                        w.write(",\"component_id\":-1");
                        w.write(",\"snapshot_id\":-1");
                        w.write(",\"score\":0.0");
                        w.write(",\"k_seed\":1");
                        w.write(",\"members\":" + intArrayToJson(singleton));
                        w.write(",\"seed_nodes\":" + intArrayToJson(singleton));
                        w.write(",\"is_singleton\":true");
                        w.write("}");
                    }
                }
                w.write("\n]}");
            }
        }
    }
    // Loader for canonical 0-based undirected edge lists (u < v). Builds 1-based adjacency.
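    // Accepted lines are whitespace- or comma-separated pairs, e.g. "0 1" or "4,5";
    // blank lines and lines starting with '#' are skipped.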
    static GraphData loadEdgeList0Based(Path edgesFile) throws IOException {
        int[] deg = new int[1 << 12]; // degree counts; grown geometrically as needed
        int maxNode = -1;
        long mUndir = 0;
        // Pass 1: count degrees and find the maximum node id.
        try (BufferedReader br = Files.newBufferedReader(edgesFile, StandardCharsets.UTF_8)) {
            String s;
            while ((s = br.readLine()) != null) {
                s = s.trim();
                if (s.isEmpty() || s.startsWith("#")) continue;
                String[] tok = s.split("[\\s,]+"); // tolerate mixed separators such as ", "
                if (tok.length < 2) continue;
                int u = Integer.parseInt(tok[0]);
                int v = Integer.parseInt(tok[1]);
                if (u == v) continue;
                int needed = Math.max(u, v) + 1;
                if (needed > deg.length) {
                    int newLen = deg.length;
                    while (newLen < needed) newLen <<= 1;
                    deg = Arrays.copyOf(deg, newLen);
                }
                deg[u]++; deg[v]++;
                if (u < v) mUndir++;
                if (u > maxNode) maxNode = u;
                if (v > maxNode) maxNode = v;
            }
        }
        final int n = maxNode + 1;
        @SuppressWarnings("unchecked")
        List<Integer>[] adj1 = (List<Integer>[]) new List<?>[n + 1];
        for (int i = 1; i <= n; i++) adj1[i] = new ArrayList<>(deg[i - 1]);
        // Pass 2: re-read the file and fill the 1-based adjacency lists.
        try (BufferedReader br = Files.newBufferedReader(edgesFile, StandardCharsets.UTF_8)) {
            String s;
            while ((s = br.readLine()) != null) {
                s = s.trim();
                if (s.isEmpty() || s.startsWith("#")) continue;
                String[] tok = s.split("[\\s,]+");
                if (tok.length < 2) continue;
                int u = Integer.parseInt(tok[0]);
                int v = Integer.parseInt(tok[1]);
                if (u == v) continue;
                adj1[u + 1].add(v + 1);
                adj1[v + 1].add(u + 1);
            }
        }
        GraphData G = new GraphData();
        G.n = n; G.m = mUndir; G.adj1Based = adj1;
        G.labels = new int[n]; Arrays.fill(G.labels, -1);
        G.labelNames = new String[0];
        return G;
    }
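
    // Lenient CLI parsing: anything starting with "DIAM" selects DIAM, anything
    // containing "LAMBDA" selects INV_SQRT_LAMBDA2; the default is DIAM.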
    static AlphaKind parseAlpha(String s) {
        String t = s.trim().toUpperCase(Locale.ROOT);
        if (t.startsWith("DIAM")) return AlphaKind.DIAM;
        if (t.contains("LAMBDA")) return AlphaKind.INV_SQRT_LAMBDA2;
        return AlphaKind.DIAM;
    }

    static String intArrayToJson(int[] arr) {
        StringBuilder sb = new StringBuilder();
        sb.append('[');
        for (int i = 0; i < arr.length; i++) {
            if (i > 0) sb.append(',');
            sb.append(arr[i]);
        }
        sb.append(']');
        return sb.toString();
    }

    enum AlphaKind {DIAM, INV_SQRT_LAMBDA2}
    static final class GraphData {
        int n;                     // number of vertices (ids 0..n-1)
        long m;                    // undirected edge count (canonical u < v lines)
        List<Integer>[] adj1Based; // adjacency lists indexed 1..n
        int[] labels;              // per-node labels (unused here; filled with -1)
        String[] labelNames;       // label vocabulary (empty for TU edge lists)
    }
}