diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2.zip b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2.zip new file mode 100644 index 000000000..8624177c7 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2.zip differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/LICENSE.txt b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/LICENSE.txt new file mode 100644 index 000000000..57c92e9cf --- /dev/null +++ b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/LICENSE.txt @@ -0,0 +1,540 @@ +## ODC Open Database License (ODbL) + +### Preamble + +The Open Database License (ODbL) is a license agreement intended to +allow users to freely share, modify, and use this Database while +maintaining this same freedom for others. Many databases are covered by +copyright, and therefore this document licenses these rights. Some +jurisdictions, mainly in the European Union, have specific rights that +cover databases, and so the ODbL addresses these rights, too. Finally, +the ODbL is also an agreement in contract for users of this Database to +act in certain ways in return for accessing this Database. + +Databases can contain a wide variety of types of content (images, +audiovisual material, and sounds all in the same database, for example), +and so the ODbL only governs the rights over the Database, and not the +contents of the Database individually. Licensors should use the ODbL +together with another license for the contents, if the contents have a +single set of rights that uniformly covers all of the contents. If the +contents have multiple sets of different rights, Licensors should +describe what rights govern what contents together in the individual +record or in some other way that clarifies what rights apply. 
+ +Sometimes the contents of a database, or the database itself, can be +covered by other rights not addressed here (such as private contracts, +trade mark over the name, or privacy rights / data protection rights +over information in the contents), and so you are advised that you may +have to consult other documents or clear other rights before doing +activities not covered by this License. + +------ + +The Licensor (as defined below) + +and + +You (as defined below) + +agree as follows: + +### 1.0 Definitions of Capitalised Words + +"Collective Database" - Means this Database in unmodified form as part +of a collection of independent databases in themselves that together are +assembled into a collective whole. A work that constitutes a Collective +Database will not be considered a Derivative Database. + +"Convey" - As a verb, means Using the Database, a Derivative Database, +or the Database as part of a Collective Database in any way that enables +a Person to make or receive copies of the Database or a Derivative +Database. Conveying does not include interaction with a user through a +computer network, or creating and Using a Produced Work, where no +transfer of a copy of the Database or a Derivative Database occurs. +"Contents" - The contents of this Database, which includes the +information, independent works, or other material collected into the +Database. For example, the contents of the Database could be factual +data or works such as images, audiovisual material, text, or sounds. + +"Database" - A collection of material (the Contents) arranged in a +systematic or methodical way and individually accessible by electronic +or other means offered under the terms of this License. + +"Database Directive" - Means Directive 96/9/EC of the European +Parliament and of the Council of 11 March 1996 on the legal protection +of databases, as amended or succeeded. 
+ +"Database Right" - Means rights resulting from the Chapter III ("sui +generis") rights in the Database Directive (as amended and as transposed +by member states), which includes the Extraction and Re-utilisation of +the whole or a Substantial part of the Contents, as well as any similar +rights available in the relevant jurisdiction under Section 10.4. + +"Derivative Database" - Means a database based upon the Database, and +includes any translation, adaptation, arrangement, modification, or any +other alteration of the Database or of a Substantial part of the +Contents. This includes, but is not limited to, Extracting or +Re-utilising the whole or a Substantial part of the Contents in a new +Database. + +"Extraction" - Means the permanent or temporary transfer of all or a +Substantial part of the Contents to another medium by any means or in +any form. + +"License" - Means this license agreement and is both a license of rights +such as copyright and Database Rights and an agreement in contract. + +"Licensor" -Means the Person that offers the Database under the terms +of this License. + +"Person" -Means a natural or legal person or a body of persons +corporate or incorporate. + +"Produced Work" - a work (such as an image, audiovisual material, text, +or sounds) resulting from using the whole or a Substantial part of the +Contents (via a search or other query) from this Database, a Derivative +Database, or this Database as part of a Collective Database. + +"Publicly" - means to Persons other than You or under Your control by +either more than 50% ownership or by the power to direct their +activities (such as contracting with an independent consultant). + +"Re-utilisation" - means any form of making available to the public all +or a Substantial part of the Contents by the distribution of copies, by +renting, by online or other forms of transmission. + +"Substantial" - Means substantial in terms of quantity or quality or a +combination of both. 
The repeated and systematic Extraction or +Re-utilisation of insubstantial parts of the Contents may amount to the +Extraction or Re-utilisation of a Substantial part of the Contents. + +"Use" - As a verb, means doing any act that is restricted by copyright +or Database Rights whether in the original medium or any other; and +includes without limitation distributing, copying, publicly performing, +publicly displaying, and preparing derivative works of the Database, as +well as modifying the Database as may be technically necessary to use it +in a different mode or format. + +"You" - Means a Person exercising rights under this License who has not +previously violated the terms of this License with respect to the +Database, or who has received express permission from the Licensor to +exercise rights under this License despite a previous violation. + +Words in the singular include the plural and vice versa. + +### 2.0 What this License covers + +2.1. Legal effect of this document. This License is: + + a. A license of applicable copyright and neighbouring rights; + + b. A license of the Database Right; and + + c. An agreement in contract between You and the Licensor. + +2.2 Legal rights covered. This License covers the legal rights in the +Database, including: + + a. Copyright. Any copyright or neighbouring rights in the Database. + The copyright licensed includes any individual elements of the + Database, but does not cover the copyright over the Contents + independent of this Database. See Section 2.4 for details. Copyright + law varies between jurisdictions, but is likely to cover: the Database + model or schema, which is the structure, arrangement, and organisation + of the Database, and can also include the Database tables and table + indexes; the data entry and output sheets; and the Field names of + Contents stored in the Database; + + b. Database Rights. 
Database Rights only extend to the Extraction and + Re-utilisation of the whole or a Substantial part of the Contents. + Database Rights can apply even when there is no copyright over the + Database. Database Rights can also apply when the Contents are removed + from the Database and are selected and arranged in a way that would + not infringe any applicable copyright; and + + c. Contract. This is an agreement between You and the Licensor for + access to the Database. In return you agree to certain conditions of + use on this access as outlined in this License. + +2.3 Rights not covered. + + a. This License does not apply to computer programs used in the making + or operation of the Database; + + b. This License does not cover any patents over the Contents or the + Database; and + + c. This License does not cover any trademarks associated with the + Database. + +2.4 Relationship to Contents in the Database. The individual items of +the Contents contained in this Database may be covered by other rights, +including copyright, patent, data protection, privacy, or personality +rights, and this License does not cover any rights (other than Database +Rights or in contract) in individual Contents contained in the Database. +For example, if used on a Database of images (the Contents), this +License would not apply to copyright over individual images, which could +have their own separate licenses, or one single license covering all of +the rights over the images. + +### 3.0 Rights granted + +3.1 Subject to the terms and conditions of this License, the Licensor +grants to You a worldwide, royalty-free, non-exclusive, terminable (but +only under Section 9) license to Use the Database for the duration of +any applicable copyright and Database Rights. These rights explicitly +include commercial use, and do not exclude any field of endeavour. 
To +the extent possible in the relevant jurisdiction, these rights may be +exercised in all media and formats whether now known or created in the +future. + +The rights granted cover, for example: + + a. Extraction and Re-utilisation of the whole or a Substantial part of + the Contents; + + b. Creation of Derivative Databases; + + c. Creation of Collective Databases; + + d. Creation of temporary or permanent reproductions by any means and + in any form, in whole or in part, including of any Derivative + Databases or as a part of Collective Databases; and + + e. Distribution, communication, display, lending, making available, or + performance to the public by any means and in any form, in whole or in + part, including of any Derivative Database or as a part of Collective + Databases. + +3.2 Compulsory license schemes. For the avoidance of doubt: + + a. Non-waivable compulsory license schemes. In those jurisdictions in + which the right to collect royalties through any statutory or + compulsory licensing scheme cannot be waived, the Licensor reserves + the exclusive right to collect such royalties for any exercise by You + of the rights granted under this License; + + b. Waivable compulsory license schemes. In those jurisdictions in + which the right to collect royalties through any statutory or + compulsory licensing scheme can be waived, the Licensor waives the + exclusive right to collect such royalties for any exercise by You of + the rights granted under this License; and, + + c. Voluntary license schemes. The Licensor waives the right to collect + royalties, whether individually or, in the event that the Licensor is + a member of a collecting society that administers voluntary licensing + schemes, via that society, from any exercise by You of the rights + granted under this License. + +3.3 The right to release the Database under different terms, or to stop +distributing or making available the Database, is reserved. 
Note that +this Database may be multiple-licensed, and so You may have the choice +of using alternative licenses for this Database. Subject to Section +10.4, all other rights not expressly granted by Licensor are reserved. + +### 4.0 Conditions of Use + +4.1 The rights granted in Section 3 above are expressly made subject to +Your complying with the following conditions of use. These are important +conditions of this License, and if You fail to follow them, You will be +in material breach of its terms. + +4.2 Notices. If You Publicly Convey this Database, any Derivative +Database, or the Database as part of a Collective Database, then You +must: + + a. Do so only under the terms of this License or another license + permitted under Section 4.4; + + b. Include a copy of this License (or, as applicable, a license + permitted under Section 4.4) or its Uniform Resource Identifier (URI) + with the Database or Derivative Database, including both in the + Database or Derivative Database and in any relevant documentation; and + + c. Keep intact any copyright or Database Right notices and notices + that refer to this License. + + d. If it is not possible to put the required notices in a particular + file due to its structure, then You must include the notices in a + location (such as a relevant directory) where users would be likely to + look for it. + +4.3 Notice for using output (Contents). Creating and Using a Produced +Work does not require the notice in Section 4.2. However, if you +Publicly Use a Produced Work, You must include a notice associated with +the Produced Work reasonably calculated to make any Person that uses, +views, accesses, interacts with, or is otherwise exposed to the Produced +Work aware that Content was obtained from the Database, Derivative +Database, or the Database as part of a Collective Database, and that it +is available under this License. + + a. Example notice. 
The following text will satisfy notice under + Section 4.3: + + Contains information from DATABASE NAME, which is made available + here under the Open Database License (ODbL). + +DATABASE NAME should be replaced with the name of the Database and a +hyperlink to the URI of the Database. "Open Database License" should +contain a hyperlink to the URI of the text of this License. If +hyperlinks are not possible, You should include the plain text of the +required URI's with the above notice. + +4.4 Share alike. + + a. Any Derivative Database that You Publicly Use must be only under + the terms of: + + i. This License; + + ii. A later version of this License similar in spirit to this + License; or + + iii. A compatible license. + + If You license the Derivative Database under one of the licenses + mentioned in (iii), You must comply with the terms of that license. + + b. For the avoidance of doubt, Extraction or Re-utilisation of the + whole or a Substantial part of the Contents into a new database is a + Derivative Database and must comply with Section 4.4. + + c. Derivative Databases and Produced Works. A Derivative Database is + Publicly Used and so must comply with Section 4.4. if a Produced Work + created from the Derivative Database is Publicly Used. + + d. Share Alike and additional Contents. For the avoidance of doubt, + You must not add Contents to Derivative Databases under Section 4.4 a + that are incompatible with the rights granted under this License. + + e. Compatible licenses. Licensors may authorise a proxy to determine + compatible licenses under Section 4.4 a iii. If they do so, the + authorised proxy's public statement of acceptance of a compatible + license grants You permission to use the compatible license. + + +4.5 Limits of Share Alike. The requirements of Section 4.4 do not apply +in the following: + + a. 
For the avoidance of doubt, You are not required to license + Collective Databases under this License if You incorporate this + Database or a Derivative Database in the collection, but this License + still applies to this Database or a Derivative Database as a part of + the Collective Database; + + b. Using this Database, a Derivative Database, or this Database as + part of a Collective Database to create a Produced Work does not + create a Derivative Database for purposes of Section 4.4; and + + c. Use of a Derivative Database internally within an organisation is + not to the public and therefore does not fall under the requirements + of Section 4.4. + +4.6 Access to Derivative Databases. If You Publicly Use a Derivative +Database or a Produced Work from a Derivative Database, You must also +offer to recipients of the Derivative Database or Produced Work a copy +in a machine readable form of: + + a. The entire Derivative Database; or + + b. A file containing all of the alterations made to the Database or + the method of making the alterations to the Database (such as an + algorithm), including any additional Contents, that make up all the + differences between the Database and the Derivative Database. + +The Derivative Database (under a.) or alteration file (under b.) must be +available at no more than a reasonable production cost for physical +distributions and free of charge if distributed over the internet. + +4.7 Technological measures and additional terms + + a. This License does not allow You to impose (except subject to + Section 4.7 b.) any terms or any technological measures on the + Database, a Derivative Database, or the whole or a Substantial part of + the Contents that alter or restrict the terms of this License, or any + rights granted under it, or have the effect or intent of restricting + the ability of any person to exercise those rights. + + b. Parallel distribution. 
You may impose terms or technological + measures on the Database, a Derivative Database, or the whole or a + Substantial part of the Contents (a "Restricted Database") in + contravention of Section 4.74 a. only if You also make a copy of the + Database or a Derivative Database available to the recipient of the + Restricted Database: + + i. That is available without additional fee; + + ii. That is available in a medium that does not alter or restrict + the terms of this License, or any rights granted under it, or have + the effect or intent of restricting the ability of any person to + exercise those rights (an "Unrestricted Database"); and + + iii. The Unrestricted Database is at least as accessible to the + recipient as a practical matter as the Restricted Database. + + c. For the avoidance of doubt, You may place this Database or a + Derivative Database in an authenticated environment, behind a + password, or within a similar access control scheme provided that You + do not alter or restrict the terms of this License or any rights + granted under it or have the effect or intent of restricting the + ability of any person to exercise those rights. + +4.8 Licensing of others. You may not sublicense the Database. Each time +You communicate the Database, the whole or Substantial part of the +Contents, or any Derivative Database to anyone else in any way, the +Licensor offers to the recipient a license to the Database on the same +terms and conditions as this License. You are not responsible for +enforcing compliance by third parties with this License, but You may +enforce any rights that You have over a Derivative Database. You are +solely responsible for any modifications of a Derivative Database made +by You or another Person at Your direction. You may not impose any +further restrictions on the exercise of the rights granted or affirmed +under this License. + +### 5.0 Moral rights + +5.1 Moral rights. 
This section covers moral rights, including any rights +to be identified as the author of the Database or to object to treatment +that would otherwise prejudice the author's honour and reputation, or +any other derogatory treatment: + + a. For jurisdictions allowing waiver of moral rights, Licensor waives + all moral rights that Licensor may have in the Database to the fullest + extent possible by the law of the relevant jurisdiction under Section + 10.4; + + b. If waiver of moral rights under Section 5.1 a in the relevant + jurisdiction is not possible, Licensor agrees not to assert any moral + rights over the Database and waives all claims in moral rights to the + fullest extent possible by the law of the relevant jurisdiction under + Section 10.4; and + + c. For jurisdictions not allowing waiver or an agreement not to assert + moral rights under Section 5.1 a and b, the author may retain their + moral rights over certain aspects of the Database. + +Please note that some jurisdictions do not allow for the waiver of moral +rights, and so moral rights may still subsist over the Database in some +jurisdictions. + +### 6.0 Fair dealing, Database exceptions, and other rights not affected + +6.1 This License does not affect any rights that You or anyone else may +independently have under any applicable law to make any use of this +Database, including without limitation: + + a. Exceptions to the Database Right including: Extraction of Contents + from non-electronic Databases for private purposes, Extraction for + purposes of illustration for teaching or scientific research, and + Extraction or Re-utilisation for public security or an administrative + or judicial procedure. + + b. Fair dealing, fair use, or any other legally recognised limitation + or exception to infringement of copyright or other applicable laws. 
+ +6.2 This License does not affect any rights of lawful users to Extract +and Re-utilise insubstantial parts of the Contents, evaluated +quantitatively or qualitatively, for any purposes whatsoever, including +creating a Derivative Database (subject to other rights over the +Contents, see Section 2.4). The repeated and systematic Extraction or +Re-utilisation of insubstantial parts of the Contents may however amount +to the Extraction or Re-utilisation of a Substantial part of the +Contents. + +### 7.0 Warranties and Disclaimer + +7.1 The Database is licensed by the Licensor "as is" and without any +warranty of any kind, either express, implied, or arising by statute, +custom, course of dealing, or trade usage. Licensor specifically +disclaims any and all implied warranties or conditions of title, +non-infringement, accuracy or completeness, the presence or absence of +errors, fitness for a particular purpose, merchantability, or otherwise. +Some jurisdictions do not allow the exclusion of implied warranties, so +this exclusion may not apply to You. + +### 8.0 Limitation of liability + +8.1 Subject to any liability that may not be excluded or limited by law, +the Licensor is not liable for, and expressly excludes, all liability +for loss or damage however and whenever caused to anyone by any use +under this License, whether by You or by anyone else, and whether caused +by any fault on the part of the Licensor or not. This exclusion of +liability includes, but is not limited to, any special, incidental, +consequential, punitive, or exemplary damages such as loss of revenue, +data, anticipated profits, and lost business. This exclusion applies +even if the Licensor has been advised of the possibility of such +damages. + +8.2 If liability may not be excluded by law, it is limited to actual and +direct financial loss to the extent it is caused by proved negligence on +the part of the Licensor. 
+ +### 9.0 Termination of Your rights under this License + +9.1 Any breach by You of the terms and conditions of this License +automatically terminates this License with immediate effect and without +notice to You. For the avoidance of doubt, Persons who have received the +Database, the whole or a Substantial part of the Contents, Derivative +Databases, or the Database as part of a Collective Database from You +under this License will not have their licenses terminated provided +their use is in full compliance with this License or a license granted +under Section 4.8 of this License. Sections 1, 2, 7, 8, 9 and 10 will +survive any termination of this License. + +9.2 If You are not in breach of the terms of this License, the Licensor +will not terminate Your rights under it. + +9.3 Unless terminated under Section 9.1, this License is granted to You +for the duration of applicable rights in the Database. + +9.4 Reinstatement of rights. If you cease any breach of the terms and +conditions of this License, then your full rights under this License +will be reinstated: + + a. Provisionally and subject to permanent termination until the 60th + day after cessation of breach; + + b. Permanently on the 60th day after cessation of breach unless + otherwise reasonably notified by the Licensor; or + + c. Permanently if reasonably notified by the Licensor of the + violation, this is the first time You have received notice of + violation of this License from the Licensor, and You cure the + violation prior to 30 days after your receipt of the notice. + +Persons subject to permanent termination of rights are not eligible to +be a recipient and receive a license under Section 4.8. + +9.5 Notwithstanding the above, Licensor reserves the right to release +the Database under different license terms or to stop distributing or +making available the Database. 
Releasing the Database under different +license terms or stopping the distribution of the Database will not +withdraw this License (or any other license that has been, or is +required to be, granted under the terms of this License), and this +License will continue in full force and effect unless terminated as +stated above. + +### 10.0 General + +10.1 If any provision of this License is held to be invalid or +unenforceable, that must not affect the validity or enforceability of +the remainder of the terms and conditions of this License and each +remaining provision of this License shall be valid and enforced to the +fullest extent permitted by law. + +10.2 This License is the entire agreement between the parties with +respect to the rights granted here over the Database. It replaces any +earlier understandings, agreements or representations with respect to +the Database. + +10.3 If You are in breach of the terms of this License, You will not be +entitled to rely on the terms of this License or to complain of any +breach by the Licensor. + +10.4 Choice of law. This License takes effect in and will be governed by +the laws of the relevant jurisdiction in which the License terms are +sought to be enforced. If the standard suite of rights granted under +applicable copyright law and Database Rights in the relevant +jurisdiction includes additional rights not granted under this License, +these additional rights are granted in this License in order to meet the +terms of this License. 
\ No newline at end of file diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/README.txt b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/README.txt new file mode 100644 index 000000000..69d9aa22d --- /dev/null +++ b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/README.txt @@ -0,0 +1,21 @@ +=============================== +MIMIC-IV Clinical Database Demo +=============================== + +The Medical Information Mart for Intensive Care (MIMIC)-IV database is comprised +of deidentified electronic health records for patients admitted to the Beth Israel +Deaconess Medical Center. Access to MIMIC-IV is limited to credentialed users. +Here, we have provided an openly-available demo of MIMIC-IV containing a subset +of 100 patients. The dataset includes similar content to MIMIC-IV, but excludes +free-text clinical notes. The demo may be useful for running workshops and for +assessing whether the MIMIC-IV is appropriate for a study before making an access +request. + +For details on the data, see the MIMIC-IV project on PhysioNet: +https://doi.org/10.13026/07hj-2a80 + +The contents of this project also contain an additional file: +demo_subject_id.csv + +This is a CSV file containing the subject_id used to filter MIMIC-IV. Only +these subject_id are available in the demo. 
diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/SHA256SUMS.txt b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/SHA256SUMS.txt new file mode 100644 index 000000000..02a038588 --- /dev/null +++ b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/SHA256SUMS.txt @@ -0,0 +1,34 @@ +bcdfcd4f31c790e30fb645dc36414907172405897adff329ad379eed85d12017 LICENSE.txt +4eaae533164245e59e5eb70b52695c061900cdc067cebe508105706b4d10e393 README.txt +6e9ba177eefc70617ccc1782b9af716b23fbcf36c4c5143e275149d78ee83624 demo_subject_id.csv +910b9f160ffdf1e08ea673585393f347c773ccc87d66875c627584a903ae8493 hosp/admissions.csv.gz +493ea7fe1d22afd5abd9138be8a4e50c776986feded6891143120d2bc724d881 hosp/d_hcpcs.csv.gz +0f715a3c4c5e44400305d4deb86ff648e499a3bfcb946f353e8b030022d3ec06 hosp/d_icd_diagnoses.csv.gz +a921a20fbf3220e2a7fe874d6392d671c1cc769a001ef0a7841b80eb01030bb6 hosp/d_icd_procedures.csv.gz +72ea5f020469fd24543bd98ab3bd0a4f645a5a4b7802d9d52af20d736f9d76db hosp/d_labitems.csv.gz +958565670c0b3903c0c12825366dc2bae8561d2d55f16feac0c949dd630ed3e4 hosp/diagnoses_icd.csv.gz +9ec47b7a49514cfe243236ca2eb5c5b517d2215d6589efb2bd80166883243537 hosp/drgcodes.csv.gz +f14a4b5597abcc2e8d62d291c7839347ef6260d4c548ea534ab62df6b2f49adc hosp/emar.csv.gz +8ca5a09f2fa2e0ad5d77121fd62bf5f58bad056713223d9758b4ecd03fa0a435 hosp/emar_detail.csv.gz +87f8af876aa6a608c6761068386e93e0c91b59ff2ce0d7a650fe8d5ba833247d hosp/hcpcsevents.csv.gz +d51a4cf1ea63245abe7d544b0d50a928e9ee2ec5eb9a5ce85500aed0bd6c6dfc hosp/labevents.csv.gz +fa462443cc5777ebf9d833f71c67f2a89b4a074eb376320da70bed220e918ad4 hosp/microbiologyevents.csv.gz +6d7c69dbb2c79b9d1de7d1e68f3424ad248821d7c074dc6537d7290774af39b9 hosp/omr.csv.gz +59dcc6cd8450ddf03ad32cc7d76e41caa4c9633541715987f590c74835cf6fae hosp/patients.csv.gz +3403865a279464b14d038a66b30ba3d609933f517262ea0402fef020046bd197 hosp/pharmacy.csv.gz +bf850e1519f2c3df8c706dc5e2afcb83e86e1d6e4afd7c01426b6b07d25cfeef hosp/poe.csv.gz 
+7010cbc4d74e9a7c998be33c3736af301ebe2c6678969e25800d28e90a2b5825 hosp/poe_detail.csv.gz +33c392ba5b9299b08eca0a61911ba106f0aebdba26ed31b856bb9ffd49fe3654 hosp/prescriptions.csv.gz +68f21dcba9ae0c4b7faa3ef38aa950a8ebeb6d167bc74416314dccdcf28674e0 hosp/procedures_icd.csv.gz +7a56a24cfe2fcd5ed1995c696417bf8f5c09088eb546503b07362c504228c63a hosp/provider.csv.gz +28306d68e952fe2cc9ef6e60b0afda93c27d7a06bee8ef39e71ccc4de5ded793 hosp/services.csv.gz +41151539886d12d57159b65ffe4d7df5a7ef8ceb7cd113ea9a56fbfbfd78a87c hosp/transfers.csv.gz +88ee01d7f21aadee273e2d1ce65b0f14fe96fe87ab42e2b40c89aa8103ef6371 icu/caregiver.csv.gz +90f096ad3db847ed1aaec073f8d86053c51fea7b616d3b046410e030f083d560 icu/chartevents.csv.gz +1690467a822345226a041c1149237c418aa60098c6b699fb4cef64d73588a01f icu/d_items.csv.gz +b9f05bb9b1c5aa52e9f25585853583a2e07046f22ed68419c20ef8daa8762da5 icu/datetimeevents.csv.gz +e05e81aa52a3022e522b6832a898101a69b84e64c17bb344d819e458d5bc21b3 icu/icustays.csv.gz +c9b0afdca936a9485422f72d227b3d1942125bb6d969eef57fdf0233458c6b23 icu/ingredientevents.csv.gz +69597dca47e554e16645ca505167e9f39b07c1724e915ba54ccfb40f5fa0078a icu/inputevents.csv.gz +3bcf68adaa11070d6c24e13d4bfaaf718606cd33b0df7f28e0de240c357a869e icu/outputevents.csv.gz +37454f286dc1eee6cef666b4fdf91452eb8bea6bfc3cf98d4ad179ebb4a741fa icu/procedureevents.csv.gz diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/demo_subject_id.csv b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/demo_subject_id.csv new file mode 100644 index 000000000..2d1febc74 --- /dev/null +++ b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/demo_subject_id.csv @@ -0,0 +1,101 @@ +subject_id +10000032 +10001217 +10001725 +10002428 +10002495 +10002930 +10003046 +10003400 +10004235 +10004422 +10004457 +10004720 +10004733 +10005348 +10005817 +10005866 +10005909 +10006053 +10006580 +10007058 +10007795 +10007818 +10007928 +10008287 +10008454 +10009035 +10009049 +10009628 +10010471 +10010867 +10011398 +10012552 
+10012853 +10013049 +10014078 +10014354 +10014729 +10015272 +10015860 +10015931 +10016150 +10016742 +10016810 +10017492 +10018081 +10018328 +10018423 +10018501 +10018845 +10019003 +10019172 +10019385 +10019568 +10019777 +10019917 +10020187 +10020306 +10020640 +10020740 +10020786 +10020944 +10021118 +10021312 +10021487 +10021666 +10021938 +10022017 +10022041 +10022281 +10022880 +10023117 +10023239 +10023771 +10024043 +10025463 +10025612 +10026255 +10026354 +10026406 +10027445 +10027602 +10029291 +10029484 +10031404 +10031757 +10032725 +10035185 +10035631 +10036156 +10037861 +10037928 +10037975 +10038081 +10038933 +10038992 +10038999 +10039708 +10039831 +10039997 +10040025 diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz new file mode 100644 index 000000000..069cf44b6 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_hcpcs.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_hcpcs.csv.gz new file mode 100644 index 000000000..de0d94609 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_hcpcs.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_icd_diagnoses.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_icd_diagnoses.csv.gz new file mode 100644 index 000000000..fa715ba44 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_icd_diagnoses.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_icd_procedures.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_icd_procedures.csv.gz new file mode 100644 index 000000000..05ad8174e Binary files /dev/null and 
b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_icd_procedures.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_labitems.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_labitems.csv.gz new file mode 100644 index 000000000..222c5978d Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/d_labitems.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/diagnoses_icd.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/diagnoses_icd.csv.gz new file mode 100644 index 000000000..72d759dd8 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/diagnoses_icd.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/drgcodes.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/drgcodes.csv.gz new file mode 100644 index 000000000..7d65259ac Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/drgcodes.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/emar.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/emar.csv.gz new file mode 100644 index 000000000..628bb3f51 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/emar.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/emar_detail.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/emar_detail.csv.gz new file mode 100644 index 000000000..a5384e86a Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/emar_detail.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/hcpcsevents.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/hcpcsevents.csv.gz new file mode 100644 index 000000000..b5d588996 Binary files /dev/null and 
b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/hcpcsevents.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/labevents.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/labevents.csv.gz new file mode 100644 index 000000000..47f9a4b3e Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/labevents.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/microbiologyevents.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/microbiologyevents.csv.gz new file mode 100644 index 000000000..7919abe64 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/microbiologyevents.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/omr.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/omr.csv.gz new file mode 100644 index 000000000..38056a3e7 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/omr.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/patients.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/patients.csv.gz new file mode 100644 index 000000000..ad9e87891 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/patients.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/pharmacy.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/pharmacy.csv.gz new file mode 100644 index 000000000..063d4a5b8 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/pharmacy.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/poe.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/poe.csv.gz new file mode 100644 index 000000000..e11d2fd26 Binary files /dev/null and 
b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/poe.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/poe_detail.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/poe_detail.csv.gz new file mode 100644 index 000000000..f0337efd1 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/poe_detail.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/prescriptions.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/prescriptions.csv.gz new file mode 100644 index 000000000..20e5eab5d Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/prescriptions.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/procedures_icd.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/procedures_icd.csv.gz new file mode 100644 index 000000000..b7c6b36af Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/procedures_icd.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/provider.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/provider.csv.gz new file mode 100644 index 000000000..094b199b1 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/provider.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/services.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/services.csv.gz new file mode 100644 index 000000000..1b1ebde15 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/services.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/transfers.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/transfers.csv.gz new file mode 100644 index 000000000..721598b81 Binary files /dev/null and 
b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/transfers.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/caregiver.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/caregiver.csv.gz new file mode 100644 index 000000000..717c4265f Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/caregiver.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/chartevents.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/chartevents.csv.gz new file mode 100644 index 000000000..38970f403 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/chartevents.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/d_items.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/d_items.csv.gz new file mode 100644 index 000000000..9ee0a9f85 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/d_items.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/datetimeevents.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/datetimeevents.csv.gz new file mode 100644 index 000000000..b62946ca9 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/datetimeevents.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/icustays.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/icustays.csv.gz new file mode 100644 index 000000000..9200e7e0f Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/icustays.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/ingredientevents.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/ingredientevents.csv.gz new file mode 100644 index 000000000..b42bf00f5 Binary files /dev/null and 
b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/ingredientevents.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/inputevents.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/inputevents.csv.gz new file mode 100644 index 000000000..8b5842e05 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/inputevents.csv.gz differ diff --git a/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/procedureevents.csv.gz b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/procedureevents.csv.gz new file mode 100644 index 000000000..b55dd3681 Binary files /dev/null and b/data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/icu/procedureevents.csv.gz differ diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..2d2251763 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -170,6 +170,7 @@ API Reference :maxdepth: 3 models/pyhealth.models.BaseModel + models/pyhealth.models.GA2M models/pyhealth.models.LogisticRegression models/pyhealth.models.MLP models/pyhealth.models.CNN diff --git a/docs/api/models/pyhealth.models.GA2M.rst b/docs/api/models/pyhealth.models.GA2M.rst new file mode 100644 index 000000000..ba2c92975 --- /dev/null +++ b/docs/api/models/pyhealth.models.GA2M.rst @@ -0,0 +1,7 @@ +pyhealth.models.GA2M +=================================== +GA2M model with binned shape functions and pairwise interactions. +.. autoclass:: pyhealth.models.GA2M + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mimic4_mortality_ga2m.py b/examples/mimic4_mortality_ga2m.py new file mode 100644 index 000000000..558f39abd --- /dev/null +++ b/examples/mimic4_mortality_ga2m.py @@ -0,0 +1,307 @@ +"""GA2M In-Hospital Mortality Prediction: Full Pipeline + Ablation Study. 
def evaluate(model: "GA2M", loader, device: str = "cpu") -> dict:
    """Compute AUC-ROC and AUC-PR for ``model`` over every batch in ``loader``.

    Args:
        model: Callable model; it is invoked as ``model(**tensors)`` and must
            return a dict with ``"y_prob"`` and ``"y_true"`` tensors.
        loader: Iterable of batch dicts. Only values that are
            ``torch.Tensor`` instances are forwarded to the model.
        device: Device the batch tensors are moved to before the forward
            pass. The model itself is not moved by this function.

    Returns:
        Dict with keys ``"auc_roc"`` and ``"auc_pr"``. Both values are NaN
        when the metrics are undefined: the loader yielded no batches, or
        the collected labels contain fewer than two distinct classes.
    """
    undefined = {"auc_roc": float("nan"), "auc_pr": float("nan")}

    model.eval()
    prob_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in loader:
            # non-tensor entries (e.g. id strings) are dropped before the call
            tensors = {
                k: v.to(device)
                for k, v in batch.items()
                if isinstance(v, torch.Tensor)
            }
            out = model(**tensors)
            prob_chunks.append(out["y_prob"].cpu().numpy())
            label_chunks.append(out["y_true"].cpu().numpy())

    # np.concatenate raises on an empty list; an empty loader simply has no
    # defined metrics, same as the single-class case below
    if not prob_chunks:
        return undefined

    y_prob = np.concatenate(prob_chunks).squeeze()
    y_true = np.concatenate(label_chunks).squeeze()

    # if only one class, metrics would not be defined (NaN)
    if len(np.unique(y_true)) < 2:
        return undefined

    return {
        "auc_roc": roc_auc_score(y_true, y_prob),
        "auc_pr": average_precision_score(y_true, y_prob),
    }
in samples], + ) + + train_ds = create_sample_dataset( + samples=train_samples, + input_schema={"features": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="mimic4_train", + in_memory=True, + ) + test_ds = create_sample_dataset( + samples=test_samples, + input_schema={"features": "tensor"}, + output_schema={"label": "binary"}, + dataset_name="mimic4_test", + in_memory=True, + ) + + train_loader = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + test_loader = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + return train_ds, test_ds, train_loader, test_loader + + +def run_experiment( + name: str, + samples, + n_bins: int = 32, + top_k_interactions: int = 10, + use_interactions: bool = True, + stage1_epochs: int = 20, + stage2_epochs: int = 20, + lr: float = 1e-2, + seed: int = 42, +): + # one GA2M experiment & return metrics + print(f"\n{'='*60}") + print(f"Experiment: {name}") + print(f" n_bins={n_bins}, top_k={top_k_interactions}, " + f"use_interactions={use_interactions}") + print(f"{'='*60}") + + torch.manual_seed(seed) + + train_ds, test_ds, train_loader, test_loader = make_dataset_and_loaders( + samples, seed=seed + ) + + model = GA2M( + dataset=train_ds, + n_bins=n_bins, + top_k_interactions=top_k_interactions, + use_interactions=use_interactions, + ) + + # 1. train main effects + model.fit_bins(train_loader) + model.fit_main_effects(train_loader, epochs=stage1_epochs, lr=lr) + + # 2. 
def main():
    """Run the four-way ablation study and print a summary table.

    Parses command-line options, loads ICU-stay samples with
    ``build_mortality_samples``, then runs in order: the full GA2M, a
    main-effects-only GA2M, a GA2M on features whose std positions are
    masked, and a logistic-regression baseline. Each experiment's metrics
    dict is stored in ``results`` and printed in a final table, followed by
    the first few finite points of the full model's shape function for
    feature 0.
    """
    parser = argparse.ArgumentParser(
        description="GA2M mortality prediction ablation on MIMIC-IV"
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default="data/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2",
        help="Path to MIMIC-IV root directory",
    )
    parser.add_argument("--n_bins", type=int, default=32,
                        help="Number of bins (paper uses 256; use 32 for demo)")
    parser.add_argument("--top_k", type=int, default=10,
                        help="Top-K interactions (paper uses 34)")
    # --epochs is applied to both training stages of every GA2M run below
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-2)
    args = parser.parse_args()

    print(f"Loading MIMIC-IV data from: {args.data_root}")
    all_samples = build_mortality_samples(root=args.data_root)
    print(f"Loaded {len(all_samples)} ICU stays "
          f"({sum(s['label'] for s in all_samples)} deaths)")

    # experiment display name -> metrics dict ("auc_roc", "auc_pr")
    results = {}

    # ablation 1: Full GA2M (main effects + interactions)
    # the trained model is kept for the shape-function printout at the end
    metrics, full_model = run_experiment(
        name="Full GA2M (main effects + interactions)",
        samples=all_samples,
        n_bins=args.n_bins,
        top_k_interactions=args.top_k,
        use_interactions=True,
        stage1_epochs=args.epochs,
        stage2_epochs=args.epochs,
        lr=args.lr,
    )
    results["Full GA2M"] = metrics

    # ablation 2: Main effects only (no interactions)
    metrics, _ = run_experiment(
        name="Main Effects Only (no interactions)",
        samples=all_samples,
        n_bins=args.n_bins,
        top_k_interactions=0,
        use_interactions=False,
        stage1_epochs=args.epochs,
        stage2_epochs=args.epochs,
        lr=args.lr,
    )
    results["Main Effects Only"] = metrics

    # ablation 3: Reduced feature set (mean features only, std masked)
    # std positions (odd indices) are overwritten with UNKNOWN_SENTINEL, not
    # zero, to test whether variability matters
    reduced_samples = []
    for s in all_samples:
        # copy so the feature lists in all_samples are left untouched
        feats = list(s["features"])
        # std features are at odd indices (1, 3, 5, ...)
        for i in range(1, len(feats), 2):
            feats[i] = UNKNOWN_SENTINEL
        reduced_samples.append({**s, "features": feats})

    metrics, _ = run_experiment(
        name="Reduced Features (mean only, std masked)",
        samples=reduced_samples,
        n_bins=args.n_bins,
        top_k_interactions=args.top_k,
        use_interactions=True,
        stage1_epochs=args.epochs,
        stage2_epochs=args.epochs,
        lr=args.lr,
    )
    results["Mean Features Only"] = metrics

    # ablation 4: Logistic Regression baseline
    # simpler interpretable model — tests whether GA2M's added complexity helps
    print(f"\n{'='*60}")
    print(f"Experiment: Logistic Regression Baseline")
    print(f"{'='*60}")
    metrics = run_logistic_regression(all_samples)
    print(f"\n AUC-ROC : {metrics['auc_roc']:.4f}")
    print(f" AUC-PR : {metrics['auc_pr']:.4f}")
    results["Logistic Regression"] = metrics

    # summary of results, one row per experiment in insertion order
    print(f"\n{'='*60}")
    print("ABLATION SUMMARY")
    print(f"{'='*60}")
    print(f"{'Experiment':<35} {'AUC-ROC':>8} {'AUC-PR':>8}")
    print(f"{'-'*55}")
    for name, m in results.items():
        print(f"{name:<35} {m['auc_roc']:>8.4f} {m['auc_pr']:>8.4f}")

    print(f"\nNote: Results on demo (100 patients) will differ from the paper")
    print(f"(paper uses ~14k training stays from full MIMIC-III).")

    # visualise shape function for heart rate
    # NOTE(review): get_shape_function is not defined in the visible part of
    # GA2M; presumably it returns two array-likes indexable by a boolean
    # mask — confirm against the model file
    print(f"\nHeart rate (feature 0) risk function range:")
    midpoints, risks = full_model.get_shape_function(0)
    # only the first five points with a finite midpoint are printed
    finite = np.isfinite(midpoints)
    for mp, r in zip(midpoints[finite][:5], risks[finite][:5]):
        print(f" bin midpoint={mp:.2f} risk={r:.4f}")
+ https://physionet.org/content/mimic-iv-demo/2.2/ + +Usage: + - from pyhealth.datasets import MIMIC4ICUMortalityDataset + - dataset = MIMIC4ICUMortalityDataset(root="path/to/mimic-iv-demo/") + - print(len(dataset)) +""" + +import os +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from pyhealth.datasets import SampleDataset, create_sample_dataset + + +# mapping of variable names to lists of mimic-iv itemids to use +# order of preference +VARIABLE_ITEMIDS: Dict[str, List[int]] = { + "heart_rate": [220045], + "respiratory_rate": [220210], + "temperature": [223762, 223761], # celsius preferred, fahrenheit fallback + "systolic_bp": [220179, 220050], # non-invasive preferred, arterial fallback + "diastolic_bp": [220180, 220051], + "mean_bp": [220181, 220052], + "spo2": [220277], + "gcs_eye": [220739], + "gcs_motor": [223901], + "gcs_verbal": [223900], + "ph": [223830], + "fio2": [223835], + "glucose": [220621], + "potassium": [227442], + "sodium": [220645], + "hematocrit": [220545], + "wbc": [220546], +} + +# unknown/missing sentinel value +UNKNOWN_SENTINEL = -1.0 + +# temp conversion threshold: values above this are assumed Fahrenheit. +_FAHRENHEIT_THRESHOLD = 50.0 + + +def _to_celsius(series: pd.Series) -> pd.Series: + # convert F° to C° for values + mask = series > _FAHRENHEIT_THRESHOLD + series = series.copy() + series[mask] = (series[mask] - 32.0) * 5.0 / 9.0 + return series + + +def build_mortality_samples( + root: str, + max_icu_hours: float = 48.0, + unknown_sentinel: float = UNKNOWN_SENTINEL, +) -> List[Dict]: + """Extract 34-feature mortality prediction samples from MIMIC-IV CSVs. + + Replicates the preprocessing pipeline from Hegselmann et al. (MLHC 2020), + - first 48 hours of each ICU stay + - mean and std dev for each of 17 variables, so 34 features per stay. + - missing vals: "unknown_sentinel" (-1). + - temp vals: to celsius. 
+ + Args: + root: Path to the MIMIC-IV (or demo) root directory, which should + contain ``hosp/`` and ``icu/`` subdirectories. + max_icu_hours: Number of hours from ICU admission to use (default 48). + unknown_sentinel: Value to impute for missing features (default -1.0). + + Returns: + List of sample dicts, each containing: + - patient_id (str) + - visit_id (str) + - features (list of 34 floats: mean+std for 17 variables) + - label (int): 1 = in-hospital death, 0 = survived + """ + hosp_dir = os.path.join(root, "hosp") + icu_dir = os.path.join(root, "icu") + + # table load + admissions = pd.read_csv( + os.path.join(hosp_dir, "admissions.csv.gz"), + parse_dates=["admittime", "dischtime", "deathtime"], + ) + icustays = pd.read_csv( + os.path.join(icu_dir, "icustays.csv.gz"), + parse_dates=["intime", "outtime"], + ) + chartevents = pd.read_csv( + os.path.join(icu_dir, "chartevents.csv.gz"), + parse_dates=["charttime"], + usecols=["subject_id", "stay_id", "charttime", "itemid", "valuenum"], + ) + + # var name look up by itemid + itemid_to_var: Dict[int, str] = {} + for var_name, item_ids in VARIABLE_ITEMIDS.items(): + for iid in item_ids: + # register first occurrence + if iid not in itemid_to_var: + itemid_to_var[iid] = var_name + + # chartevents - only keep rows w/ itemids we want. 
def MIMIC4ICUMortalityDataset(
    root: str,
    max_icu_hours: float = 48.0,
    unknown_sentinel: float = UNKNOWN_SENTINEL,
) -> SampleDataset:
    """Build an in-hospital mortality ``SampleDataset`` from MIMIC-IV CSVs.

    Thin wrapper around ``build_mortality_samples``: it extracts one sample
    per ICU stay (34 features = mean and std of 17 physiological variables
    over the first ``max_icu_hours`` hours, missing values filled with
    ``unknown_sentinel``, temperatures converted to Celsius), checks that
    both outcome classes are present, and wraps the samples with
    ``create_sample_dataset``.

    Args:
        root: Path to MIMIC-IV root directory (containing hosp/ and icu/).
        max_icu_hours: Hours of ICU data to use per stay (default: 48).
        unknown_sentinel: Value used to impute missing features
            (default: ``UNKNOWN_SENTINEL``, i.e. -1.0). NOTE(review): the
            GA2M model added alongside this module compares features
            against its own -1.0 sentinel, so a different value here would
            presumably not be treated as "missing" there — confirm before
            overriding.

    Returns:
        An in-memory ``SampleDataset`` built with input schema
        ``{"features": "tensor"}`` and output schema ``{"label": "binary"}``.

    Raises:
        ValueError: If the extracted samples do not contain both label
            values (this includes the case where no samples were found).
    """
    samples = build_mortality_samples(
        root=root,
        max_icu_hours=max_icu_hours,
        unknown_sentinel=unknown_sentinel,
    )

    # pyhealth needs 2 or more unique labels for bin class tasks
    labels = {s["label"] for s in samples}
    if len(labels) < 2:
        raise ValueError(
            f"Dataset must contain both positive and negative mortality "
            f"labels. Found only: {labels}. Try using a larger subset."
        )

    return create_sample_dataset(
        samples=samples,
        input_schema={"features": "tensor"},
        output_schema={"label": "binary"},
        dataset_name="mimic4_icu_mortality",
        task_name="in_hospital_mortality",
        in_memory=True,
    )
+ +Implements the GA2M architecture described in: + Hegselmann et al., "An Evaluation of the Doctor-Interpretability of + Generalized Additive Models with Interactions", MLHC 2020. + https://proceedings.mlr.press/v126/hegselmann20a.html +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from typing import Dict, List, Optional, Tuple + +from pyhealth.models.base_model import BaseModel +from pyhealth.datasets import SampleDataset + + +# Sentinel value for missing values (Following paper implementation) +UNKNOWN_SENTINEL = -1.0 + + +class GA2M(BaseModel): + """ + GA2M model with binned shape functions and pairwise interactions. + """ + + def __init__( + self, + dataset: SampleDataset, + n_bins: int = 256, # Sets model hyperparameters + top_k_interactions: int = 34, # Builds embeddings for main effects + use_interactions: bool = True, # Prepares interaction structures + ) -> None: + super().__init__(dataset) + + assert len(self.feature_keys) == 1, ( + "expects exactly one input feature key" + ) + assert len(self.label_keys) == 1, ( + "expects exactly one binary label key" + ) + + # Store dataset keys for forward pass consistency + self.feature_key = self.feature_keys[0] + self.label_key = self.label_keys[0] + self.n_bins = n_bins + self.top_k_interactions = top_k_interactions + self.use_interactions = use_interactions + + # Infer input dimensionality from dataset + sample = dataset[0] + self.input_dim: int = sample[self.feature_key].shape[0] + + # Bin boundaries - not learned + self.register_buffer( + "bin_edges", + torch.zeros(self.input_dim, n_bins - 1), + ) + self._bins_fitted: bool = False + + # Bias term - global intercept + self.bias = nn.Parameter(torch.zeros(1)) + + # Main effect shape functions - each feature has (n_bins + 1) embeddings (1 for unknown values) + self.main_effects = nn.ModuleList([ + nn.Embedding(n_bins + 1, 1) + for _ in range(self.input_dim) + ]) + + # Initialise all risk scores to zero + for emb 
in self.main_effects: + nn.init.zeros_(emb.weight) + + # Interaction componenets for Stage 2 + self.interaction_pairs: List[Tuple[int, int]] = [] + self.interactions = nn.ModuleDict() + + # Track current stage for error messaging + self._interactions_selected: bool = False + + # --- Bin step - preprocessing stage before training --- + + def fit_bins(self, data_loader: torch.utils.data.DataLoader) -> None: + """ + Computes quantile-based discretization bins + Includes proper preprocessing pipeline step + Ensures interpretability which is necessary from the paper + """ + all_features: List[torch.Tensor] = [] + + # Collet dataset for global quantile computation + for batch in data_loader: + x = batch[self.feature_key] # (B, D) + all_features.append(x.cpu()) + X = torch.cat(all_features, dim=0) # (N, D) + + edges = torch.zeros(self.input_dim, self.n_bins - 1) + for d in range(self.input_dim): + # Extract feature column + col = x[:, d].contiguous() + + # Remove missing values + valid = col[col != UNKNOWN_SENTINEL] + if valid.numel() == 0: + continue + quantiles = torch.linspace(0.0, 1.0, self.n_bins + 1)[1:-1] + edges[d] = torch.quantile(valid.float(), quantiles) + + self.bin_edges.copy_(edges) + self._bins_fitted = True + + # --- bin assignment --- + + def _assign_bins(self, x: torch.Tensor) -> torch.LongTensor: + """ + Map continuous feature values to bin indices + """ + batch_size = x.size(0) + bin_idx = torch.zeros( + batch_size, self.input_dim, + dtype=torch.long, device=x.device, + ) + + for d in range(self.input_dim): + col = x[:, d].contiguous() + unknown_mask = (col == UNKNOWN_SENTINEL) + + edges = self.bin_edges[d].contiguous() + idx = torch.bucketize(col, edges) # (B,) in [0, n_bins-1] + + # Assign unkown values to special bins + idx[unknown_mask] = self.n_bins + bin_idx[:, d] = idx + + return bin_idx + + # --- Stage 1 - main effect training --- + + def fit_main_effects( + self, + data_loader: torch.utils.data.DataLoader, + epochs: int = 10, + lr: float = 
def select_top_interactions(self) -> List[Tuple[int, int]]:
    """Pick the top-K feature pairs and allocate their interaction tables.

    Each feature is scored by the variance of its main-effect embedding
    weights. A pair ``(i, j)`` with ``i < j`` is scored by the product of
    the two variances, and the ``self.top_k_interactions`` best pairs are
    kept. Ties keep enumeration order because the sort is stable.

    For every kept pair a zero-initialised ``nn.Embedding`` with
    ``(n_bins + 1) ** 2`` rows is stored in ``self.interactions`` under the
    key ``"{i}_{j}"``, on the same device as that pair's main-effect
    weights. Previously selected interactions are replaced, and
    ``self._interactions_selected`` is set to ``True``.

    Returns:
        The selected ``(i, j)`` pairs, highest score first.
    """
    # per-feature variance of learned risk scores
    weights = [self.main_effects[d].weight.data for d in range(self.input_dim)]
    variances = [w.var().item() for w in weights]

    # Rank all pairs by score
    pair_scores: List[Tuple[float, int, int]] = [
        (variances[i] * variances[j], i, j)
        for i in range(self.input_dim)
        for j in range(i + 1, self.input_dim)
    ]
    pair_scores.sort(key=lambda t: t[0], reverse=True)

    # A negative K would otherwise slice from the end and keep most pairs.
    top_k = max(self.top_k_interactions, 0)
    top_pairs = [(i, j) for _, i, j in pair_scores[:top_k]]

    self.interaction_pairs = top_pairs

    grid_size = (self.n_bins + 1) ** 2  # flattened 2D bin grid
    tables = {}
    for i, j in top_pairs:
        table = nn.Embedding(grid_size, 1)
        nn.init.zeros_(table.weight)
        # Follow the main effects so a model already moved to an
        # accelerator does not end up with tables on another device.
        tables[f"{i}_{j}"] = table.to(weights[i].device)
    self.interactions = nn.ModuleDict(tables)

    self._interactions_selected = True
    print(
        f"[GA2M] Selected {len(top_pairs)} interaction pairs "
        f"(top_k={self.top_k_interactions})."
    )
    return top_pairs
def get_shape_function(self, feature_idx: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return plotting coordinates and risk scores for one feature.

    The feature's bin edges are extended with ``-inf`` and ``+inf`` to
    form a lower and an upper bound per bin. A bin bounded on both sides
    is represented by the mean of its bounds. A bin open on one side is
    represented by its finite bound shifted by 1.0 into the bin.

    Args:
        feature_idx: Index into ``self.bin_edges`` and ``self.main_effects``.

    Returns:
        ``(midpoints, risk_scores)``, where ``risk_scores`` holds the first
        ``self.n_bins`` rows of the feature's embedding weight (the row
        for the unknown bin is left out).
    """
    boundaries = self.bin_edges[feature_idx].detach().cpu().numpy()
    lower = np.concatenate([[-np.inf], boundaries])
    upper = np.concatenate([boundaries, [np.inf]])

    has_lower = np.isfinite(lower)
    has_upper = np.isfinite(upper)

    # representative x for bins that are open on one side
    one_sided = np.where(has_lower, lower + 1.0, upper - 1.0)
    # find midpoints
    midpoints = np.where(has_lower & has_upper, (lower + upper) / 2, one_sided)

    table = self.main_effects[feature_idx].weight.data
    known_rows = table[: self.n_bins]  # exclude unknown bin
    risk_scores = known_rows.squeeze(1).detach().cpu().numpy()
    return midpoints, risk_scores
