Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package com.thealgorithms.dynamicprogramming;

import java.util.Arrays;
import java.util.Comparator;

/**
* Computes the minimum search cost of an optimal binary search tree.
*
* <p>The algorithm sorts the keys, preserves the corresponding search frequencies, and uses
* dynamic programming with Knuth's optimization to compute the minimum weighted search cost.
*
* <p>Example: if keys = [10, 12] and frequencies = [34, 50], the best tree puts 12 at the root
* and 10 as its left child. The total cost is 50 * 1 + 34 * 2 = 118.
*
* <p>Reference:
* https://en.wikipedia.org/wiki/Optimal_binary_search_tree
*/
public final class OptimalBinarySearchTree {
private OptimalBinarySearchTree() {
}

/**
* Computes the minimum weighted search cost for the given keys and search frequencies.
*
* @param keys the BST keys
* @param frequencies the search frequencies associated with the keys
* @return the minimum search cost
* @throws IllegalArgumentException if the input is invalid
*/
public static long findOptimalCost(int[] keys, int[] frequencies) {
validateInput(keys, frequencies);
if (keys.length == 0) {
return 0L;
}

int[][] sortedNodes = sortNodes(keys, frequencies);
int nodeCount = sortedNodes.length;
long[] prefixSums = buildPrefixSums(sortedNodes);
long[][] optimalCost = new long[nodeCount][nodeCount];
int[][] root = new int[nodeCount][nodeCount];

// Small example:
// keys = [10, 12]
// frequencies = [34, 50]
// Choosing 12 as the root gives cost 50 * 1 + 34 * 2 = 118,
// which is better than choosing 10 as the root.

// Base case: a subtree containing one key has cost equal to its frequency,
// because that key becomes the root of the subtree and is searched at depth 1.
for (int index = 0; index < nodeCount; index++) {
optimalCost[index][index] = sortedNodes[index][1];
root[index][index] = index;
}

// Build solutions for longer and longer key ranges.
// optimalCost[start][end] stores the minimum search cost for keys in that range.
for (int length = 2; length <= nodeCount; length++) {
for (int start = 0; start <= nodeCount - length; start++) {
int end = start + length - 1;

// Every key in this range moves one level deeper when we choose a root,
// so the sum of frequencies is added once to the subtree cost.
long frequencySum = prefixSums[end + 1] - prefixSums[start];
optimalCost[start][end] = Long.MAX_VALUE;

// Knuth's optimization:
// the best root for [start, end] lies between the best roots of
// [start, end - 1] and [start + 1, end], so we search only this interval.
int leftBoundary = root[start][end - 1];
int rightBoundary = root[start + 1][end];
for (int currentRoot = leftBoundary; currentRoot <= rightBoundary; currentRoot++) {
long leftCost = currentRoot > start ? optimalCost[start][currentRoot - 1] : 0L;
long rightCost = currentRoot < end ? optimalCost[currentRoot + 1][end] : 0L;
long currentCost = frequencySum + leftCost + rightCost;

if (currentCost < optimalCost[start][end]) {
optimalCost[start][end] = currentCost;
root[start][end] = currentRoot;
}
}
}
}

return optimalCost[0][nodeCount - 1];
}

private static void validateInput(int[] keys, int[] frequencies) {
if (keys == null || frequencies == null) {
throw new IllegalArgumentException("Keys and frequencies cannot be null");
}
if (keys.length != frequencies.length) {
throw new IllegalArgumentException("Keys and frequencies must have the same length");
}

for (int frequency : frequencies) {
if (frequency < 0) {
throw new IllegalArgumentException("Frequencies cannot be negative");
}
}
}

private static int[][] sortNodes(int[] keys, int[] frequencies) {
int[][] sortedNodes = new int[keys.length][2];
for (int index = 0; index < keys.length; index++) {
sortedNodes[index][0] = keys[index];
sortedNodes[index][1] = frequencies[index];
}

// Sort by key so the nodes can be treated as an in-order BST sequence.
Arrays.sort(sortedNodes, Comparator.comparingInt(node -> node[0]));

for (int index = 1; index < sortedNodes.length; index++) {
if (sortedNodes[index - 1][0] == sortedNodes[index][0]) {
throw new IllegalArgumentException("Keys must be distinct");
}
}

return sortedNodes;
}

private static long[] buildPrefixSums(int[][] sortedNodes) {
long[] prefixSums = new long[sortedNodes.length + 1];
for (int index = 0; index < sortedNodes.length; index++) {
// prefixSums[i] holds the total frequency of the first i sorted keys.
// This lets us get the frequency sum of any range in O(1) time.
prefixSums[index + 1] = prefixSums[index] + sortedNodes[index][1];
}
return prefixSums;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package com.thealgorithms.dynamicprogramming;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.Arrays;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

class OptimalBinarySearchTreeTest {

@ParameterizedTest
@MethodSource("validTestCases")
void testFindOptimalCost(int[] keys, int[] frequencies, long expectedCost) {
assertEquals(expectedCost, OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
}

private static Stream<Arguments> validTestCases() {
return Stream.of(Arguments.of(new int[] {}, new int[] {}, 0L), Arguments.of(new int[] {15}, new int[] {9}, 9L), Arguments.of(new int[] {10, 12}, new int[] {34, 50}, 118L), Arguments.of(new int[] {20, 10, 30}, new int[] {50, 34, 8}, 134L),
Arguments.of(new int[] {12, 10, 20, 42, 25, 37}, new int[] {8, 34, 50, 3, 40, 30}, 324L), Arguments.of(new int[] {1, 2, 3}, new int[] {0, 0, 0}, 0L));
}

@ParameterizedTest
@MethodSource("crossCheckTestCases")
void testFindOptimalCostAgainstBruteForce(int[] keys, int[] frequencies) {
assertEquals(bruteForceOptimalCost(keys, frequencies), OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
}

private static Stream<Arguments> crossCheckTestCases() {
return Stream.of(Arguments.of(new int[] {3, 1, 2}, new int[] {4, 2, 6}), Arguments.of(new int[] {5, 2, 8, 6}, new int[] {3, 7, 1, 4}), Arguments.of(new int[] {9, 4, 11, 2}, new int[] {1, 8, 2, 5}));
}

@ParameterizedTest
@MethodSource("invalidTestCases")
void testFindOptimalCostInvalidInput(int[] keys, int[] frequencies) {
assertThrows(IllegalArgumentException.class, () -> OptimalBinarySearchTree.findOptimalCost(keys, frequencies));
}

private static Stream<Arguments> invalidTestCases() {
return Stream.of(Arguments.of(null, new int[] {}), Arguments.of(new int[] {}, null), Arguments.of(new int[] {1, 2}, new int[] {3}), Arguments.of(new int[] {1, 1}, new int[] {2, 3}), Arguments.of(new int[] {1, 2}, new int[] {3, -1}));
}

private static long bruteForceOptimalCost(int[] keys, int[] frequencies) {
int[][] sortedNodes = new int[keys.length][2];
for (int index = 0; index < keys.length; index++) {
sortedNodes[index][0] = keys[index];
sortedNodes[index][1] = frequencies[index];
}
Arrays.sort(sortedNodes, java.util.Comparator.comparingInt(node -> node[0]));

int[] sortedFrequencies = new int[sortedNodes.length];
for (int index = 0; index < sortedNodes.length; index++) {
sortedFrequencies[index] = sortedNodes[index][1];
}

return bruteForceOptimalCost(sortedFrequencies, 0, sortedFrequencies.length - 1, 1);
}

private static long bruteForceOptimalCost(int[] frequencies, int start, int end, int depth) {
if (start > end) {
return 0L;
}

long minimumCost = Long.MAX_VALUE;
for (int root = start; root <= end; root++) {
long currentCost = (long) depth * frequencies[root] + bruteForceOptimalCost(frequencies, start, root - 1, depth + 1) + bruteForceOptimalCost(frequencies, root + 1, end, depth + 1);
minimumCost = Math.min(minimumCost, currentCost);
}
return minimumCost;
}
}
Loading