mirror of
https://github.com/metabarcoding/obitools4.git
synced 2026-03-25 13:30:52 +00:00
Add Jaccard distance and similarity computations for KmerSet and KmerSetGroup
Add Jaccard distance and similarity computations for KmerSet and KmerSetGroup. This commit introduces Jaccard distance and similarity methods for KmerSet and KmerSetGroup. For KmerSet: added a JaccardDistance method to compute the Jaccard distance between two KmerSets; added a JaccardSimilarity method to compute the Jaccard similarity between two KmerSets. For KmerSetGroup: added a JaccardDistanceMatrix method to compute a pairwise Jaccard distance matrix; added a JaccardSimilarityMatrix method to compute a pairwise Jaccard similarity matrix. Also includes: a new DistMatrix implementation in pkg/obidist for storing and computing distance/similarity matrices; updated version handling with a bump-version target in the Makefile; tests for all new methods.
This commit is contained in:
272
pkg/obidist/dist_matrix.go
Normal file
272
pkg/obidist/dist_matrix.go
Normal file
@@ -0,0 +1,272 @@
|
||||
package obidist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DistMatrix is a symmetric n×n matrix of float64 values kept in compact
// triangular form.
//
// Every diagonal cell shares the single value held in diagonalValue (the
// constructors in this package use 0.0 for distance matrices and 1.0 for
// similarity matrices). Only the cells strictly above the diagonal (i < j)
// are stored, which amounts to n(n-1)/2 values.
type DistMatrix struct {
	n             int       // matrix dimension (number of rows and of columns)
	data          []float64 // upper triangle, one value per pair with i < j
	labels        []string  // one label per row/column, optional
	diagonalValue float64   // value reported for every (i, i) cell
}
|
||||
|
||||
// NewDistMatrix creates a new distance matrix of size n×n.
|
||||
// All distances are initialized to 0.0, diagonal is 0.0.
|
||||
func NewDistMatrix(n int) *DistMatrix {
|
||||
if n < 0 {
|
||||
panic("matrix size must be non-negative")
|
||||
}
|
||||
|
||||
// Number of elements in upper triangle: n(n-1)/2
|
||||
size := n * (n - 1) / 2
|
||||
|
||||
return &DistMatrix{
|
||||
n: n,
|
||||
data: make([]float64, size),
|
||||
labels: make([]string, n),
|
||||
diagonalValue: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDistMatrixWithLabels creates a new distance matrix with labels.
|
||||
// Diagonal is 0.0 by default.
|
||||
func NewDistMatrixWithLabels(labels []string) *DistMatrix {
|
||||
dm := NewDistMatrix(len(labels))
|
||||
copy(dm.labels, labels)
|
||||
return dm
|
||||
}
|
||||
|
||||
// NewSimilarityMatrix creates a new similarity matrix of size n×n.
|
||||
// All off-diagonal values are initialized to 0.0, diagonal is 1.0.
|
||||
func NewSimilarityMatrix(n int) *DistMatrix {
|
||||
if n < 0 {
|
||||
panic("matrix size must be non-negative")
|
||||
}
|
||||
|
||||
// Number of elements in upper triangle: n(n-1)/2
|
||||
size := n * (n - 1) / 2
|
||||
|
||||
return &DistMatrix{
|
||||
n: n,
|
||||
data: make([]float64, size),
|
||||
labels: make([]string, n),
|
||||
diagonalValue: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSimilarityMatrixWithLabels creates a new similarity matrix with labels.
|
||||
// Diagonal is 1.0.
|
||||
func NewSimilarityMatrixWithLabels(labels []string) *DistMatrix {
|
||||
dm := NewSimilarityMatrix(len(labels))
|
||||
copy(dm.labels, labels)
|
||||
return dm
|
||||
}
|
||||
|
||||
// Size returns the dimension of the matrix (n for an n×n matrix).
|
||||
func (dm *DistMatrix) Size() int {
|
||||
return dm.n
|
||||
}
|
||||
|
||||
// indexFor computes the index in the data array for element (i, j).
|
||||
// Assumes i < j (upper triangle).
|
||||
//
|
||||
// The upper triangle is stored row by row:
|
||||
// (0,1), (0,2), ..., (0,n-1), (1,2), (1,3), ..., (1,n-1), (2,3), ...
|
||||
//
|
||||
// For element (i, j) where i < j:
|
||||
// index = i*(n-1) + j - 1 - i*(i+1)/2
|
||||
//
|
||||
// This can be simplified to:
|
||||
// index = i*n - i*(i+1)/2 + j - i - 1
|
||||
// = i*(n - (i+1)/2 - 1) + j - 1
|
||||
// = i*(n - 1 - i/2 - 1/2) + j - 1
|
||||
//
|
||||
// But the clearest formula is:
|
||||
// index = i*n - i*(i+3)/2 + j - 1
|
||||
func (dm *DistMatrix) indexFor(i, j int) int {
|
||||
if i >= j {
|
||||
panic(fmt.Sprintf("indexFor expects i < j, got i=%d, j=%d", i, j))
|
||||
}
|
||||
// Formula: number of elements in previous rows + position in current row
|
||||
// Previous rows (0 to i-1): sum from k=0 to i-1 of (n-1-k) = i*n - i*(i+1)/2
|
||||
// Current row position: j - i - 1
|
||||
return i*dm.n - i*(i+1)/2 + j - i - 1
|
||||
}
|
||||
|
||||
// Get returns the value at position (i, j).
|
||||
// The matrix is symmetric, so Get(i, j) == Get(j, i).
|
||||
// The diagonal returns the diagonalValue (0.0 for distances, 1.0 for similarities).
|
||||
func (dm *DistMatrix) Get(i, j int) float64 {
|
||||
if i < 0 || i >= dm.n || j < 0 || j >= dm.n {
|
||||
panic(fmt.Sprintf("indices out of bounds: i=%d, j=%d, n=%d", i, j, dm.n))
|
||||
}
|
||||
|
||||
// Diagonal: return the diagonal value
|
||||
if i == j {
|
||||
return dm.diagonalValue
|
||||
}
|
||||
|
||||
// Ensure i < j for indexing
|
||||
if i > j {
|
||||
i, j = j, i
|
||||
}
|
||||
|
||||
return dm.data[dm.indexFor(i, j)]
|
||||
}
|
||||
|
||||
// Set sets the value at position (i, j).
|
||||
// The matrix is symmetric, so Set(i, j, v) also sets (j, i) to v.
|
||||
// Setting the diagonal (i == j) is ignored (diagonal has a fixed value).
|
||||
func (dm *DistMatrix) Set(i, j int, value float64) {
|
||||
if i < 0 || i >= dm.n || j < 0 || j >= dm.n {
|
||||
panic(fmt.Sprintf("indices out of bounds: i=%d, j=%d, n=%d", i, j, dm.n))
|
||||
}
|
||||
|
||||
// Ignore diagonal assignments (diagonal has a fixed value)
|
||||
if i == j {
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure i < j for indexing
|
||||
if i > j {
|
||||
i, j = j, i
|
||||
}
|
||||
|
||||
dm.data[dm.indexFor(i, j)] = value
|
||||
}
|
||||
|
||||
// GetLabel returns the label for element i.
|
||||
func (dm *DistMatrix) GetLabel(i int) string {
|
||||
if i < 0 || i >= dm.n {
|
||||
panic(fmt.Sprintf("index out of bounds: i=%d, n=%d", i, dm.n))
|
||||
}
|
||||
return dm.labels[i]
|
||||
}
|
||||
|
||||
// SetLabel sets the label for element i.
|
||||
func (dm *DistMatrix) SetLabel(i int, label string) {
|
||||
if i < 0 || i >= dm.n {
|
||||
panic(fmt.Sprintf("index out of bounds: i=%d, n=%d", i, dm.n))
|
||||
}
|
||||
dm.labels[i] = label
|
||||
}
|
||||
|
||||
// Labels returns a copy of all labels.
|
||||
func (dm *DistMatrix) Labels() []string {
|
||||
labels := make([]string, dm.n)
|
||||
copy(labels, dm.labels)
|
||||
return labels
|
||||
}
|
||||
|
||||
// GetRow returns the i-th row of the distance matrix.
|
||||
// The returned slice is a copy.
|
||||
func (dm *DistMatrix) GetRow(i int) []float64 {
|
||||
if i < 0 || i >= dm.n {
|
||||
panic(fmt.Sprintf("index out of bounds: i=%d, n=%d", i, dm.n))
|
||||
}
|
||||
|
||||
row := make([]float64, dm.n)
|
||||
for j := 0; j < dm.n; j++ {
|
||||
row[j] = dm.Get(i, j)
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
// GetColumn returns the j-th column of the distance matrix.
|
||||
// Since the matrix is symmetric, GetColumn(j) == GetRow(j).
|
||||
// The returned slice is a copy.
|
||||
func (dm *DistMatrix) GetColumn(j int) []float64 {
|
||||
return dm.GetRow(j)
|
||||
}
|
||||
|
||||
// MinDistance returns the minimum non-zero distance in the matrix,
|
||||
// along with the indices (i, j) where it occurs.
|
||||
// Returns (0.0, -1, -1) if the matrix is empty or all distances are 0.
|
||||
func (dm *DistMatrix) MinDistance() (float64, int, int) {
|
||||
if dm.n <= 1 {
|
||||
return 0.0, -1, -1
|
||||
}
|
||||
|
||||
minDist := -1.0
|
||||
minI, minJ := -1, -1
|
||||
|
||||
for i := 0; i < dm.n-1; i++ {
|
||||
for j := i + 1; j < dm.n; j++ {
|
||||
dist := dm.Get(i, j)
|
||||
if minDist < 0 || dist < minDist {
|
||||
minDist = dist
|
||||
minI = i
|
||||
minJ = j
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if minI < 0 {
|
||||
return 0.0, -1, -1
|
||||
}
|
||||
|
||||
return minDist, minI, minJ
|
||||
}
|
||||
|
||||
// MaxDistance returns the maximum distance in the matrix,
|
||||
// along with the indices (i, j) where it occurs.
|
||||
// Returns (0.0, -1, -1) if the matrix is empty.
|
||||
func (dm *DistMatrix) MaxDistance() (float64, int, int) {
|
||||
if dm.n <= 1 {
|
||||
return 0.0, -1, -1
|
||||
}
|
||||
|
||||
maxDist := -1.0
|
||||
maxI, maxJ := -1, -1
|
||||
|
||||
for i := 0; i < dm.n-1; i++ {
|
||||
for j := i + 1; j < dm.n; j++ {
|
||||
dist := dm.Get(i, j)
|
||||
if maxDist < 0 || dist > maxDist {
|
||||
maxDist = dist
|
||||
maxI = i
|
||||
maxJ = j
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if maxI < 0 {
|
||||
return 0.0, -1, -1
|
||||
}
|
||||
|
||||
return maxDist, maxI, maxJ
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the matrix.
|
||||
func (dm *DistMatrix) Copy() *DistMatrix {
|
||||
newDM := &DistMatrix{
|
||||
n: dm.n,
|
||||
data: make([]float64, len(dm.data)),
|
||||
labels: make([]string, dm.n),
|
||||
diagonalValue: dm.diagonalValue,
|
||||
}
|
||||
|
||||
copy(newDM.data, dm.data)
|
||||
copy(newDM.labels, dm.labels)
|
||||
|
||||
return newDM
|
||||
}
|
||||
|
||||
// ToFullMatrix returns a full n×n matrix representation.
|
||||
// This allocates n² values, so use only when needed.
|
||||
func (dm *DistMatrix) ToFullMatrix() [][]float64 {
|
||||
matrix := make([][]float64, dm.n)
|
||||
for i := 0; i < dm.n; i++ {
|
||||
matrix[i] = make([]float64, dm.n)
|
||||
for j := 0; j < dm.n; j++ {
|
||||
matrix[i][j] = dm.Get(i, j)
|
||||
}
|
||||
}
|
||||
return matrix
|
||||
}
|
||||
386
pkg/obidist/dist_matrix_test.go
Normal file
386
pkg/obidist/dist_matrix_test.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package obidist
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewDistMatrix(t *testing.T) {
|
||||
dm := NewDistMatrix(5)
|
||||
|
||||
if dm.Size() != 5 {
|
||||
t.Errorf("Expected size 5, got %d", dm.Size())
|
||||
}
|
||||
|
||||
// Check that all values are initialized to 0
|
||||
for i := 0; i < 5; i++ {
|
||||
for j := 0; j < 5; j++ {
|
||||
if dm.Get(i, j) != 0.0 {
|
||||
t.Errorf("Expected 0.0 at (%d, %d), got %f", i, j, dm.Get(i, j))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixDiagonal(t *testing.T) {
|
||||
dm := NewDistMatrix(5)
|
||||
|
||||
// Diagonal should always be 0
|
||||
for i := 0; i < 5; i++ {
|
||||
if dm.Get(i, i) != 0.0 {
|
||||
t.Errorf("Expected diagonal element (%d, %d) to be 0.0, got %f", i, i, dm.Get(i, i))
|
||||
}
|
||||
}
|
||||
|
||||
// Try to set diagonal (should be ignored)
|
||||
dm.Set(2, 2, 5.0)
|
||||
if dm.Get(2, 2) != 0.0 {
|
||||
t.Errorf("Diagonal should remain 0.0 even after Set, got %f", dm.Get(2, 2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixSymmetry(t *testing.T) {
|
||||
dm := NewDistMatrix(4)
|
||||
|
||||
dm.Set(0, 1, 1.5)
|
||||
dm.Set(0, 2, 2.5)
|
||||
dm.Set(1, 3, 3.5)
|
||||
|
||||
// Check symmetry
|
||||
if dm.Get(0, 1) != dm.Get(1, 0) {
|
||||
t.Errorf("Matrix not symmetric: Get(0,1)=%f, Get(1,0)=%f", dm.Get(0, 1), dm.Get(1, 0))
|
||||
}
|
||||
|
||||
if dm.Get(0, 2) != dm.Get(2, 0) {
|
||||
t.Errorf("Matrix not symmetric: Get(0,2)=%f, Get(2,0)=%f", dm.Get(0, 2), dm.Get(2, 0))
|
||||
}
|
||||
|
||||
if dm.Get(1, 3) != dm.Get(3, 1) {
|
||||
t.Errorf("Matrix not symmetric: Get(1,3)=%f, Get(3,1)=%f", dm.Get(1, 3), dm.Get(3, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixSetGet(t *testing.T) {
|
||||
dm := NewDistMatrix(4)
|
||||
|
||||
testCases := []struct {
|
||||
i int
|
||||
j int
|
||||
value float64
|
||||
}{
|
||||
{0, 1, 1.5},
|
||||
{0, 2, 2.5},
|
||||
{0, 3, 3.5},
|
||||
{1, 2, 4.5},
|
||||
{1, 3, 5.5},
|
||||
{2, 3, 6.5},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
dm.Set(tc.i, tc.j, tc.value)
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
got := dm.Get(tc.i, tc.j)
|
||||
if math.Abs(got-tc.value) > 1e-10 {
|
||||
t.Errorf("Get(%d, %d): expected %f, got %f", tc.i, tc.j, tc.value, got)
|
||||
}
|
||||
|
||||
// Check symmetry
|
||||
got = dm.Get(tc.j, tc.i)
|
||||
if math.Abs(got-tc.value) > 1e-10 {
|
||||
t.Errorf("Get(%d, %d) (symmetric): expected %f, got %f", tc.j, tc.i, tc.value, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixLabels(t *testing.T) {
|
||||
labels := []string{"A", "B", "C", "D"}
|
||||
dm := NewDistMatrixWithLabels(labels)
|
||||
|
||||
if dm.Size() != 4 {
|
||||
t.Errorf("Expected size 4, got %d", dm.Size())
|
||||
}
|
||||
|
||||
for i, label := range labels {
|
||||
if dm.GetLabel(i) != label {
|
||||
t.Errorf("Expected label %s at index %d, got %s", label, i, dm.GetLabel(i))
|
||||
}
|
||||
}
|
||||
|
||||
// Modify a label
|
||||
dm.SetLabel(1, "Modified")
|
||||
if dm.GetLabel(1) != "Modified" {
|
||||
t.Errorf("Expected label 'Modified' at index 1, got %s", dm.GetLabel(1))
|
||||
}
|
||||
|
||||
// Check Labels() returns a copy
|
||||
labelsCopy := dm.Labels()
|
||||
labelsCopy[0] = "ChangedCopy"
|
||||
if dm.GetLabel(0) != "A" {
|
||||
t.Errorf("Modifying Labels() return value should not affect original labels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixMinDistance(t *testing.T) {
|
||||
dm := NewDistMatrix(4)
|
||||
|
||||
dm.Set(0, 1, 2.5)
|
||||
dm.Set(0, 2, 1.5) // minimum
|
||||
dm.Set(0, 3, 3.5)
|
||||
dm.Set(1, 2, 4.5)
|
||||
dm.Set(1, 3, 5.5)
|
||||
dm.Set(2, 3, 6.5)
|
||||
|
||||
minDist, minI, minJ := dm.MinDistance()
|
||||
|
||||
if math.Abs(minDist-1.5) > 1e-10 {
|
||||
t.Errorf("Expected min distance 1.5, got %f", minDist)
|
||||
}
|
||||
|
||||
if (minI != 0 || minJ != 2) && (minI != 2 || minJ != 0) {
|
||||
t.Errorf("Expected min at (0, 2) or (2, 0), got (%d, %d)", minI, minJ)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixMaxDistance(t *testing.T) {
|
||||
dm := NewDistMatrix(4)
|
||||
|
||||
dm.Set(0, 1, 2.5)
|
||||
dm.Set(0, 2, 1.5)
|
||||
dm.Set(0, 3, 3.5)
|
||||
dm.Set(1, 2, 4.5)
|
||||
dm.Set(1, 3, 5.5)
|
||||
dm.Set(2, 3, 6.5) // maximum
|
||||
|
||||
maxDist, maxI, maxJ := dm.MaxDistance()
|
||||
|
||||
if math.Abs(maxDist-6.5) > 1e-10 {
|
||||
t.Errorf("Expected max distance 6.5, got %f", maxDist)
|
||||
}
|
||||
|
||||
if (maxI != 2 || maxJ != 3) && (maxI != 3 || maxJ != 2) {
|
||||
t.Errorf("Expected max at (2, 3) or (3, 2), got (%d, %d)", maxI, maxJ)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixGetRow(t *testing.T) {
|
||||
dm := NewDistMatrix(3)
|
||||
|
||||
dm.Set(0, 1, 1.0)
|
||||
dm.Set(0, 2, 2.0)
|
||||
dm.Set(1, 2, 3.0)
|
||||
|
||||
row0 := dm.GetRow(0)
|
||||
expected0 := []float64{0.0, 1.0, 2.0}
|
||||
|
||||
for i, val := range expected0 {
|
||||
if math.Abs(row0[i]-val) > 1e-10 {
|
||||
t.Errorf("Row 0[%d]: expected %f, got %f", i, val, row0[i])
|
||||
}
|
||||
}
|
||||
|
||||
row1 := dm.GetRow(1)
|
||||
expected1 := []float64{1.0, 0.0, 3.0}
|
||||
|
||||
for i, val := range expected1 {
|
||||
if math.Abs(row1[i]-val) > 1e-10 {
|
||||
t.Errorf("Row 1[%d]: expected %f, got %f", i, val, row1[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixCopy(t *testing.T) {
|
||||
dm := NewDistMatrixWithLabels([]string{"A", "B", "C"})
|
||||
dm.Set(0, 1, 1.5)
|
||||
dm.Set(0, 2, 2.5)
|
||||
dm.Set(1, 2, 3.5)
|
||||
|
||||
dmCopy := dm.Copy()
|
||||
|
||||
// Check values are copied
|
||||
if dmCopy.Get(0, 1) != dm.Get(0, 1) {
|
||||
t.Errorf("Copy has different value at (0, 1)")
|
||||
}
|
||||
|
||||
// Check labels are copied
|
||||
if dmCopy.GetLabel(0) != dm.GetLabel(0) {
|
||||
t.Errorf("Copy has different label at index 0")
|
||||
}
|
||||
|
||||
// Modify copy and ensure original unchanged
|
||||
dmCopy.Set(0, 1, 99.9)
|
||||
if dm.Get(0, 1) == 99.9 {
|
||||
t.Errorf("Modifying copy affected original matrix")
|
||||
}
|
||||
|
||||
dmCopy.SetLabel(0, "Modified")
|
||||
if dm.GetLabel(0) == "Modified" {
|
||||
t.Errorf("Modifying copy label affected original matrix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixToFullMatrix(t *testing.T) {
|
||||
dm := NewDistMatrix(3)
|
||||
dm.Set(0, 1, 1.0)
|
||||
dm.Set(0, 2, 2.0)
|
||||
dm.Set(1, 2, 3.0)
|
||||
|
||||
full := dm.ToFullMatrix()
|
||||
|
||||
expected := [][]float64{
|
||||
{0.0, 1.0, 2.0},
|
||||
{1.0, 0.0, 3.0},
|
||||
{2.0, 3.0, 0.0},
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < 3; j++ {
|
||||
if math.Abs(full[i][j]-expected[i][j]) > 1e-10 {
|
||||
t.Errorf("Full matrix[%d][%d]: expected %f, got %f",
|
||||
i, j, expected[i][j], full[i][j])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixBoundsChecking(t *testing.T) {
|
||||
dm := NewDistMatrix(3)
|
||||
|
||||
// Test Get out of bounds
|
||||
testPanic := func(f func()) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("Expected panic, but didn't get one")
|
||||
}
|
||||
}()
|
||||
f()
|
||||
}
|
||||
|
||||
testPanic(func() { dm.Get(-1, 0) })
|
||||
testPanic(func() { dm.Get(0, 3) })
|
||||
testPanic(func() { dm.Set(3, 0, 1.0) })
|
||||
testPanic(func() { dm.GetLabel(-1) })
|
||||
testPanic(func() { dm.SetLabel(3, "Invalid") })
|
||||
testPanic(func() { dm.GetRow(3) })
|
||||
}
|
||||
|
||||
func TestDistMatrixEmptyMatrix(t *testing.T) {
|
||||
dm := NewDistMatrix(0)
|
||||
|
||||
if dm.Size() != 0 {
|
||||
t.Errorf("Expected size 0, got %d", dm.Size())
|
||||
}
|
||||
|
||||
minDist, minI, minJ := dm.MinDistance()
|
||||
if minDist != 0.0 || minI != -1 || minJ != -1 {
|
||||
t.Errorf("Empty matrix MinDistance should return (0.0, -1, -1), got (%f, %d, %d)",
|
||||
minDist, minI, minJ)
|
||||
}
|
||||
|
||||
maxDist, maxI, maxJ := dm.MaxDistance()
|
||||
if maxDist != 0.0 || maxI != -1 || maxJ != -1 {
|
||||
t.Errorf("Empty matrix MaxDistance should return (0.0, -1, -1), got (%f, %d, %d)",
|
||||
maxDist, maxI, maxJ)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistMatrixSingleElement(t *testing.T) {
|
||||
dm := NewDistMatrix(1)
|
||||
|
||||
if dm.Size() != 1 {
|
||||
t.Errorf("Expected size 1, got %d", dm.Size())
|
||||
}
|
||||
|
||||
// Only element is diagonal (always 0)
|
||||
if dm.Get(0, 0) != 0.0 {
|
||||
t.Errorf("Expected 0.0 at (0, 0), got %f", dm.Get(0, 0))
|
||||
}
|
||||
|
||||
minDist, minI, minJ := dm.MinDistance()
|
||||
if minDist != 0.0 || minI != -1 || minJ != -1 {
|
||||
t.Errorf("Single element matrix MinDistance should return (0.0, -1, -1), got (%f, %d, %d)",
|
||||
minDist, minI, minJ)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSimilarityMatrix(t *testing.T) {
|
||||
sm := NewSimilarityMatrix(4)
|
||||
|
||||
if sm.Size() != 4 {
|
||||
t.Errorf("Expected size 4, got %d", sm.Size())
|
||||
}
|
||||
|
||||
// Check diagonal is 1.0
|
||||
for i := 0; i < 4; i++ {
|
||||
if sm.Get(i, i) != 1.0 {
|
||||
t.Errorf("Expected diagonal element (%d, %d) to be 1.0, got %f", i, i, sm.Get(i, i))
|
||||
}
|
||||
}
|
||||
|
||||
// Check off-diagonal is 0.0
|
||||
if sm.Get(0, 1) != 0.0 {
|
||||
t.Errorf("Expected off-diagonal to be 0.0, got %f", sm.Get(0, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSimilarityMatrixWithLabels(t *testing.T) {
|
||||
labels := []string{"A", "B", "C"}
|
||||
sm := NewSimilarityMatrixWithLabels(labels)
|
||||
|
||||
if sm.Size() != 3 {
|
||||
t.Errorf("Expected size 3, got %d", sm.Size())
|
||||
}
|
||||
|
||||
// Check labels
|
||||
for i, label := range labels {
|
||||
if sm.GetLabel(i) != label {
|
||||
t.Errorf("Expected label %s at index %d, got %s", label, i, sm.GetLabel(i))
|
||||
}
|
||||
}
|
||||
|
||||
// Check diagonal is 1.0
|
||||
for i := 0; i < 3; i++ {
|
||||
if sm.Get(i, i) != 1.0 {
|
||||
t.Errorf("Expected diagonal element (%d, %d) to be 1.0, got %f", i, i, sm.Get(i, i))
|
||||
}
|
||||
}
|
||||
|
||||
// Set some similarities
|
||||
sm.Set(0, 1, 0.8)
|
||||
sm.Set(0, 2, 0.6)
|
||||
sm.Set(1, 2, 0.7)
|
||||
|
||||
// Check values
|
||||
if math.Abs(sm.Get(0, 1)-0.8) > 1e-10 {
|
||||
t.Errorf("Expected 0.8 at (0, 1), got %f", sm.Get(0, 1))
|
||||
}
|
||||
|
||||
if math.Abs(sm.Get(1, 0)-0.8) > 1e-10 {
|
||||
t.Errorf("Expected 0.8 at (1, 0) (symmetry), got %f", sm.Get(1, 0))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimilarityMatrixCopy(t *testing.T) {
|
||||
sm := NewSimilarityMatrix(3)
|
||||
sm.Set(0, 1, 0.9)
|
||||
sm.Set(0, 2, 0.7)
|
||||
|
||||
smCopy := sm.Copy()
|
||||
|
||||
// Check diagonal is preserved
|
||||
if smCopy.Get(0, 0) != 1.0 {
|
||||
t.Errorf("Copied similarity matrix should have diagonal 1.0, got %f", smCopy.Get(0, 0))
|
||||
}
|
||||
|
||||
// Check values are preserved
|
||||
if math.Abs(smCopy.Get(0, 1)-0.9) > 1e-10 {
|
||||
t.Errorf("Copy should preserve values, expected 0.9, got %f", smCopy.Get(0, 1))
|
||||
}
|
||||
|
||||
// Modify copy and ensure original unchanged
|
||||
smCopy.Set(0, 1, 0.5)
|
||||
if math.Abs(sm.Get(0, 1)-0.9) > 1e-10 {
|
||||
t.Errorf("Modifying copy should not affect original, expected 0.9, got %f", sm.Get(0, 1))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user