def _make_dataset(
    n_patients: int = 4,
    n_features: int = 6,
    include_unknowns: bool = False,
):
    """Build a small in-memory synthetic dataset for the GA2M tests.

    Every sample carries ``n_features`` values drawn uniformly from
    [0, 5) and a 0/1 label, both from a generator seeded with 42 so the
    data is reproducible. With ``include_unknowns`` set, the first feature
    of every even-indexed sample is replaced by ``UNKNOWN_SENTINEL``.

    Returns:
        Whatever ``pyhealth.datasets.create_sample_dataset`` returns for
        these samples.
    """
    from pyhealth.datasets import create_sample_dataset

    generator = np.random.default_rng(42)

    def _sample(idx: int) -> dict:
        # Draw the features before the label so the random stream is
        # consumed in a fixed order for every sample.
        values = generator.uniform(0.0, 5.0, size=n_features).tolist()
        if include_unknowns and idx % 2 == 0:
            # mark first feature unknown, pretend it's missing in 1/2 samples
            values[0] = UNKNOWN_SENTINEL
        label = int(generator.integers(0, 2))
        return {
            "patient_id": f"p{idx}",
            "visit_id": f"v{idx}",
            "features": values,
            "label": label,
        }

    return create_sample_dataset(
        samples=[_sample(idx) for idx in range(n_patients)],
        input_schema={"features": "tensor"},
        output_schema={"label": "binary"},
        dataset_name="synthetic_test",
        in_memory=True,  # no disk i/o, keeps tests fast
    )
