1313from moto import mock_s3
1414from sqlalchemy .orm import aliased , Session
1515
16+ from process_tracker .models .contact import Contact
1617from process_tracker .models .extract import (
1718 Extract ,
1819 ExtractDatasetType ,
3536from process_tracker .models .source import (
3637 DatasetType ,
3738 Source ,
39+ SourceContact ,
3840 SourceDatasetType ,
3941 SourceObject ,
4042 SourceObjectDatasetType ,
4143)
4244
43- from process_tracker .utilities .data_store import DataStore
45+ from process_tracker .utilities .data_store import DataStore , ClusterProcess
4446from process_tracker .extract_tracker import ExtractTracker
4547from process_tracker .process_tracker import ProcessTracker
4648from process_tracker .utilities import utilities
4749
4850test_bucket = "test_bucket"
4951
5052
53+
5154# @mock_s3
5255class TestProcessTracker (unittest .TestCase ):
5356 @classmethod
@@ -64,6 +67,7 @@ def setUpClass(cls):
6467
6568 @classmethod
6669 def tearDownClass (cls ):
70+ cls .session .query (ClusterProcess ).delete ()
6771 cls .session .query (Location ).delete ()
6872 cls .session .query (DatasetType ).delete ()
6973 cls .session .query (ProcessSourceObject ).delete ()
@@ -656,6 +660,39 @@ def test_find_extracts_by_process_not_descending(self):
656660
657661 self .assertNotEqual (expected_result , given_result )
658662
663+ def test_find_process_contacts (self ):
664+ """
665+ Testing that when passed a process_id, a list of source contacts will be returned.
666+ :return:
667+ """
668+ contact = DataStore ().get_or_create_item (
669+ model = Contact ,
670+ contact_name = "Test Contact" ,
671+ contact_email = "testcontact@test.com" ,
672+ )
673+
674+ source = DataStore ().get_or_create_item (model = Source , source_name = "Unittests" )
675+
676+ DataStore ().get_or_create_item (
677+ model = SourceContact ,
678+ source_id = source .source_id ,
679+ contact_id = contact .contact_id ,
680+ )
681+
682+ given_result = self .process_tracker .find_process_contacts (
683+ process = self .process_id
684+ )
685+
686+ expected_result = [
687+ {
688+ "contact_name" : "Test Contact" ,
689+ "contact_email" : "testcontact@test.com" ,
690+ "contact_type" : "source" ,
691+ }
692+ ]
693+
694+ self .assertEqual (expected_result , given_result )
695+
659696 def test_initializing_process_tracking (self ):
660697 """
661698 Testing that when ProcessTracking is initialized, the necessary objects are created.
0 commit comments