diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index e21c539d6d8..f5719641df7 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -154,6 +154,7 @@ public enum Builtins { GARCH("garch", true), GAUSSIAN_CLASSIFIER("gaussianClassifier", true), GET_ACCURACY("getAccuracy", true), + GET_CATEGORICAL_MASK("getCategoricalMask", false), GLM("glm", true), GLM_PREDICT("glmPredict", true), GLOVE("glove", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 1b0536416d6..9a894dde13b 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -215,6 +215,8 @@ public enum Opcodes { TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin), TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin), + GET_CATEGORICAL_MASK("get_categorical_mask", InstructionType.Binary), + //Ternary instruction opcodes PM("+*", InstructionType.Ternary), MINUSMULT("-*", InstructionType.Ternary), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..c2832aeb8cd 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -639,6 +639,7 @@ public enum OpOp2 { MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=)) LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) MINUS1_MULT(false), //1-X*Y + GET_CATEGORICAL_MASK(false), // get transformation mask QUANTIZE_COMPRESS(false), //quantization-fused compression UNION_DISTINCT(false); diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 2b803a053c1..dc7edf76e50 100644 --- 
a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -853,7 +853,10 @@ else if( (op == OpOp2.CBIND && getDataType().isList()) || (op == OpOp2.RBIND && getDataType().isList())) { _etype = ExecType.CP; } - + + if( op == OpOp2.GET_CATEGORICAL_MASK) + _etype = ExecType.CP; + //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 28f6949f722..ab0c7993b4e 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -2018,6 +2018,15 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV else raiseValidateError("The compress or decompress instruction is not allowed in dml scripts"); break; + case GET_CATEGORICAL_MASK: + checkNumParameters(2); + checkFrameParam(getFirstExpr()); + checkScalarParam(getSecondExpr()); + output.setDataType(DataType.MATRIX); + output.setDimensions(1, -1); + output.setBlocksize( id.getBlocksize()); + output.setValueType(ValueType.FP64); + break; case QUANTIZE_COMPRESS: if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) { checkNumParameters(2); @@ -2383,6 +2392,13 @@ protected void checkMatrixFrameParam(Expression e) { //always unconditional raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS); } } + + protected void checkFrameParam(Expression e) { + if(e.getOutput().getDataType() != DataType.FRAME) { + raiseValidateError("Expecting frame parameter for function " + getOpCode(), false, + LanguageErrorCodes.UNSUPPORTED_PARAMETERS); + } + } protected void checkMatrixScalarParam(Expression e) { //always unconditional if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != 
DataType.SCALAR) { diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c6e7188d7bc..e14cfd31388 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2821,6 +2821,9 @@ else if ( in.length == 2 ) DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr); break; + case GET_CATEGORICAL_MASK: + currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, ValueType.FP64, OpOp2.GET_CATEGORICAL_MASK, expr, expr2); + break; default: throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); } diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java index 39735be62e0..eed2c58f78c 100644 --- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java @@ -54,7 +54,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST, TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, - DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, + DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, GET_CATEGORICAL_MASK, MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE} private static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; @@ -120,6 +120,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, String2BuiltinCode.put( "_map", BuiltinCode.MAP); String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP); String2BuiltinCode.put( "applySchema", BuiltinCode.APPLY_SCHEMA); + 
String2BuiltinCode.put( "get_categorical_mask", BuiltinCode.GET_CATEGORICAL_MASK);
 	}
 
 	protected Builtin(BuiltinCode bf) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
index 28b8775ebd5..86184f47be6 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
@@ -59,6 +59,8 @@ else if (in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.T
 			return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str);
 		else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.FRAME)
 			return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str);