lr=1e-2) + model.select_top_interactions() + return model + + +@pytest.fixture() +def tmp_dir(): + # function-scoped so each test gets a clean directory + path = tempfile.mkdtemp() + yield path + shutil.rmtree(path) # always cleaned up, even if the test fails + + +# 1. instantiation and basic properties +class TestInstantiation: + + def test_instantiation_defaults(self, base_dataset): + # model init w/ paper-default hyperparams + model = GA2M(dataset=base_dataset, n_bins=256, top_k_interactions=34) + assert model.n_bins == 256 + assert model.top_k_interactions == 34 + assert model.use_interactions is True + + def test_instantiation_no_interactions(self, base_dataset): + # use_interactions=False: creates no interaction modules + model = GA2M(dataset=base_dataset, n_bins=8, use_interactions=False) + assert len(model.interactions) == 0 + + def test_main_effects_count(self, base_dataset): + # one embedding / input feature is created + model = GA2M(dataset=base_dataset, n_bins=8) + assert len(model.main_effects) == model.input_dim + + def test_embedding_size(self, base_dataset): + # every main effect embedding has n_bins + 1 rows (unknown bin) + model = GA2M(dataset=base_dataset, n_bins=8) + for emb in model.main_effects: + assert emb.weight.shape == (9, 1) # 8 bins + 1 unknown + + def test_wrong_feature_keys_raises(self): + # model raises if dataset has multiple feature keys + from pyhealth.datasets import create_sample_dataset + samples = [ + {"patient_id": "p0", "visit_id": "v0", + "feat_a": [1.0, 2.0], "feat_b": [3.0], "label": 0}, + {"patient_id": "p1", "visit_id": "v1", + "feat_a": [2.0, 3.0], "feat_b": [4.0], "label": 1}, + ] + ds = create_sample_dataset( + samples=samples, + input_schema={"feat_a": "tensor", "feat_b": "tensor"}, + output_schema={"label": "binary"}, + in_memory=True, + ) + with pytest.raises(AssertionError): + GA2M(dataset=ds, n_bins=8) + + +# bin fitting and edge properties +class TestBinFitting: + + def test_fit_bins_sets_flag(self, 
# bin assignment logic, including unknown sentinel handling
class TestBinAssignment:
    """Checks how ``_assign_bins`` handles known and unknown values."""

    @staticmethod
    def _model_with_bins(dataset):
        """Return an 8-bin GA2M whose bin edges were fitted on ``dataset``."""
        ga2m = GA2M(dataset=dataset, n_bins=8)
        ga2m.fit_bins(_make_loader(dataset))
        return ga2m

    def test_known_values_in_range(self, base_dataset):
        # known vals map to bins 0...n_bins-1
        model = self._model_with_bins(base_dataset)

        row = torch.tensor([[1.0, 2.0, 3.0, 1.5, 2.5, 0.5]])
        assigned = model._assign_bins(row)

        assert assigned.shape == (1, model.input_dim)
        assert (assigned < model.n_bins).all(), (
            "Known values should not be routed to unknown bin."
        )

    def test_unknown_sentinel_routes_to_unknown_bin(self, base_dataset):
        # UNKNOWN_SENTINEL values mapped to bin index n_bins
        model = self._model_with_bins(base_dataset)

        row = torch.full((1, model.input_dim), 1.0)
        row[0, 0] = UNKNOWN_SENTINEL  # mark just the first feature as missing
        assigned = model._assign_bins(row)

        assert assigned[0, 0].item() == model.n_bins, (
            "UNKNOWN_SENTINEL should route to the dedicated unknown bin."
        )
        # all other features should be unaffected
        assert (assigned[0, 1:] < model.n_bins).all()
