Skip to content

Commit 7016703

Browse files
committed
Implemented first instance of preview
1 parent 1ae8602 commit 7016703

2 files changed

Lines changed: 88 additions & 40 deletions

File tree

evaluation_function/preview.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,56 @@
1+
import ast
12
from typing import Any
23
from lf_toolkit.preview import Result, Params, Preview
34

4-
def preview_function(response: Any, params: Params) -> Result:
5-
"""
6-
Function used to preview a student response.
7-
---
8-
The handler function passes three arguments to preview_function():
5+
_BLOCKED_MODULES = {
6+
"os", "sys", "subprocess", "socket", "urllib", "http",
7+
"requests", "shutil", "pathlib", "ftplib", "smtplib",
8+
"ctypes", "multiprocessing", "threading", "importlib",
9+
"pickle", "builtins",
10+
}
11+
12+
_BLOCKED_BUILTINS = {"exec", "eval", "compile", "open", "__import__", "input"}
13+
14+
15+
class _SecurityVisitor(ast.NodeVisitor):
16+
def __init__(self):
17+
self.violations: list[str] = []
918

10-
- `response` which are the answers provided by the student.
11-
- `params` which are any extra parameters that may be useful,
12-
e.g., error tolerances.
19+
def visit_Import(self, node):
20+
for alias in node.names:
21+
root = alias.name.split(".")[0]
22+
if root in _BLOCKED_MODULES:
23+
self.violations.append(f"import of '{root}' is not allowed")
24+
self.generic_visit(node)
1325

14-
The output of this function is what is returned as the API response
15-
and therefore must be JSON-encodable. It must also conform to the
16-
response schema.
26+
def visit_ImportFrom(self, node):
27+
if node.module:
28+
root = node.module.split(".")[0]
29+
if root in _BLOCKED_MODULES:
30+
self.violations.append(f"import of '{root}' is not allowed")
31+
self.generic_visit(node)
1732

18-
Any standard python library may be used, as well as any package
19-
available on pip (provided it is added to requirements.txt).
33+
def visit_Call(self, node):
34+
if isinstance(node.func, ast.Name) and node.func.id in _BLOCKED_BUILTINS:
35+
self.violations.append(f"use of '{node.func.id}()' is not allowed")
36+
self.generic_visit(node)
2037

21-
The way you wish to structure you code (all in this function, or
22-
split into many) is entirely up to you.
23-
"""
38+
def visit_Attribute(self, node):
39+
if node.attr.startswith("__") and node.attr.endswith("__"):
40+
self.violations.append(f"access to '{node.attr}' is not allowed")
41+
self.generic_visit(node)
2442

43+
44+
def preview_function(response: Any, params: Params) -> Result:
2545
try:
26-
return Result(preview=Preview(sympy=response))
27-
except Exception as e:
28-
return Result(preview=Preview(feedback=str(e)))
46+
tree = ast.parse(str(response))
47+
except SyntaxError as e:
48+
return Result(preview=Preview(feedback=f"SyntaxError: {e.msg} (line {e.lineno})"))
49+
50+
visitor = _SecurityVisitor()
51+
visitor.visit(tree)
52+
if visitor.violations:
53+
lines = "\n".join(f"- {v}" for v in visitor.violations)
54+
return Result(preview=Preview(feedback=f"Unsafe code detected:\n{lines}"))
55+
56+
return Result(preview=Preview(feedback="Valid Python syntax."))

evaluation_function/preview_test.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,48 @@
22

33
from .preview import Params, preview_function
44

5+
56
class TestPreviewFunction(unittest.TestCase):
6-
"""
7-
TestCase Class used to test the algorithm.
8-
---
9-
Tests are used here to check that the algorithm written
10-
is working as it should.
11-
12-
It's best practice to write these tests first to get a
13-
kind of 'specification' for how your algorithm should
14-
work, and you should run these tests before committing
15-
your code to AWS.
16-
17-
Read the docs on how to use unittest here:
18-
https://docs.python.org/3/library/unittest.html
19-
20-
Use preview_function() to check your algorithm works
21-
as it should.
22-
"""
23-
24-
def test_preview(self):
25-
response, params = "A", Params()
7+
8+
def test_valid_python(self):
9+
response, params = "x = 1 + 2", Params()
10+
result = preview_function(response, params)
11+
12+
self.assertIn("preview", result)
13+
self.assertNotIn("SyntaxError", result["preview"].get("feedback", ""))
14+
self.assertNotIn("Unsafe", result["preview"].get("feedback", ""))
15+
16+
def test_invalid_python(self):
17+
response, params = "def foo(:", Params()
18+
result = preview_function(response, params)
19+
20+
self.assertIn("preview", result)
21+
self.assertIn("SyntaxError", result["preview"].get("feedback", ""))
22+
23+
def test_dangerous_import(self):
24+
response, params = "import os", Params()
25+
result = preview_function(response, params)
26+
27+
self.assertIn("preview", result)
28+
self.assertIn("Unsafe", result["preview"].get("feedback", ""))
29+
30+
def test_dangerous_from_import(self):
31+
response, params = "from subprocess import call", Params()
32+
result = preview_function(response, params)
33+
34+
self.assertIn("preview", result)
35+
self.assertIn("Unsafe", result["preview"].get("feedback", ""))
36+
37+
def test_dangerous_builtin_call(self):
38+
response, params = "exec('x=1')", Params()
39+
result = preview_function(response, params)
40+
41+
self.assertIn("preview", result)
42+
self.assertIn("Unsafe", result["preview"].get("feedback", ""))
43+
44+
def test_dunder_access(self):
45+
response, params = "x.__class__.__bases__", Params()
2646
result = preview_function(response, params)
2747

2848
self.assertIn("preview", result)
29-
self.assertIsNotNone(result["preview"])
49+
self.assertIn("Unsafe", result["preview"].get("feedback", ""))

0 commit comments

Comments
 (0)