+		else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR)
+			return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str);
 		else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX)
 			return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str);
 		else
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
new file mode 100644
index 00000000000..bbf4774ed7a
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import java.util.Arrays;
+import java.util.Locale;
+
+import org.apache.sysds.common.Builtins;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
+import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONArray;
+import org.apache.wink.json4j.JSONObject;
+
+/**
+ * CP instruction for binary operations on a frame and a scalar. The only supported opcode is
+ * get_categorical_mask, which produces a 1 x n matrix holding 1 for every output column that stems
+ * from a recoded, dummycoded or hashed input column and 0 for all other output columns.
+ */
+public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction {
+	/** Opcode of the mask operation, lower-cased with a fixed locale to stay locale-independent. */
+	private static final String MASK_OPCODE = Builtins.GET_CATEGORICAL_MASK.toString()
+		.toLowerCase(Locale.ROOT);
+
+	protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out,
+		String opcode, String istr) {
+		super(CPType.Binary, op, in1, in2, out, opcode, istr);
+	}
+
+	/**
+	 * Reads the frame and the spec scalar, dispatches on the opcode and always releases the frame input.
+	 *
+	 * @param ec execution context providing the inputs and receiving the output
+	 * @throws DMLRuntimeException if the opcode is not get_categorical_mask or the mask cannot be built
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		FrameBlock inBlock1 = ec.getFrameInput(input1.getName());
+		try {
+			// resolve the spec as literal or as variable, depending on the operand itself
+			ScalarObject spec = ec.getScalarInput(input2.getName(), ValueType.STRING, input2.isLiteral());
+			if(!getOpcode().equals(MASK_OPCODE))
+				throw new DMLRuntimeException("Unsupported operation");
+			processGetCategorical(ec, inBlock1, spec);
+		}
+		finally {
+			// release the input frame on success and on failure alike
+			ec.releaseFrameInput(input1.getName());
+		}
+	}
+
+	/**
+	 * Builds the categorical mask for the given frame and id-based transform spec and binds it to the
+	 * output variable of this instruction.
+	 *
+	 * @param ec   execution context that receives the resulting 1 x n matrix
+	 * @param f    frame whose first-row "¿n" cells or column metadata provide the distinct counts
+	 * @param spec JSON transform spec; must contain "ids": true
+	 * @throws DMLRuntimeException if the spec is not id-based, references an invalid column id, or
+	 *                             cannot be parsed
+	 */
+	public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) {
+		try {
+			final int nCol = f.getNumColumns();
+			final JSONObject jSpec = new JSONObject(spec.getStringValue());
+
+			if(!jSpec.containsKey("ids") || !jSpec.getBoolean("ids")) {
+				throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask");
+			}
+
+			final String recode = TfMethod.RECODE.toString();
+			final String dummycode = TfMethod.DUMMYCODE.toString();
+			final String hash = TfMethod.HASH.toString();
+
+			// assume all columns encode to at least one column.
+			int[] lengths = new int[nCol];
+			Arrays.fill(lengths, 1);
+			boolean[] categorical = new boolean[nCol];
+
+			// feature-hashed columns map to K buckets; a plain hashed column
+			// produces a single (categorical) bucket-id column, while a hashed
+			// column that is additionally dummycoded expands to K columns.
+			boolean[] hashed = new boolean[nCol];
+			int K = 0;
+			if(jSpec.containsKey(hash)) {
+				K = jSpec.getInt("K");
+				JSONArray a = jSpec.getJSONArray(hash);
+				for(Object aa : a) {
+					int av = toColumnIndex(aa, nCol);
+					hashed[av] = true;
+					categorical[av] = true;
+				}
+			}
+
+			if(jSpec.containsKey(recode)) {
+				JSONArray a = jSpec.getJSONArray(recode);
+				for(Object aa : a)
+					categorical[toColumnIndex(aa, nCol)] = true;
+			}
+
+			if(jSpec.containsKey(dummycode)) {
+				JSONArray a = jSpec.getJSONArray(dummycode);
+				for(Object aa : a) {
+					int av = toColumnIndex(aa, nCol);
+					// feature hashing followed by dummycoding yields K columns
+					lengths[av] = hashed[av] ? K : getNumDistinct(f, av);
+					categorical[av] = true;
+				}
+			}
+
+			// get total size after mapping
+			int sumLengths = 0;
+			for(int i : lengths) {
+				sumLengths += i;
+			}
+
+			MatrixBlock ret = new MatrixBlock(1, sumLengths, false);
+			ret.allocateDenseBlock();
+			int off = 0;
+			for(int i = 0; i < lengths.length; i++) {
+				for(int j = 0; j < lengths[i]; j++) {
+					ret.set(0, off++, categorical[i] ? 1 : 0);
+				}
+			}
+
+			ec.setMatrixOutput(output.getName(), ret);
+		}
+		catch(DMLRuntimeException e) {
+			// already carries a descriptive message; do not bury it in another wrapper
+			throw e;
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException(e);
+		}
+	}
+
+	/**
+	 * Converts a 1-based column id from the spec into a validated 0-based column index.
+	 *
+	 * @param id   element of a spec id array
+	 * @param nCol number of columns of the frame
+	 * @return 0-based column index in [0, nCol)
+	 * @throws DMLRuntimeException if the id is not numeric or lies outside [1, nCol]
+	 */
+	private static int toColumnIndex(Object id, int nCol) {
+		if(!(id instanceof Number))
+			throw new DMLRuntimeException("Invalid column id in transform spec: " + id);
+		int ix = ((Number) id).intValue() - 1;
+		if(ix < 0 || ix >= nCol)
+			throw new DMLRuntimeException("Column id " + id + " out of range [1, " + nCol + "]");
+		return ix;
+	}
+
+	/**
+	 * Returns the number of columns a dummycoded column expands to. A first-row cell of the form
+	 * "¿n" takes precedence; otherwise the column metadata is used, with default metadata counting
+	 * as zero distinct values.
+	 *
+	 * @param f   frame to inspect
+	 * @param col 0-based column index
+	 * @return number of distinct values recorded for the column
+	 */
+	private static int getNumDistinct(FrameBlock f, int col) {
+		// guard against empty frames and null cells before inspecting the prefix
+		String v = f.getNumRows() > 0 ? f.getString(0, col) : null;
+		if(v != null && v.length() > 1 && v.charAt(0) == '¿')
+			return UtilFunctions.parseToInt(v.substring(1));
+		ColumnMetadata d = f.getColumnMetadata()[col];
+		return d.isDefault() ? 0 : (int) d.getNumDistinct();
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index 5ebc243dd44..86db894f8e3 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -2941,6 +2941,25 @@ public static void writeTestScalar(String file, double value) {
 		}
 	}
 