+ ) + + def test_forward_main_effects_only_output_shapes(self, base_dataset): + # internal main effects only fwd returns correct shapes + model = GA2M(dataset=base_dataset, n_bins=8, use_interactions=False) + loader = _make_loader(base_dataset) + model.fit_bins(loader) + + batch = next(iter(loader)) + out = model._forward_main_effects_only(batch) + B = batch["features"].shape[0] + assert out["logits"].shape == (B, 1) + assert out["loss"].shape == () # scalar, not a vector + + +# select top interactions +class TestInteractionSelection: + + def test_select_top_interactions_count(self, fitted_model): + # top_k_interactions pairs are selected + assert len(fitted_model.interaction_pairs) == fitted_model.top_k_interactions + + def test_interaction_module_dict_keys(self, fitted_model): + # ModuleDict keys match pair tuples that are chosen + for i, j in fitted_model.interaction_pairs: + assert f"{i}_{j}" in fitted_model.interactions + + def test_interaction_embedding_size(self, fitted_model): + # every interaction embedding has (n_bins+1)^2 rows for 2D grid of bin pairs + expected = (fitted_model.n_bins + 1) ** 2 + for emb in fitted_model.interactions.values(): + assert emb.weight.shape[0] == expected + + def test_no_self_interactions(self, fitted_model): + # features are not paired with themselves + for i, j in fitted_model.interaction_pairs: + assert i != j + + def test_pairs_are_upper_triangular(self, fitted_model): + # pairs: (i, j) with i < j (no duplicates) + for i, j in fitted_model.interaction_pairs: + assert i < j + + def test_forward_raises_before_select(self, base_dataset): + # full fwd should raise if select_top_interactions is NOT called + model = GA2M(dataset=base_dataset, n_bins=8, use_interactions=True) + loader = _make_loader(base_dataset) + model.fit_bins(loader) + model.fit_main_effects(loader, epochs=1, lr=1e-2) + # skip select_top_interactions on purpose + batch = next(iter(loader)) + with pytest.raises(RuntimeError, match="select_top_interactions"): + 
# full forward pass - after all stages are done
class TestFullForward:
    """Full forward pass through a model that has completed every stage.

    The ``fitted_model`` fixture is module-scoped, so its parameters keep
    any ``.grad`` left by an earlier test. Each backward test therefore
    clears gradients first; otherwise its ``grad is not None`` check could
    pass without the current backward pass doing anything.
    """

    @staticmethod
    def _first_batch(dataset):
        """Return the first batch of a fresh loader over ``dataset``."""
        return next(iter(_make_loader(dataset)))

    def test_output_keys(self, fitted_model, base_dataset):
        # forward returns expected dict keys
        out = fitted_model(**self._first_batch(base_dataset))
        for key in ("loss", "y_prob", "y_true", "logits"):
            assert key in out, f"Missing key '{key}' in forward output."

    def test_output_shapes(self, fitted_model, base_dataset):
        # logits and y_prob have shape (B, 1), loss is scalar
        batch = self._first_batch(base_dataset)
        B = batch["features"].shape[0]
        out = fitted_model(**batch)

        assert out["logits"].shape == (B, 1), (
            f"Expected logits shape ({B}, 1), got {out['logits'].shape}"
        )
        assert out["y_prob"].shape == (B, 1), (
            f"Expected y_prob shape ({B}, 1), got {out['y_prob'].shape}"
        )
        assert out["loss"].shape == (), "Loss should be a scalar tensor."

    def test_y_prob_in_unit_interval(self, fitted_model, base_dataset):
        # predicted probabilities must lie in [0, 1]
        out = fitted_model(**self._first_batch(base_dataset))
        assert (out["y_prob"] >= 0.0).all() and (out["y_prob"] <= 1.0).all()

    def test_loss_is_positive(self, fitted_model, base_dataset):
        # bce loss > 0 for non-perfect predictions
        out = fitted_model(**self._first_batch(base_dataset))
        assert out["loss"].item() > 0.0

    def test_backward_pass_main_effects(self, fitted_model, base_dataset):
        # gradients flow through full model back to main effect embeddings
        fitted_model.zero_grad(set_to_none=True)  # drop grads from earlier tests
        out = fitted_model(**self._first_batch(base_dataset))
        out["loss"].backward()

        # check that at least one main effect got a gradient
        has_grad = any(
            emb.weight.grad is not None
            for emb in fitted_model.main_effects
        )
        assert has_grad, "No gradients flowed to main effect embeddings."

    def test_backward_pass_interactions(self, fitted_model, base_dataset):
        # interaction embeddings should also receive gradients
        fitted_model.zero_grad(set_to_none=True)  # drop grads from earlier tests
        out = fitted_model(**self._first_batch(base_dataset))
        out["loss"].backward()

        # at least one interaction embedding should have a gradient
        has_grad = any(
            emb.weight.grad is not None
            for emb in fitted_model.interactions.values()
        )
        assert has_grad, "No gradients flowed to interaction embeddings."

    def test_backward_pass_bias(self, fitted_model, base_dataset):
        # the global bias (intercept) should also get a gradient
        fitted_model.zero_grad(set_to_none=True)  # drop grads from earlier tests
        out = fitted_model(**self._first_batch(base_dataset))
        out["loss"].backward()
        assert fitted_model.bias.grad is not None, "Bias parameter has no gradient."

    def test_no_interactions_flag(self, base_dataset):
        # use_interactions=False runs without select_top_interactions
        model = GA2M(dataset=base_dataset, n_bins=8, use_interactions=False)
        loader = _make_loader(base_dataset)
        model.fit_bins(loader)
        model.fit_main_effects(loader, epochs=1)

        batch = next(iter(loader))
        out = model(**batch)
        B = batch["features"].shape[0]
        assert out["logits"].shape == (B, 1)
        assert out["loss"].item() > 0.0
