Skip to content

Commit a6f4a9c

Browse files
gh-37: Allow TNS range assingment.
1 parent bd17c08 commit a6f4a9c

File tree

2 files changed

+217
-35
lines changed

2 files changed

+217
-35
lines changed

src/interpreter.c

Lines changed: 142 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,61 +1682,168 @@ ExecResult assign_index_chain(Interpreter* interp, Env* env, Expr* idx_expr, Val
16821682

16831683
if (cur->type == VAL_TNS) {
16841684
Tensor* t = cur->as.tns;
1685-
// Lvalue assignment only supports full integer indexing (no ranges/wildcards).
1686-
if (indices->count != t->ndim) {
1687-
out = make_error("Cannot assign through tensor slice", node->line, node->column);
1688-
goto cleanup;
1689-
}
16901685

1691-
size_t* idxs0 = malloc(sizeof(size_t) * indices->count);
1692-
if (!idxs0) {
1686+
// Allow indexing with ranges/wildcards or integers. Indices may be fewer than ndim.
1687+
// Build per-dimension start/end (1-based inclusive) arrays for the full tensor dims.
1688+
int64_t* starts = malloc(sizeof(int64_t) * t->ndim);
1689+
int64_t* ends = malloc(sizeof(int64_t) * t->ndim);
1690+
if (!starts || !ends) {
1691+
free(starts); free(ends);
16931692
out = make_error("Out of memory", stmt_line, stmt_col);
16941693
goto cleanup;
16951694
}
16961695

1696+
// default full spans
1697+
for (size_t i = 0; i < t->ndim; i++) { starts[i] = 1; ends[i] = (int64_t)t->shape[i]; }
1698+
1699+
// fill from provided indices
16971700
for (size_t i = 0; i < indices->count; i++) {
16981701
Expr* it = indices->items[i];
1699-
if (it->type == EXPR_WILDCARD || it->type == EXPR_RANGE) {
1700-
free(idxs0);
1701-
out = make_error("Cannot assign using ranges or wildcards", it->line, it->column);
1702-
goto cleanup;
1702+
if (it->type == EXPR_WILDCARD) {
1703+
starts[i] = 1; ends[i] = (int64_t)t->shape[i];
1704+
continue;
1705+
}
1706+
if (it->type == EXPR_RANGE) {
1707+
// range node holds start and end expressions as children
1708+
Expr* rs = it->as.range.start;
1709+
Expr* re = it->as.range.end;
1710+
Value vs = eval_expr(interp, rs, env);
1711+
if (interp->error) { ExecResult err = make_error(interp->error, interp->error_line, interp->error_col); clear_error(interp); free(starts); free(ends); out = err; goto cleanup; }
1712+
Value ve = eval_expr(interp, re, env);
1713+
if (interp->error) { ExecResult err = make_error(interp->error, interp->error_line, interp->error_col); clear_error(interp); value_free(vs); free(starts); free(ends); out = err; goto cleanup; }
1714+
if (vs.type != VAL_INT || ve.type != VAL_INT) { value_free(vs); value_free(ve); free(starts); free(ends); out = make_error("Range endpoints must evaluate to INT", it->line, it->column); goto cleanup; }
1715+
starts[i] = vs.as.i; ends[i] = ve.as.i;
1716+
value_free(vs); value_free(ve);
1717+
continue;
17031718
}
17041719

1720+
// single index expression
17051721
Value vi = eval_expr(interp, it, env);
1706-
if (interp->error) {
1707-
ExecResult err = make_error(interp->error, interp->error_line, interp->error_col);
1708-
clear_error(interp);
1709-
free(idxs0);
1710-
out = err;
1711-
goto cleanup;
1722+
if (interp->error) { ExecResult err = make_error(interp->error, interp->error_line, interp->error_col); clear_error(interp); free(starts); free(ends); out = err; goto cleanup; }
1723+
if (vi.type != VAL_INT) { value_free(vi); free(starts); free(ends); out = make_error("Index expression must evaluate to INT", it->line, it->column); goto cleanup; }
1724+
starts[i] = vi.as.i; ends[i] = vi.as.i; // fixed single element
1725+
value_free(vi);
1726+
}
1727+
1728+
// Normalize negative indices and clamp; compute lengths
1729+
size_t new_ndim = 0;
1730+
int* orig_to_out = malloc(sizeof(int) * t->ndim);
1731+
if (!orig_to_out) { free(starts); free(ends); out = make_error("Out of memory", stmt_line, stmt_col); goto cleanup; }
1732+
for (size_t i = 0; i < t->ndim; i++) {
1733+
int64_t s = starts[i];
1734+
int64_t e = ends[i];
1735+
int64_t dim = (int64_t)t->shape[i];
1736+
if (s < 0) s = dim + s + 1;
1737+
if (e < 0) e = dim + e + 1;
1738+
if (s < 1) s = 1;
1739+
if (e > dim) e = dim;
1740+
if (s > e) { starts[i] = 1; ends[i] = 0; orig_to_out[i] = -1; continue; }
1741+
starts[i] = s; ends[i] = e;
1742+
size_t len = (size_t)(e - s + 1);
1743+
if (len <= 1) orig_to_out[i] = -1; else orig_to_out[i] = (int)new_ndim++;
1744+
}
1745+
1746+
if (new_ndim == 0) {
1747+
// All dimensions fixed -> single element assignment
1748+
size_t src_offset = 0;
1749+
for (size_t i = 0; i < t->ndim; i++) {
1750+
size_t pos = (starts[i] <= ends[i]) ? (size_t)(starts[i] - 1) : 0;
1751+
src_offset += pos * t->strides[i];
17121752
}
1713-
if (vi.type != VAL_INT) {
1714-
value_free(vi);
1715-
free(idxs0);
1716-
out = make_error("Index expression must evaluate to INT", it->line, it->column);
1753+
1754+
// type compatibility
1755+
if (rhs.type != VAL_TNS && value_type_to_decl(rhs.type) != t->elem_type) {
1756+
free(starts); free(ends); free(orig_to_out);
1757+
out = make_error("Element type mismatch", stmt_line, stmt_col);
17171758
goto cleanup;
17181759
}
17191760

1720-
int64_t v = vi.as.i;
1721-
int64_t dim = (int64_t)t->shape[i];
1722-
if (v < 0) v = dim + v + 1;
1723-
if (v < 1 || v > dim) {
1724-
value_free(vi);
1725-
free(idxs0);
1726-
out = make_error("Index out of range", it->line, it->column);
1761+
mtx_lock(&t->lock);
1762+
value_free(t->data[src_offset]);
1763+
if (rhs.type == VAL_TNS) {
1764+
// RHS is a tensor but single-element selection: copy whole RHS value
1765+
t->data[src_offset] = value_copy(rhs);
1766+
} else {
1767+
t->data[src_offset] = value_copy(rhs);
1768+
}
1769+
mtx_unlock(&t->lock);
1770+
free(starts); free(ends); free(orig_to_out);
1771+
// Set cur to point at this element for further chaining
1772+
cur = &t->data[src_offset];
1773+
continue;
1774+
}
1775+
1776+
// Build output shape and validate RHS
1777+
size_t* out_shape = malloc(sizeof(size_t) * new_ndim);
1778+
if (!out_shape) { free(starts); free(ends); free(orig_to_out); out = make_error("Out of memory", stmt_line, stmt_col); goto cleanup; }
1779+
for (size_t i = 0; i < t->ndim; i++) {
1780+
if (orig_to_out[i] >= 0) {
1781+
out_shape[orig_to_out[i]] = (size_t)(ends[i] - starts[i] + 1);
1782+
}
1783+
}
1784+
1785+
if (rhs.type != VAL_TNS) {
1786+
free(starts); free(ends); free(orig_to_out); free(out_shape);
1787+
out = make_error("Right-hand side must be a TNS matching slice shape", node->line, node->column);
1788+
goto cleanup;
1789+
}
1790+
1791+
Tensor* rt = rhs.as.tns;
1792+
if (rt->ndim != new_ndim) {
1793+
free(starts); free(ends); free(orig_to_out); free(out_shape);
1794+
out = make_error("Right-hand side tensor dimensionality mismatch", node->line, node->column);
1795+
goto cleanup;
1796+
}
1797+
for (size_t d = 0; d < new_ndim; d++) {
1798+
if (rt->shape[d] != out_shape[d]) {
1799+
free(starts); free(ends); free(orig_to_out); free(out_shape);
1800+
out = make_error("Right-hand side tensor shape mismatch", node->line, node->column);
17271801
goto cleanup;
17281802
}
1729-
idxs0[i] = (size_t)(v - 1);
1730-
value_free(vi);
17311803
}
17321804

1733-
Value* elem = value_tns_get_ptr(*cur, idxs0, indices->count);
1734-
free(idxs0);
1735-
if (!elem) {
1736-
out = make_error("Index out of range", node->line, node->column);
1805+
if (rt->elem_type != t->elem_type) {
1806+
free(starts); free(ends); free(orig_to_out); free(out_shape);
1807+
out = make_error("Element type mismatch", stmt_line, stmt_col);
17371808
goto cleanup;
17381809
}
1739-
cur = elem;
1810+
1811+
// Write RHS elements into target tensor region
1812+
// Iterate over output positions and compute corresponding source offset
1813+
for (size_t out_idx = 0; out_idx < rt->length; out_idx++) {
1814+
// compute multi-index for out
1815+
size_t rem = out_idx;
1816+
size_t src_offset = 0;
1817+
for (size_t d = 0; d < new_ndim; d++) {
1818+
size_t pos = rem / rt->strides[d];
1819+
rem = rem % rt->strides[d];
1820+
// find orig dim for this d
1821+
size_t orig = 0;
1822+
for (size_t k = 0; k < t->ndim; k++) {
1823+
if (orig_to_out[k] == (int)d) { orig = k; break; }
1824+
}
1825+
size_t src_pos = pos + (size_t)(starts[orig] - 1);
1826+
src_offset += src_pos * t->strides[orig];
1827+
}
1828+
// add fixed-dimension offsets
1829+
for (size_t k = 0; k < t->ndim; k++) {
1830+
if (orig_to_out[k] == -1) {
1831+
size_t pos = (ends[k] >= starts[k]) ? (size_t)(starts[k] - 1) : 0;
1832+
src_offset += pos * t->strides[k];
1833+
}
1834+
}
1835+
1836+
// assign element
1837+
mtx_lock(&t->lock);
1838+
value_free(t->data[src_offset]);
1839+
t->data[src_offset] = value_copy(rt->data[out_idx]);
1840+
mtx_unlock(&t->lock);
1841+
}
1842+
1843+
free(out_shape);
1844+
free(starts); free(ends); free(orig_to_out);
1845+
// After slice assignment, set cur to base (no further chaining into this node)
1846+
cur = &base_val;
17401847
continue;
17411848
}
17421849

src/parser.c

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,81 @@ static Stmt* parse_statement(Parser* parser) {
498498
char* name = parser->current_token.literal;
499499
advance(parser);
500500
DeclType dtype = parse_type_name(type_tok.literal);
501+
// Support typed declaration with indexed-assignment target, e.g. `TNS: t[1-10] = ...`
502+
if (parser->current_token.type == TOKEN_LBRACKET || parser->current_token.type == TOKEN_LANGLE) {
503+
// construct base identifier expr and parse trailing indexers
504+
Expr* base = expr_ident(name, type_tok.line, type_tok.column);
505+
while (parser->current_token.type == TOKEN_LBRACKET || parser->current_token.type == TOKEN_LANGLE) {
506+
if (parser->current_token.type == TOKEN_LBRACKET) {
507+
int line = parser->current_token.line;
508+
int column = parser->current_token.column;
509+
advance(parser); // consume '['
510+
Expr* idx = expr_index(base, line, column);
511+
if (parser->current_token.type == TOKEN_RBRACKET) {
512+
report_error(parser, "Empty index list");
513+
return NULL;
514+
}
515+
while (parser->current_token.type != TOKEN_RBRACKET && parser->current_token.type != TOKEN_EOF) {
516+
if (match(parser, TOKEN_STAR)) {
517+
Expr* wc = expr_wildcard(parser->previous_token.line, parser->previous_token.column);
518+
expr_list_add(&idx->as.index.indices, wc);
519+
} else {
520+
Expr* start = parse_expression(parser);
521+
if (!start) return NULL;
522+
bool is_range = false;
523+
if (parser->current_token.type == TOKEN_DASH) is_range = true;
524+
else if (parser->current_token.type == TOKEN_NUMBER && parser->current_token.literal && parser->current_token.literal[0] == '-') is_range = true;
525+
if (is_range) {
526+
if (parser->current_token.type == TOKEN_DASH) advance(parser);
527+
Expr* end = parse_expression(parser);
528+
if (!end) return NULL;
529+
Expr* range = expr_range(start, end, start->line, start->column);
530+
expr_list_add(&idx->as.index.indices, range);
531+
} else {
532+
expr_list_add(&idx->as.index.indices, start);
533+
}
534+
}
535+
536+
if (parser->current_token.type == TOKEN_COMMA) { advance(parser); continue; }
537+
break;
538+
}
539+
consume(parser, TOKEN_RBRACKET, "Expected ']' after index list");
540+
base = idx;
541+
continue;
542+
}
543+
544+
// angle-bracket indexing for maps
545+
if (parser->current_token.type == TOKEN_LANGLE) {
546+
int line = parser->current_token.line;
547+
int column = parser->current_token.column;
548+
advance(parser); // consume '<'
549+
Expr* idx = expr_index(base, line, column);
550+
if (parser->current_token.type == TOKEN_RANGLE) {
551+
report_error(parser, "Empty index list");
552+
return NULL;
553+
}
554+
while (parser->current_token.type != TOKEN_RANGLE && parser->current_token.type != TOKEN_EOF) {
555+
Expr* key = parse_expression(parser);
556+
if (!key) return NULL;
557+
expr_list_add(&idx->as.index.indices, key);
558+
if (parser->current_token.type == TOKEN_COMMA) { advance(parser); continue; }
559+
break;
560+
}
561+
consume(parser, TOKEN_RANGLE, "Expected '>' after index list");
562+
base = idx;
563+
continue;
564+
}
565+
}
566+
567+
if (match(parser, TOKEN_EQUALS)) {
568+
Expr* expr = parse_expression(parser);
569+
if (!expr) return NULL;
570+
return stmt_assign(true, dtype, NULL, base, expr, type_tok.line, type_tok.column);
571+
}
572+
report_error(parser, "Expected '=' after typed indexed target");
573+
return NULL;
574+
}
575+
501576
if (match(parser, TOKEN_EQUALS)) {
502577
Expr* expr = parse_expression(parser);
503578
if (!expr) return NULL;

0 commit comments

Comments
 (0)