+
+	/**
+	 * Write scalar to file
+	 *
+	 * @param file  File to write to
+	 * @param value Value to write
+	 */
+	public static void writeTestScalar(String file, String value) {
+		try {
+			DataOutputStream out = new DataOutputStream(new FileOutputStream(file));
+			try(PrintWriter pw = new PrintWriter(out)) {
+				pw.println(value);
+			}
+		}
+		catch(IOException e) {
+			fail("unable to write test scalar (" + file + "): " + e.getMessage());
+		}
+	}
+
 	/**
 	 * Write scalar to file
 	 *
diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java
new file mode 100644
index 00000000000..b8df642542f
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.frame.transform;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryFrameScalarCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Component tests that invoke the get_categorical_mask instruction directly. They cover the inline
+ * distinct-count prefix of a metadata cell, default column metadata, specs that are not id-based,
+ * and the rejection of any other frame-scalar opcode.
+ */
+public class GetCategoricalMaskInstructionTest {
+	protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskInstructionTest.class.getName());
+
+	private static final String MASK_OPCODE = "get_categorical_mask";
+
+	@BeforeClass
+	public static void init() throws java.io.IOException {
+		CacheableData.initCaching("get_categorical_mask_instruction_test");
+	}
+
+	@Test
+	public void dummycodeReadsDistinctCountFromMetadataPrefix() {
+		// the leading '¿' marks a cell that carries the distinct count itself
+		FrameBlock metaFrame = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"¿3"}});
+		MatrixBlock mask = run(metaFrame, "{\"ids\": true, \"dummycode\": [1]}");
+
+		assertEquals(1, mask.getNumRows());
+		assertEquals(3, mask.getNumColumns());
+		assertArrayEquals(new double[] {1, 1, 1}, mask.getDenseBlockValues(), 0.0);
+	}
+
+	@Test
+	public void dummycodeDefaultMetadataContributesNoColumns() {
+		// column 1 is dummycoded but only has default metadata, so it adds no mask entries;
+		// the untouched column 2 still yields one (non-categorical) entry
+		FrameBlock metaFrame = new FrameBlock(new ValueType[] {ValueType.STRING, ValueType.STRING},
+			new String[][] {{"x", "y"}});
+		MatrixBlock mask = run(metaFrame, "{\"ids\": true, \"dummycode\": [1]}");
+
+		assertEquals(1, mask.getNumRows());
+		assertEquals(1, mask.getNumColumns());
+		assertEquals(0.0, mask.get(0, 0), 0.0);
+	}
+
+	@Test
+	public void nonIdSpecMissingIdsKeyThrows() {
+		// without an "ids" entry the spec has to be refused
+		FrameBlock metaFrame = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}});
+		assertThrowsMessage("non ID based spec", () -> run(metaFrame, "{\"recode\": [1]}"));
+	}
+
+	@Test
+	public void nonIdSpecIdsFalseThrows() {
+		// an explicit "ids": false has to be refused as well
+		FrameBlock metaFrame = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}});
+		assertThrowsMessage("non ID based spec", () -> run(metaFrame, "{\"ids\": false, \"recode\": [1]}"));
+	}
+
+	@Test
+	public void unsupportedOpcodeThrows() {
+		// the instruction accepts get_categorical_mask only; "+" stands in for everything else
+		ExecutionContext ctx = ExecutionContextFactory.createContext();
+		ctx.setAutoCreateVars(true);
+		ctx.setVariable("F", frameObject(new FrameBlock(new ValueType[] {ValueType.STRING},
+			new String[][] {{"a"}})));
+		assertThrowsMessage("Unsupported operation", () -> maskInstruction("+").processInstruction(ctx));
+	}
+
+	/** Runs the action and requires a DMLRuntimeException whose message chain mentions the expected text. */
+	private static void assertThrowsMessage(String expected, Runnable action) {
+		try {
+			action.run();
+			fail("Expected DMLRuntimeException containing \"" + expected + "\" but nothing was thrown");
+		}
+		catch(DMLRuntimeException thrown) {
+			StringBuilder messages = new StringBuilder();
+			for(Throwable cause = thrown; cause != null; cause = cause.getCause())
+				messages.append(cause.getMessage()).append(" | ");
+			assertTrue("Exception chain [" + messages + "] should contain \"" + expected + "\"",
+				messages.toString().contains(expected));
+		}
+	}
+
+	private static MatrixBlock run(FrameBlock meta, String spec) {
+		ExecutionContext ctx = ExecutionContextFactory.createContext();
+		ctx.setAutoCreateVars(true);
+		maskInstruction(MASK_OPCODE).processGetCategorical(ctx, meta, new StringObject(spec));
+		return ctx.getMatrixObject("out").acquireReadAndRelease();
+	}
+
+	private static BinaryFrameScalarCPInstruction maskInstruction(String opcode) {
+		String frameIn = InstructionUtils.concatOperandParts("F", DataType.FRAME.name(), ValueType.STRING.name(), "false");
+		String specIn = InstructionUtils.concatOperandParts("spec", DataType.SCALAR.name(), ValueType.STRING.name(), "true");
+		String maskOut = InstructionUtils.concatOperandParts("out", DataType.MATRIX.name(), ValueType.FP64.name(), "false");
+		String instr = InstructionUtils.concatOperands("CP", opcode, frameIn, specIn, maskOut);
+		return (BinaryFrameScalarCPInstruction) BinaryCPInstruction.parseInstruction(instr);
+	}
+
+	private static FrameObject frameObject(FrameBlock fb) {
+		MatrixCharacteristics dims = new MatrixCharacteristics(fb.getNumRows(), fb.getNumColumns(), -1, -1);
+		FrameObject wrapped = new FrameObject("F", new MetaDataFormat(dims, FileFormat.BINARY), fb.getSchema());
+		wrapped.acquireModify(fb);
+		wrapped.release();
+		return wrapped;
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java
new file mode 100644
index 00000000000..30681f373e4
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import static org.junit.Assert.fail;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/** Script-level tests of getCategoricalMask on the metadata frame returned by transformencode. */
+public class GetCategoricalMaskTest extends AutomatedTestBase {
+	protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskTest.class.getName());
+
+	private final static String TEST_NAME1 = "GetCategoricalMaskTest";
+	private final static String TEST_DIR = "functions/transform/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + GetCategoricalMaskTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"y"}));
+	}
+
+	@Test
+	public void testRecode() throws Exception {
+		// a single recoded column stays one column and is flagged categorical
+		FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8}, 32);
+		MatrixBlock expected = new MatrixBlock(1, 1, 1.0);
+		String spec = "{\"ids\": true, \"recode\": [1]}";
+		runTransformTest(fb, spec, expected);
+	}
+
+	@Test
+	public void testRecode2() throws Exception {
+		// only the second of two columns is recoded, so the expected mask is {0, 1}
+		FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8, ValueType.UINT8}, 32);
+		MatrixBlock expected = new MatrixBlock(1, 2, new double[] {0, 1});
+
+		String spec = "{\"ids\": true, \"recode\": [2]}";
+		runTransformTest(fb, spec, expected);
+	}
+
+	@Test
+	public void testDummy1() throws Exception {
+		// dummycoding column 2: one non-categorical entry followed by five categorical ones
+		FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32);
+		MatrixBlock expected = new MatrixBlock(1, 6, new double[] {0, 1, 1, 1, 1, 1});
+
+		String spec = "{\"ids\": true, \"dummycode\": [2]}";
+		runTransformTest(fb, spec, expected);
+	}
+
+	@Test
+	public void testDummy2() throws Exception {
+		// dummycoding column 1: five categorical entries followed by the untouched column 2
+		FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32);
+		MatrixBlock expected = new MatrixBlock(1, 6, new double[] {1, 1, 1, 1, 1, 0});
+
+		String spec = "{\"ids\": true, \"dummycode\": [1]}";
+		runTransformTest(fb, spec, expected);
+	}
+
+	@Test
+	public void testHash1() throws Exception {
+		// hash + dummycode on column 1 is expected to give K = 3 categorical entries
+		FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32);
+		MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0});
+
+		String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}";
+		runTransformTest(fb, spec, expected);
+	}
+
+	@Test
+	public void testHash2() throws Exception {
+		// same spec as testHash1 on 100 rows; the expected mask is unchanged
+		FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32);
+		MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0});
+
+		String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}";
+		runTransformTest(fb, spec, expected);
+	}
+
+	@Test
+	public void testHash3() throws Exception {
+		// columns 1 and 3 are hashed and dummycoded (K = 3 each) around the untouched column 2
+		FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8}, 32);
+		MatrixBlock expected = new MatrixBlock(1, 7, new double[] {1, 1, 1, 0, 1, 1, 1});
+
+		String spec = "{\"ids\": true, \"dummycode\": [1,3], \"hash\": [1,3], \"K\": 3}";
+		runTransformTest(fb, spec, expected);
+	}
+
+	@Test
+	public void testHybrid1() throws Exception {
+		// column 4 is dummycoded without hashing and is expected to add two entries of its own
+		FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8, ValueType.BOOLEAN}, 32);
+		MatrixBlock expected = new MatrixBlock(1, 9, new double[] {1, 1, 1, 0, 1, 1, 1,1,1});
+
+		String spec = "{\"ids\": true, \"dummycode\": [1,3,4], \"hash\": [1,3], \"K\": 3}";
+		runTransformTest(fb, spec, expected);
+	}
+
+	@Test
+	public void testHybrid2() throws Exception {
+		// all four columns are dummycoded (columns 1 and 3 via hashing), so every entry is 1
+		FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.BOOLEAN,ValueType.UINT8, ValueType.BOOLEAN}, 32);
+		MatrixBlock expected = new MatrixBlock(1, 10, new double[] {1, 1, 1, 1,1, 1, 1, 1,1,1});
+
+		String spec = "{\"ids\": true, \"dummycode\": [1,2,3,4], \"hash\": [1,3], \"K\": 3}";
+		runTransformTest(fb, spec, expected);
+	}
+
+	/** Writes frame and spec, runs the DML script and compares the mask it wrote with the expected one. */
+	private void runTransformTest(FrameBlock fb, String spec, MatrixBlock expected) throws Exception {
+		try {
+			getAndLoadTestConfiguration(TEST_NAME1);
+
+			String inF = input("F-In");
+			String inS = input("spec");
+			// write both inputs to the paths that are handed to the script below
+			TestUtils.writeTestFrame(inF, fb, fb.getSchema(), FileFormat.CSV);
+			TestUtils.writeTestScalar(inS, spec);
+
+			String out = output("ret");
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+			programArgs = new String[] {"-args", inF, inS, out, expected.getNumColumns() + ""};
+
+			runTest(true, false, null, -1);
+
+			MatrixBlock result = TestUtils.readBinary(out);
+
+			TestUtils.compareMatrices(expected, result, 0.0);
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+}
diff --git a/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml
new file mode 100644
index 00000000000..5d7bb35a250
--- /dev/null
+++ b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+F1 = read($1, data_type="frame", format="csv");
+
+jspec = read($2, data_type="scalar", value_type="string");
+
+[X, M] = transformencode(target=F1, spec=jspec);
+
+Cm = getCategoricalMask(M, jspec)
+expectedColumns = $4
+if(ncol(Cm) != expectedColumns){
+	stop("Wrong number of metadata columns in categorical mask")
+}
+# print mean to verify that Cm is a matrix, not a Frame according to compiler
+print(mean(Cm))
+
+write(Cm, $3, format="csv");
+