# interpretability helper methods (get_shape_function, get_interaction_shape)
class TestInterpretabilityHelpers:
    """Checks for ``get_shape_function`` and ``get_interaction_shape``."""

    def test_get_shape_function_lengths(self, fitted_model):
        # shape fn returns arrays of length n_bins for midpoints & risks
        midpoints, risks = fitted_model.get_shape_function(0)
        assert len(midpoints) == fitted_model.n_bins
        assert len(risks) == fitted_model.n_bins

    def test_get_shape_function_finite(self, fitted_model):
        # both plot axes must be finite, not only the risk scores
        for d in range(fitted_model.input_dim):
            midpoints, risks = fitted_model.get_shape_function(d)
            assert np.all(np.isfinite(midpoints)), (
                f"Feature {d} shape function contains non-finite midpoints."
            )
            assert np.all(np.isfinite(risks)), (
                f"Feature {d} shape function contains non-finite risk values."
            )

    def test_get_interaction_shape_selected(self, fitted_model):
        # get_interaction_shape returns a 2D arr for selected pairs
        i, j = fitted_model.interaction_pairs[0]
        grid = fitted_model.get_interaction_shape(i, j)
        assert grid is not None
        # unknown row/col excluded from the visualisation grid
        assert grid.shape == (fitted_model.n_bins, fitted_model.n_bins)

    def test_get_interaction_shape_unselected_returns_none(self, fitted_model):
        # get_interaction_shape returns None for non-selected pairs
        selected = set(fitted_model.interaction_pairs)
        unselected = [
            (i, j)
            for i in range(fitted_model.input_dim)
            for j in range(i + 1, fitted_model.input_dim)
            if (i, j) not in selected
        ]
        # Fail loudly instead of passing without checking anything.
        assert unselected, "Fixture selected every pair; nothing to check."

        i, j = unselected[0]  # one check to confirm correct behaviour
        assert fitted_model.get_interaction_shape(i, j) is None