diff --git a/engine/minskyTensorOps.cc b/engine/minskyTensorOps.cc index 413f34903..7d8884c55 100644 --- a/engine/minskyTensorOps.cc +++ b/engine/minskyTensorOps.cc @@ -1160,7 +1160,6 @@ namespace minsky dimension=i-xv.begin(); } - sumy.setArgument(y,args); TensorPtr spreadX; if (x) { @@ -1169,7 +1168,7 @@ namespace minsky } else { - if (rank()>1 && dimension>=y->rank()) return; + if (y->rank()>1 && dimension>=y->rank()) return; // construct x from y's x-vector auto tv=make_shared(); spreadX=tv; @@ -1194,13 +1193,17 @@ namespace minsky } } } - sumx.setArgument(spreadX,args); - auto fxy=[](double x, double y){return isfinite(x) && isfinite(y)? x*y: 0;}; - sumyy.setArgument(make_shared(fxy,y,y),args); - sumxx.setArgument(make_shared(fxy,spreadX,spreadX),args); + auto mask=[](double x, double y){return isfinite(x) && isfinite(y);}; + auto fx=[mask](double x, double y){return mask(x,y)? x:0;}; + auto fxy=[mask](double x, double y){return mask(x,y)? x*y: 0;}; + auto maskedX=make_shared(fx,spreadX,y); + auto maskedY=make_shared(fx,y,spreadX); + sumx.setArgument(maskedX,args); + sumy.setArgument(maskedY,args); + sumyy.setArgument(make_shared(fxy,maskedY,y),args); + sumxx.setArgument(make_shared(fxy,maskedX,spreadX),args); sumxy.setArgument(make_shared(fxy,y,spreadX),args); - count.setArgument - (make_shared([](double x,double y) {return isfinite(x)*isfinite(y);},y,spreadX),args); + count.setArgument(make_shared(mask,y,spreadX),args); assert(sumx.hypercube()==sumy.hypercube()); assert(sumx.index()==sumy.index()); diff --git a/test/testTensorOps.cc b/test/testTensorOps.cc index b9cda33d6..4f6c45dd2 100644 --- a/test/testTensorOps.cc +++ b/test/testTensorOps.cc @@ -1844,6 +1844,28 @@ TEST_F(CorrelationSuite,xvectorValueLinearRegression) for (size_t _i=0; _i{0,1,2,3,4,5}); + TensorVal y(hc); y=vector{1, std::numeric_limits::quiet_NaN(), 3, 4, + std::numeric_limits::infinity(), 6}; + fromVal=y; + + // finite pairs: (x=0,y=1),(x=2,y=3),(x=3,y=4),(x=5,y=6) => line y=x+1 + vector result={1,2,3,4,5,6}; + + OperationPtr op(OperationType::linearRegression); + g->addItem(op); + Wire w1(from->ports(0),op->ports(1)), w3(op->ports(0),to->ports(1)); + Eval(*to, op)(); + + auto& toVal=*to->vValue(); + ASSERT_EQ(result.size(), toVal.size()); + for (size_t _i=0; _i