Skip to content

Commit

Permalink
in memory node refactor
Browse files Browse the repository at this point in the history
Signed-off-by: Sarthak Aggarwal <[email protected]>
  • Loading branch information
sarthakaggarwal97 committed Sep 1, 2024
1 parent 398a07b commit a6daf05
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -703,6 +702,20 @@ private InMemoryTreeNode getNewNode() {
return new InMemoryTreeNode();
}

/**
 * Creates an in-memory star-tree node initialised with the supplied attributes and
 * bumps the running {@code numStarTreeNodes} counter by one.
 *
 * @param dimensionId    dimension id of the star-tree node
 * @param startDocId     start doc id of the star-tree node
 * @param endDocId       end doc id of the star-tree node
 * @param nodeType       node type of the star-tree node
 * @param dimensionValue dimension value of the star-tree node
 * @return the newly created star-tree node
 */
private InMemoryTreeNode getNewNode(int dimensionId, int startDocId, int endDocId, byte nodeType, long dimensionValue) {
    // The counter is updated before construction, so it is incremented even if the constructor throws.
    numStarTreeNodes += 1;
    final InMemoryTreeNode createdNode = new InMemoryTreeNode(dimensionId, startDocId, endDocId, nodeType, dimensionValue);
    return createdNode;
}

/**
* Implements the algorithm to construct a star-tree
*
Expand All @@ -713,30 +726,30 @@ private InMemoryTreeNode getNewNode() {
*/
private void constructStarTree(InMemoryTreeNode node, int startDocId, int endDocId) throws IOException {

int childDimensionId = node.dimensionId + 1;
int childDimensionId = node.getDimensionId() + 1;
if (childDimensionId == numDimensions) {
return;
}

// Construct all non-star children nodes
node.childDimensionId = childDimensionId;
Map<Long, InMemoryTreeNode> children = constructNonStarNodes(startDocId, endDocId, childDimensionId);
node.children = children;
node.setChildDimensionId(childDimensionId);
constructNonStarNodes(node, startDocId, endDocId, childDimensionId);

// Construct star-node if required
if (!skipStarNodeCreationForDimensions.contains(childDimensionId) && children.size() > 1) {
node.childStarNode = constructStarNode(startDocId, endDocId, childDimensionId);
if (!skipStarNodeCreationForDimensions.contains(childDimensionId) && node.getChildren().size() > 1) {
node.addChildNode(constructStarNode(startDocId, endDocId, childDimensionId), (long) ALL);
}

// Further split star node if needed
if (node.childStarNode != null && (node.childStarNode.endDocId - node.childStarNode.startDocId > maxLeafDocuments)) {
constructStarTree(node.childStarNode, node.childStarNode.startDocId, node.childStarNode.endDocId);
if (node.getChildStarNode() != null
&& (node.getChildStarNode().getEndDocId() - node.getChildStarNode().getStartDocId() > maxLeafDocuments)) {
constructStarTree(node.getChildStarNode(), node.getChildStarNode().getStartDocId(), node.getChildStarNode().getEndDocId());
}

// Further split on child nodes if required
for (InMemoryTreeNode child : children.values()) {
if (child.endDocId - child.startDocId > maxLeafDocuments) {
constructStarTree(child, child.startDocId, child.endDocId);
for (InMemoryTreeNode child : node.getChildren().values()) {
if (child.getEndDocId() - child.getStartDocId() > maxLeafDocuments) {
constructStarTree(child, child.getStartDocId(), child.getEndDocId());
}
}

Expand All @@ -745,42 +758,41 @@ private void constructStarTree(InMemoryTreeNode node, int startDocId, int endDoc
/**
* Constructs non star tree nodes
*
* @param node parent node
* @param startDocId start document id (inclusive)
* @param endDocId end document id (exclusive)
* @param dimensionId id of the dimension in the star tree
* @return root node with non-star nodes constructed
*
* @throws IOException throws an exception if we are unable to construct non-star nodes
*/
private Map<Long, InMemoryTreeNode> constructNonStarNodes(int startDocId, int endDocId, int dimensionId) throws IOException {
Map<Long, InMemoryTreeNode> nodes = new HashMap<>();
private void constructNonStarNodes(InMemoryTreeNode node, int startDocId, int endDocId, int dimensionId) throws IOException {
int nodeStartDocId = startDocId;
Long nodeDimensionValue = getDimensionValue(startDocId, dimensionId);
for (int i = startDocId + 1; i < endDocId; i++) {
Long dimensionValue = getDimensionValue(i, dimensionId);
if (Objects.equals(dimensionValue, nodeDimensionValue) == false) {
InMemoryTreeNode child = getNewNode();
child.dimensionId = dimensionId;
if (nodeDimensionValue == null) {
child.dimensionValue = ALL;
child.nodeType = StarTreeNodeType.NULL.getValue();
} else {
child.dimensionValue = nodeDimensionValue;
}
child.startDocId = nodeStartDocId;
child.endDocId = i;
nodes.put(nodeDimensionValue, child);
addChildNode(node, i, dimensionId, nodeStartDocId, nodeDimensionValue);

nodeStartDocId = i;
nodeDimensionValue = dimensionValue;
}
}
InMemoryTreeNode lastNode = getNewNode();
lastNode.dimensionId = dimensionId;
lastNode.dimensionValue = nodeDimensionValue != null ? nodeDimensionValue : ALL;
lastNode.startDocId = nodeStartDocId;
lastNode.endDocId = endDocId;
nodes.put(nodeDimensionValue, lastNode);
return nodes;
addChildNode(node, endDocId, dimensionId, nodeStartDocId, nodeDimensionValue);
}

/**
 * Creates a non-star child node covering the documents {@code [nodeStartDocId, endDocId)}
 * and attaches it to the given parent node.
 *
 * <p>A {@code null} dimension value produces a node of type {@code StarTreeNodeType.NULL}
 * whose dimension value is {@code ALL}; any other value produces a
 * {@code StarTreeNodeType.DEFAULT} node carrying that value.
 *
 * @param node               parent node that receives the new child
 * @param endDocId           end document id of the child
 * @param dimensionId        id of the dimension the child belongs to
 * @param nodeStartDocId     start document id of the child
 * @param nodeDimensionValue dimension value shared by the child's documents, may be {@code null}
 */
private void addChildNode(InMemoryTreeNode node, int endDocId, int dimensionId, int nodeStartDocId, Long nodeDimensionValue) {
    final boolean valueMissing = nodeDimensionValue == null;
    final long resolvedDimensionValue = valueMissing ? ALL : nodeDimensionValue;
    final byte resolvedNodeType = valueMissing ? StarTreeNodeType.NULL.getValue() : StarTreeNodeType.DEFAULT.getValue();

    final InMemoryTreeNode childNode = getNewNode(dimensionId, nodeStartDocId, endDocId, resolvedNodeType, resolvedDimensionValue);
    // The parent is handed the original (possibly null) boxed value, not the resolved primitive.
    node.addChildNode(childNode, nodeDimensionValue);
}

/**
Expand All @@ -793,15 +805,10 @@ private Map<Long, InMemoryTreeNode> constructNonStarNodes(int startDocId, int en
* @throws IOException throws an exception if we are unable to construct non-star nodes
*/
private InMemoryTreeNode constructStarNode(int startDocId, int endDocId, int dimensionId) throws IOException {
InMemoryTreeNode starNode = getNewNode();
starNode.dimensionId = dimensionId;
starNode.dimensionValue = ALL;
starNode.nodeType = StarTreeNodeType.STAR.getValue();
starNode.startDocId = numStarTreeDocs;
int starNodeStartDocId = numStarTreeDocs;
Iterator<StarTreeDocument> starTreeDocumentIterator = generateStarTreeDocumentsForStarNode(startDocId, endDocId, dimensionId);
appendDocumentsToStarTree(starTreeDocumentIterator);
starNode.endDocId = numStarTreeDocs;
return starNode;
return getNewNode(dimensionId, starNodeStartDocId, numStarTreeDocs, StarTreeNodeType.STAR.getValue(), ALL);
}

/**
Expand All @@ -815,54 +822,54 @@ private StarTreeDocument createAggregatedDocs(InMemoryTreeNode node) throws IOEx
StarTreeDocument aggregatedStarTreeDocument = null;

// For leaf node
if (node.children == null && node.childStarNode == null) {
if (node.getChildren().isEmpty() && node.getChildStarNode() == null) {

if (node.startDocId == node.endDocId - 1) {
if (node.getStartDocId() == node.getEndDocId() - 1) {
// If it has only one document, use it as the aggregated document
aggregatedStarTreeDocument = getStarTreeDocument(node.startDocId);
node.aggregatedDocId = node.startDocId;
aggregatedStarTreeDocument = getStarTreeDocument(node.getStartDocId());
node.setAggregatedDocId(node.getStartDocId());
} else {
// If it has multiple documents, aggregate all of them
for (int i = node.startDocId; i < node.endDocId; i++) {
for (int i = node.getStartDocId(); i < node.getEndDocId(); i++) {
aggregatedStarTreeDocument = reduceStarTreeDocuments(aggregatedStarTreeDocument, getStarTreeDocument(i));
}
if (null == aggregatedStarTreeDocument) {
throw new IllegalStateException("aggregated star-tree document is null after reducing the documents");
}
for (int i = node.dimensionId + 1; i < numDimensions; i++) {
for (int i = node.getDimensionId() + 1; i < numDimensions; i++) {
aggregatedStarTreeDocument.dimensions[i] = STAR_IN_DOC_VALUES_INDEX;
}
node.aggregatedDocId = numStarTreeDocs;
node.setAggregatedDocId(numStarTreeDocs);
appendToStarTree(aggregatedStarTreeDocument);
}
} else {
// For non-leaf node
if (node.childStarNode != null) {
if (node.getChildStarNode() != null) {
// If it has star child, use the star child aggregated document directly
aggregatedStarTreeDocument = createAggregatedDocs(node.childStarNode);
node.aggregatedDocId = node.childStarNode.aggregatedDocId;
aggregatedStarTreeDocument = createAggregatedDocs(node.getChildStarNode());
node.setAggregatedDocId(node.getChildStarNode().getAggregatedDocId());

for (InMemoryTreeNode child : node.children.values()) {
for (InMemoryTreeNode child : node.getChildren().values()) {
createAggregatedDocs(child);
}
} else {
// If no star child exists, aggregate all aggregated documents from non-star children
if (node.children.values().size() == 1) {
for (InMemoryTreeNode child : node.children.values()) {
if (node.getChildren().values().size() == 1) {
for (InMemoryTreeNode child : node.getChildren().values()) {
aggregatedStarTreeDocument = reduceStarTreeDocuments(aggregatedStarTreeDocument, createAggregatedDocs(child));
node.aggregatedDocId = child.aggregatedDocId;
node.setAggregatedDocId(child.getAggregatedDocId());
}
} else {
for (InMemoryTreeNode child : node.children.values()) {
for (InMemoryTreeNode child : node.getChildren().values()) {
aggregatedStarTreeDocument = reduceStarTreeDocuments(aggregatedStarTreeDocument, createAggregatedDocs(child));
}
if (null == aggregatedStarTreeDocument) {
throw new IllegalStateException("aggregated star-tree document is null after reducing the documents");
}
for (int i = node.dimensionId + 1; i < numDimensions; i++) {
for (int i = node.getDimensionId() + 1; i < numDimensions; i++) {
aggregatedStarTreeDocument.dimensions[i] = STAR_IN_DOC_VALUES_INDEX;
}
node.aggregatedDocId = numStarTreeDocs;
node.setAggregatedDocId(numStarTreeDocs);
appendToStarTree(aggregatedStarTreeDocument);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
import org.opensearch.index.compositeindex.datacube.startree.node.InMemoryTreeNode;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

import static org.opensearch.index.compositeindex.datacube.startree.fileformats.node.FixedLengthStarTreeNode.SERIALIZABLE_DATA_SIZE_IN_BYTES;
Expand Down Expand Up @@ -73,21 +70,14 @@ private static void writeStarTreeNodes(IndexOutput output, InMemoryTreeNode root
int totalNumberOfChildren = 0;
int firstChildId = currentNodeId + queue.size() + 1;

if (node.childStarNode != null) {
if (node.getChildStarNode() != null) {
totalNumberOfChildren++;
queue.add(node.childStarNode);
queue.add(node.getChildStarNode());
}

if (node.children != null) {
// Sort all children nodes based on dimension value
// TODO: Verify if linked hashmap can help avoid the children sort
List<InMemoryTreeNode> sortedChildren = new ArrayList<>(node.children.values());
sortedChildren.sort(
Comparator.comparingInt(InMemoryTreeNode::getNodeType).thenComparingLong(InMemoryTreeNode::getDimensionValue)
);

totalNumberOfChildren = totalNumberOfChildren + sortedChildren.size();
queue.addAll(sortedChildren);
if (node.getChildren() != null) {
totalNumberOfChildren = totalNumberOfChildren + node.getChildren().values().size();
queue.addAll(node.getChildren().values());
}

int lastChildId = firstChildId + totalNumberOfChildren - 1;
Expand All @@ -109,12 +99,12 @@ private static void writeStarTreeNodes(IndexOutput output, InMemoryTreeNode root
* @throws IOException if an I/O error occurs while writing the node
*/
private static void writeStarTreeNode(IndexOutput output, InMemoryTreeNode node, int firstChildId, int lastChildId) throws IOException {
output.writeInt(node.dimensionId);
output.writeLong(node.dimensionValue);
output.writeInt(node.startDocId);
output.writeInt(node.endDocId);
output.writeInt(node.aggregatedDocId);
output.writeByte(node.nodeType);
output.writeInt(node.getDimensionId());
output.writeLong(node.getDimensionValue());
output.writeInt(node.getStartDocId());
output.writeInt(node.getEndDocId());
output.writeInt(node.getAggregatedDocId());
output.writeByte(node.getNodeType());
output.writeInt(firstChildId);
output.writeInt(lastChildId);
}
Expand Down
Loading

0 comments on commit a6daf05

Please sign in to comment.